import binascii
import sys
import traceback
from base64 import b64decode, b64encode
from collections.abc import Callable, Iterable
from functools import wraps
from sqlite3 import Connection
from typing import Any, Optional, Protocol, runtime_checkable
from wsgiref.util import request_uri

from . import db, store


@runtime_checkable
class Closable(Protocol):
    """Anything exposing ``close()``, e.g. a WSGI response iterable.

    Used by ``if_none_match`` to close response iterables it discards.
    """

    def close(self): ...


class StartResponse(Protocol):
    """Type of the WSGI ``start_response`` callable (PEP 3333)."""

    def __call__(
        self,
        status: str,
        headers: list[tuple[str, str]],
        exc_info: Optional[tuple] = ...,
        /,
    ) -> Callable[[bytes], object]: ...


# WSGI environ dict, application callable, and response body iterable.
Env = dict[str, Any]
App = Callable[[Env, StartResponse], Iterable[bytes]]
Response = Iterable[bytes]

# Default SQLite database file (a request may override it via environ["PASTE_DB"]).
DB_PATH = "paste.sqlite3"

# Methods this service implements, in the order advertised in Allow headers.
ALLOWED_METHODS = ("GET", "HEAD", "POST", "PUT", "DELETE", "OPTIONS")
ALLOW_HEADER = ", ".join(ALLOWED_METHODS)


def simple_response(
    start_response: StartResponse,
    status: str,
    exc_info: Optional[tuple] = None,
    *,
    extra_headers: Optional[list[tuple[str, str]]] = None,
) -> Response:
    """Start a plain-text response whose body is the status line itself.

    Args:
        start_response: The WSGI ``start_response`` callable.
        status: Full status line, e.g. ``"404 Not Found"``.
        exc_info: Forwarded to ``start_response`` (used when reporting errors).
        extra_headers: Additional headers appended after Content-Type and
            Content-Length (e.g. ``Allow`` on a 405).

    Returns:
        A one-element list holding ``status + "\\n"`` encoded as UTF-8.
    """
    body = (status + "\n").encode()
    headers = [
        ("Content-Type", "text/plain"),
        ("Content-Length", str(len(body))),
    ]
    if extra_headers:
        headers.extend(extra_headers)
    start_response(status, headers, exc_info)
    return [body]


def redirect(start_response: StartResponse, location: bytes, typ: str) -> Response:
    """Start a redirect to ``location`` chosen by the pseudo content type ``typ``.

    ``text/x.redirect.301`` yields 301 and ``text/x.redirect.302`` yields 302.
    Any other ``typ`` is a server-side data error and yields 500. The target
    URL is also used as the response body.
    """
    status = {
        "text/x.redirect.301": "301 Moved Permanently",
        "text/x.redirect.302": "302 Found",
    }.get(typ)
    if not status:
        return simple_response(start_response, "500 Internal Server Error")
    body = location
    start_response(
        status,
        [
            ("Location", location.decode()),
            ("Content-Type", "text/plain"),
            ("Content-Length", str(len(body))),
        ],
    )
    return [body]


ProtoMiddleware = Callable[[App, Env, StartResponse], Response]
Middleware = Callable[[App], App]


def middleware(f: ProtoMiddleware) -> Middleware:
    """Turn ``f(app, environ, start_response)`` into an ``app -> app`` decorator."""

    @wraps(f)
    def outer(app: App) -> App:
        # wraps(app) keeps the wrapped application's name/docstring visible.
        @wraps(app)
        def inner(environ: Env, start_response: StartResponse):
            return f(app, environ, start_response)

        return inner

    return outer


@middleware
def catch_exceptions(app: App, environ: Env, start_response: StartResponse) -> Response:
    """Convert an exception raised by ``app`` into a logged 500 response.

    The traceback is written to the WSGI error stream ``environ["wsgi.errors"]``,
    falling back to ``sys.stderr``. ``sys.exc_info()`` is passed on so that, if
    headers were already sent, ``start_response`` re-raises (PEP 3333).
    Only exceptions raised by the ``app(...)`` call itself are caught, not
    exceptions raised while the server iterates the returned body.
    """
    try:
        return app(environ, start_response)
    except Exception:
        errors = environ.get("wsgi.errors", sys.stderr)
        errors.write(traceback.format_exc())
        return simple_response(
            start_response, "500 Internal Server Error", sys.exc_info()
        )


