diff options
-rw-r--r-- | paste/db.py | 29 |
1 files changed, 19 insertions, 10 deletions
diff --git a/paste/db.py b/paste/db.py index 3de42fb..15e4f03 100644 --- a/paste/db.py +++ b/paste/db.py @@ -38,17 +38,14 @@ def _sha256_udf(b: Union[bytes, str]) -> bytes: return sha256(b).digest() -@contextmanager -def connect( - database: str, migrations: list[str] = migrations, **kwargs -) -> Iterator[sqlite3.Connection]: - conn = sqlite3.connect(database, **kwargs) - conn.execute("PRAGMA foreign_keys = ON") - conn.execute("PRAGMA journal_mode = WAL") - conn.row_factory = sqlite3.Row - conn.create_function(name="sha256", narg=1, func=_sha256_udf, deterministic=True) +def get_version(conn: sqlite3.Connection) -> int: (user_version,) = conn.execute("PRAGMA user_version").fetchone() - for i in count(user_version + 1): + return user_version + + +def migrate(conn: sqlite3.Connection, migrations: list[str]) -> None: + version = get_version(conn) + for i in count(version + 1): if i - 1 >= len(migrations): break migration = migrations[i - 1] @@ -63,6 +60,18 @@ def connect( raise else: conn.execute("COMMIT TRANSACTION") + + +@contextmanager +def connect( + database: str, migrations: list[str] = migrations, **kwargs +) -> Iterator[sqlite3.Connection]: + conn = sqlite3.connect(database, uri=True, **kwargs) + conn.execute("PRAGMA foreign_keys = ON") + conn.execute("PRAGMA journal_mode = WAL") + conn.row_factory = sqlite3.Row + conn.create_function(name="sha256", narg=1, func=_sha256_udf, deterministic=True) + migrate(conn, migrations) try: yield conn finally: |