import binascii
import sys
import traceback
from base64 import b64decode, b64encode
from collections.abc import Callable, Iterable
from functools import wraps
from sqlite3 import Connection
from typing import Any, Optional, Protocol, runtime_checkable
from wsgiref.util import request_uri

from . import db, store


@runtime_checkable
class Closable(Protocol):
    """Structural type for response iterables that expose ``close()``."""

    def close(self): ...


class StartResponse(Protocol):
    """Call signature of the WSGI ``start_response`` callable."""

    def __call__(
        self,
        status: str,
        headers: list[tuple[str, str]],
        exc_info: Optional[tuple] = ...,
        /,
    ) -> Callable[[bytes], object]: ...


Env = dict[str, Any]
App = Callable[[Env, StartResponse], Iterable[bytes]]
Response = Iterable[bytes]

# Default SQLite file; overridable per request via the PASTE_DB environ key.
DB_PATH = "paste.sqlite3"


def simple_response(
    start_response: StartResponse, status: str, exc_info: Optional[tuple] = None
) -> Response:
    """Send a plain-text response whose body is the status line itself.

    Args:
        start_response: the WSGI start_response callable.
        status: full status line, e.g. ``"404 Not Found"``.
        exc_info: forwarded to start_response (used when reporting errors).

    Returns:
        A one-element list holding ``status + "\\n"`` encoded as UTF-8.
    """
    body = (status + "\n").encode()
    start_response(
        status,
        [
            ("Content-Type", "text/plain"),
            ("Content-Length", str(len(body))),
        ],
        exc_info,
    )
    return [body]


def redirect(start_response: StartResponse, location: bytes, typ: str) -> Response:
    """Send a 301/302 redirect to ``location``.

    Args:
        start_response: the WSGI start_response callable.
        location: redirect target as bytes (decoded as UTF-8 for the header).
        typ: ``"text/x.redirect.301"`` or ``"text/x.redirect.302"``; any
            other value yields a 500.

    Returns:
        A one-element list holding the (whitespace-stripped) location.
    """
    status = {
        "text/x.redirect.301": "301 Moved Permanently",
        "text/x.redirect.302": "302 Found",
    }.get(typ)
    if not status:
        return simple_response(start_response, "500 Internal Server Error")
    # Stored targets frequently end with a newline (e.g. uploaded from a shell);
    # surrounding whitespace is never meaningful in a URL.
    location = location.strip()
    # Any remaining CR/LF would split the Location header (header injection).
    if b"\r" in location or b"\n" in location:
        return simple_response(start_response, "500 Internal Server Error")
    body = location
    start_response(
        status,
        [
            ("Location", location.decode()),
            ("Content-Type", "text/plain"),
            ("Content-Length", str(len(body))),
        ],
    )
    return [body]


ProtoMiddleware = Callable[[App, Env, StartResponse], Response]
Middleware = Callable[[App], App]


def middleware(f: ProtoMiddleware) -> Middleware:
    """Turn ``f(app, environ, start_response)`` into an ``app -> app`` decorator."""

    @wraps(f)
    def outer(app: App):
        @wraps(app)
        def inner(environ: Env, start_response: StartResponse):
            return f(app, environ, start_response)

        return inner

    return outer


@middleware
def catch_exceptions(app: App, environ: Env, start_response: StartResponse) -> Response:
    """Convert any exception raised by ``app`` into a logged 500 response.

    The traceback goes to the server's ``wsgi.errors`` stream (falling back to
    stderr). It is never written to stdout, which is the response body itself
    under CGI-style servers.
    """
    try:
        return app(environ, start_response)
    except Exception:
        errors = environ.get("wsgi.errors", sys.stderr)
        traceback.print_exc(file=errors)
        # Passing exc_info lets start_response re-raise if headers were
        # already sent, as WSGI requires.
        return simple_response(
            start_response, "500 Internal Server Error", sys.exc_info()
        )