@middleware
def validate_method(app: App, environ: Env, start_response: StartResponse) -> Response:
    """Pass implemented methods through; reject everything else.

    CONNECT, TRACE and PATCH get 405 (with the mandatory Allow header);
    any other unknown method gets 501.
    """
    method = environ["REQUEST_METHOD"]
    if method in ALLOWED_METHODS:
        return app(environ, start_response)
    if method in {"CONNECT", "TRACE", "PATCH"}:
        # RFC 9110 §15.5.6: a 405 response MUST carry an Allow header.
        return simple_response(
            start_response,
            "405 Method Not Allowed",
            extra_headers=[("Allow", ALLOW_HEADER)],
        )
    return simple_response(start_response, "501 Not Implemented")


@middleware
def if_none_match(app: App, environ: Env, start_response: StartResponse) -> Response:
    """Answer conditional GET/HEAD requests with 304 Not Modified.

    The If-None-Match header is removed from ``environ`` before the wrapped app
    runs. For GET/HEAD, the app is first probed with a HEAD request to learn the
    current ETag. If the header is ``*`` or lists a matching tag (weak
    comparison: ``W/`` prefixes are ignored), a 304 carrying the ETag is
    returned. A malformed tag list yields 400. Without an ETag (e.g. 404) the
    request proceeds normally.
    """
    header = environ.pop("HTTP_IF_NONE_MATCH", None)
    if header is None:
        return app(environ, start_response)
    if environ["REQUEST_METHOD"] not in {"GET", "HEAD"}:
        return app(environ, start_response)

    head_env = environ.copy()
    head_env["REQUEST_METHOD"] = "HEAD"
    etag: Optional[str] = None  # raw ETag header value, quotes included

    def head_start_response(
        status: str,
        headers: list[tuple[str, str]],
        exc_info: Optional[tuple] = None,
    ) -> Callable[[bytes], object]:
        # Only the headers matter for the probe; status/exc_info are ignored.
        nonlocal etag
        del status, exc_info
        for key, value in headers:
            if key == "ETag":
                etag = value
        return lambda _: None

    probe = app(head_env, head_start_response)
    if isinstance(probe, Closable):
        probe.close()
    if etag is None:
        return app(environ, start_response)

    if header.strip(" \t") == "*":
        start_response("304 Not Modified", [("ETag", etag)])
        return []

    wanted: set[str] = set()
    for item in header.split(","):
        item = item.strip(" \t").removeprefix("W/")
        if not item:
            # RFC 9110 §5.6.1: empty list elements are ignored.
            continue
        if len(item) < 2 or item[0] != '"' or item[-1] != '"':
            return simple_response(start_response, "400 Bad Request")
        wanted.add(item[1:-1])

    # Compare opaque tags (quotes stripped) but send our tag back quoted.
    if etag[1:-1] in wanted:
        start_response("304 Not Modified", [("ETag", etag)])
        return []
    return app(environ, start_response)


@middleware
def options(app: App, environ: Env, start_response: StartResponse) -> Response:
    """Answer OPTIONS with 204 and the Allow header; pass other methods on."""
    if environ["REQUEST_METHOD"] != "OPTIONS":
        return app(environ, start_response)
    start_response("204 No Content", [("Allow", ALLOW_HEADER)])
    return []
@middleware
def open_database(app: App, environ: Env, start_response: StartResponse) -> Response:
    """Open the paste database and expose it as ``environ["paste.db_conn"]``.

    The path is ``environ["PASTE_DB"]`` when set, otherwise ``DB_PATH``.
    """
    db_path = environ.get("PASTE_DB", DB_PATH)
    with db.connect(db_path) as conn:
        environ["paste.db_conn"] = conn
        # NOTE(review): the connection's context exits as soon as app() returns,
        # so response bodies must not lazily use the DB; the handlers below all
        # return fully built lists. What exiting the context does (commit,
        # close) depends on db.connect — confirm.
        return app(environ, start_response)


