aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTomasz Kramkowski <tomasz@kramkow.ski>2023-03-24 20:22:38 +0000
committerTomasz Kramkowski <tomasz@kramkow.ski>2023-03-24 20:25:01 +0000
commitd60eaea561d8095cf6f677df0c151ab7e5932e42 (patch)
tree69fcd0ad3dbd11e56c9e2ceffe7e52d98cc49691
parent761231a2a67b2d8817c10f2e5cd7d866005407f1 (diff)
downloadpaste-d60eaea561d8095cf6f677df0c151ab7e5932e42.tar.gz
paste-d60eaea561d8095cf6f677df0c151ab7e5932e42.tar.xz
paste-d60eaea561d8095cf6f677df0c151ab7e5932e42.zip
paste.db.connect: add docstrings and doctests
-rw-r--r--paste/db.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/paste/db.py b/paste/db.py
index fc1d98c..3dff8d6 100644
--- a/paste/db.py
+++ b/paste/db.py
@@ -45,11 +45,48 @@ def _data_hash_udf(b: Union[bytes, str]) -> bytes:
def get_version(connection: sqlite3.Connection) -> int:
+ """Return the 'user_version' of a SQLite database connection.
+
+ >>> with connect("file::memory:") as conn:
+ ... get_version(conn)
+ 1
+ """
(user_version,) = connection.execute("PRAGMA user_version").fetchone()
return user_version
def migrate(connection: sqlite3.Connection, migrations: list[str]) -> None:
+ """Migrate a SQLite connection. Raise sqlite3.Error upon failure.
+ Raise RuntimeError upon downgrade attempts.
+
+ connection is a open SQLite3 Connection object.
+
+ migrations is a list of migrations to apply. Each migration must be
+ a sequence of valid SQLite DDL or DML statements. The value of the
+ 'user_version' pragma is used to determine which migrations are
+ executed. Migrations are executed atomically. When a migration is
+ successfully executed, the 'user_version' is incremented. Migrations
+ with an index less than the 'user_version' are skipped. When
+ successful, the 'user_version' will be the same as the length of the
+ provided migrations list. If the length of the migrations list is
+ less than user_version, RuntimeError is raised.
+
+ >>> with connect("file::memory:", migrations=None) as conn:
+ ... get_version(conn)
+ ... migrate(conn, ["CREATE TABLE t (v)"])
+ ... get_version(conn)
+ ... for row in conn.execute("SELECT type, name FROM sqlite_schema;"):
+ ... print(f"{row['type']}, {row['name']}")
+ ... migrate(conn, ["...", "DROP TABLE t; CREATE TABLE u (v);"])
+ ... get_version(conn)
+ ... for row in conn.execute("SELECT type, name FROM sqlite_schema;"):
+ ... print(f"{row['type']}, {row['name']}")
+ 0
+ 1
+ table, t
+ 2
+ table, u
+ """
version = get_version(connection)
if len(migrations) < version:
raise RuntimeError(
@@ -76,6 +113,33 @@ def migrate(connection: sqlite3.Connection, migrations: list[str]) -> None:
def connect(
database: str, migrations: Optional[list[str]] = migrations, **kwargs
) -> Iterator[sqlite3.Connection]:
+ """Return a context manager for a SQLite database connection. Upon entry,
+ open a connection, enable foreign keys, migrate if necessary, and return
+ the connection object. Upon exit, close the connection. Raise
+ sqlite3.Error upon failure. Raise RuntimeError upon downgrade attempts.
+
+ database is treated as a SQLite URI if it starts with 'file:' or a
+ file path otherwise.
+
+ migrations is an optional list of migration strings to apply. If
+ provided, they are passed to db.migrate after the connection is
+ opened but before it is returned. By default, db.migrations are
+ used.
+
+ >>> with connect("file::memory:") as conn:
+ ... get_version(conn)
+ ... _ = conn.execute("INSERT INTO file (content) VALUES (?)", (b"abc",))
+ ... for row in conn.execute("SELECT hash, content FROM file").fetchall():
+ ... print(f"{row['hash'].hex()}, {row['content']}")
+ ...
+ 1
+ bddd813c634239723171ef3fee98579b94964e3bb1cb3e427262c8c068d52319, b'abc'
+ >>> with connect("file::memory:", migrations=None) as conn:
+ ... get_version(conn)
+ ...
+ 0
+ """
+
conn = sqlite3.connect(database, uri=True, **kwargs)
conn.execute("PRAGMA foreign_keys = ON")
conn.execute("PRAGMA journal_mode = WAL")
@@ -90,3 +154,9 @@ def connect(
yield conn
finally:
conn.close()
+
+
+if __name__ == "__main__":
+ import doctest
+
+ doctest.testmod()