# Methods this service implements; advertised by OPTIONS and by 405 responses.
ALLOWED_METHODS = ("GET", "HEAD", "POST", "PUT", "DELETE", "OPTIONS")
ALLOW_HEADER = ", ".join(ALLOWED_METHODS)


@middleware
def validate_method(app: App, environ: Env, start_response: StartResponse) -> Response:
    """Pass implemented methods through; reject everything else.

    CONNECT, TRACE and PATCH get 405 with the mandatory ``Allow`` header.
    Any other unknown method gets 501.
    """
    method = environ["REQUEST_METHOD"]
    if method in ALLOWED_METHODS:
        return app(environ, start_response)
    if method in {"CONNECT", "TRACE", "PATCH"}:
        status = "405 Method Not Allowed"
        body = (status + "\n").encode()
        start_response(
            status,
            [
                ("Content-Type", "text/plain"),
                ("Content-Length", str(len(body))),
                ("Allow", ALLOW_HEADER),
            ],
        )
        return [body]
    return simple_response(start_response, "501 Not Implemented")


@middleware
def if_none_match(app: App, environ: Env, start_response: StartResponse) -> Response:
    """Answer conditional GET/HEAD requests with 304 when an ETag matches.

    The If-None-Match header is removed from ``environ`` before any inner app
    sees it. For GET/HEAD, the inner app is first invoked as HEAD to discover
    the current ETag. Then:

    * ``*`` matches any resource that has an ETag.
    * Otherwise the header is parsed as a comma-separated list of quoted
      entity tags. ``W/`` prefixes are dropped (weak comparison), and empty
      list elements are ignored.
    * A malformed entity tag yields 400.
    """
    if "HTTP_IF_NONE_MATCH" not in environ:
        return app(environ, start_response)
    header = environ.pop("HTTP_IF_NONE_MATCH")
    if environ["REQUEST_METHOD"] not in {"GET", "HEAD"}:
        return app(environ, start_response)

    head_env = environ.copy()
    head_env["REQUEST_METHOD"] = "HEAD"
    etag = None

    def head_start_response(
        status: str,
        headers: list[tuple[str, str]],
        exc_info: Optional[tuple] = None,
    ) -> Callable[[bytes], object]:
        _, _ = status, exc_info
        nonlocal etag
        for key, value in headers:
            if key == "ETag":
                # Stored unquoted; re-quoted when echoed back in a 304.
                etag = value[1:-1]
        return lambda _: None

    resp = app(head_env, head_start_response)
    if isinstance(resp, Closable):
        resp.close()
    if not isinstance(etag, str):
        return app(environ, start_response)

    def not_modified() -> Response:
        # Entity tags are quoted strings on the wire.
        start_response("304 Not Modified", [("ETag", f'"{etag}"')])
        return []

    if header.strip(" \t") == "*":
        return not_modified()

    candidates = set()
    for item in header.split(","):
        item = item.strip(" \t")
        if not item:
            # HTTP list syntax: recipients ignore empty list elements.
            continue
        item = item.removeprefix("W/")
        if len(item) < 2 or item[0] != '"' or item[-1] != '"':
            return simple_response(start_response, "400 Bad Request")
        candidates.add(item[1:-1])

    if etag in candidates:
        return not_modified()
    return app(environ, start_response)


@middleware
def options(app: App, environ: Env, start_response: StartResponse) -> Response:
    """Answer OPTIONS with 204 and the list of allowed methods."""
    if environ["REQUEST_METHOD"] != "OPTIONS":
        return app(environ, start_response)
    start_response(
        "204 No Content",
        [
            ("Allow", ALLOW_HEADER),
        ],
    )
    return []


@middleware
def open_database(app: App, environ: Env, start_response: StartResponse) -> Response:
    """Open a DB connection for the request and expose it as ``paste.db_conn``.

    The path comes from the ``PASTE_DB`` environ key, else ``DB_PATH``.
    Commit/close behaviour is whatever ``db.connect``'s context manager does
    (NOTE(review): not visible here — confirm).
    """
    db_path = environ.get("PASTE_DB", DB_PATH)
    with db.connect(db_path) as conn:
        environ["paste.db_conn"] = conn
        return app(environ, start_response)


