diff options
| -rw-r--r-- | paste/__init__.py | 47 | ||||
| -rw-r--r-- | tests/middleware/test_authenticate.py | 39 | 
2 files changed, 42 insertions, 44 deletions
diff --git a/paste/__init__.py b/paste/__init__.py index ddd1054..6e93923 100644 --- a/paste/__init__.py +++ b/paste/__init__.py @@ -142,28 +142,31 @@ def open_database(app: App, environ: Env, start_response: StartResponse) -> Resp          return app(environ, start_response) -@middleware -def authenticate(app: App, environ: Env, start_response: StartResponse) -> Response: -    def check_auth(): -        value = environ.get("HTTP_AUTHORIZATION") -        if not isinstance(value, str): -            return False -        if not value.startswith("APIKey "): -            return False -        value = value.removeprefix("APIKey ") -        try: -            value = b64decode(value.encode(), validate=True) -        except (binascii.Error, UnicodeEncodeError): -            return False -        return Auth(environ["paste.db_conn"]).check_token(value) +def authenticate(get_auth: Callable[[Env], Auth]) -> Middleware: +    @middleware +    def authenticate(app: App, environ: Env, start_response: StartResponse) -> Response: +        def check_auth(): +            value = environ.get("HTTP_AUTHORIZATION") +            if not isinstance(value, str): +                return False +            if not value.startswith("APIKey "): +                return False +            value = value.removeprefix("APIKey ") +            try: +                value = b64decode(value.encode(), validate=True) +            except (binascii.Error, UnicodeEncodeError): +                return False +            return get_auth(environ).check_token(value) -    if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth(): -        return app(environ, start_response) -    return simple_response( -        start_response, -        "401 Unauthorized", -        extra_headers=[("WWW-Authenticate", "APIKey")], -    ) +        if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth(): +            return app(environ, start_response) +        return simple_response( +            start_response, +            "401 Unauthorized", +            extra_headers=[("WWW-Authenticate", "APIKey")], +        ) + +    return authenticate  @catch_exceptions @@ -171,7 +174,7 @@ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Respo  @options  @if_none_match  @open_database -@authenticate +@authenticate(lambda environ: Auth(environ["paste.db_conn"]))  def application(environ: Env, start_response: StartResponse) -> Response:      store = Store(environ["paste.db_conn"])      name = environ["PATH_INFO"] diff --git a/tests/middleware/test_authenticate.py b/tests/middleware/test_authenticate.py index d3d7607..ec8734a 100644 --- a/tests/middleware/test_authenticate.py +++ b/tests/middleware/test_authenticate.py @@ -8,10 +8,22 @@ from paste import authenticate  from ..common_wsgi import call_app +def get_auth(environ): +    assert environ +    assert "test.check_token" in environ + +    class MockAuth: +        @staticmethod +        def check_token(tok): +            return environ["test.check_token"](tok) + +    return MockAuth() + +  @pytest.fixture  def app():      @validator -    @authenticate +    @authenticate(get_auth)  # type: ignore      @validator      def app(_, start_response):          start_response("200 OK", [("Content-Type", "text/plain")]) @@ -33,7 +45,6 @@ def test_unauthenticated_request(app, method):  def test_unauthenticated_request_with_key(app, method):      environ = {          "REQUEST_METHOD": method, -        "paste.db_conn": None,          "HTTP_AUTHORIZATION": "ApiKey AAAA",      }      response = call_app(app, environ) @@ -63,22 +74,8 @@ def test_authenticate_malformed_key(app, method, key):      assert ("WWW-Authenticate", "APIKey") in response.headers -class MockConnection: -    def __init__(self, check_token): -        self.check_token = check_token - - -class MockAuth: -    def __init__(self, c): -        assert isinstance(c, MockConnection) -        self.conn = c - -    def check_token(self, tok): -        return self.conn.check_token(tok) - -  @pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"]) -def test_authenticate_check_token_fail(app, method, monkeypatch): +def test_authenticate_check_token_fail(app, method):      check_token_called = False      token = b"test" @@ -90,11 +87,10 @@ def test_authenticate_check_token_fail(app, method, monkeypatch):      environ = {          "REQUEST_METHOD": method, -        "paste.db_conn": MockConnection(check_token), +        "test.check_token": check_token,          "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}",      } -    monkeypatch.setattr("paste.Auth", MockAuth)      response = call_app(app, environ)      assert check_token_called      assert response.data == b"401 Unauthorized\n" @@ -104,7 +100,7 @@ def test_authenticate_check_token_fail(app, method, monkeypatch):  @pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"]) -def test_authenticate_check_token_success(app, method, monkeypatch): +def test_authenticate_check_token_success(app, method):      check_token_called = False      token = b"test" @@ -116,11 +112,10 @@ def test_authenticate_check_token_success(app, method, monkeypatch):      environ = {          "REQUEST_METHOD": method, -        "paste.db_conn": MockConnection(check_token), +        "test.check_token": check_token,          "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}",      } -    monkeypatch.setattr("paste.Auth", MockAuth)      response = call_app(app, environ)      assert check_token_called      assert response.data == b"Hello, world!"  | 
