aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTomasz Kramkowski <tomasz@kramkow.ski>2023-03-27 22:46:44 +0100
committerTomasz Kramkowski <tomasz@kramkow.ski>2023-03-27 22:46:44 +0100
commit77889ec30410280ef1ec1671b6226a8d0c0444f7 (patch)
treeb6f371ee6838b179681dc101c8c32b0ed088fe41
parent3a9629e49b4c7e1d10c89bbffa04d18e96948116 (diff)
downloadpaste-77889ec30410280ef1ec1671b6226a8d0c0444f7.tar.gz
paste-77889ec30410280ef1ec1671b6226a8d0c0444f7.tar.xz
paste-77889ec30410280ef1ec1671b6226a8d0c0444f7.zip
Throw all data manipulation code in one place
This means that everything now goes through a Store object, which should make testing a little bit easier.
-rw-r--r--paste/__init__.py28
-rw-r--r--paste/__main__.py67
-rw-r--r--paste/store.py201
-rw-r--r--tests/middleware/test_authenticate.py53
-rw-r--r--tests/middleware/test_open_database.py7
-rw-r--r--tests/test_application.py6
6 files changed, 182 insertions, 180 deletions
diff --git a/paste/__init__.py b/paste/__init__.py
index 36ce0cb..0f75587 100644
--- a/paste/__init__.py
+++ b/paste/__init__.py
@@ -5,11 +5,10 @@ import urllib.parse
from base64 import b64decode, b64encode
from collections.abc import Callable, Iterable
from functools import wraps
-from sqlite3 import Connection
from typing import Any, Optional, Protocol, runtime_checkable
from wsgiref.util import application_uri, request_uri
-from . import db, store
+from . import store
@runtime_checkable
@@ -154,18 +153,11 @@ def options(app: App, environ: Env, start_response: StartResponse) -> Response:
@middleware
def open_database(app: App, environ: Env, start_response: StartResponse) -> Response:
db_path = environ.get("PASTE_DB", DB_PATH)
- with db.connect(db_path) as conn:
- environ["paste.db_conn"] = conn
+ with store.open(db_path) as stor:
+ environ["paste.store"] = stor
return app(environ, start_response)
-def check_token(conn: Connection, token: bytes) -> bool:
- (count,) = conn.execute(
- "SELECT COUNT(*) FROM token WHERE hash = sha256(?)", (token,)
- ).fetchone()
- return count == 1
-
-
@middleware
def authenticate(app: App, environ: Env, start_response: StartResponse) -> Response:
def check_auth():
@@ -179,7 +171,7 @@ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Respo
value = b64decode(value.encode(), validate=True)
except (binascii.Error, UnicodeEncodeError):
return False
- return check_token(environ["paste.db_conn"], value)
+ return environ["paste.store"].check_token(value)
if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth():
return app(environ, start_response)
@@ -197,10 +189,10 @@ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Respo
@open_database
@authenticate
def application(environ: Env, start_response: StartResponse) -> Response:
- conn = environ["paste.db_conn"]
+ stor = environ["paste.store"]
name = environ["PATH_INFO"]
if environ["REQUEST_METHOD"] == "GET":
- row = store.get(conn, name)
+ row = stor.get(name)
if not row:
return simple_response(start_response, "404 Not Found")
content_type, content_hash, content = row
@@ -214,7 +206,7 @@ def application(environ: Env, start_response: StartResponse) -> Response:
)
return [content]
elif environ["REQUEST_METHOD"] == "HEAD":
- row = store.head(conn, name)
+ row = stor.head(name)
if not row:
return simple_response(start_response, "404 Not Found")
content_type, content_hash, content_length = row
@@ -231,7 +223,7 @@ def application(environ: Env, start_response: StartResponse) -> Response:
content_type = environ.get("CONTENT_TYPE", "text/plain")
content_length = int(environ["CONTENT_LENGTH"])
content = environ["wsgi.input"].read(content_length)
- created, content_hash = store.put(conn, name, content, content_type)
+ created, content_hash = stor.put(name, content, content_type)
start_response(
"201 Created" if created else "204 No Content",
[
@@ -244,7 +236,7 @@ def application(environ: Env, start_response: StartResponse) -> Response:
content_type = environ.get("CONTENT_TYPE", "text/plain")
content_length = int(environ["CONTENT_LENGTH"])
content = environ["wsgi.input"].read(content_length)
- path, content_hash = store.post(conn, name, content, content_type)
+ path, content_hash = stor.post(name, content, content_type)
uri = application_uri(environ)
path = urllib.parse.quote(path)
if uri[-1] == "/" and path[:1] == "/":
@@ -259,7 +251,7 @@ def application(environ: Env, start_response: StartResponse) -> Response:
)
return []
elif environ["REQUEST_METHOD"] == "DELETE":
- if store.delete(conn, name):
+ if stor.delete(name):
start_response("204 No Content", [])
return []
return simple_response(start_response, "404 Not Found")
diff --git a/paste/__main__.py b/paste/__main__.py
index 2b19baa..27ff72b 100644
--- a/paste/__main__.py
+++ b/paste/__main__.py
@@ -1,16 +1,12 @@
from base64 import b64decode, b64encode
from contextlib import AbstractContextManager
from datetime import datetime, timezone
-from hashlib import sha256
from os import getenv
-from secrets import token_bytes
-from sqlite3 import Connection
from sys import argv, stderr
from wsgiref.simple_server import make_server
-from . import DB_PATH, application, db
+from . import DB_PATH, application, store
-TOKEN_BYTES = 96 // 8
PROGRAM_NAME = "paste"
@@ -43,50 +39,11 @@ Environment:
)
-def generate_token(conn: Connection):
- token = token_bytes(TOKEN_BYTES)
- token_hash = sha256(token).digest()
- with conn:
- conn.execute("INSERT INTO token (hash) VALUES (?)", (token_hash,))
- return token
-
-
-def get_tokens(conn: Connection):
- return conn.execute(
- "SELECT hash, created_at FROM token ORDER BY created_at"
- ).fetchall()
-
-
-def delete_token_hash(conn: Connection, token_hash: bytes):
- if len(token_hash) != 256 // 8:
- raise ValueError("Invalid token hash")
- with conn:
- cur = conn.execute("DELETE FROM token WHERE hash = ?", (token_hash,))
- if cur.rowcount <= 0:
- raise KeyError("Token hash does not exist")
-
-
-def delete_token(conn: Connection, token: bytes):
- if len(token) != TOKEN_BYTES:
- raise ValueError("Invalid token")
- return delete_token_hash(conn, sha256(token).digest())
-
-
-def check_token(conn: Connection, token: bytes):
- if len(token) != TOKEN_BYTES:
- raise ValueError("Invalid token")
- token_hash = sha256(token).digest()
- (count,) = conn.execute(
- "SELECT COUNT(*) FROM token WHERE hash = ?", (token_hash,)
- ).fetchone()
- return count > 0
-
-
db_path = getenv("PASTE_DB", DB_PATH)
-def connect() -> AbstractContextManager[Connection]:
- return db.connect(db_path)
+def open_db() -> AbstractContextManager[store.Store]:
+ return store.open(db_path)
def main():
@@ -97,11 +54,11 @@ def main():
httpd.serve_forever()
elif len(argv) == 2:
if argv[1] == "new-token":
- with connect() as conn:
- print(b64encode(generate_token(conn)).decode())
+ with open_db() as db:
+ print(b64encode(db.generate_token()).decode())
elif argv[1] == "list-tokens":
- with connect() as conn:
- for token_hash, created_at in get_tokens(conn):
+ with open_db() as db:
+ for token_hash, created_at in db.get_tokens():
created_at = datetime.fromtimestamp(created_at, timezone.utc)
print(f"{token_hash.hex()}\t{created_at.ctime()}")
elif argv[1] == "-h" or argv[1] == "--help":
@@ -113,12 +70,12 @@ def main():
elif len(argv) == 3:
if argv[1] == "delete-token":
token = argv[2]
- with connect() as conn:
+ with open_db() as db:
try:
try:
- delete_token(conn, b64decode(token))
+ db.delete_token(b64decode(token))
except ValueError:
- delete_token_hash(conn, bytes.fromhex(token))
+ db.delete_token_hash(bytes.fromhex(token))
except ValueError:
print("Malformed token", file=stderr)
exit(1)
@@ -126,9 +83,9 @@ def main():
print("Token not found", file=stderr)
exit(1)
elif argv[1] == "verify-token":
- with connect() as conn:
+ with open_db() as db:
try:
- if not check_token(conn, b64decode(argv[2])):
+ if not db.check_token(b64decode(argv[2])):
print("Token not found", file=stderr)
exit(1)
print("Found")
diff --git a/paste/store.py b/paste/store.py
index dd00edd..b7e471d 100644
--- a/paste/store.py
+++ b/paste/store.py
@@ -1,77 +1,132 @@
-from secrets import token_urlsafe
+from collections.abc import Iterator
+from contextlib import contextmanager
+from hashlib import sha256
+from secrets import token_bytes, token_urlsafe
from sqlite3 import Connection, IntegrityError
+from . import db
-def put(conn: Connection, name: str, content: bytes, content_type: str):
- with conn:
- conn.execute(
- "INSERT OR IGNORE INTO file (content) VALUES (?)",
- (content,),
- )
- (content_hash,) = conn.execute("SELECT DATA_HASH(?)", (content,)).fetchone()
- cur = conn.execute(
- """UPDATE link
- SET content_type = ?, file_hash = ?
+TOKEN_BYTES = 96 // 8
+
+
+class Store:
+ def __init__(self, conn: Connection):
+ self.conn = conn
+
+ def put(self, name: str, content: bytes, content_type: str):
+ with self.conn:
+ self.conn.execute(
+ "INSERT OR IGNORE INTO file (content) VALUES (?)",
+ (content,),
+ )
+ (content_hash,) = self.conn.execute(
+ "SELECT DATA_HASH(?)", (content,)
+ ).fetchone()
+ cur = self.conn.execute(
+ """UPDATE link
+ SET content_type = ?, file_hash = ?
+ WHERE name_hash = DATA_HASH(?)""",
+ (content_type, content_hash, name),
+ )
+ if cur.rowcount == 1:
+ return False, content_hash
+ self.conn.execute(
+ """INSERT INTO link (
+ name, content_type, file_hash
+ ) VALUES (?, ?, ?)""",
+ (name, content_type, content_hash),
+ )
+ return True, content_hash
+
+ def post(self, prefix: str, content: bytes, content_type: str):
+ with self.conn:
+ self.conn.execute(
+ "INSERT OR IGNORE INTO file (content) VALUES (?)",
+ (content,),
+ )
+ (content_hash,) = self.conn.execute(
+ "SELECT DATA_HASH(?)", (content,)
+ ).fetchone()
+ for _ in range(16):
+ name = prefix + token_urlsafe(5)
+ try:
+ self.conn.execute(
+ """INSERT INTO link (name, content_type, file_hash)
+ VALUES (?, ?, ?)""",
+ (name, content_type, content_hash),
+ )
+ except IntegrityError:
+ continue
+ break
+ else:
+ raise RuntimeError("Could not insert a link in 16 attempts")
+ return name, content_hash
+
+ def get(self, name: str):
+ row = self.conn.execute(
+ """SELECT link.content_type, file.hash, file.content
+ FROM link
+ JOIN file ON file.hash = link.file_hash
+ WHERE name_hash = DATA_HASH(?)""",
+ (name,),
+ ).fetchone()
+ return row
+
+ def head(self, name: str):
+ row = self.conn.execute(
+ """SELECT link.content_type, file.hash, length(file.content)
+ FROM link
+ JOIN file ON file.hash = link.file_hash
WHERE name_hash = DATA_HASH(?)""",
- (content_type, content_hash, name),
- )
- if cur.rowcount == 1:
- return False, content_hash
- conn.execute(
- """INSERT INTO link (
- name, content_type, file_hash
- ) VALUES (?, ?, ?)""",
- (name, content_type, content_hash),
- )
- return True, content_hash
-
-
-def post(conn: Connection, prefix: str, content: bytes, content_type: str):
- with conn:
- conn.execute(
- "INSERT OR IGNORE INTO file (content) VALUES (?)",
- (content,),
- )
- (content_hash,) = conn.execute("SELECT DATA_HASH(?)", (content,)).fetchone()
- for _ in range(16):
- name = prefix + token_urlsafe(5)
- try:
- conn.execute(
- """INSERT INTO link (name, content_type, file_hash)
- VALUES (?, ?, ?)""",
- (name, content_type, content_hash),
- )
- except IntegrityError:
- continue
- break
- else:
- raise RuntimeError("Could not insert a link in 16 attempts")
- return name, content_hash
-
-
-def get(conn: Connection, name: str):
- row = conn.execute(
- """SELECT link.content_type, file.hash, file.content
- FROM link
- JOIN file ON file.hash = link.file_hash
- WHERE name_hash = DATA_HASH(?)""",
- (name,),
- ).fetchone()
- return row
-
-
-def head(conn: Connection, name: str):
- row = conn.execute(
- """SELECT link.content_type, file.hash, length(file.content)
- FROM link
- JOIN file ON file.hash = link.file_hash
- WHERE name_hash = DATA_HASH(?)""",
- (name,),
- ).fetchone()
- return row
-
-
-def delete(conn: Connection, name: str):
- with conn:
- cur = conn.execute("DELETE FROM link WHERE name_hash = DATA_HASH(?)", (name,))
- return cur.rowcount == 1
+ (name,),
+ ).fetchone()
+ return row
+
+ def delete(self, name: str):
+ with self.conn:
+ cur = self.conn.execute(
+ "DELETE FROM link WHERE name_hash = DATA_HASH(?)", (name,)
+ )
+ return cur.rowcount == 1
+
+ def generate_token(self):
+ token = token_bytes(TOKEN_BYTES)
+ with self.conn:
+ self.conn.execute("INSERT INTO token (hash) VALUES (SHA256(?))", (token,))
+ return token
+
+ def get_tokens(self):
+ return self.conn.execute(
+ "SELECT hash, created_at FROM token ORDER BY created_at"
+ ).fetchall()
+
+ def delete_token_hash(self, token_hash: bytes):
+ if len(token_hash) != 256 // 8:
+ raise ValueError("Invalid token hash")
+ with self.conn:
+ cur = self.conn.execute("DELETE FROM token WHERE hash = ?", (token_hash,))
+ if cur.rowcount <= 0:
+ raise KeyError("Token hash does not exist")
+
+ def delete_token(self, token: bytes):
+ if len(token) != TOKEN_BYTES:
+ raise ValueError("Invalid token")
+ return self.delete_token_hash(sha256(token).digest())
+
+ def check_token(self, token: bytes):
+ if len(token) != TOKEN_BYTES:
+ raise ValueError("Invalid token")
+ (count,) = self.conn.execute(
+ "SELECT COUNT(*) FROM token WHERE hash = SHA256(?)", (token,)
+ ).fetchone()
+ return count > 0
+
+ @property
+ def version(self):
+ return db.get_version(self.conn)
+
+
+@contextmanager
+def open(uri: str) -> Iterator[Store]:
+ with db.connect(uri) as conn:
+ yield Store(conn)
diff --git a/tests/middleware/test_authenticate.py b/tests/middleware/test_authenticate.py
index 2395316..2acfe92 100644
--- a/tests/middleware/test_authenticate.py
+++ b/tests/middleware/test_authenticate.py
@@ -22,8 +22,7 @@ def app():
@pytest.mark.parametrize("method", ["GET", "HEAD"])
-def test_unauthenticated_request(app, method, monkeypatch):
- monkeypatch.delattr(paste, "check_token")
+def test_unauthenticated_request(app, method):
environ = {"REQUEST_METHOD": method}
response = call_app(app, environ)
assert response.data == b"Hello, world!"
@@ -32,11 +31,10 @@ def test_unauthenticated_request(app, method, monkeypatch):
@pytest.mark.parametrize("method", ["GET", "HEAD"])
-def test_unauthenticated_request_with_key(app, method, monkeypatch):
- monkeypatch.delattr(paste, "check_token")
+def test_unauthenticated_request_with_key(app, method):
environ = {
"REQUEST_METHOD": method,
- "paste.db_conn": None,
+ "paste.store": None,
"HTTP_AUTHORIZATION": "ApiKey AAAA",
}
response = call_app(app, environ)
@@ -46,8 +44,7 @@ def test_unauthenticated_request_with_key(app, method, monkeypatch):
@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"])
-def test_authenticate_no_header(app, method, monkeypatch):
- monkeypatch.delattr(paste, "check_token")
+def test_authenticate_no_header(app, method):
environ = {"REQUEST_METHOD": method}
response = call_app(app, environ)
assert response.data == b"401 Unauthorized\n"
@@ -58,8 +55,7 @@ def test_authenticate_no_header(app, method, monkeypatch):
@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"])
@pytest.mark.parametrize("key", ["ApiKey AAAA", "APIKey AAA", "APIKey AAAA", "AAAA"])
-def test_authenticate_malformed_key(app, method, key, monkeypatch):
- monkeypatch.delattr(paste, "check_token")
+def test_authenticate_malformed_key(app, method, key):
environ = {"REQUEST_METHOD": method, "HTTP_AUTHORIZATION": key}
response = call_app(app, environ)
assert response.data == b"401 Unauthorized\n"
@@ -68,24 +64,28 @@ def test_authenticate_malformed_key(app, method, key, monkeypatch):
assert ("WWW-Authenticate", "APIKey") in response.headers
+class MockDB:
+ def __init__(self, check_token):
+ self.check_token = check_token
+
+
@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"])
-def test_authenticate_check_token_fail(app, method, monkeypatch):
+def test_authenticate_check_token_fail(app, method):
check_token_called = False
token = b"test"
- environ = {
- "REQUEST_METHOD": method,
- "paste.db_conn": object(),
- "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}",
- }
- def check_token(conn, tok):
+ def check_token(tok):
nonlocal check_token_called
- assert conn == environ["paste.db_conn"]
assert tok == token
check_token_called = True
return False
- monkeypatch.setattr(paste, "check_token", check_token)
+ environ = {
+ "REQUEST_METHOD": method,
+ "paste.store": MockDB(check_token),
+ "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}",
+ }
+
response = call_app(app, environ)
assert check_token_called
assert response.data == b"401 Unauthorized\n"
@@ -95,23 +95,22 @@ def test_authenticate_check_token_fail(app, method, monkeypatch):
@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"])
-def test_authenticate_check_token_success(app, method, monkeypatch):
+def test_authenticate_check_token_success(app, method):
check_token_called = False
token = b"test"
- environ = {
- "REQUEST_METHOD": method,
- "paste.db_conn": object(),
- "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}",
- }
- def check_token(conn, tok):
+ def check_token(tok):
nonlocal check_token_called
- assert conn == environ["paste.db_conn"]
assert tok == token
check_token_called = True
return True
- monkeypatch.setattr(paste, "check_token", check_token)
+ environ = {
+ "REQUEST_METHOD": method,
+ "paste.store": MockDB(check_token),
+ "HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}",
+ }
+
response = call_app(app, environ)
assert check_token_called
assert response.data == b"Hello, world!"
diff --git a/tests/middleware/test_open_database.py b/tests/middleware/test_open_database.py
index 43ebb07..1584ac0 100644
--- a/tests/middleware/test_open_database.py
+++ b/tests/middleware/test_open_database.py
@@ -15,10 +15,9 @@ def app():
@open_database
@validator
def app(environ, start_response):
- assert "paste.db_conn" in environ
- conn = environ["paste.db_conn"]
- (ver,) = conn.execute("PRAGMA user_version").fetchone()
- assert ver > 0
+ assert "paste.store" in environ
+ db = environ["paste.store"]
+ assert db.version > 0
start_response("200 OK", [("Content-Type", "text/plain")])
return [b"Hello, World!"]
diff --git a/tests/test_application.py b/tests/test_application.py
index 5a7847e..e86d937 100644
--- a/tests/test_application.py
+++ b/tests/test_application.py
@@ -4,6 +4,7 @@ import pytest
from webtest import TestApp
import paste.db
+import paste.store
from paste import __main__, application
DB = "file::memory:?cache=shared"
@@ -11,7 +12,7 @@ DB = "file::memory:?cache=shared"
@pytest.fixture
def db():
- with paste.db.connect(DB) as d:
+ with paste.store.open(DB) as d:
yield d
@@ -24,7 +25,7 @@ def app(db):
@pytest.fixture
def token(db):
- return b64encode(__main__.generate_token(db)).decode()
+ return b64encode(db.generate_token()).decode()
@pytest.mark.parametrize("method", ["put", "post", "delete"])
@@ -229,7 +230,6 @@ def test_delete(app, token):
)
assert res.status == "201 Created"
assert res.headers["Location"] == res.request.url
- etag = res.headers["ETag"]
res = app.delete("/test_key", expect_errors=True)
assert res.status == "401 Unauthorized"
res = app.delete("/test_key", headers={"Authorization": f"APIKey {token}"})