aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--paste/db.py29
1 files changed, 19 insertions, 10 deletions
diff --git a/paste/db.py b/paste/db.py
index 3de42fb..15e4f03 100644
--- a/paste/db.py
+++ b/paste/db.py
@@ -38,17 +38,14 @@ def _sha256_udf(b: Union[bytes, str]) -> bytes:
return sha256(b).digest()
-@contextmanager
-def connect(
- database: str, migrations: list[str] = migrations, **kwargs
-) -> Iterator[sqlite3.Connection]:
- conn = sqlite3.connect(database, **kwargs)
- conn.execute("PRAGMA foreign_keys = ON")
- conn.execute("PRAGMA journal_mode = WAL")
- conn.row_factory = sqlite3.Row
- conn.create_function(name="sha256", narg=1, func=_sha256_udf, deterministic=True)
+def get_version(conn: sqlite3.Connection) -> int:
(user_version,) = conn.execute("PRAGMA user_version").fetchone()
- for i in count(user_version + 1):
+ return user_version
+
+
+def migrate(conn: sqlite3.Connection, migrations: list[str]) -> None:
+ version = get_version(conn)
+ for i in count(version + 1):
if i - 1 >= len(migrations):
break
migration = migrations[i - 1]
@@ -63,6 +60,18 @@ def connect(
raise
else:
conn.execute("COMMIT TRANSACTION")
+
+
+@contextmanager
+def connect(
+ database: str, migrations: list[str] = migrations, **kwargs
+) -> Iterator[sqlite3.Connection]:
+ conn = sqlite3.connect(database, uri=True, **kwargs)
+ conn.execute("PRAGMA foreign_keys = ON")
+ conn.execute("PRAGMA journal_mode = WAL")
+ conn.row_factory = sqlite3.Row
+ conn.create_function(name="sha256", narg=1, func=_sha256_udf, deterministic=True)
+ migrate(conn, migrations)
try:
yield conn
finally: