From db04b5ad8f914b173806880bf34861c76d0732a9 Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Fri, 24 Feb 2023 23:06:30 +0000 Subject: paste.db: Refactor connect --- paste/db.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/paste/db.py b/paste/db.py index 3de42fb..15e4f03 100644 --- a/paste/db.py +++ b/paste/db.py @@ -38,17 +38,14 @@ def _sha256_udf(b: Union[bytes, str]) -> bytes: return sha256(b).digest() -@contextmanager -def connect( - database: str, migrations: list[str] = migrations, **kwargs -) -> Iterator[sqlite3.Connection]: - conn = sqlite3.connect(database, **kwargs) - conn.execute("PRAGMA foreign_keys = ON") - conn.execute("PRAGMA journal_mode = WAL") - conn.row_factory = sqlite3.Row - conn.create_function(name="sha256", narg=1, func=_sha256_udf, deterministic=True) +def get_version(conn: sqlite3.Connection) -> int: (user_version,) = conn.execute("PRAGMA user_version").fetchone() - for i in count(user_version + 1): + return user_version + + +def migrate(conn: sqlite3.Connection, migrations: list[str]) -> None: + version = get_version(conn) + for i in count(version + 1): if i - 1 >= len(migrations): break migration = migrations[i - 1] @@ -63,6 +60,18 @@ def connect( raise else: conn.execute("COMMIT TRANSACTION") + + +@contextmanager +def connect( + database: str, migrations: list[str] = migrations, **kwargs +) -> Iterator[sqlite3.Connection]: + conn = sqlite3.connect(database, uri=True, **kwargs) + conn.execute("PRAGMA foreign_keys = ON") + conn.execute("PRAGMA journal_mode = WAL") + conn.row_factory = sqlite3.Row + conn.create_function(name="sha256", narg=1, func=_sha256_udf, deterministic=True) + migrate(conn, migrations) try: yield conn finally: -- cgit v1.2.3-54-g00ecf