aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTomasz Kramkowski <tomasz@kramkow.ski>2023-03-28 18:53:15 +0100
committerTomasz Kramkowski <tomasz@kramkow.ski>2023-03-28 20:10:48 +0100
commitd198fca95919cc78275d3d9fa8f1b0a8acfdbab3 (patch)
treec5e01750f40310df6a365da94b7e76ef540f8a27
parenta4669636144bf1f3d61bb5af80e23841f2ad8481 (diff)
downloadpaste-d198fca95919cc78275d3d9fa8f1b0a8acfdbab3.tar.gz
paste-d198fca95919cc78275d3d9fa8f1b0a8acfdbab3.tar.xz
paste-d198fca95919cc78275d3d9fa8f1b0a8acfdbab3.zip
Switch back to centralised opening of the database
Create Store instances when needed This will make more sense with following commits
-rw-r--r--paste/__init__.py25
-rw-r--r--paste/__main__.py29
-rw-r--r--paste/store.py6
-rw-r--r--tests/middleware/test_authenticate.py23
-rw-r--r--tests/middleware/test_open_store.py20
-rw-r--r--tests/test_application.py6
6 files changed, 59 insertions, 50 deletions
diff --git a/paste/__init__.py b/paste/__init__.py
index 567aa8e..155e927 100644
--- a/paste/__init__.py
+++ b/paste/__init__.py
@@ -8,7 +8,8 @@ from functools import wraps
from typing import Optional
from wsgiref.util import application_uri, request_uri
-from . import store
+from . import db
+from .store import Store
from .types import (
App,
Closable,
@@ -134,10 +135,10 @@ def options(app: App, environ: Env, start_response: StartResponse) -> Response:
@middleware
-def open_store(app: App, environ: Env, start_response: StartResponse) -> Response:
+def open_database(app: App, environ: Env, start_response: StartResponse) -> Response:
db_path = environ.get("PASTE_DB", DB_PATH)
- with store.open(db_path) as stor:
- environ["paste.store"] = stor
+ with db.connect(db_path) as conn:
+ environ["paste.db_conn"] = conn
return app(environ, start_response)
@@ -154,7 +155,7 @@ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Respo
value = b64decode(value.encode(), validate=True)
except (binascii.Error, UnicodeEncodeError):
return False
- return environ["paste.store"].check_token(value)
+ return Store(environ["paste.db_conn"]).check_token(value)
if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth():
return app(environ, start_response)
@@ -169,13 +170,13 @@ def authenticate(app: App, environ: Env, start_response: StartResponse) -> Respo
@validate_method
@options
@if_none_match
-@open_store
+@open_database
@authenticate
def application(environ: Env, start_response: StartResponse) -> Response:
- stor = environ["paste.store"]
+ store = Store(environ["paste.db_conn"])
name = environ["PATH_INFO"]
if environ["REQUEST_METHOD"] == "GET":
- row = stor.get(name)
+ row = store.get(name)
if not row:
return simple_response(start_response, "404 Not Found")
content_type, content_hash, content = row
@@ -189,7 +190,7 @@ def application(environ: Env, start_response: StartResponse) -> Response:
)
return [content]
elif environ["REQUEST_METHOD"] == "HEAD":
- row = stor.head(name)
+ row = store.head(name)
if not row:
return simple_response(start_response, "404 Not Found")
content_type, content_hash, content_length = row
@@ -206,7 +207,7 @@ def application(environ: Env, start_response: StartResponse) -> Response:
content_type = environ.get("CONTENT_TYPE", "text/plain")
content_length = int(environ["CONTENT_LENGTH"])
content = environ["wsgi.input"].read(content_length)
- created, content_hash = stor.put(name, content, content_type)
+ created, content_hash = store.put(name, content, content_type)
start_response(
"201 Created" if created else "204 No Content",
[
@@ -219,7 +220,7 @@ def application(environ: Env, start_response: StartResponse) -> Response:
content_type = environ.get("CONTENT_TYPE", "text/plain")
content_length = int(environ["CONTENT_LENGTH"])
content = environ["wsgi.input"].read(content_length)
- path, content_hash = stor.post(name, content, content_type)
+ path, content_hash = store.post(name, content, content_type)
uri = application_uri(environ)
path = urllib.parse.quote(path)
if uri[-1] == "/" and path[:1] == "/":
@@ -234,7 +235,7 @@ def application(environ: Env, start_response: StartResponse) -> Response:
)
return []
elif environ["REQUEST_METHOD"] == "DELETE":
- if stor.delete(name):
+ if store.delete(name):
start_response("204 No Content", [])
return []
return simple_response(start_response, "404 Not Found")
diff --git a/paste/__main__.py b/paste/__main__.py
index 27ff72b..6e684c4 100644
--- a/paste/__main__.py
+++ b/paste/__main__.py
@@ -1,11 +1,12 @@
from base64 import b64decode, b64encode
-from contextlib import AbstractContextManager
+from collections.abc import Iterator
+from contextlib import contextmanager
from datetime import datetime, timezone
from os import getenv
from sys import argv, stderr
from wsgiref.simple_server import make_server
-from . import DB_PATH, application, store
+from . import DB_PATH, application, db, store
PROGRAM_NAME = "paste"
@@ -42,8 +43,10 @@ Environment:
db_path = getenv("PASTE_DB", DB_PATH)
-def open_db() -> AbstractContextManager[store.Store]:
- return store.open(db_path)
+@contextmanager
+def open_auth() -> Iterator[store.Auth]:
+ with db.connect(db_path) as conn:
+ yield store.Auth(conn)
def main():
@@ -54,11 +57,11 @@ def main():
httpd.serve_forever()
elif len(argv) == 2:
if argv[1] == "new-token":
- with open_db() as db:
- print(b64encode(db.generate_token()).decode())
+ with open_auth() as auth:
+ print(b64encode(auth.generate_token()).decode())
elif argv[1] == "list-tokens":
- with open_db() as db:
- for token_hash, created_at in db.get_tokens():
+ with open_auth() as auth:
+ for token_hash, created_at in auth.get_tokens():
created_at = datetime.fromtimestamp(created_at, timezone.utc)
print(f"{token_hash.hex()}\t{created_at.ctime()}")
elif argv[1] == "-h" or argv[1] == "--help":
@@ -70,12 +73,12 @@ def main():
elif len(argv) == 3:
if argv[1] == "delete-token":
token = argv[2]
- with open_db() as db:
+ with open_auth() as auth:
try:
try:
- db.delete_token(b64decode(token))
+ auth.delete_token(b64decode(token))
except ValueError:
- db.delete_token_hash(bytes.fromhex(token))
+ auth.delete_token_hash(bytes.fromhex(token))
except ValueError:
print("Malformed token", file=stderr)
exit(1)
@@ -83,9 +86,9 @@ def main():
print("Token not found", file=stderr)
exit(1)
elif argv[1] == "verify-token":
- with open_db() as db:
+ with open_auth() as auth:
try:
- if not db.check_token(b64decode(argv[2])):
+ if not auth.check_token(b64decode(argv[2])):
print("Token not found", file=stderr)
exit(1)
print("Found")
diff --git a/paste/store.py b/paste/store.py
index e7022f6..f69c03d 100644
--- a/paste/store.py
+++ b/paste/store.py
@@ -120,9 +120,3 @@ class Store:
"SELECT COUNT(*) FROM token WHERE hash = SHA256(?)", (token,)
).fetchone()
return count > 0
-
-
-@contextmanager
-def open(uri: str) -> Iterator[Store]:
- with db.connect(uri) as conn:
- yield Store(conn)
diff --git a/tests/middleware/test_authenticate.py b/tests/middleware/test_authenticate.py
index 2acfe92..9fccb32 100644
--- a/tests/middleware/test_authenticate.py
+++ b/tests/middleware/test_authenticate.py
@@ -34,7 +34,7 @@ def test_unauthenticated_request(app, method):
def test_unauthenticated_request_with_key(app, method):
environ = {
"REQUEST_METHOD": method,
- "paste.store": None,
+ "paste.db_conn": None,
"HTTP_AUTHORIZATION": "ApiKey AAAA",
}
response = call_app(app, environ)
@@ -64,13 +64,22 @@ def test_authenticate_malformed_key(app, method, key):
assert ("WWW-Authenticate", "APIKey") in response.headers
-class MockDB:
+class MockConnection:
def __init__(self, check_token):
self.check_token = check_token
+class MockStore:
+ def __init__(self, c):
+ assert isinstance(c, MockConnection)
+ self.conn = c
+
+ def check_token(self, tok):
+ return self.conn.check_token(tok)
+
+
@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"])
-def test_authenticate_check_token_fail(app, method):
+def test_authenticate_check_token_fail(app, method, monkeypatch):
check_token_called = False
token = b"test"
@@ -82,10 +91,11 @@ def test_authenticate_check_token_fail(app, method):
environ = {
"REQUEST_METHOD": method,
- "paste.store": MockDB(check_token),
+ "paste.db_conn": MockConnection(check_token),
"HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}",
}
+ monkeypatch.setattr("paste.Store", MockStore)
response = call_app(app, environ)
assert check_token_called
assert response.data == b"401 Unauthorized\n"
@@ -95,7 +105,7 @@ def test_authenticate_check_token_fail(app, method):
@pytest.mark.parametrize("method", ["POST", "PUT", "DELETE"])
-def test_authenticate_check_token_success(app, method):
+def test_authenticate_check_token_success(app, method, monkeypatch):
check_token_called = False
token = b"test"
@@ -107,10 +117,11 @@ def test_authenticate_check_token_success(app, method):
environ = {
"REQUEST_METHOD": method,
- "paste.store": MockDB(check_token),
+ "paste.db_conn": MockConnection(check_token),
"HTTP_AUTHORIZATION": f"APIKey {b64encode(token).decode()}",
}
+ monkeypatch.setattr("paste.Store", MockStore)
response = call_app(app, environ)
assert check_token_called
assert response.data == b"Hello, world!"
diff --git a/tests/middleware/test_open_store.py b/tests/middleware/test_open_store.py
index dd430e9..3e844a9 100644
--- a/tests/middleware/test_open_store.py
+++ b/tests/middleware/test_open_store.py
@@ -1,30 +1,30 @@
from contextlib import contextmanager
from wsgiref.validate import validator
-from paste import open_store
+from paste import open_database
from ..common_wsgi import call_app
-def test_open_store(monkeypatch):
+def test_open_database(monkeypatch):
db_path = "test_db_path"
- store = object()
+ conn = object()
@contextmanager
- def store_open(path):
- assert path == db_path
- yield store
+ def connect(uri):
+ assert uri == db_path
+ yield conn
@validator
- @open_store
+ @open_database
@validator
def app(environ, start_response):
- assert "paste.store" in environ
- assert environ["paste.store"] == store
+ assert "paste.db_conn" in environ
+ assert environ["paste.db_conn"] == conn
start_response("200 OK", [("Content-Type", "text/plain")])
return [b"Hello, World!"]
- monkeypatch.setattr("paste.store.open", store_open)
+ monkeypatch.setattr("paste.db.connect", connect)
response = call_app(app, environ={"PASTE_DB": db_path})
assert response.status == "200 OK"
assert response.data == b"Hello, World!"
diff --git a/tests/test_application.py b/tests/test_application.py
index e86d937..20cbeac 100644
--- a/tests/test_application.py
+++ b/tests/test_application.py
@@ -4,15 +4,15 @@ import pytest
from webtest import TestApp
import paste.db
-import paste.store
from paste import __main__, application
+from paste.store import Store
DB = "file::memory:?cache=shared"
@pytest.fixture
def db():
- with paste.store.open(DB) as d:
+ with paste.db.connect(DB) as d:
yield d
@@ -25,7 +25,7 @@ def app(db):
@pytest.fixture
def token(db):
- return b64encode(db.generate_token()).decode()
+ return b64encode(Store(db).generate_token()).decode()
@pytest.mark.parametrize("method", ["put", "post", "delete"])