import sqlite3 from collections.abc import Iterator from contextlib import contextmanager from hashlib import blake2b from itertools import count from typing import Optional, Union migrations = [ """CREATE TABLE file ( hash BLOB UNIQUE GENERATED ALWAYS AS (DATA_HASH(content)) STORED NOT NULL, content BLOB NOT NULL, created_at INTEGER DEFAULT (unixepoch('now')) ) STRICT; CREATE UNIQUE INDEX file_hash_ix ON file ( hash ); CREATE TABLE link ( name_hash BLOB UNIQUE GENERATED ALWAYS AS (DATA_HASH(name)) STORED NOT NULL, name TEXT NOT NULL, content_type TEXT NOT NULL DEFAULT "text/plain", file_hash BLOB NOT NULL, created_at INTEGER DEFAULT (unixepoch('now')), FOREIGN KEY(file_hash) REFERENCES file(hash) ) STRICT; CREATE UNIQUE INDEX link_name_hash_ix ON link ( name_hash ); CREATE TABLE token ( hash BLOB PRIMARY KEY, created_at INTEGER DEFAULT (unixepoch('now')) ) STRICT, WITHOUT ROWID;""", ] def _sha256_udf(b: Union[bytes, str]) -> bytes: if isinstance(b, str): b = b.encode() return sha256(b).digest() def _data_hash_udf(b: Union[bytes, str]) -> bytes: if isinstance(b, str): b = b.encode() return blake2b(b, digest_size=32).digest() def get_version(connection: sqlite3.Connection) -> int: (user_version,) = connection.execute("PRAGMA user_version").fetchone() return user_version def migrate(connection: sqlite3.Connection, migrations: list[str]) -> None: version = get_version(connection) if len(migrations) < version: raise RuntimeError( f"Attempt to downgrade from v{version} to v{len(migrations)}" ) for i in count(version + 1): if i - 1 >= len(migrations): break migration = migrations[i - 1] try: connection.executescript( "BEGIN IMMEDIATE TRANSACTION;\n" f"{migration};\n" f"PRAGMA user_version = {i:d}" ) except: connection.execute("ROLLBACK TRANSACTION") raise else: connection.execute("COMMIT TRANSACTION") @contextmanager def connect( database: str, migrations: list[str] = migrations, **kwargs ) -> Iterator[sqlite3.Connection]: conn = sqlite3.connect(database, uri=True, **kwargs) conn.execute("PRAGMA foreign_keys = ON") conn.execute("PRAGMA journal_mode = WAL") conn.row_factory = sqlite3.Row conn.create_function(name="sha256", narg=1, func=_sha256_udf, deterministic=True) conn.create_function( name="DATA_HASH", narg=1, func=_data_hash_udf, deterministic=True ) migrate(conn, migrations) try: yield conn finally: conn.close()