"""SQLite-backed named content store and API-token authentication.

NOTE(review): the SQL below calls ``DATA_HASH(?)`` and ``SHA256(?)``, which are
not built-in SQLite functions; they are presumably registered on the
connection elsewhere (e.g. by the ``db`` module) — confirm before using these
classes with a bare ``sqlite3.Connection``.
"""

from collections.abc import Iterator
from contextlib import contextmanager
from hashlib import sha256
from secrets import token_bytes, token_urlsafe
from sqlite3 import Connection, IntegrityError

from . import db

# Length of a raw API token in bytes (96 bits from ``secrets.token_bytes``).
TOKEN_BYTES = 96 // 8

# Length of a SHA-256 digest in bytes; ``Auth.delete_token_hash`` only
# accepts token hashes of exactly this size.
TOKEN_HASH_BYTES = 256 // 8


class Store:
    """Named links pointing at file contents, persisted in SQLite.

    The queries use a ``file`` table (``content``, ``hash``) and a ``link``
    table (``name``, ``name_hash``, ``content_type``, ``file_hash``), joined
    on ``file.hash = link.file_hash``. Links are looked up by
    ``DATA_HASH(name)`` against ``name_hash``; presumably ``name_hash`` is
    derived from ``name`` by the schema (trigger or generated column) —
    confirm.
    """

    def __init__(self, conn: Connection):
        self.conn = conn

    def _store_content(self, content: bytes):
        """Insert *content* into ``file`` if not already present; return its DATA_HASH.

        Intended to be called inside an open ``with self.conn`` transaction.
        """
        # OR IGNORE: assumes a uniqueness constraint on the file's hash so
        # identical uploads are stored once — verify against the schema.
        self.conn.execute(
            "INSERT OR IGNORE INTO file (content) VALUES (?)",
            (content,),
        )
        (content_hash,) = self.conn.execute(
            "SELECT DATA_HASH(?)", (content,)
        ).fetchone()
        return content_hash

    def put(self, name: str, content: bytes, content_type: str):
        """Create or replace the link *name* so it points at *content*.

        Returns ``(created, content_hash)``. ``created`` is False when an
        existing link row was updated and True when a new link row was
        inserted.
        """
        with self.conn:
            content_hash = self._store_content(content)
            cur = self.conn.execute(
                """UPDATE link SET content_type = ?, file_hash = ?
                WHERE name_hash = DATA_HASH(?)""",
                (content_type, content_hash, name),
            )
            if cur.rowcount == 1:
                return False, content_hash
            self.conn.execute(
                """INSERT INTO link (
                    name, content_type, file_hash
                ) VALUES (?, ?, ?)""",
                (name, content_type, content_hash),
            )
            return True, content_hash

    def post(
        self,
        prefix: str,
        content: bytes,
        content_type: str,
        *,
        attempts: int = 16,
    ):
        """Store *content* under a new random name starting with *prefix*.

        The name is ``prefix`` followed by ``token_urlsafe(5)``. A fresh
        suffix is drawn after each ``IntegrityError`` (e.g. a name collision),
        up to *attempts* times.

        Returns ``(name, content_hash)``.

        Raises:
            RuntimeError: no insert succeeded within *attempts* tries. The
                last ``IntegrityError`` is attached as ``__cause__`` so
                non-collision failures remain diagnosable. The transaction
                (including the ``file`` insert) is rolled back.
        """
        with self.conn:
            content_hash = self._store_content(content)
            last_error = None
            for _ in range(attempts):
                name = prefix + token_urlsafe(5)
                try:
                    self.conn.execute(
                        """INSERT INTO link (name, content_type, file_hash)
                        VALUES (?, ?, ?)""",
                        (name, content_type, content_hash),
                    )
                except IntegrityError as err:
                    # Keep the error: `err` is unbound once the except clause ends.
                    last_error = err
                else:
                    break
            else:
                raise RuntimeError(
                    f"Could not insert a link in {attempts} attempts"
                ) from last_error
        return name, content_hash

    def get(self, name: str):
        """Return ``(content_type, hash, content)`` for link *name*, or None if absent."""
        return self.conn.execute(
            """SELECT link.content_type, file.hash, file.content
            FROM link JOIN file ON file.hash = link.file_hash
            WHERE name_hash = DATA_HASH(?)""",
            (name,),
        ).fetchone()

    def head(self, name: str):
        """Return ``(content_type, hash, content_length)`` for link *name*, or None.

        Uses SQL ``length()`` so the content itself is not returned.
        """
        return self.conn.execute(
            """SELECT link.content_type, file.hash, length(file.content)
            FROM link JOIN file ON file.hash = link.file_hash
            WHERE name_hash = DATA_HASH(?)""",
            (name,),
        ).fetchone()

    def delete(self, name: str) -> bool:
        """Delete link *name*; return True if exactly one link row was removed.

        Only the ``link`` row is deleted here. Whether unreferenced ``file``
        rows are cleaned up is not visible in this module.
        """
        with self.conn:
            cur = self.conn.execute(
                "DELETE FROM link WHERE name_hash = DATA_HASH(?)", (name,)
            )
        return cur.rowcount == 1


class Auth:
    """Issue, list, revoke and verify random API tokens.

    Only the SQL ``SHA256(token)`` is stored in the ``token`` table (columns
    ``hash``, ``created_at``). ``delete_token`` hashes with
    ``hashlib.sha256``, so it assumes the SQL function produces the same
    digest.
    """

    def __init__(self, conn: Connection):
        self.conn = conn

    def generate_token(self) -> bytes:
        """Create a new random token, persist its hash and return the raw token."""
        token = token_bytes(TOKEN_BYTES)
        with self.conn:
            self.conn.execute("INSERT INTO token (hash) VALUES (SHA256(?))", (token,))
        return token

    def get_tokens(self) -> list:
        """Return all ``(hash, created_at)`` rows, ordered by ``created_at``."""
        return self.conn.execute(
            "SELECT hash, created_at FROM token ORDER BY created_at"
        ).fetchall()

    def delete_token_hash(self, token_hash: bytes) -> None:
        """Revoke the token whose stored hash is *token_hash*.

        Raises:
            ValueError: *token_hash* is not ``TOKEN_HASH_BYTES`` long.
            KeyError: no stored token has this hash.
        """
        if len(token_hash) != TOKEN_HASH_BYTES:
            raise ValueError("Invalid token hash")
        with self.conn:
            cur = self.conn.execute("DELETE FROM token WHERE hash = ?", (token_hash,))
            if cur.rowcount <= 0:
                raise KeyError("Token hash does not exist")

    def delete_token(self, token: bytes) -> None:
        """Revoke a raw *token* by hashing it and deleting the matching row.

        Raises:
            ValueError: *token* is not ``TOKEN_BYTES`` long.
            KeyError: the token is not stored.
        """
        if len(token) != TOKEN_BYTES:
            raise ValueError("Invalid token")
        return self.delete_token_hash(sha256(token).digest())

    def check_token(self, token: bytes) -> bool:
        """Return True if *token* has been issued and not revoked.

        Raises:
            ValueError: *token* is not ``TOKEN_BYTES`` long.
        """
        if len(token) != TOKEN_BYTES:
            raise ValueError("Invalid token")
        (count,) = self.conn.execute(
            "SELECT COUNT(*) FROM token WHERE hash = SHA256(?)", (token,)
        ).fetchone()
        return count > 0