aboutsummaryrefslogtreecommitdiffstats
path: root/paste/db.py
blob: c4ee46ee22c336c655e015198f920c52b07a9f17 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from hashlib import blake2b, sha256
from itertools import count
from typing import Optional, Union

# Ordered schema migrations, applied by ``migrate`` (and by ``connect`` by
# default).  Entry k (0-based) upgrades a database whose
# ``PRAGMA user_version`` is k to version k + 1.
#
# Never edit or reorder an entry that has shipped: ``migrate`` skips every
# entry whose index is below the database's user_version, so a change to an
# already-applied entry never reaches existing databases.  Append a new
# entry instead.
#
# The schema calls DATA_HASH(), a Python UDF that ``connect`` registers as a
# deterministic function (SQLite only permits deterministic functions in
# generated-column expressions).  Apply and use this schema on a connection
# where that function is registered.  The SQL also relies on generated
# columns (SQLite >= 3.31.0), STRICT tables (>= 3.37.0) and unixepoch()
# (>= 3.38.0).
#
# NOTE(review): the UNIQUE column constraints on file.hash and
# link.name_hash already make SQLite create automatic unique indexes, so
# file_hash_ix / link_name_hash_ix look redundant; removing them would need
# a new migration, not an edit here.
# NOTE(review): the default "text/plain" is written in double quotes, which
# SQL reserves for identifiers; single quotes are the standard spelling of a
# string literal - prefer them in future migrations.
migrations: list[str] = [
    """CREATE TABLE file (
    hash BLOB UNIQUE GENERATED ALWAYS AS (DATA_HASH(content)) STORED NOT NULL,
    content BLOB NOT NULL,
    created_at INTEGER DEFAULT (unixepoch('now'))
) STRICT;

CREATE UNIQUE INDEX file_hash_ix ON file ( hash );

CREATE TABLE link (
    name_hash BLOB UNIQUE GENERATED ALWAYS AS (DATA_HASH(name)) STORED NOT NULL,
    name TEXT NOT NULL,
    content_type TEXT NOT NULL DEFAULT "text/plain",
    file_hash BLOB NOT NULL,
    created_at INTEGER DEFAULT (unixepoch('now')),
    FOREIGN KEY(file_hash) REFERENCES file(hash)
) STRICT;

CREATE UNIQUE INDEX link_name_hash_ix ON link ( name_hash );

CREATE TABLE token (
    hash BLOB PRIMARY KEY,
    created_at INTEGER DEFAULT (unixepoch('now'))
) STRICT, WITHOUT ROWID;""",
]


def _sha256_udf(b: Union[bytes, str]) -> bytes:
    if isinstance(b, str):
        b = b.encode()
    return sha256(b).digest()


def _data_hash_udf(b: Union[bytes, str]) -> bytes:
    if isinstance(b, str):
        b = b.encode()
    return blake2b(b, digest_size=32).digest()


def get_version(connection: sqlite3.Connection) -> int:
    """Read the database's ``PRAGMA user_version`` header integer.

    ``migrate`` stores the number of applied migrations in this field, so
    the value doubles as the schema version of the database.

    >>> with connect("file::memory:") as conn:
    ...     get_version(conn)
    1
    """
    pragma_row = connection.execute("PRAGMA user_version").fetchone()
    return pragma_row[0]


