aboutsummaryrefslogtreecommitdiffstats
path: root/paste/db.py
blob: 25e85ce0f0f1a6ed92d47fdb61d711c34a34d40d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from hashlib import blake2b, sha256
from itertools import count
from typing import Optional, Union

# Schema migrations, applied in order by ``migrate``: the script at index ``k``
# upgrades the database from ``PRAGMA user_version`` k to k + 1, and ``migrate``
# skips every script whose version is already recorded. Because an applied
# script never re-runs, treat this list as append-only: change the schema by
# adding a new entry rather than editing an existing one.
#
# Version 1 creates:
#   * ``file``  - blobs keyed by ``hash``, a STORED generated column computed
#                 with the application-defined DATA_HASH() function (UNIQUE).
#   * ``link``  - a named pointer, with a content type, to a ``file`` row via
#                 ``file_hash``; ``name_hash`` is likewise DATA_HASH(name).
#   * ``token`` - rows keyed by a hash blob; presumably hashed access tokens,
#                 TODO confirm against the code that uses this table.
#
# DATA_HASH must be registered on the connection before these scripts run and
# before rows are inserted or updated, because SQLite evaluates it to fill the
# generated columns.
# NOTE(review): STRICT tables and unixepoch() need a recent SQLite (3.37 and
# 3.38 respectively, per the SQLite docs); confirm the deployed library version.
# NOTE(review): the UNIQUE constraints on file.hash and link.name_hash already
# give SQLite an automatic unique index, so file_hash_ix / link_name_hash_ix
# look redundant. If dropping them, do it in a new migration, not by editing this one.
# NOTE(review): DEFAULT "text/plain" uses double quotes (identifier quoting in
# standard SQL); SQLite accepts it as a string here, but 'text/plain' is the
# portable spelling. Again, only change it through a new migration.
migrations: list[str] = [
    """CREATE TABLE file (
    hash BLOB UNIQUE GENERATED ALWAYS AS (DATA_HASH(content)) STORED NOT NULL,
    content BLOB NOT NULL,
    created_at INTEGER DEFAULT (unixepoch('now'))
) STRICT;

CREATE UNIQUE INDEX file_hash_ix ON file ( hash );

CREATE TABLE link (
    name_hash BLOB UNIQUE GENERATED ALWAYS AS (DATA_HASH(name)) STORED NOT NULL,
    name TEXT NOT NULL,
    content_type TEXT NOT NULL DEFAULT "text/plain",
    file_hash BLOB NOT NULL,
    created_at INTEGER DEFAULT (unixepoch('now')),
    FOREIGN KEY(file_hash) REFERENCES file(hash)
) STRICT;

CREATE UNIQUE INDEX link_name_hash_ix ON link ( name_hash );

CREATE TABLE token (
    hash BLOB PRIMARY KEY,
    created_at INTEGER DEFAULT (unixepoch('now'))
) STRICT, WITHOUT ROWID;""",
]


def _sha256_udf(b: Union[bytes, str]) -> bytes:
    if isinstance(b, str):
        b = b.encode()
    return sha256(b).digest()


def _data_hash_udf(b: Union[bytes, str]) -> bytes:
    if isinstance(b, str):
        b = b.encode()
    return blake2b(b, digest_size=32).digest()


def get_version(connection: sqlite3.Connection) -> int:
    """Return the schema version recorded in the database's ``user_version`` pragma.

    :param connection: an open SQLite connection.
    :returns: the integer stored by the most recent ``PRAGMA user_version = N``
        (0 for a database that never set it).
    """
    cursor = connection.execute("PRAGMA user_version")
    row = cursor.fetchone()
    return row[0]


def migrate(connection: sqlite3.Connection, migrations: list[str]) -> None:
    """Upgrade the database schema to version ``len(migrations)``.

    The current version is read from ``PRAGMA user_version``. Script
    ``migrations[k]`` upgrades the schema from version ``k`` to ``k + 1``.
    Every pending script is run in order, each inside its own
    ``BEGIN IMMEDIATE`` transaction together with the matching
    ``PRAGMA user_version`` bump, so a script and its version number are
    committed or rolled back together.

    :param connection: open SQLite connection. The scripts may call
        application-defined SQL functions (the bundled ones use DATA_HASH),
        which must already be registered on it.
    :param migrations: ordered list of SQL scripts.
    :raises RuntimeError: if the database is already at a version newer than
        ``len(migrations)`` (downgrades are refused).
    :raises sqlite3.Error: if a script fails. That script's transaction is
        rolled back, so the version stays at the last successful migration.
    """
    version = get_version(connection)
    target = len(migrations)
    if target < version:
        raise RuntimeError(f"Attempt to downgrade from v{version} to v{target}")
    # NOTE(review): the version is read outside the write transaction, so two
    # processes migrating the same file concurrently may both attempt the same
    # script; the second would then most likely fail on an already-existing object.
    for new_version in range(version + 1, target + 1):
        migration = migrations[new_version - 1]
        try:
            # The script deliberately ends without COMMIT; the commit happens
            # below, only once every statement (incl. the version bump) succeeded.
            connection.executescript(
                "BEGIN IMMEDIATE TRANSACTION;\n"
                f"{migration}\n"
                f"PRAGMA user_version = {new_version:d}"
            )
        except BaseException:
            # BEGIN itself may have failed (e.g. database locked), or SQLite may
            # have rolled back automatically. An unconditional ROLLBACK would then
            # raise "no transaction is active" and hide the original error.
            if connection.in_transaction:
                connection.execute("ROLLBACK TRANSACTION")
            raise
        else:
            connection.execute("COMMIT TRANSACTION")


@contextmanager
def connect(
    database: str, migrations: list[str] = migrations, **kwargs
) -> Iterator[sqlite3.Connection]:
    """Open *database*, configure and migrate it, and yield the connection.

    The connection is set up as follows:

    * URI filenames are enabled by default.
    * Foreign keys are enforced.
    * WAL journaling is requested.
    * Rows are returned as ``sqlite3.Row``.
    * The ``sha256`` and ``DATA_HASH`` SQL functions are registered. The
      schema's generated columns call ``DATA_HASH``, so it must exist before
      any migration runs.

    The connection is closed when the ``with`` block exits, and also when
    setup or migration fails. ``close()`` does not commit, so callers must
    commit their own work.

    :param database: filename or ``file:`` URI, passed to ``sqlite3.connect``.
    :param migrations: schema scripts to apply; defaults to this module's list.
    :param kwargs: extra keyword arguments for ``sqlite3.connect``. ``uri``
        defaults to ``True`` but may now be overridden.
    :raises RuntimeError: propagated from ``migrate`` on a schema downgrade.
    """
    # setdefault instead of a hard-coded uri=True, which raised TypeError
    # ("multiple values for keyword argument") if a caller passed uri.
    kwargs.setdefault("uri", True)
    conn = sqlite3.connect(database, **kwargs)
    # Everything after connect() sits inside try/finally so a failing PRAGMA,
    # function registration, or migration does not leak the open connection.
    try:
        conn.execute("PRAGMA foreign_keys = ON")
        conn.execute("PRAGMA journal_mode = WAL")
        conn.row_factory = sqlite3.Row
        # name/narg/func are passed positionally: passing them as keywords is
        # deprecated since Python 3.13 and slated to become an error in 3.15.
        conn.create_function("sha256", 1, _sha256_udf, deterministic=True)
        conn.create_function("DATA_HASH", 1, _data_hash_udf, deterministic=True)
        migrate(conn, migrations)
        yield conn
    finally:
        conn.close()