From 77889ec30410280ef1ec1671b6226a8d0c0444f7 Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Mon, 27 Mar 2023 22:46:44 +0100 Subject: Throw all data manipulation code in one place This means that everything now goes through a Store object, which should make testing a little bit easier. --- paste/__init__.py | 28 ++--- paste/__main__.py | 67 ++--------- paste/store.py | 201 +++++++++++++++++++++------------ tests/middleware/test_authenticate.py | 53 +++++---- tests/middleware/test_open_database.py | 7 +- tests/test_application.py | 6 +- 6 files changed, 182 insertions(+), 180 deletions(-) diff --git a/paste/__init__.py b/paste/__init__.py index 36ce0cb..0f75587 100644 --- a/paste/__init__.py +++ b/paste/__init__.py @@ -5,11 +5,10 @@ import urllib.parse from base64 import b64decode, b64encode from collections.abc import Callable, Iterable from functools import wraps -from sqlite3 import Connection from typing import Any, Optional, Protocol, runtime_checkable from wsgiref.util import application_uri, request_uri -from . import db, store +from . import store @runtime_checkable @@ -154,18 +153,11 @@ def options(app: App, environ: Env, start_response: StartResponse) -> Response: @middleware def open_database(app: App, environ: Env, start_response: StartResponse) -> Response: db_path = environ.get("PASTE_DB", DB_PATH) - with db.connect(db_path) as conn: - environ["paste.db_conn"] = conn + with store.open(db_path) as stor: + environ["paste.store"] = stor return app(environ, start_response) -def check_token(conn: Connection, token: bytes) -> bool: - (count,) = conn.execute( - "SELECT COUNT(*) FROM token WHERE hash = sha256(?)", (token,) - ).fetchone() - return count == 1 - - @middleware def authenticate(app: App, environ: Env, start_response: StartResponse) -> Response: def check_auth(): @@ -179,7 +171,7 @@ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Respo value = b64decode(value.encode(), validate=True) except (binascii.Error, UnicodeEncodeError): return False - return check_token(environ["paste.db_conn"], value) + return environ["paste.store"].check_token(value) if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth(): return app(environ, start_response) @@ -197,10 +189,10 @@ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Respo @open_database @authenticate def application(environ: Env, start_response: StartResponse) -> Response: - conn = environ["paste.db_conn"] + stor = environ["paste.store"] name = environ["PATH_INFO"] if environ["REQUEST_METHOD"] == "GET": - row = store.get(conn, name) + row = stor.get(name) if not row: return simple_response(start_response, "404 Not Found") content_type, content_hash, content = row @@ -214,7 +206,7 @@ def application(environ: Env, start_response: StartResponse) -> Response: ) return [content] elif environ["REQUEST_METHOD"] == "HEAD": - row = store.head(conn, name) + row = stor.head(name) if not row: return simple_response(start_response, "404 Not Found") content_type, content_hash, content_length = row @@ -231,7 +223,7 @@ def application(environ: Env, start_response: StartResponse) -> Response: content_type = environ.get("CONTENT_TYPE", "text/plain") content_length = int(environ["CONTENT_LENGTH"]) content = environ["wsgi.input"].read(content_length) - created, content_hash = store.put(conn, name, content, content_type) + created, content_hash = stor.put(name, content, content_type) start_response( "201 Created" if created else "204 No Content", [ @@ -244,7 +236,7 @@ def application(environ: Env, start_response: StartResponse) -> Response: content_type = environ.get("CONTENT_TYPE", "text/plain") content_length = int(environ["CONTENT_LENGTH"]) content = environ["wsgi.input"].read(content_length) - path, content_hash = store.post(conn, name, content, content_type) + path, content_hash = stor.post(name, content, content_type) uri = application_uri(environ) path = urllib.parse.quote(path) if uri[-1] == "/" and path[:1] == "/": @@ -259,7 +251,7 @@ def application(environ: Env, start_response: StartResponse) -> Response: ) return [] elif environ["REQUEST_METHOD"] == "DELETE": - if store.delete(conn, name): + if stor.delete(name): start_response("204 No Content", []) return [] return simple_response(start_response, "404 Not Found") diff --git a/paste/__main__.py b/paste/__main__.py index 2b19baa..27ff72b 100644 --- a/paste/__main__.py +++ b/paste/__main__.py @@ -1,16 +1,12 @@ from base64 import b64decode, b64encode from contextlib import AbstractContextManager from datetime import datetime, timezone -from hashlib import sha256 from os import getenv -from secrets import token_bytes -from sqlite3 import Connection from sys import argv, stderr from wsgiref.simple_server import make_server -from . import DB_PATH, application, db +from . import DB_PATH, application, store -TOKEN_BYTES = 96 // 8 PROGRAM_NAME = "paste" @@ -43,50 +39,11 @@ Environment: ) -def generate_token(conn: Connection): - token = token_bytes(TOKEN_BYTES) - token_hash = sha256(token).digest() - with conn: - conn.execute("INSERT INTO token (hash) VALUES (?)", (token_hash,)) - return token - - -def get_tokens(conn: Connection): - return conn.execute( - "SELECT hash, created_at FROM token ORDER BY created_at" - ).fetchall() - - -def delete_token_hash(conn: Connection, token_hash: bytes): - if len(token_hash) != 256 // 8: - raise ValueError("Invalid token hash") - with conn: - cur = conn.execute("DELETE FROM token WHERE hash = ?", (token_hash,)) - if cur.rowcount <= 0: - raise KeyError("Token hash does not exist") - - -def delete_token(conn: Connection, token: bytes): - if len(token) != TOKEN_BYTES: - raise ValueError("Invalid token") - return delete_token_hash(conn, sha256(token).digest()) - - -def check_token(conn: Connection, token: bytes): - if len(token) != TOKEN_BYTES: - raise ValueError("Invalid token") - token_hash = sha256(token).digest() - (count,) = conn.execute( - "SELECT COUNT(*) FROM token WHERE hash = ?", (token_hash,) - ).fetchone() - return count > 0 - - db_path = getenv("PASTE_DB", DB_PATH) -def connect() -> AbstractContextManager[Connection]: - return db.connect(db_path) +def open_db() -> AbstractContextManager[store.Store]: + return store.open(db_path) def main(): @@ -97,11 +54,11 @@ def main(): httpd.serve_forever() elif len(argv) == 2: if argv[1] == "new-token": - with connect() as conn: - print(b64encode(generate_token(conn)).decode()) + with open_db() as db: + print(b64encode(db.generate_token()).decode()) elif argv[1] == "list-tokens": - with connect() as conn: - for token_hash, created_at in get_tokens(conn): + with open_db() as db: + for token_hash, created_at in db.get_tokens(): created_at = datetime.fromtimestamp(created_at, timezone.utc) print(f"{token_hash.hex()}\t{created_at.ctime()}") elif argv[1] == "-h" or argv[1] == "--help": @@ -113,12 +70,12 @@ def main(): elif len(argv) == 3: if argv[1] == "delete-token": token = argv[2] - with connect() as conn: + with open_db() as db: try: try: - delete_token(conn, b64decode(token)) + db.delete_token(b64decode(token)) except ValueError: - delete_token_hash(conn, bytes.fromhex(token)) + db.delete_token_hash(bytes.fromhex(token)) except ValueError: print("Malformed token", file=stderr) exit(1) @@ -126,9 +83,9 @@ def main(): print("Token not found", file=stderr) exit(1) elif argv[1] == "verify-token": - with connect() as conn: + with open_db() as db: try: - if not check_token(conn, b64decode(argv[2])): + if not db.check_token(b64decode(argv[2])): print("Token not found", file=stderr) exit(1) print("Found") diff --git a/paste/store.py b/paste/store.py index dd00edd..b7e471d 100644 --- a/paste/store.py +++ b/paste/store.py @@ -1,77 +1,132 @@ -from secrets import token_urlsafe +from collections.abc import Iterator +from contextlib import contextmanager +from hashlib import sha256 +from secrets import token_bytes, token_urlsafe from sqlite3 import Connection, IntegrityError +from . import db -def put(conn: Connection, name: str, content: bytes, content_type: str): - with conn: - conn.execute( - "INSERT OR IGNORE INTO file (content) VALUES (?)", - (content,), - ) - (content_hash,) = conn.execute("SELECT DATA_HASH(?)", (content,)).fetchone() - cur = conn.execute( - """UPDATE link - SET content_type = ?, file_hash = ? +TOKEN_BYTES = 96 // 8 + + +class Store: + def __init__(self, conn: Connection): + self.conn = conn + + def put(self, name: str, content: bytes, content_type: str): + with self.conn: + self.conn.execute( + "INSERT OR IGNORE INTO file (content) VALUES (?)", + (content,), + ) + (content_hash,) = self.conn.execute( + "SELECT DATA_HASH(?)", (content,) + ).fetchone() + cur = self.conn.execute( + """UPDATE link + SET content_type = ?, file_hash = ? + WHERE name_hash = DATA_HASH(?)""", + (content_type, content_hash, name), + ) + if cur.rowcount == 1: + return False, content_hash + self.conn.execute( + """INSERT INTO link ( + name, content_type, file_hash + ) VALUES (?, ?, ?)""", + (name, content_type, content_hash), + ) + return True, content_hash + + def post(self, prefix: str, content: bytes, content_type: str): + with self.conn: + self.conn.execute( + "INSERT OR IGNORE INTO file (content) VALUES (?)", + (content,), + ) + (content_hash,) = self.conn.execute( + "SELECT DATA_HASH(?)", (content,) + ).fetchone() + for _ in range(16): + name = prefix + token_urlsafe(5) + try: + self.conn.execute( + """INSERT INTO link (name, content_type, file_hash) + VALUES (?, ?, ?)""", + (name, content_type, content_hash), + ) + except IntegrityError: + continue + break + else: + raise RuntimeError("Could not insert a link in 16 attempts") + return name, content_hash + + def get(self, name: str): + row = self.conn.execute( + """SELECT link.content_type, file.hash, file.content + FROM link + JOIN file ON file.hash = link.file_hash + WHERE name_hash = DATA_HASH(?)""", + (name,), + ).fetchone() + return row + + def head(self, name: str): + row = self.conn.execute( + """SELECT link.content_type, file.hash, length(file.content) + FROM link + JOIN file ON file.hash = link.file_hash WHERE name_hash = DATA_HASH(?)""", - (content_type, content_hash, name), - ) - if cur.rowcount == 1: - return False, content_hash - conn.execute( - """INSERT INTO link ( - name, content_type, file_hash - ) VALUES (?, ?, ?)""", - (name, content_type, content_hash), - ) - return True, content_hash - - -def post(conn: Connection, prefix: str, content: bytes, content_type: str): - with conn: - conn.execute( - "INSERT OR IGNORE INTO file (content) VALUES (?)", - (content,), - ) - (content_hash,) = conn.execute("SELECT DATA_HASH(?)", (content,)).fetchone() - for _ in range(16): - name = prefix + token_urlsafe(5) - try: - conn.execute( - """INSERT INTO link (name, content_type, file_hash) - VALUES (?, ?, ?)""", - (name, content_type, content_hash), - ) - except IntegrityError: - continue - break - else: - raise RuntimeError("Could not insert a link in 16 attempts") - return name, content_hash - - -def get(conn: Connection, name: str): - row = conn.execute( - """SELECT link.content_type, file.hash, file.content - FROM link - JOIN file ON file.hash = link.file_hash - WHERE name_hash = DATA_HASH(?)""", - (name,), - ).fetchone() - return row - - -def head(conn: Connection, name: str): - row = conn.execute( - """SELECT link.content_type, file.hash, length(file.content) - FROM link - JOIN file ON file.hash = link.file_hash - WHERE name_hash = DATA_HASH(?)""", - (name,), - ).fetchone() - return row - - -def delete(conn: Connection, name: str): - with conn: - cur = conn.execute("DELETE FROM link WHERE name_hash = DATA_HASH(?)", (name,)) - return cur.rowcount == 1 + (name,), + ).fetchone() + return row + + def delete(self, name: str): + with self.conn: + cur = self.conn.execute( + "DELETE FROM link WHERE name_hash = DATA_HASH(?)", (name,) + ) + return cur.rowcount == 1 + + def generate_token(self): + token = token_bytes(TOKEN_BYTES) + with self.conn: + self.conn.execute("INSERT INTO token (hash) VALUES (SHA256(?))", (token,)) + return token + + def get_tokens(self): + return self.conn.execute( + "SELECT hash, created_at FROM token ORDER BY created_at" + ).fetchall() + + def delete_token_hash(self, token_hash: bytes): + if len(token_hash) != 256 // 8: + raise ValueError("Invalid token hash") + with self.conn: + cur = self.conn.execute("DELETE FROM token WHERE hash = ?", (token_hash,)) + if cur.rowcount <= 0: + raise KeyError("Token hash does not exist") + + def delete_token(self, token: bytes): + if len(token) != TOKEN_BYTES: + raise ValueError("Invalid token") + return self.delete_token_hash(sha256(token).digest()) + + def check_token(self, token: bytes): + if len(token) != TOKEN_BYTES: + raise ValueError("Invalid token") + (count,) = self.conn.execute( + "SELECT COUNT(*) FROM token WHERE hash = SHA256(?)", (token,) + ).fetchone() + return count > 0 + + @property + def version(self): + return db.get_version(self.conn) + + +@contextmanager +def open(uri: str) -> Iterator[Store]: + with db.connect(uri) as conn: + yield Store(conn) diff --git a/tests/middleware/test_authenticate.py b/tests/middleware/test_authenticate.py index 2395316..2acfe92 100644 --- a/tests/middleware/test_authenticate.py +++ b/tests/middleware/test_authenticate.py @@ -22,8 +22,7 @@ def app(): @pytest.mark.parametrize("method", ["GET", "HEAD"]) -def test_unauthenticated_request(app, method, monkeypatch): - monkeypatch.delattr(paste, "check_token") +def test_unauthenticated_request(app, method): environ = {"REQUEST_METHOD": method} response = call_app(app, environ) assert response.data == b"Hello, world!" @@ -32,11 +31,10 @@ def test_unauthenticated_request(app, method, monkeypatch): @pytest.mark.parametrize("method", ["GET", "HEAD"]) -def test_unauthenticated_request_with_key(app, method, monkeypatch): - monkeypatch.delattr(paste, "check_token") +def test_unauthenticated_request_with_key(app, method): environ = { "REQUEST_METHOD": method, - "paste.db_conn": None, + "paste.store": None, "HTTP_AUTHORIZATION": "ApiKey AAAA", } response = call_app(app, environ) @@ -46,8 +44,7 @@ def test_unauthenticated_request_with_key(app, method, monkeypatch): @pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"]) -def test_authenticate_no_header(app, method, monkeypatch): - monkeypatch.delattr(paste, "check_token") +def test_authenticate_no_header(app, method): environ = {"REQUEST_METHOD": method} response = call_app(app, environ) assert response.data == b"401 Unauthorized\n" @@ -58,8 +55,7 @@ def test_authenticate_no_header(app, method, monkeypatch): @pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"]) @pytest.mark.parametrize("key", ["ApiKey AAAA", "APIKey AAA", "APIKey AAAA", "AAAA"]) -def test_authenticate_malformed_key(app, method, key, monkeypatch): - monkeypatch.delattr(paste, "check_token") +def test_authenticate_malformed_key(app, method, key): environ = {"REQUEST_METHOD": method, "HTTP_AUTHORIZATION": key} response = call_app(app, environ) assert response.data == b"401 Unauthorized\n" @@ -68,24 +64,28 @@ def test_authenticate_malformed_key(app, method, key, monkeypatch): assert ("WWW-Authenticate", "APIKey") in response.headers +class MockDB: + def __init__(self, check_token): + self.check_token = check_token + + @pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"]) -def test_authenticate_check_token_fail(app, method, monkeypatch): +def test_authenticate_check_token_fail(app, method): check_token_called = False token = b"test" - environ = { - "REQUEST_METHOD": method, - "paste.db_conn": object(), - "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}", - } - def check_token(conn, tok): + def check_token(tok): nonlocal check_token_called - assert conn == environ["paste.db_conn"] assert tok == token check_token_called = True return False - monkeypatch.setattr(paste, "check_token", check_token) + environ = { + "REQUEST_METHOD": method, + "paste.store": MockDB(check_token), + "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}", + } + response = call_app(app, environ) assert check_token_called assert response.data == b"401 Unauthorized\n" @@ -95,23 +95,22 @@ def test_authenticate_check_token_fail(app, method, monkeypatch): @pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"]) -def test_authenticate_check_token_success(app, method, monkeypatch): +def test_authenticate_check_token_success(app, method): check_token_called = False token = b"test" - environ = { - "REQUEST_METHOD": method, - "paste.db_conn": object(), - "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}", - } - def check_token(conn, tok): + def check_token(tok): nonlocal check_token_called - assert conn == environ["paste.db_conn"] assert tok == token check_token_called = True return True - monkeypatch.setattr(paste, "check_token", check_token) + environ = { + "REQUEST_METHOD": method, + "paste.store": MockDB(check_token), + "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}", + } + response = call_app(app, environ) assert check_token_called assert response.data == b"Hello, world!" diff --git a/tests/middleware/test_open_database.py b/tests/middleware/test_open_database.py index 43ebb07..1584ac0 100644 --- a/tests/middleware/test_open_database.py +++ b/tests/middleware/test_open_database.py @@ -15,10 +15,9 @@ def app(): @open_database @validator def app(environ, start_response): - assert "paste.db_conn" in environ - conn = environ["paste.db_conn"] - (ver,) = conn.execute("PRAGMA user_version").fetchone() - assert ver > 0 + assert "paste.store" in environ + db = environ["paste.store"] + assert db.version > 0 start_response("200 OK", [("Content-Type", "text/plain")]) return [b"Hello, World!"] diff --git a/tests/test_application.py b/tests/test_application.py index 5a7847e..e86d937 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -4,6 +4,7 @@ import pytest from webtest import TestApp import paste.db +import paste.store from paste import __main__, application DB = "file::memory:?cache=shared" @@ -11,7 +12,7 @@ DB = "file::memory:?cache=shared" @pytest.fixture def db(): - with paste.db.connect(DB) as d: + with paste.store.open(DB) as d: yield d @@ -24,7 +25,7 @@ def app(db): @pytest.fixture def token(db): - return b64encode(__main__.generate_token(db)).decode() + return b64encode(db.generate_token()).decode() @pytest.mark.parametrize("method", ["put", "post", "delete"]) @@ -229,7 +230,6 @@ def test_delete(app, token): ) assert res.status == "201 Created" assert res.headers["Location"] == res.request.url - etag = res.headers["ETag"] res = app.delete("/test_key", expect_errors=True) assert res.status == "401 Unauthorized" res = app.delete("/test_key", headers={"Authorization": f"APIKey {token}"}) -- cgit v1.2.3-54-g00ecf