From d198fca95919cc78275d3d9fa8f1b0a8acfdbab3 Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Tue, 28 Mar 2023 18:53:15 +0100 Subject: Switch back to centralised opening of the database Create Store instances when needed This will make more sense with following commits --- paste/__init__.py | 25 +++++++++++++------------ paste/__main__.py | 29 ++++++++++++++++------------- paste/store.py | 6 ------ tests/middleware/test_authenticate.py | 23 +++++++++++++++++------ tests/middleware/test_open_store.py | 20 ++++++++++---------- tests/test_application.py | 6 +++--- 6 files changed, 59 insertions(+), 50 deletions(-) diff --git a/paste/__init__.py b/paste/__init__.py index 567aa8e..155e927 100644 --- a/paste/__init__.py +++ b/paste/__init__.py @@ -8,7 +8,8 @@ from functools import wraps from typing import Optional from wsgiref.util import application_uri, request_uri -from . import store +from . import db +from .store import Store from .types import ( App, Closable, @@ -134,10 +135,10 @@ def options(app: App, environ: Env, start_response: StartResponse) -> Response: @middleware -def open_store(app: App, environ: Env, start_response: StartResponse) -> Response: +def open_database(app: App, environ: Env, start_response: StartResponse) -> Response: db_path = environ.get("PASTE_DB", DB_PATH) - with store.open(db_path) as stor: - environ["paste.store"] = stor + with db.connect(db_path) as conn: + environ["paste.db_conn"] = conn return app(environ, start_response) @@ -154,7 +155,7 @@ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Respo value = b64decode(value.encode(), validate=True) except (binascii.Error, UnicodeEncodeError): return False - return environ["paste.store"].check_token(value) + return Store(environ["paste.db_conn"]).check_token(value) if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth(): return app(environ, start_response) @@ -169,13 +170,13 @@ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Respo @validate_method @options @if_none_match -@open_store +@open_database @authenticate def application(environ: Env, start_response: StartResponse) -> Response: - stor = environ["paste.store"] + store = Store(environ["paste.db_conn"]) name = environ["PATH_INFO"] if environ["REQUEST_METHOD"] == "GET": - row = stor.get(name) + row = store.get(name) if not row: return simple_response(start_response, "404 Not Found") content_type, content_hash, content = row @@ -189,7 +190,7 @@ def application(environ: Env, start_response: StartResponse) -> Response: ) return [content] elif environ["REQUEST_METHOD"] == "HEAD": - row = stor.head(name) + row = store.head(name) if not row: return simple_response(start_response, "404 Not Found") content_type, content_hash, content_length = row @@ -206,7 +207,7 @@ def application(environ: Env, start_response: StartResponse) -> Response: content_type = environ.get("CONTENT_TYPE", "text/plain") content_length = int(environ["CONTENT_LENGTH"]) content = environ["wsgi.input"].read(content_length) - created, content_hash = stor.put(name, content, content_type) + created, content_hash = store.put(name, content, content_type) start_response( "201 Created" if created else "204 No Content", [ @@ -219,7 +220,7 @@ def application(environ: Env, start_response: StartResponse) -> Response: content_type = environ.get("CONTENT_TYPE", "text/plain") content_length = int(environ["CONTENT_LENGTH"]) content = environ["wsgi.input"].read(content_length) - path, content_hash = stor.post(name, content, content_type) + path, content_hash = store.post(name, content, content_type) uri = application_uri(environ) path = urllib.parse.quote(path) if uri[-1] == "/" and path[:1] == "/": @@ -234,7 +235,7 @@ def application(environ: Env, start_response: StartResponse) -> Response: ) return [] elif environ["REQUEST_METHOD"] == "DELETE": - if stor.delete(name): + if store.delete(name): start_response("204 No Content", []) return [] return simple_response(start_response, "404 Not Found") diff --git a/paste/__main__.py b/paste/__main__.py index 27ff72b..6e684c4 100644 --- a/paste/__main__.py +++ b/paste/__main__.py @@ -1,11 +1,12 @@ from base64 import b64decode, b64encode -from contextlib import AbstractContextManager +from collections.abc import Iterator +from contextlib import contextmanager from datetime import datetime, timezone from os import getenv from sys import argv, stderr from wsgiref.simple_server import make_server -from . import DB_PATH, application, store +from . import DB_PATH, application, db, store PROGRAM_NAME = "paste" @@ -42,8 +43,10 @@ Environment: db_path = getenv("PASTE_DB", DB_PATH) -def open_db() -> AbstractContextManager[store.Store]: - return store.open(db_path) +@contextmanager +def open_auth() -> Iterator[store.Auth]: + with db.connect(db_path) as conn: + yield store.Auth(conn) def main(): @@ -54,11 +57,11 @@ def main(): httpd.serve_forever() elif len(argv) == 2: if argv[1] == "new-token": - with open_db() as db: - print(b64encode(db.generate_token()).decode()) + with open_auth() as auth: + print(b64encode(auth.generate_token()).decode()) elif argv[1] == "list-tokens": - with open_db() as db: - for token_hash, created_at in db.get_tokens(): + with open_auth() as auth: + for token_hash, created_at in auth.get_tokens(): created_at = datetime.fromtimestamp(created_at, timezone.utc) print(f"{token_hash.hex()}\t{created_at.ctime()}") elif argv[1] == "-h" or argv[1] == "--help": @@ -70,12 +73,12 @@ def main(): elif len(argv) == 3: if argv[1] == "delete-token": token = argv[2] - with open_db() as db: + with open_auth() as auth: try: try: - db.delete_token(b64decode(token)) + auth.delete_token(b64decode(token)) except ValueError: - db.delete_token_hash(bytes.fromhex(token)) + auth.delete_token_hash(bytes.fromhex(token)) except ValueError: print("Malformed token", file=stderr) exit(1) @@ -83,9 +86,9 @@ def main(): print("Token not found", file=stderr) exit(1) elif argv[1] == "verify-token": - with open_db() as db: + with open_auth() as auth: try: - if not db.check_token(b64decode(argv[2])): + if not auth.check_token(b64decode(argv[2])): print("Token not found", file=stderr) exit(1) print("Found") diff --git a/paste/store.py b/paste/store.py index e7022f6..f69c03d 100644 --- a/paste/store.py +++ b/paste/store.py @@ -120,9 +120,3 @@ class Store: "SELECT COUNT(*) FROM token WHERE hash = SHA256(?)", (token,) ).fetchone() return count > 0 - - -@contextmanager -def open(uri: str) -> Iterator[Store]: - with db.connect(uri) as conn: - yield Store(conn) diff --git a/tests/middleware/test_authenticate.py b/tests/middleware/test_authenticate.py index 2acfe92..9fccb32 100644 --- a/tests/middleware/test_authenticate.py +++ b/tests/middleware/test_authenticate.py @@ -34,7 +34,7 @@ def test_unauthenticated_request(app, method): def test_unauthenticated_request_with_key(app, method): environ = { "REQUEST_METHOD": method, - "paste.store": None, + "paste.db_conn": None, "HTTP_AUTHORIZATION": "ApiKey AAAA", } response = call_app(app, environ) @@ -64,13 +64,22 @@ def test_authenticate_malformed_key(app, method, key): assert ("WWW-Authenticate", "APIKey") in response.headers -class MockDB: +class MockConnection: def __init__(self, check_token): self.check_token = check_token +class MockStore: + def __init__(self, c): + assert isinstance(c, MockConnection) + self.conn = c + + def check_token(self, tok): + return self.conn.check_token(tok) + + @pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"]) -def test_authenticate_check_token_fail(app, method): +def test_authenticate_check_token_fail(app, method, monkeypatch): check_token_called = False token = b"test" @@ -82,10 +91,11 @@ def test_authenticate_check_token_fail(app, method): environ = { "REQUEST_METHOD": method, - "paste.store": MockDB(check_token), + "paste.db_conn": MockConnection(check_token), "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}", } + monkeypatch.setattr("paste.Store", MockStore) response = call_app(app, environ) assert check_token_called assert response.data == b"401 Unauthorized\n" @@ -95,7 +105,7 @@ def test_authenticate_check_token_fail(app, method): @pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"]) -def test_authenticate_check_token_success(app, method): +def test_authenticate_check_token_success(app, method, monkeypatch): check_token_called = False token = b"test" @@ -107,10 +117,11 @@ def test_authenticate_check_token_success(app, method): environ = { "REQUEST_METHOD": method, - "paste.store": MockDB(check_token), + "paste.db_conn": MockConnection(check_token), "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}", } + monkeypatch.setattr("paste.Store", MockStore) response = call_app(app, environ) assert check_token_called assert response.data == b"Hello, world!" diff --git a/tests/middleware/test_open_store.py b/tests/middleware/test_open_store.py index dd430e9..3e844a9 100644 --- a/tests/middleware/test_open_store.py +++ b/tests/middleware/test_open_store.py @@ -1,30 +1,30 @@ from contextlib import contextmanager from wsgiref.validate import validator -from paste import open_store +from paste import open_database from ..common_wsgi import call_app -def test_open_store(monkeypatch): +def test_open_database(monkeypatch): db_path = "test_db_path" - store = object() + conn = object() @contextmanager - def store_open(path): - assert path == db_path - yield store + def connect(uri): + assert uri == db_path + yield conn @validator - @open_store + @open_database @validator def app(environ, start_response): - assert "paste.store" in environ - assert environ["paste.store"] == store + assert "paste.db_conn" in environ + assert environ["paste.db_conn"] == conn start_response("200 OK", [("Content-Type", "text/plain")]) return [b"Hello, World!"] - monkeypatch.setattr("paste.store.open", store_open) + monkeypatch.setattr("paste.db.connect", connect) response = call_app(app, environ={"PASTE_DB": db_path}) assert response.status == "200 OK" assert response.data == b"Hello, World!" diff --git a/tests/test_application.py b/tests/test_application.py index e86d937..20cbeac 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -4,15 +4,15 @@ import pytest from webtest import TestApp import paste.db -import paste.store from paste import __main__, application +from paste.store import Store DB = "file::memory:?cache=shared" @pytest.fixture def db(): - with paste.store.open(DB) as d: + with paste.db.connect(DB) as d: yield d @@ -25,7 +25,7 @@ def app(db): @pytest.fixture def token(db): - return b64encode(db.generate_token()).decode() + return b64encode(Store(db).generate_token()).decode() @pytest.mark.parametrize("method", ["put", "post", "delete"]) -- cgit v1.2.3-54-g00ecf