aboutsummaryrefslogtreecommitdiffstats
path: root/paste/db.py
blob: 15e4f03a6e0615b9f183d3444b91bf8042d1b273 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from contextlib import contextmanager
from hashlib import sha256
from itertools import count
from collections.abc import Iterator
from typing import Union
import sqlite3

# Schema migrations, applied in order by ``migrate``: ``migrations[k]`` moves a
# database from ``PRAGMA user_version`` k to k + 1. A database that has already
# run a migration never runs it again, so only append new entries; editing an
# existing entry is safe only when the resulting schema stays equivalent.
# Each entry must end with ``;`` because ``migrate`` appends a PRAGMA after it.
#
# The generated ``hash``/``name_hash`` columns call the application-defined
# ``sha256()`` SQL function, so it must be registered (deterministic) on any
# connection that inserts or updates rows in these tables. ``STRICT`` tables
# need SQLite >= 3.37 and ``unixepoch()`` needs SQLite >= 3.38.
#
# NOTE(review): the UNIQUE column constraints on ``file.hash`` and
# ``link.name_hash`` already create indexes, so ``file_hash_ix`` and
# ``link_name_hash_ix`` look redundant; they are kept so new databases get the
# same schema as existing ones.
migrations = [
    """CREATE TABLE file (
    hash BLOB UNIQUE GENERATED ALWAYS AS (sha256(content)) STORED NOT NULL,
    content BLOB NOT NULL,
    created_at INTEGER DEFAULT (unixepoch('now'))
) STRICT;

CREATE UNIQUE INDEX file_hash_ix ON file ( hash );

CREATE TABLE link (
    name_hash BLOB UNIQUE GENERATED ALWAYS AS (sha256(name)) STORED NOT NULL,
    name TEXT NOT NULL,
    content_type TEXT NOT NULL DEFAULT 'text/plain',
    file_hash BLOB NOT NULL,
    created_at INTEGER DEFAULT (unixepoch('now')),
    FOREIGN KEY(file_hash) REFERENCES file(hash)
) STRICT;

CREATE UNIQUE INDEX link_name_hash_ix ON link ( name_hash );

CREATE TABLE token (
    hash BLOB PRIMARY KEY,
    created_at INTEGER DEFAULT (unixepoch('now'))
) STRICT, WITHOUT ROWID;""",
]


def _sha256_udf(b: Union[bytes, str]) -> bytes:
    if isinstance(b, str):
        b = b.encode()
    return sha256(b).digest()


def get_version(conn: sqlite3.Connection) -> int:
    """Return the schema version recorded in ``PRAGMA user_version``."""
    row = conn.execute("PRAGMA user_version").fetchone()
    return row[0]


def migrate(conn: sqlite3.Connection, migrations: list[str]) -> None:
    """Apply every migration in *migrations* that the database has not run yet.

    ``PRAGMA user_version`` records how many migrations have been applied:
    ``migrations[k]`` brings the schema to version ``k + 1``. Each pending
    migration runs in its own ``BEGIN IMMEDIATE`` transaction together with the
    matching ``user_version`` bump. A failure therefore leaves the database at
    the last fully applied version.

    If the database is already at or beyond ``len(migrations)`` nothing is done
    (a database newer than this code is not reported). Each migration script is
    placed directly before a ``PRAGMA`` statement, so it must end with ``;``.

    Raises:
        ValueError: if ``user_version`` is negative.
        Whatever the failing migration, PRAGMA or COMMIT raised, re-raised
        after the open transaction (if any) has been rolled back.
    """
    current = get_version(conn)
    if current < 0:
        raise ValueError(f"invalid schema version {current} (PRAGMA user_version)")
    for version, migration in enumerate(migrations[current:], start=current + 1):
        try:
            conn.executescript(
                "BEGIN IMMEDIATE TRANSACTION;\n"
                f"{migration}\n"
                f"PRAGMA user_version = {version:d}"
            )
            # Inside the try so a failed COMMIT (e.g. busy) is rolled back too.
            conn.execute("COMMIT TRANSACTION")
        except BaseException:
            # BEGIN itself may have failed (e.g. database locked). Then no
            # transaction is open, and an unconditional ROLLBACK would raise
            # its own error, hiding the original one.
            if conn.in_transaction:
                conn.execute("ROLLBACK TRANSACTION")
            raise


@contextmanager
def connect(
    database: str, migrations: list[str] = migrations, **kwargs
) -> Iterator[sqlite3.Connection]:
    """Open *database*, prepare it for use, and close it when the block exits.

    *database* is opened with ``uri=True``, so it may be a plain path or a
    ``file:`` URI. Extra keyword arguments are passed to ``sqlite3.connect``.

    Before the connection is yielded:

    * foreign-key enforcement is enabled and WAL journaling is requested;
    * rows are returned as ``sqlite3.Row``;
    * the deterministic ``sha256`` SQL function is registered. The schema's
      generated columns call it, so it must exist before ``migrate`` runs;
    * pending *migrations* are applied.

    The connection is closed on exit, including when setup itself fails.
    Closing does not commit, so callers must commit their own changes.
    """
    conn = sqlite3.connect(database, uri=True, **kwargs)
    # All setup lives inside the try: if a PRAGMA, create_function or a
    # migration raises, the connection must still be closed.
    try:
        conn.execute("PRAGMA foreign_keys = ON")
        conn.execute("PRAGMA journal_mode = WAL")
        conn.row_factory = sqlite3.Row
        conn.create_function(
            name="sha256", narg=1, func=_sha256_udf, deterministic=True
        )
        migrate(conn, migrations)
        yield conn
    finally:
        conn.close()