1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
|
from contextlib import contextmanager
from hashlib import sha256
from itertools import count
from collections.abc import Iterator
from typing import Union
import sqlite3
# Ordered schema migrations, applied by `connect`. Entry i (0-based) is schema
# version i + 1 and is recorded in `PRAGMA user_version`; `connect` only runs
# entries above a database's current version, so new entries must be appended
# and shipped entries must not be reordered or changed in meaning.
#
# Version 1 creates:
#   file  - content blobs with a STORED generated hash = sha256(content);
#   link  - named rows (name_hash = sha256(name)) referencing file(hash), with
#           a content_type defaulting to 'text/plain';
#   token - hash primary key plus a creation timestamp (WITHOUT ROWID).
# String literals in the SQL use single quotes: a double-quoted "text/plain"
# is an identifier that SQLite only accepts via its legacy DQS misfeature and
# is rejected where that is disabled.
# NOTE: sha256() is not a built-in SQLite function; it must be registered on
# the connection (as `connect` does) before rows are written. STRICT tables
# need SQLite >= 3.37 and unixepoch() needs SQLite >= 3.38.
# NOTE(review): the UNIQUE constraints on file.hash and link.name_hash already
# create unique indexes, so file_hash_ix / link_name_hash_ix look redundant;
# they are kept so fresh and already-migrated databases share one schema.
migrations = [
    """CREATE TABLE file (
    hash BLOB UNIQUE GENERATED ALWAYS AS (sha256(content)) STORED NOT NULL,
    content BLOB NOT NULL,
    created_at INTEGER DEFAULT (unixepoch('now'))
) STRICT;
CREATE UNIQUE INDEX file_hash_ix ON file ( hash );
CREATE TABLE link (
    name_hash BLOB UNIQUE GENERATED ALWAYS AS (sha256(name)) STORED NOT NULL,
    name TEXT NOT NULL,
    content_type TEXT NOT NULL DEFAULT 'text/plain',
    file_hash BLOB NOT NULL,
    created_at INTEGER DEFAULT (unixepoch('now')),
    FOREIGN KEY(file_hash) REFERENCES file(hash)
) STRICT;
CREATE UNIQUE INDEX link_name_hash_ix ON link ( name_hash );
CREATE TABLE token (
    hash BLOB PRIMARY KEY,
    created_at INTEGER DEFAULT (unixepoch('now'))
) STRICT, WITHOUT ROWID;""",
]
def _sha256_udf(b: Union[bytes, str, None]) -> Union[bytes, None]:
    """Return the raw 32-byte SHA-256 digest of *b*; backs the SQL ``sha256()``.

    ``str`` values (SQLite TEXT) are UTF-8 encoded before hashing and ``bytes``
    values (BLOB) are hashed as-is. SQL NULL arrives as ``None`` and yields
    ``None``, so ``sha256(NULL)`` is NULL, as with most built-in SQLite
    functions. Without this, the call fails with an opaque "user-defined
    function raised exception" error.
    """
    if b is None:
        return None
    if isinstance(b, str):
        b = b.encode()
    return sha256(b).digest()
@contextmanager
def connect(
    database: str, migrations: list[str] = migrations, **kwargs
) -> Iterator[sqlite3.Connection]:
    """Open *database*, apply pending schema migrations, and yield the connection.

    The connection is configured before use:
    - foreign-key enforcement is turned on;
    - ``journal_mode`` is set to WAL;
    - rows are returned as ``sqlite3.Row``;
    - the deterministic SQL function ``sha256()`` (backed by ``_sha256_udf``)
      is registered, so the schema's generated columns can be computed.

    Migrations are tracked with ``PRAGMA user_version``: ``migrations[i]`` is
    schema version ``i + 1``. Every script above the database's current
    version runs in order, each in its own ``BEGIN IMMEDIATE`` transaction
    together with the ``user_version`` bump. A failing script is therefore
    rolled back and leaves the database at the last applied version.

    Args:
        database: Database path (or ``":memory:"``) for ``sqlite3.connect``.
        migrations: Ordered migration scripts; defaults to the module-level list.
        **kwargs: Forwarded unchanged to ``sqlite3.connect``.

    Yields:
        The configured connection. It is closed when the ``with`` block exits,
        and also if configuration or a migration raises.

    Raises:
        sqlite3.Error: If opening, configuring, or migrating the database fails.
    """
    conn = sqlite3.connect(database, **kwargs)
    # Everything after a successful open is inside try/finally so the
    # connection is closed even when configuration or a migration raises.
    try:
        conn.execute("PRAGMA foreign_keys = ON")
        conn.execute("PRAGMA journal_mode = WAL")
        conn.row_factory = sqlite3.Row
        # name/narg/func are passed positionally: passing them as keywords is
        # deprecated since Python 3.13 and they become positional-only in 3.15.
        conn.create_function("sha256", 1, _sha256_udf, deterministic=True)

        (user_version,) = conn.execute("PRAGMA user_version").fetchone()
        for version in range(user_version + 1, len(migrations) + 1):
            migration = migrations[version - 1]
            try:
                # executescript() does no implicit transaction control beyond
                # committing a pending one first, so the script opens its own
                # transaction. Bumping user_version inside it makes the schema
                # change and the recorded version commit atomically.
                conn.executescript(
                    "BEGIN IMMEDIATE TRANSACTION;\n"
                    f"{migration}\n"
                    f"PRAGMA user_version = {version:d}"
                )
            except BaseException:
                # BEGIN itself may have failed (e.g. database locked), or
                # SQLite may already have rolled back. An unconditional
                # ROLLBACK would then raise and mask the original error.
                if conn.in_transaction:
                    conn.execute("ROLLBACK TRANSACTION")
                raise
            else:
                conn.execute("COMMIT TRANSACTION")

        yield conn
    finally:
        conn.close()
|