diff options
Diffstat (limited to 'paste/store.py')
-rw-r--r-- | paste/store.py | 201 |
1 files changed, 128 insertions, 73 deletions
diff --git a/paste/store.py b/paste/store.py index dd00edd..b7e471d 100644 --- a/paste/store.py +++ b/paste/store.py @@ -1,77 +1,132 @@ -from secrets import token_urlsafe +from collections.abc import Iterator +from contextlib import contextmanager +from hashlib import sha256 +from secrets import token_bytes, token_urlsafe from sqlite3 import Connection, IntegrityError +from . import db -def put(conn: Connection, name: str, content: bytes, content_type: str): - with conn: - conn.execute( - "INSERT OR IGNORE INTO file (content) VALUES (?)", - (content,), - ) - (content_hash,) = conn.execute("SELECT DATA_HASH(?)", (content,)).fetchone() - cur = conn.execute( - """UPDATE link - SET content_type = ?, file_hash = ? +TOKEN_BYTES = 96 // 8 + + +class Store: + def __init__(self, conn: Connection): + self.conn = conn + + def put(self, name: str, content: bytes, content_type: str): + with self.conn: + self.conn.execute( + "INSERT OR IGNORE INTO file (content) VALUES (?)", + (content,), + ) + (content_hash,) = self.conn.execute( + "SELECT DATA_HASH(?)", (content,) + ).fetchone() + cur = self.conn.execute( + """UPDATE link + SET content_type = ?, file_hash = ? + WHERE name_hash = DATA_HASH(?)""", + (content_type, content_hash, name), + ) + if cur.rowcount == 1: + return False, content_hash + self.conn.execute( + """INSERT INTO link ( + name, content_type, file_hash + ) VALUES (?, ?, ?)""", + (name, content_type, content_hash), + ) + return True, content_hash + + def post(self, prefix: str, content: bytes, content_type: str): + with self.conn: + self.conn.execute( + "INSERT OR IGNORE INTO file (content) VALUES (?)", + (content,), + ) + (content_hash,) = self.conn.execute( + "SELECT DATA_HASH(?)", (content,) + ).fetchone() + for _ in range(16): + name = prefix + token_urlsafe(5) + try: + self.conn.execute( + """INSERT INTO link (name, content_type, file_hash) + VALUES (?, ?, ?)""", + (name, content_type, content_hash), + ) + except IntegrityError: + continue + break + else: + raise RuntimeError("Could not insert a link in 16 attempts") + return name, content_hash + + def get(self, name: str): + row = self.conn.execute( + """SELECT link.content_type, file.hash, file.content + FROM link + JOIN file ON file.hash = link.file_hash + WHERE name_hash = DATA_HASH(?)""", + (name,), + ).fetchone() + return row + + def head(self, name: str): + row = self.conn.execute( + """SELECT link.content_type, file.hash, length(file.content) + FROM link + JOIN file ON file.hash = link.file_hash WHERE name_hash = DATA_HASH(?)""", - (content_type, content_hash, name), - ) - if cur.rowcount == 1: - return False, content_hash - conn.execute( - """INSERT INTO link ( - name, content_type, file_hash - ) VALUES (?, ?, ?)""", - (name, content_type, content_hash), - ) - return True, content_hash - - -def post(conn: Connection, prefix: str, content: bytes, content_type: str): - with conn: - conn.execute( - "INSERT OR IGNORE INTO file (content) VALUES (?)", - (content,), - ) - (content_hash,) = conn.execute("SELECT DATA_HASH(?)", (content,)).fetchone() - for _ in range(16): - name = prefix + token_urlsafe(5) - try: - conn.execute( - """INSERT INTO link (name, content_type, file_hash) - VALUES (?, ?, ?)""", - (name, content_type, content_hash), - ) - except IntegrityError: - continue - break - else: - raise RuntimeError("Could not insert a link in 16 attempts") - return name, content_hash - - -def get(conn: Connection, name: str): - row = conn.execute( - """SELECT link.content_type, file.hash, file.content - FROM link - JOIN file ON file.hash = link.file_hash - WHERE name_hash = DATA_HASH(?)""", - (name,), - ).fetchone() - return row - - -def head(conn: Connection, name: str): - row = conn.execute( - """SELECT link.content_type, file.hash, length(file.content) - FROM link - JOIN file ON file.hash = link.file_hash - WHERE name_hash = DATA_HASH(?)""", - (name,), - ).fetchone() - return row - - -def delete(conn: Connection, name: str): - with conn: - cur = conn.execute("DELETE FROM link WHERE name_hash = DATA_HASH(?)", (name,)) - return cur.rowcount == 1 + (name,), + ).fetchone() + return row + + def delete(self, name: str): + with self.conn: + cur = self.conn.execute( + "DELETE FROM link WHERE name_hash = DATA_HASH(?)", (name,) + ) + return cur.rowcount == 1 + + def generate_token(self): + token = token_bytes(TOKEN_BYTES) + with self.conn: + self.conn.execute("INSERT INTO token (hash) VALUES (SHA256(?))", (token,)) + return token + + def get_tokens(self): + return self.conn.execute( + "SELECT hash, created_at FROM token ORDER BY created_at" + ).fetchall() + + def delete_token_hash(self, token_hash: bytes): + if len(token_hash) != 256 // 8: + raise ValueError("Invalid token hash") + with self.conn: + cur = self.conn.execute("DELETE FROM token WHERE hash = ?", (token_hash,)) + if cur.rowcount <= 0: + raise KeyError("Token hash does not exist") + + def delete_token(self, token: bytes): + if len(token) != TOKEN_BYTES: + raise ValueError("Invalid token") + return self.delete_token_hash(sha256(token).digest()) + + def check_token(self, token: bytes): + if len(token) != TOKEN_BYTES: + raise ValueError("Invalid token") + (count,) = self.conn.execute( + "SELECT COUNT(*) FROM token WHERE hash = SHA256(?)", (token,) + ).fetchone() + return count > 0 + + @property + def version(self): + return db.get_version(self.conn) + + +@contextmanager +def open(uri: str) -> Iterator[Store]: + with db.connect(uri) as conn: + yield Store(conn) |