diff options
-rw-r--r-- | paste/__init__.py | 25 |
1 files changed, 15 insertions, 10 deletions
diff --git a/paste/__init__.py b/paste/__init__.py index f8985eb..482c534 100644 --- a/paste/__init__.py +++ b/paste/__init__.py @@ -173,13 +173,7 @@ def open_database(app: App, environ: Env, start_response: StartResponse) -> Resp return app(environ, start_response) -def check_auth(conn: Connection, auth: Optional[str]) -> bool: - if not auth or not auth.startswith("Bearer "): - return False - try: - token = b64decode(auth.removeprefix("Bearer ").encode()) - except binascii.Error: - return False +def check_token(conn: Connection, token: bytes) -> bool: (count,) = conn.execute( "SELECT COUNT(*) FROM token WHERE hash = sha256(?)", (token,) ).fetchone() @@ -188,9 +182,20 @@ def check_auth(conn: Connection, auth: Optional[str]) -> bool: @middleware def authenticate(app: App, environ: Env, start_response: StartResponse) -> Response: - conn = environ["paste.db_conn"] - token = environ.get("HTTP_AUTHORIZATION") - if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth(conn, token): + def check_auth(): + value = environ.get("HTTP_AUTHORIZATION") + if not isinstance(value, str): + return False + if not value.startswith("APIKey "): + return False + value = value.removeprefix("APIKey ") + try: + value = b64decode(value.encode(), validate=True) + except (binascii.Error, UnicodeEncodeError): + return False + return check_token(environ["paste.db_conn"], value) + + if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth(): return app(environ, start_response) return simple_response(start_response, "401 Unauthorized") |