def check_auth(conn: Connection, auth: Optional[str]) -> bool:
    """Return True if ``auth`` is a valid ``Bearer <base64 token>`` header.

    The base64-decoded token is hashed with the SQL function ``sha256()`` and
    must match exactly one row of the ``token`` table. ``sha256()`` is not a
    SQLite built-in; presumably ``db.connect`` registers it — confirm.
    Missing, non-Bearer or undecodable headers return False.
    """
    if not auth or not auth.startswith("Bearer "):
        return False
    try:
        token = b64decode(auth.removeprefix("Bearer ").encode())
    except binascii.Error:
        return False
    (count,) = conn.execute(
        "SELECT COUNT(*) FROM token WHERE hash = sha256(?)", (token,)
    ).fetchone()
    return count == 1


@middleware
def authenticate(app: App, environ: Env, start_response: StartResponse) -> Response:
    """Allow GET/HEAD publicly; require a valid bearer token for other methods.

    Reads the connection placed in ``environ["paste.db_conn"]`` by
    ``open_database``. Unauthorized requests get 401.
    """
    conn = environ["paste.db_conn"]
    token = environ.get("HTTP_AUTHORIZATION")
    if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth(conn, token):
        return app(environ, start_response)
    return simple_response(start_response, "401 Unauthorized")


def _etag(content_hash: bytes) -> str:
    """Format a stored content hash as a quoted ETag header value."""
    return f'"{b64encode(content_hash).decode()}"'


@catch_exceptions
@validate_method
@options
@if_none_match
@open_database
@authenticate
def application(environ: Env, start_response: StartResponse) -> Response:
    """Serve, store and delete pastes keyed by the request path (PATH_INFO).

    - GET: return the stored content with Content-Type and ETag, or follow a
      stored ``text/x.redirect.*`` entry; 404 if absent.
    - HEAD: same status and headers as GET, never a body.
    - POST/PUT: store the request body under the path, then 302 back to it.
    - DELETE: remove the paste and answer 204.
    """
    conn = environ["paste.db_conn"]
    name = environ["PATH_INFO"]
    method = environ["REQUEST_METHOD"]

    if method == "GET":
        row = store.get(conn, name)
        if not row:
            return simple_response(start_response, "404 Not Found")
        content_type, content_hash, content = row
        if content_type.startswith("text/x.redirect"):
            return redirect(start_response, content, content_type)
        start_response(
            "200 OK",
            [
                ("Content-Type", content_type),
                ("Content-Length", str(len(content))),
                ("ETag", _etag(content_hash)),
            ],
        )
        return [content]

    if method == "HEAD":
        # HEAD mirrors GET's headers but must not send content (RFC 9110 §9.3.2),
        # so bodies produced by the helpers are discarded after start_response.
        row = store.head(conn, name)
        if not row:
            simple_response(start_response, "404 Not Found")
            return []
        content_type, content_hash, content_length, opt_content = row
        if content_type.startswith("text/x.redirect"):
            redirect(start_response, opt_content, content_type)
            return []
        start_response(
            "200 OK",
            [
                ("Content-Type", content_type),
                # WSGI header values must be str; store.head may return an int.
                ("Content-Length", str(content_length)),
                ("ETag", _etag(content_hash)),
            ],
        )
        return []

    if method in {"POST", "PUT"}:
        raw_length = environ.get("CONTENT_LENGTH", "").strip()
        if not raw_length:
            # PEP 3333: CONTENT_LENGTH "may be empty or absent".
            return simple_response(start_response, "411 Length Required")
        try:
            content_length = int(raw_length)
        except ValueError:
            return simple_response(start_response, "400 Bad Request")
        if content_length < 0:
            # For file-like objects read(n < 0) means "read to EOF".
            return simple_response(start_response, "400 Bad Request")
        # An empty CONTENT_TYPE (allowed by PEP 3333) also falls back to text/plain.
        content_type = environ.get("CONTENT_TYPE") or "text/plain"
        content = environ["wsgi.input"].read(content_length)
        store.put(conn, name, content, content_type)
        return redirect(
            start_response, request_uri(environ).encode(), "text/x.redirect.302"
        )

    if method == "DELETE":
        store.delete(conn, name)
        # A 204 carries neither content nor Content-Length (RFC 9110 §8.6, §15.3.5).
        start_response("204 No Content", [])
        return []

    # Not reachable behind validate_method/options; kept as a safety net.
    return simple_response(start_response, "500 Internal Server Error")