aboutsummaryrefslogtreecommitdiffstats
path: root/paste/__main__.py
diff options
context:
space:
mode:
Diffstat (limited to 'paste/__main__.py')
-rw-r--r--paste/__main__.py67
1 files changed, 12 insertions, 55 deletions
diff --git a/paste/__main__.py b/paste/__main__.py
index 2b19baa..27ff72b 100644
--- a/paste/__main__.py
+++ b/paste/__main__.py
@@ -1,16 +1,12 @@
from base64 import b64decode, b64encode
from contextlib import AbstractContextManager
from datetime import datetime, timezone
-from hashlib import sha256
from os import getenv
-from secrets import token_bytes
-from sqlite3 import Connection
from sys import argv, stderr
from wsgiref.simple_server import make_server
-from . import DB_PATH, application, db
+from . import DB_PATH, application, store
-TOKEN_BYTES = 96 // 8
PROGRAM_NAME = "paste"
@@ -43,50 +39,11 @@ Environment:
)
-def generate_token(conn: Connection):
- token = token_bytes(TOKEN_BYTES)
- token_hash = sha256(token).digest()
- with conn:
- conn.execute("INSERT INTO token (hash) VALUES (?)", (token_hash,))
- return token
-
-
-def get_tokens(conn: Connection):
- return conn.execute(
- "SELECT hash, created_at FROM token ORDER BY created_at"
- ).fetchall()
-
-
-def delete_token_hash(conn: Connection, token_hash: bytes):
- if len(token_hash) != 256 // 8:
- raise ValueError("Invalid token hash")
- with conn:
- cur = conn.execute("DELETE FROM token WHERE hash = ?", (token_hash,))
- if cur.rowcount <= 0:
- raise KeyError("Token hash does not exist")
-
-
-def delete_token(conn: Connection, token: bytes):
- if len(token) != TOKEN_BYTES:
- raise ValueError("Invalid token")
- return delete_token_hash(conn, sha256(token).digest())
-
-
-def check_token(conn: Connection, token: bytes):
- if len(token) != TOKEN_BYTES:
- raise ValueError("Invalid token")
- token_hash = sha256(token).digest()
- (count,) = conn.execute(
- "SELECT COUNT(*) FROM token WHERE hash = ?", (token_hash,)
- ).fetchone()
- return count > 0
-
-
db_path = getenv("PASTE_DB", DB_PATH)
-def connect() -> AbstractContextManager[Connection]:
- return db.connect(db_path)
+def open_db() -> AbstractContextManager[store.Store]:
+ return store.open(db_path)
def main():
@@ -97,11 +54,11 @@ def main():
httpd.serve_forever()
elif len(argv) == 2:
if argv[1] == "new-token":
- with connect() as conn:
- print(b64encode(generate_token(conn)).decode())
+ with open_db() as db:
+ print(b64encode(db.generate_token()).decode())
elif argv[1] == "list-tokens":
- with connect() as conn:
- for token_hash, created_at in get_tokens(conn):
+ with open_db() as db:
+ for token_hash, created_at in db.get_tokens():
created_at = datetime.fromtimestamp(created_at, timezone.utc)
print(f"{token_hash.hex()}\t{created_at.ctime()}")
elif argv[1] == "-h" or argv[1] == "--help":
@@ -113,12 +70,12 @@ def main():
elif len(argv) == 3:
if argv[1] == "delete-token":
token = argv[2]
- with connect() as conn:
+ with open_db() as db:
try:
try:
- delete_token(conn, b64decode(token))
+ db.delete_token(b64decode(token))
except ValueError:
- delete_token_hash(conn, bytes.fromhex(token))
+ db.delete_token_hash(bytes.fromhex(token))
except ValueError:
print("Malformed token", file=stderr)
exit(1)
@@ -126,9 +83,9 @@ def main():
print("Token not found", file=stderr)
exit(1)
elif argv[1] == "verify-token":
- with connect() as conn:
+ with open_db() as db:
try:
- if not check_token(conn, b64decode(argv[2])):
+ if not db.check_token(b64decode(argv[2])):
print("Token not found", file=stderr)
exit(1)
print("Found")