diff options
Diffstat (limited to 'paste/__main__.py')
-rw-r--r-- | paste/__main__.py | 67 |
1 files changed, 12 insertions, 55 deletions
diff --git a/paste/__main__.py b/paste/__main__.py index 2b19baa..27ff72b 100644 --- a/paste/__main__.py +++ b/paste/__main__.py @@ -1,16 +1,12 @@ from base64 import b64decode, b64encode from contextlib import AbstractContextManager from datetime import datetime, timezone -from hashlib import sha256 from os import getenv -from secrets import token_bytes -from sqlite3 import Connection from sys import argv, stderr from wsgiref.simple_server import make_server -from . import DB_PATH, application, db +from . import DB_PATH, application, store -TOKEN_BYTES = 96 // 8 PROGRAM_NAME = "paste" @@ -43,50 +39,11 @@ Environment: ) -def generate_token(conn: Connection): - token = token_bytes(TOKEN_BYTES) - token_hash = sha256(token).digest() - with conn: - conn.execute("INSERT INTO token (hash) VALUES (?)", (token_hash,)) - return token - - -def get_tokens(conn: Connection): - return conn.execute( - "SELECT hash, created_at FROM token ORDER BY created_at" - ).fetchall() - - -def delete_token_hash(conn: Connection, token_hash: bytes): - if len(token_hash) != 256 // 8: - raise ValueError("Invalid token hash") - with conn: - cur = conn.execute("DELETE FROM token WHERE hash = ?", (token_hash,)) - if cur.rowcount <= 0: - raise KeyError("Token hash does not exist") - - -def delete_token(conn: Connection, token: bytes): - if len(token) != TOKEN_BYTES: - raise ValueError("Invalid token") - return delete_token_hash(conn, sha256(token).digest()) - - -def check_token(conn: Connection, token: bytes): - if len(token) != TOKEN_BYTES: - raise ValueError("Invalid token") - token_hash = sha256(token).digest() - (count,) = conn.execute( - "SELECT COUNT(*) FROM token WHERE hash = ?", (token_hash,) - ).fetchone() - return count > 0 - - db_path = getenv("PASTE_DB", DB_PATH) -def connect() -> AbstractContextManager[Connection]: - return db.connect(db_path) +def open_db() -> AbstractContextManager[store.Store]: + return store.open(db_path) def main(): @@ -97,11 +54,11 @@ def main(): httpd.serve_forever() elif len(argv) == 2: if argv[1] == "new-token": - with connect() as conn: - print(b64encode(generate_token(conn)).decode()) + with open_db() as db: + print(b64encode(db.generate_token()).decode()) elif argv[1] == "list-tokens": - with connect() as conn: - for token_hash, created_at in get_tokens(conn): + with open_db() as db: + for token_hash, created_at in db.get_tokens(): created_at = datetime.fromtimestamp(created_at, timezone.utc) print(f"{token_hash.hex()}\t{created_at.ctime()}") elif argv[1] == "-h" or argv[1] == "--help": @@ -113,12 +70,12 @@ def main(): elif len(argv) == 3: if argv[1] == "delete-token": token = argv[2] - with connect() as conn: + with open_db() as db: try: try: - delete_token(conn, b64decode(token)) + db.delete_token(b64decode(token)) except ValueError: - delete_token_hash(conn, bytes.fromhex(token)) + db.delete_token_hash(bytes.fromhex(token)) except ValueError: print("Malformed token", file=stderr) exit(1) @@ -126,9 +83,9 @@ def main(): print("Token not found", file=stderr) exit(1) elif argv[1] == "verify-token": - with connect() as conn: + with open_db() as db: try: - if not check_token(conn, b64decode(argv[2])): + if not db.check_token(b64decode(argv[2])): print("Token not found", file=stderr) exit(1) print("Found") |