aboutsummaryrefslogtreecommitdiffstats
path: root/paste/store.py
diff options
context:
space:
mode:
Diffstat (limited to 'paste/store.py')
-rw-r--r--paste/store.py201
1 files changed, 128 insertions, 73 deletions
diff --git a/paste/store.py b/paste/store.py
index dd00edd..b7e471d 100644
--- a/paste/store.py
+++ b/paste/store.py
@@ -1,77 +1,132 @@
-from secrets import token_urlsafe
+from collections.abc import Iterator
+from contextlib import contextmanager
+from hashlib import sha256
+from secrets import token_bytes, token_urlsafe
from sqlite3 import Connection, IntegrityError
+from . import db
-def put(conn: Connection, name: str, content: bytes, content_type: str):
- with conn:
- conn.execute(
- "INSERT OR IGNORE INTO file (content) VALUES (?)",
- (content,),
- )
- (content_hash,) = conn.execute("SELECT DATA_HASH(?)", (content,)).fetchone()
- cur = conn.execute(
- """UPDATE link
- SET content_type = ?, file_hash = ?
+TOKEN_BYTES = 96 // 8
+
+
+class Store:
+ def __init__(self, conn: Connection):
+ self.conn = conn
+
+ def put(self, name: str, content: bytes, content_type: str):
+ with self.conn:
+ self.conn.execute(
+ "INSERT OR IGNORE INTO file (content) VALUES (?)",
+ (content,),
+ )
+ (content_hash,) = self.conn.execute(
+ "SELECT DATA_HASH(?)", (content,)
+ ).fetchone()
+ cur = self.conn.execute(
+ """UPDATE link
+ SET content_type = ?, file_hash = ?
+ WHERE name_hash = DATA_HASH(?)""",
+ (content_type, content_hash, name),
+ )
+ if cur.rowcount == 1:
+ return False, content_hash
+ self.conn.execute(
+ """INSERT INTO link (
+ name, content_type, file_hash
+ ) VALUES (?, ?, ?)""",
+ (name, content_type, content_hash),
+ )
+ return True, content_hash
+
+ def post(self, prefix: str, content: bytes, content_type: str):
+ with self.conn:
+ self.conn.execute(
+ "INSERT OR IGNORE INTO file (content) VALUES (?)",
+ (content,),
+ )
+ (content_hash,) = self.conn.execute(
+ "SELECT DATA_HASH(?)", (content,)
+ ).fetchone()
+ for _ in range(16):
+ name = prefix + token_urlsafe(5)
+ try:
+ self.conn.execute(
+ """INSERT INTO link (name, content_type, file_hash)
+ VALUES (?, ?, ?)""",
+ (name, content_type, content_hash),
+ )
+ except IntegrityError:
+ continue
+ break
+ else:
+ raise RuntimeError("Could not insert a link in 16 attempts")
+ return name, content_hash
+
+ def get(self, name: str):
+ row = self.conn.execute(
+ """SELECT link.content_type, file.hash, file.content
+ FROM link
+ JOIN file ON file.hash = link.file_hash
+ WHERE name_hash = DATA_HASH(?)""",
+ (name,),
+ ).fetchone()
+ return row
+
+ def head(self, name: str):
+ row = self.conn.execute(
+ """SELECT link.content_type, file.hash, length(file.content)
+ FROM link
+ JOIN file ON file.hash = link.file_hash
WHERE name_hash = DATA_HASH(?)""",
- (content_type, content_hash, name),
- )
- if cur.rowcount == 1:
- return False, content_hash
- conn.execute(
- """INSERT INTO link (
- name, content_type, file_hash
- ) VALUES (?, ?, ?)""",
- (name, content_type, content_hash),
- )
- return True, content_hash
-
-
-def post(conn: Connection, prefix: str, content: bytes, content_type: str):
- with conn:
- conn.execute(
- "INSERT OR IGNORE INTO file (content) VALUES (?)",
- (content,),
- )
- (content_hash,) = conn.execute("SELECT DATA_HASH(?)", (content,)).fetchone()
- for _ in range(16):
- name = prefix + token_urlsafe(5)
- try:
- conn.execute(
- """INSERT INTO link (name, content_type, file_hash)
- VALUES (?, ?, ?)""",
- (name, content_type, content_hash),
- )
- except IntegrityError:
- continue
- break
- else:
- raise RuntimeError("Could not insert a link in 16 attempts")
- return name, content_hash
-
-
-def get(conn: Connection, name: str):
- row = conn.execute(
- """SELECT link.content_type, file.hash, file.content
- FROM link
- JOIN file ON file.hash = link.file_hash
- WHERE name_hash = DATA_HASH(?)""",
- (name,),
- ).fetchone()
- return row
-
-
-def head(conn: Connection, name: str):
- row = conn.execute(
- """SELECT link.content_type, file.hash, length(file.content)
- FROM link
- JOIN file ON file.hash = link.file_hash
- WHERE name_hash = DATA_HASH(?)""",
- (name,),
- ).fetchone()
- return row
-
-
-def delete(conn: Connection, name: str):
- with conn:
- cur = conn.execute("DELETE FROM link WHERE name_hash = DATA_HASH(?)", (name,))
- return cur.rowcount == 1
+ (name,),
+ ).fetchone()
+ return row
+
+ def delete(self, name: str):
+ with self.conn:
+ cur = self.conn.execute(
+ "DELETE FROM link WHERE name_hash = DATA_HASH(?)", (name,)
+ )
+ return cur.rowcount == 1
+
+ def generate_token(self):
+ token = token_bytes(TOKEN_BYTES)
+ with self.conn:
+ self.conn.execute("INSERT INTO token (hash) VALUES (SHA256(?))", (token,))
+ return token
+
+ def get_tokens(self):
+ return self.conn.execute(
+ "SELECT hash, created_at FROM token ORDER BY created_at"
+ ).fetchall()
+
+ def delete_token_hash(self, token_hash: bytes):
+ if len(token_hash) != 256 // 8:
+ raise ValueError("Invalid token hash")
+ with self.conn:
+ cur = self.conn.execute("DELETE FROM token WHERE hash = ?", (token_hash,))
+ if cur.rowcount <= 0:
+ raise KeyError("Token hash does not exist")
+
+ def delete_token(self, token: bytes):
+ if len(token) != TOKEN_BYTES:
+ raise ValueError("Invalid token")
+ return self.delete_token_hash(sha256(token).digest())
+
+ def check_token(self, token: bytes):
+ if len(token) != TOKEN_BYTES:
+ raise ValueError("Invalid token")
+ (count,) = self.conn.execute(
+ "SELECT COUNT(*) FROM token WHERE hash = SHA256(?)", (token,)
+ ).fetchone()
+ return count > 0
+
+ @property
+ def version(self):
+ return db.get_version(self.conn)
+
+
+@contextmanager
+def open(uri: str) -> Iterator[Store]:
+ with db.connect(uri) as conn:
+ yield Store(conn)