diff options
Diffstat (limited to 'paste/__init__.py')
-rw-r--r-- | paste/__init__.py | 28 |
1 files changed, 10 insertions, 18 deletions
diff --git a/paste/__init__.py b/paste/__init__.py index 36ce0cb..0f75587 100644 --- a/paste/__init__.py +++ b/paste/__init__.py @@ -5,11 +5,10 @@ import urllib.parse from base64 import b64decode, b64encode from collections.abc import Callable, Iterable from functools import wraps -from sqlite3 import Connection from typing import Any, Optional, Protocol, runtime_checkable from wsgiref.util import application_uri, request_uri -from . import db, store +from . import store @runtime_checkable @@ -154,18 +153,11 @@ def options(app: App, environ: Env, start_response: StartResponse) -> Response: @middleware def open_database(app: App, environ: Env, start_response: StartResponse) -> Response: db_path = environ.get("PASTE_DB", DB_PATH) - with db.connect(db_path) as conn: - environ["paste.db_conn"] = conn + with store.open(db_path) as stor: + environ["paste.store"] = stor return app(environ, start_response) -def check_token(conn: Connection, token: bytes) -> bool: - (count,) = conn.execute( - "SELECT COUNT(*) FROM token WHERE hash = sha256(?)", (token,) - ).fetchone() - return count == 1 - - @middleware def authenticate(app: App, environ: Env, start_response: StartResponse) -> Response: def check_auth(): @@ -179,7 +171,7 @@ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Respo value = b64decode(value.encode(), validate=True) except (binascii.Error, UnicodeEncodeError): return False - return check_token(environ["paste.db_conn"], value) + return environ["paste.store"].check_token(value) if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth(): return app(environ, start_response) @@ -197,10 +189,10 @@ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Respo @open_database @authenticate def application(environ: Env, start_response: StartResponse) -> Response: - conn = environ["paste.db_conn"] + stor = environ["paste.store"] name = environ["PATH_INFO"] if environ["REQUEST_METHOD"] == "GET": - row = store.get(conn, name) + row = stor.get(name) if not row: return simple_response(start_response, "404 Not Found") content_type, content_hash, content = row @@ -214,7 +206,7 @@ def application(environ: Env, start_response: StartResponse) -> Response: ) return [content] elif environ["REQUEST_METHOD"] == "HEAD": - row = store.head(conn, name) + row = stor.head(name) if not row: return simple_response(start_response, "404 Not Found") content_type, content_hash, content_length = row @@ -231,7 +223,7 @@ def application(environ: Env, start_response: StartResponse) -> Response: content_type = environ.get("CONTENT_TYPE", "text/plain") content_length = int(environ["CONTENT_LENGTH"]) content = environ["wsgi.input"].read(content_length) - created, content_hash = store.put(conn, name, content, content_type) + created, content_hash = stor.put(name, content, content_type) start_response( "201 Created" if created else "204 No Content", [ @@ -244,7 +236,7 @@ def application(environ: Env, start_response: StartResponse) -> Response: content_type = environ.get("CONTENT_TYPE", "text/plain") content_length = int(environ["CONTENT_LENGTH"]) content = environ["wsgi.input"].read(content_length) - path, content_hash = store.post(conn, name, content, content_type) + path, content_hash = stor.post(name, content, content_type) uri = application_uri(environ) path = urllib.parse.quote(path) if uri[-1] == "/" and path[:1] == "/": @@ -259,7 +251,7 @@ def application(environ: Env, start_response: StartResponse) -> Response: ) return [] elif environ["REQUEST_METHOD"] == "DELETE": - if store.delete(conn, name): + if stor.delete(name): start_response("204 No Content", []) return [] return simple_response(start_response, "404 Not Found") |