def migrate(connection: sqlite3.Connection, migrations: list[str]) -> None:
    """Migrate a SQLite connection.  Raise sqlite3.Error upon failure.
    Raise RuntimeError upon downgrade attempts.

    connection is an open SQLite3 Connection object.

    migrations is a list of migrations to apply.  Each migration must be
    a sequence of valid SQLite DDL or DML statements.  The value of the
    'user_version' pragma is used to determine which migrations are
    executed.  Migrations are executed atomically.  When a migration is
    successfully executed, the 'user_version' is incremented.  Migrations
    with an index less than the 'user_version' are skipped.  When
    successful, the 'user_version' will be the same as the length of the
    provided migrations list.  If the length of the migrations list is
    less than user_version, RuntimeError is raised.

    Each migration runs inside its own BEGIN IMMEDIATE transaction
    together with the 'user_version' update.  A migration must therefore
    not contain its own BEGIN/COMMIT/ROLLBACK statements.  If a migration
    fails, its transaction is rolled back (when one is still open) and the
    original error is re-raised.  Migrations that already succeeded stay
    committed.

    >>> with connect("file::memory:", migrations=None) as conn:
    ...     get_version(conn)
    ...     migrate(conn, ["CREATE TABLE t (v)"])
    ...     get_version(conn)
    ...     for row in conn.execute("SELECT type, name FROM sqlite_schema;"):
    ...         print(f"{row['type']}, {row['name']}")
    ...     migrate(conn, ["...", "DROP TABLE t; CREATE TABLE u (v);"])
    ...     get_version(conn)
    ...     for row in conn.execute("SELECT type, name FROM sqlite_schema;"):
    ...         print(f"{row['type']}, {row['name']}")
    0
    1
    table, t
    2
    table, u
    """
    version = get_version(connection)
    if len(migrations) < version:
        raise RuntimeError(
            f"Attempt to downgrade from v{version} to v{len(migrations)}"
        )
    # migrations[k] produces schema version k + 1, so everything from index
    # ``version`` onward is still pending.
    for target, migration in enumerate(migrations[version:], start=version + 1):
        try:
            # The terminator goes on its own line so that a migration ending
            # in a ``--`` comment cannot swallow it and glue the PRAGMA onto
            # its last statement.  The user_version bump runs in the same
            # transaction, so the schema change and the version commit
            # together.
            connection.executescript(
                "BEGIN IMMEDIATE TRANSACTION;\n"
                f"{migration}\n;\n"
                f"PRAGMA user_version = {target:d}"
            )
        except BaseException:
            # Roll back only if a transaction is actually open.  BEGIN itself
            # may have failed (e.g. database locked), or SQLite may already
            # have rolled back on its own.  An unconditional ROLLBACK would
            # then raise "no transaction is active" and hide the real error.
            if connection.in_transaction:
                connection.execute("ROLLBACK TRANSACTION")
            raise
        connection.execute("COMMIT TRANSACTION")


@contextmanager
def connect(
    database: str, migrations: Optional[list[str]] = migrations, **kwargs
) -> Iterator[sqlite3.Connection]:
    """Return a context manager for a SQLite database connection.  Upon entry,
    open a connection, enable foreign keys, migrate if necessary, and return
    the connection object.  Upon exit, close the connection.  Raise
    sqlite3.Error upon failure. Raise RuntimeError upon downgrade attempts.

    database is treated as a SQLite URI if it starts with 'file:' or a
    file path otherwise.  Extra keyword arguments are passed through to
    sqlite3.connect (uri is always True).

    migrations is an optional list of migration strings to apply.  If
    provided, they are passed to db.migrate after the connection is
    opened but before it is returned.  By default, db.migrations are
    used.

    The connection is configured as follows:

    - The journal mode is set to WAL.
    - Rows are returned as sqlite3.Row.
    - Two SQL functions are registered: sha256(x), and DATA_HASH(x), a
      32-byte BLAKE2b digest required by the schema's generated columns.

    The connection is closed on every exit path, including when this setup
    or migration fails.  It is not committed on exit, so changes the
    caller has not committed are lost when it closes.

    >>> with connect("file::memory:") as conn:
    ...     get_version(conn)
    ...     _ = conn.execute("INSERT INTO file (content) VALUES (?)", (b"abc",))
    ...     for row in conn.execute("SELECT hash, content FROM file").fetchall():
    ...         print(f"{row['hash'].hex()}, {row['content']}")
    ...
    1
    bddd813c634239723171ef3fee98579b94964e3bb1cb3e427262c8c068d52319, b'abc'
    >>> with connect("file::memory:", migrations=None) as conn:
    ...     get_version(conn)
    ...
    0
    """

    conn = sqlite3.connect(database, uri=True, **kwargs)
    # All setup runs under ``try`` so that a failing PRAGMA, UDF
    # registration or migration (including the RuntimeError raised on a
    # downgrade attempt) still closes the connection instead of leaking it.
    try:
        conn.execute("PRAGMA foreign_keys = ON")
        conn.execute("PRAGMA journal_mode = WAL")
        conn.row_factory = sqlite3.Row
        # name/narg/func are passed positionally: passing them as keywords is
        # deprecated since Python 3.13.  deterministic=True is required
        # because SQLite only allows deterministic functions in generated
        # columns, and DATA_HASH is used in one.
        conn.create_function("sha256", 1, _sha256_udf, deterministic=True)
        conn.create_function("DATA_HASH", 1, _data_hash_udf, deterministic=True)
        if migrations is not None:
            migrate(conn, migrations)
        yield conn
    finally:
        conn.close()


if __name__ == "__main__":
    import doctest

    # testmod() only prints a report.  Turn failures into a nonzero exit
    # status so scripts and CI can detect them.
    failed, _attempted = doctest.testmod()
    raise SystemExit(1 if failed else 0)