aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTomasz Kramkowski <tomasz@kramkow.ski>2023-03-28 19:10:41 +0100
committerTomasz Kramkowski <tomasz@kramkow.ski>2023-03-28 20:10:48 +0100
commite344527db7faae25ff1cb13ff70edc98cd811b4e (patch)
tree351c25486eda232270c16a4061fb7933a1dfe3e4
parentcba20d70c21cefaad3345a5779c88423edd0655b (diff)
downloadpaste-e344527db7faae25ff1cb13ff70edc98cd811b4e.tar.gz
paste-e344527db7faae25ff1cb13ff70edc98cd811b4e.tar.xz
paste-e344527db7faae25ff1cb13ff70edc98cd811b4e.zip
Make authenticate easier to test without monkeypatching
By having authenticate be a function taking a parameter to a getter which can get an Auth from an Env, it's now possible to test it without needing monkeypatching.
-rw-r--r--paste/__init__.py49
-rw-r--r--tests/middleware/test_authenticate.py39
2 files changed, 43 insertions, 45 deletions
diff --git a/paste/__init__.py b/paste/__init__.py
index ddd1054..6e93923 100644
--- a/paste/__init__.py
+++ b/paste/__init__.py
@@ -142,28 +142,31 @@ def open_database(app: App, environ: Env, start_response: StartResponse) -> Resp
return app(environ, start_response)
-@middleware
-def authenticate(app: App, environ: Env, start_response: StartResponse) -> Response:
- def check_auth():
- value = environ.get("HTTP_AUTHORIZATION")
- if not isinstance(value, str):
- return False
- if not value.startswith("APIKey "):
- return False
- value = value.removeprefix("APIKey ")
- try:
- value = b64decode(value.encode(), validate=True)
- except (binascii.Error, UnicodeEncodeError):
- return False
- return Auth(environ["paste.db_conn"]).check_token(value)
-
- if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth():
- return app(environ, start_response)
- return simple_response(
- start_response,
- "401 Unauthorized",
- extra_headers=[("WWW-Authenticate", "APIKey")],
- )
+def authenticate(get_auth: Callable[[Env], Auth]) -> Middleware:
+ @middleware
+ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Response:
+ def check_auth():
+ value = environ.get("HTTP_AUTHORIZATION")
+ if not isinstance(value, str):
+ return False
+ if not value.startswith("APIKey "):
+ return False
+ value = value.removeprefix("APIKey ")
+ try:
+ value = b64decode(value.encode(), validate=True)
+ except (binascii.Error, UnicodeEncodeError):
+ return False
+ return get_auth(environ).check_token(value)
+
+ if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth():
+ return app(environ, start_response)
+ return simple_response(
+ start_response,
+ "401 Unauthorized",
+ extra_headers=[("WWW-Authenticate", "APIKey")],
+ )
+
+ return authenticate
@catch_exceptions
@@ -171,7 +174,7 @@ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Respo
@options
@if_none_match
@open_database
-@authenticate
+@authenticate(lambda environ: Auth(environ["paste.db_conn"]))
def application(environ: Env, start_response: StartResponse) -> Response:
store = Store(environ["paste.db_conn"])
name = environ["PATH_INFO"]
diff --git a/tests/middleware/test_authenticate.py b/tests/middleware/test_authenticate.py
index d3d7607..ec8734a 100644
--- a/tests/middleware/test_authenticate.py
+++ b/tests/middleware/test_authenticate.py
@@ -8,10 +8,22 @@ from paste import authenticate
from ..common_wsgi import call_app
+def get_auth(environ):
+ assert environ
+ assert "test.check_token" in environ
+
+ class MockAuth:
+ @staticmethod
+ def check_token(tok):
+ return environ["test.check_token"](tok)
+
+ return MockAuth()
+
+
@pytest.fixture
def app():
@validator
- @authenticate
+ @authenticate(get_auth) # type: ignore
@validator
def app(_, start_response):
start_response("200 OK", [("Content-Type", "text/plain")])
@@ -33,7 +45,6 @@ def test_unauthenticated_request(app, method):
def test_unauthenticated_request_with_key(app, method):
environ = {
"REQUEST_METHOD": method,
- "paste.db_conn": None,
"HTTP_AUTHORIZATION": "ApiKey AAAA",
}
response = call_app(app, environ)
@@ -63,22 +74,8 @@ def test_authenticate_malformed_key(app, method, key):
assert ("WWW-Authenticate", "APIKey") in response.headers
-class MockConnection:
- def __init__(self, check_token):
- self.check_token = check_token
-
-
-class MockAuth:
- def __init__(self, c):
- assert isinstance(c, MockConnection)
- self.conn = c
-
- def check_token(self, tok):
- return self.conn.check_token(tok)
-
-
@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"])
-def test_authenticate_check_token_fail(app, method, monkeypatch):
+def test_authenticate_check_token_fail(app, method):
check_token_called = False
token = b"test"
@@ -90,11 +87,10 @@ def test_authenticate_check_token_fail(app, method, monkeypatch):
environ = {
"REQUEST_METHOD": method,
- "paste.db_conn": MockConnection(check_token),
+ "test.check_token": check_token,
"HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}",
}
- monkeypatch.setattr("paste.Auth", MockAuth)
response = call_app(app, environ)
assert check_token_called
assert response.data == b"401 Unauthorized\n"
@@ -104,7 +100,7 @@ def test_authenticate_check_token_fail(app, method, monkeypatch):
@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"])
-def test_authenticate_check_token_success(app, method, monkeypatch):
+def test_authenticate_check_token_success(app, method):
check_token_called = False
token = b"test"
@@ -116,11 +112,10 @@ def test_authenticate_check_token_success(app, method, monkeypatch):
environ = {
"REQUEST_METHOD": method,
- "paste.db_conn": MockConnection(check_token),
+ "test.check_token": check_token,
"HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}",
}
- monkeypatch.setattr("paste.Auth", MockAuth)
response = call_app(app, environ)
assert check_token_called
assert response.data == b"Hello, world!"