aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--paste/__init__.py25
1 files changed, 15 insertions, 10 deletions
diff --git a/paste/__init__.py b/paste/__init__.py
index f8985eb..482c534 100644
--- a/paste/__init__.py
+++ b/paste/__init__.py
@@ -173,13 +173,7 @@ def open_database(app: App, environ: Env, start_response: StartResponse) -> Resp
return app(environ, start_response)
-def check_auth(conn: Connection, auth: Optional[str]) -> bool:
- if not auth or not auth.startswith("Bearer "):
- return False
- try:
- token = b64decode(auth.removeprefix("Bearer ").encode())
- except binascii.Error:
- return False
+def check_token(conn: Connection, token: bytes) -> bool:
(count,) = conn.execute(
"SELECT COUNT(*) FROM token WHERE hash = sha256(?)", (token,)
).fetchone()
@@ -188,9 +182,20 @@ def check_auth(conn: Connection, auth: Optional[str]) -> bool:
@middleware
def authenticate(app: App, environ: Env, start_response: StartResponse) -> Response:
- conn = environ["paste.db_conn"]
- token = environ.get("HTTP_AUTHORIZATION")
- if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth(conn, token):
+ def check_auth():
+ value = environ.get("HTTP_AUTHORIZATION")
+ if not isinstance(value, str):
+ return False
+ if not value.startswith("APIKey "):
+ return False
+ value = value.removeprefix("APIKey ")
+ try:
+ value = b64decode(value.encode(), validate=True)
+ except (binascii.Error, UnicodeEncodeError):
+ return False
+ return check_token(environ["paste.db_conn"], value)
+
+ if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth():
return app(environ, start_response)
return simple_response(start_response, "401 Unauthorized")