diff options
author | Tomasz Kramkowski <tomasz@kramkow.ski> | 2023-03-28 19:10:41 +0100 |
---|---|---|
committer | Tomasz Kramkowski <tomasz@kramkow.ski> | 2023-03-28 20:10:48 +0100 |
commit | e344527db7faae25ff1cb13ff70edc98cd811b4e (patch) | |
tree | 351c25486eda232270c16a4061fb7933a1dfe3e4 | |
parent | cba20d70c21cefaad3345a5779c88423edd0655b (diff) | |
download | paste-e344527db7faae25ff1cb13ff70edc98cd811b4e.tar.gz paste-e344527db7faae25ff1cb13ff70edc98cd811b4e.tar.xz paste-e344527db7faae25ff1cb13ff70edc98cd811b4e.zip |
Make authenticate easier to test without monkeypatching
By having authenticate be a function taking a parameter to a getter
which can get an Auth from an Env, it's now possible to test it without
needing monkeypatching.
-rw-r--r-- | paste/__init__.py | 49 | ||||
-rw-r--r-- | tests/middleware/test_authenticate.py | 39 |
2 files changed, 43 insertions, 45 deletions
diff --git a/paste/__init__.py b/paste/__init__.py index ddd1054..6e93923 100644 --- a/paste/__init__.py +++ b/paste/__init__.py @@ -142,28 +142,31 @@ def open_database(app: App, environ: Env, start_response: StartResponse) -> Resp return app(environ, start_response) -@middleware -def authenticate(app: App, environ: Env, start_response: StartResponse) -> Response: - def check_auth(): - value = environ.get("HTTP_AUTHORIZATION") - if not isinstance(value, str): - return False - if not value.startswith("APIKey "): - return False - value = value.removeprefix("APIKey ") - try: - value = b64decode(value.encode(), validate=True) - except (binascii.Error, UnicodeEncodeError): - return False - return Auth(environ["paste.db_conn"]).check_token(value) - - if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth(): - return app(environ, start_response) - return simple_response( - start_response, - "401 Unauthorized", - extra_headers=[("WWW-Authenticate", "APIKey")], - ) +def authenticate(get_auth: Callable[[Env], Auth]) -> Middleware: + @middleware + def authenticate(app: App, environ: Env, start_response: StartResponse) -> Response: + def check_auth(): + value = environ.get("HTTP_AUTHORIZATION") + if not isinstance(value, str): + return False + if not value.startswith("APIKey "): + return False + value = value.removeprefix("APIKey ") + try: + value = b64decode(value.encode(), validate=True) + except (binascii.Error, UnicodeEncodeError): + return False + return get_auth(environ).check_token(value) + + if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth(): + return app(environ, start_response) + return simple_response( + start_response, + "401 Unauthorized", + extra_headers=[("WWW-Authenticate", "APIKey")], + ) + + return authenticate @catch_exceptions @@ -171,7 +174,7 @@ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Respo @options @if_none_match @open_database -@authenticate +@authenticate(lambda environ: Auth(environ["paste.db_conn"])) def application(environ: Env, start_response: StartResponse) -> Response: store = Store(environ["paste.db_conn"]) name = environ["PATH_INFO"] diff --git a/tests/middleware/test_authenticate.py b/tests/middleware/test_authenticate.py index d3d7607..ec8734a 100644 --- a/tests/middleware/test_authenticate.py +++ b/tests/middleware/test_authenticate.py @@ -8,10 +8,22 @@ from paste import authenticate from ..common_wsgi import call_app +def get_auth(environ): + assert environ + assert "test.check_token" in environ + + class MockAuth: + @staticmethod + def check_token(tok): + return environ["test.check_token"](tok) + + return MockAuth() + + @pytest.fixture def app(): @validator - @authenticate + @authenticate(get_auth) # type: ignore @validator def app(_, start_response): start_response("200 OK", [("Content-Type", "text/plain")]) @@ -33,7 +45,6 @@ def test_unauthenticated_request(app, method): def test_unauthenticated_request_with_key(app, method): environ = { "REQUEST_METHOD": method, - "paste.db_conn": None, "HTTP_AUTHORIZATION": "ApiKey AAAA", } response = call_app(app, environ) @@ -63,22 +74,8 @@ def test_authenticate_malformed_key(app, method, key): assert ("WWW-Authenticate", "APIKey") in response.headers -class MockConnection: - def __init__(self, check_token): - self.check_token = check_token - - -class MockAuth: - def __init__(self, c): - assert isinstance(c, MockConnection) - self.conn = c - - def check_token(self, tok): - return self.conn.check_token(tok) - - @pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"]) -def test_authenticate_check_token_fail(app, method, monkeypatch): +def test_authenticate_check_token_fail(app, method): check_token_called = False token = b"test" @@ -90,11 +87,10 @@ def test_authenticate_check_token_fail(app, method, monkeypatch): environ = { "REQUEST_METHOD": method, - "paste.db_conn": MockConnection(check_token), + "test.check_token": check_token, "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}", } - monkeypatch.setattr("paste.Auth", MockAuth) response = call_app(app, environ) assert check_token_called assert response.data == b"401 Unauthorized\n" @@ -104,7 +100,7 @@ def test_authenticate_check_token_fail(app, method, monkeypatch): @pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"]) -def test_authenticate_check_token_success(app, method, monkeypatch): +def test_authenticate_check_token_success(app, method): check_token_called = False token = b"test" @@ -116,11 +112,10 @@ def test_authenticate_check_token_success(app, method, monkeypatch): environ = { "REQUEST_METHOD": method, - "paste.db_conn": MockConnection(check_token), + "test.check_token": check_token, "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}", } - monkeypatch.setattr("paste.Auth", MockAuth) response = call_app(app, environ) assert check_token_called assert response.data == b"Hello, world!" |