def check_auth(conn: Connection, auth: Optional[str]) -> bool:
    """Return True if ``auth`` is a valid ``Bearer <base64 token>`` header.

    The token is base64-decoded strictly, and its hash must match exactly one
    row of the ``token`` table. ``sha256`` is an SQL function, presumably
    registered by ``db.connect`` — verify.
    """
    if not auth or not auth.startswith("Bearer "):
        return False
    try:
        # validate=True rejects stray characters instead of silently
        # discarding them, so each token has a single accepted spelling.
        token = b64decode(
            auth.removeprefix("Bearer ").strip().encode(), validate=True
        )
    except binascii.Error:
        return False
    (count,) = conn.execute(
        "SELECT COUNT(*) FROM token WHERE hash = sha256(?)", (token,)
    ).fetchone()
    return count == 1


@middleware
def authenticate(app: App, environ: Env, start_response: StartResponse) -> Response:
    """Allow GET/HEAD anonymously; require a valid bearer token otherwise."""
    conn = environ["paste.db_conn"]
    token = environ.get("HTTP_AUTHORIZATION")
    if environ["REQUEST_METHOD"] in {"GET", "HEAD"} or check_auth(conn, token):
        return app(environ, start_response)
    return simple_response(start_response, "401 Unauthorized")


@catch_exceptions
@validate_method
@options
@if_none_match
@open_database
@authenticate
def application(environ: Env, start_response: StartResponse) -> Response:
    """Serve, store and delete pastes keyed by ``PATH_INFO``.

    * GET: return the content, or redirect for ``text/x.redirect.*`` types.
    * HEAD: the same headers as GET, with no body.
    * POST/PUT: store the request body, then 302 back to the request URI.
    * DELETE: remove the paste and answer 204.
    """
    conn = environ["paste.db_conn"]
    name = environ["PATH_INFO"]
    method = environ["REQUEST_METHOD"]

    if method == "GET":
        row = store.get(conn, name)
        if not row:
            return simple_response(start_response, "404 Not Found")
        content_type, content_hash, content = row
        if content_type.startswith("text/x.redirect"):
            return redirect(start_response, content, content_type)
        start_response(
            "200 OK",
            [
                ("Content-Type", content_type),
                ("Content-Length", str(len(content))),
                ("ETag", f'"{b64encode(content_hash).decode()}"'),
            ],
        )
        return [content]

    if method == "HEAD":
        # HEAD responses carry headers only. The helpers below still compute
        # a body (and matching Content-Length), but it is discarded.
        row = store.head(conn, name)
        if not row:
            simple_response(start_response, "404 Not Found")
            return []
        content_type, content_hash, content_length, opt_content = row
        if content_type.startswith("text/x.redirect"):
            redirect(start_response, opt_content, content_type)
            return []
        start_response(
            "200 OK",
            [
                ("Content-Type", content_type),
                # WSGI header values must be str; the stored length may be an int.
                ("Content-Length", str(content_length)),
                ("ETag", f'"{b64encode(content_hash).decode()}"'),
            ],
        )
        return []

    if method in {"POST", "PUT"}:
        # CGI-style servers may pass CONTENT_TYPE/CONTENT_LENGTH as "".
        content_type = environ.get("CONTENT_TYPE") or "text/plain"
        raw_length = environ.get("CONTENT_LENGTH", "").strip()
        if not raw_length:
            return simple_response(start_response, "411 Length Required")
        if not (raw_length.isascii() and raw_length.isdigit()):
            return simple_response(start_response, "400 Bad Request")
        content = environ["wsgi.input"].read(int(raw_length))
        store.put(conn, name, content, content_type)
        return redirect(
            start_response, request_uri(environ).encode(), "text/x.redirect.302"
        )

    if method == "DELETE":
        store.delete(conn, name)
        # A 204 response must not include a body.
        start_response("204 No Content", [])
        return []

    # Unreachable given validate_method/options; kept as a safe default.
    return simple_response(start_response, "500 Internal Server Error")