diff options
author | Tomasz Kramkowski <tomasz@kramkow.ski> | 2023-03-24 20:22:38 +0000 |
---|---|---|
committer | Tomasz Kramkowski <tomasz@kramkow.ski> | 2023-03-24 20:25:01 +0000 |
commit | d60eaea561d8095cf6f677df0c151ab7e5932e42 (patch) | |
tree | 69fcd0ad3dbd11e56c9e2ceffe7e52d98cc49691 | |
parent | 761231a2a67b2d8817c10f2e5cd7d866005407f1 (diff) | |
download | paste-d60eaea561d8095cf6f677df0c151ab7e5932e42.tar.gz paste-d60eaea561d8095cf6f677df0c151ab7e5932e42.tar.xz paste-d60eaea561d8095cf6f677df0c151ab7e5932e42.zip |
paste.db.connect: add docstrings and doctests
-rw-r--r-- | paste/db.py | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/paste/db.py b/paste/db.py index fc1d98c..3dff8d6 100644 --- a/paste/db.py +++ b/paste/db.py @@ -45,11 +45,48 @@ def _data_hash_udf(b: Union[bytes, str]) -> bytes: def get_version(connection: sqlite3.Connection) -> int: + """Return the 'user_version' of a SQLite database connection. + + >>> with connect("file::memory:") as conn: + ... get_version(conn) + 1 + """ (user_version,) = connection.execute("PRAGMA user_version").fetchone() return user_version def migrate(connection: sqlite3.Connection, migrations: list[str]) -> None: + """Migrate a SQLite connection. Raise sqlite3.Error upon failure. + Raise RuntimeError upon downgrade attempts. + + connection is a open SQLite3 Connection object. + + migrations is a list of migrations to apply. Each migration must be + a sequence of valid SQLite DDL or DML statements. The value of the + 'user_version' pragma is used to determine which migrations are + executed. Migrations are executed atomically. When a migration is + successfully executed, the 'user_version' is incremented. Migrations + with an index less than the 'user_version' are skipped. When + successful, the 'user_version' will be the same as the length of the + provided migrations list. If the length of the migrations list is + less than user_version, RuntimeError is raised. + + >>> with connect("file::memory:", migrations=None) as conn: + ... get_version(conn) + ... migrate(conn, ["CREATE TABLE t (v)"]) + ... get_version(conn) + ... for row in conn.execute("SELECT type, name FROM sqlite_schema;"): + ... print(f"{row['type']}, {row['name']}") + ... migrate(conn, ["...", "DROP TABLE t; CREATE TABLE u (v);"]) + ... get_version(conn) + ... for row in conn.execute("SELECT type, name FROM sqlite_schema;"): + ... print(f"{row['type']}, {row['name']}") + 0 + 1 + table, t + 2 + table, u + """ version = get_version(connection) if len(migrations) < version: raise RuntimeError( @@ -76,6 +113,33 @@ def migrate(connection: sqlite3.Connection, migrations: list[str]) -> None: def connect( database: str, migrations: Optional[list[str]] = migrations, **kwargs ) -> Iterator[sqlite3.Connection]: + """Return a context manager for a SQLite database connection. Upon entry, + open a connection, enable foreign keys, migrate if necessary, and return + the connection object. Upon exit, close the connection. Raise + sqlite3.Error upon failure. Raise RuntimeError upon downgrade attempts. + + database is treated as a SQLite URI if it starts with 'file:' or a + file path otherwise. + + migrations is an optional list of migration strings to apply. If + provided, they are passed to db.migrate after the connection is + opened but before it is returned. By default, db.migrations are + used. + + >>> with connect("file::memory:") as conn: + ... get_version(conn) + ... _ = conn.execute("INSERT INTO file (content) VALUES (?)", (b"abc",)) + ... for row in conn.execute("SELECT hash, content FROM file").fetchall(): + ... print(f"{row['hash'].hex()}, {row['content']}") + ... + 1 + bddd813c634239723171ef3fee98579b94964e3bb1cb3e427262c8c068d52319, b'abc' + >>> with connect("file::memory:", migrations=None) as conn: + ... get_version(conn) + ... + 0 + """ + conn = sqlite3.connect(database, uri=True, **kwargs) conn.execute("PRAGMA foreign_keys = ON") conn.execute("PRAGMA journal_mode = WAL") @@ -90,3 +154,9 @@ def connect( yield conn finally: conn.close() + + +if __name__ == "__main__": + import doctest + + doctest.testmod() |