from base64 import b64decode, b64encode
from contextlib import AbstractContextManager
from datetime import datetime, timezone
from hashlib import sha256
from os import getenv
from secrets import token_bytes
from sqlite3 import Connection
from sys import argv, exit, stderr  # sys.exit: the site builtin is for interactive use only
from wsgiref.simple_server import make_server

from . import DB_PATH, application, db

# Size of a raw access token in bytes (96 bits).
TOKEN_BYTES = 96 // 8
# Size of a SHA-256 digest in bytes; tokens are stored only as this hash.
HASH_BYTES = 256 // 8
PROGRAM_NAME = "paste"


def print_usage() -> None:
    """Print a one-line usage summary to stderr."""
    print(
        f"Usage: {PROGRAM_NAME} [-h|serve|new-token|list-tokens"
        "|delete-token <token|hash>|verify-token <token>]",
        file=stderr,
    )


def print_help() -> None:
    """Print the command, option and environment reference to stderr."""
    print(
        f"""A simple WSGI paste site

Commands:
  serve (default)  Start a basic HTTP server (do NOT use in production)
  new-token        Generate a new access token
  list-tokens      List access token hashes (sha256 hash, created at)
  verify-token     Check if a token is valid and exists in the database
                   (alias: check-token)
  delete-token     Delete a token (specify token or hash)

Options:
  -h, --help       Show this help

Environment:
  PASTE_HOST  The HTTP server host (default: localhost)
  PASTE_PORT  The HTTP server port (default: 8080)
  PASTE_DB    The path to the sqlite3 database (default: {DB_PATH})""",
        file=stderr,
    )


def generate_token(conn: Connection) -> bytes:
    """Create a random access token, store its SHA-256 hash and return it.

    Only the hash is written to the ``token`` table; the raw token is returned
    to the caller and is not persisted anywhere.
    """
    token = token_bytes(TOKEN_BYTES)
    token_hash = sha256(token).digest()
    # ``with conn`` commits on success and rolls back if the insert raises.
    with conn:
        conn.execute("INSERT INTO token (hash) VALUES (?)", (token_hash,))
    return token


def get_tokens(conn: Connection) -> list:
    """Return all ``(hash, created_at)`` rows, oldest first.

    ``created_at`` is passed to ``datetime.fromtimestamp`` by the CLI, so it
    is presumably a Unix timestamp — confirm against the schema.
    """
    return conn.execute(
        "SELECT hash, created_at FROM token ORDER BY created_at"
    ).fetchall()


def delete_token_hash(conn: Connection, token_hash: bytes) -> None:
    """Delete the token whose stored SHA-256 hash is ``token_hash``.

    Raises:
        ValueError: ``token_hash`` is not ``HASH_BYTES`` long.
        KeyError: no stored token has this hash.
    """
    if len(token_hash) != HASH_BYTES:
        raise ValueError("Invalid token hash")
    with conn:
        cur = conn.execute("DELETE FROM token WHERE hash = ?", (token_hash,))
        if cur.rowcount <= 0:
            raise KeyError("Token hash does not exist")


def delete_token(conn: Connection, token: bytes) -> None:
    """Delete the stored entry for the raw ``token``.

    Raises:
        ValueError: ``token`` is not ``TOKEN_BYTES`` long.
        KeyError: the token is not in the database.
    """
    if len(token) != TOKEN_BYTES:
        raise ValueError("Invalid token")
    return delete_token_hash(conn, sha256(token).digest())


def check_token(conn: Connection, token: bytes) -> bool:
    """Return whether the raw ``token`` has a stored hash in the database.

    Raises:
        ValueError: ``token`` is not ``TOKEN_BYTES`` long.
    """
    if len(token) != TOKEN_BYTES:
        raise ValueError("Invalid token")
    token_hash = sha256(token).digest()
    (count,) = conn.execute(
        "SELECT COUNT(*) FROM token WHERE hash = ?", (token_hash,)
    ).fetchone()
    return count > 0


db_path = getenv("PASTE_DB", DB_PATH)


def connect() -> AbstractContextManager[Connection]:
    """Open the database configured by ``PASTE_DB`` (or the package default)."""
    return db.connect(db_path)


def _decode_token(text: str) -> bytes:
    """Decode a base64 token given on the command line.

    ``validate=True`` rejects characters outside the base64 alphabet instead
    of silently dropping them, so a typo cannot decode to a different token.
    Raises ``binascii.Error`` (a ``ValueError``) on malformed input.
    """
    return b64decode(text.strip(), validate=True)


def _serve() -> None:
    """Run the development WSGI server on PASTE_HOST:PASTE_PORT."""
    host = getenv("PASTE_HOST", "localhost")
    port = int(getenv("PASTE_PORT", "8080"))
    # The context manager closes the listening socket when serving stops.
    with make_server(host, port, application) as httpd:
        httpd.serve_forever()


def _new_token() -> None:
    """Generate a token and print it base64-encoded to stdout."""
    with connect() as conn:
        print(b64encode(generate_token(conn)).decode())


def _list_tokens() -> None:
    """Print each stored token hash (hex) and its UTC creation time."""
    with connect() as conn:
        for token_hash, created_at in get_tokens(conn):
            created = datetime.fromtimestamp(created_at, timezone.utc)
            print(f"{token_hash.hex()}\t{created.ctime()}")


def _delete_token(arg: str) -> None:
    """Delete a token given either as base64 token or as hex SHA-256 hash."""
    with connect() as conn:
        try:
            try:
                delete_token(conn, _decode_token(arg))
            except ValueError:
                # Not a well-formed base64 token; treat it as a hex hash.
                delete_token_hash(conn, bytes.fromhex(arg))
        except ValueError:
            print("Malformed token", file=stderr)
            exit(1)
        except KeyError:
            print("Token not found", file=stderr)
            exit(1)


def _verify_token(arg: str) -> None:
    """Print "Found" if the base64 token exists; otherwise exit with status 1."""
    with connect() as conn:
        try:
            found = check_token(conn, _decode_token(arg))
        except ValueError:
            print("Malformed token", file=stderr)
            exit(1)  # previously fell through with exit status 0
        if not found:
            print("Token not found", file=stderr)
            exit(1)
        print("Found")


def main() -> None:
    """Parse ``sys.argv`` and dispatch to the requested subcommand.

    Unknown commands or wrong argument counts print usage and exit with 1.
    """
    no_arg_commands = {
        "serve": _serve,
        "new-token": _new_token,
        "list-tokens": _list_tokens,
    }
    one_arg_commands = {
        "delete-token": _delete_token,
        "verify-token": _verify_token,
        # Name advertised by the help text; kept as an alias.
        "check-token": _verify_token,
    }

    args = argv[1:]
    if not args:
        _serve()
    elif len(args) == 1 and args[0] in ("-h", "--help"):
        print_usage()
        print_help()
    elif len(args) == 1 and args[0] in no_arg_commands:
        no_arg_commands[args[0]]()
    elif len(args) == 2 and args[0] in one_arg_commands:
        one_arg_commands[args[0]](args[1])
    else:
        print_usage()
        exit(1)


if __name__ == "__main__":
    main()