1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
|
from collections.abc import Iterator
from contextlib import contextmanager
from hashlib import sha256
from secrets import token_bytes, token_urlsafe
from sqlite3 import Connection, IntegrityError
from . import db
# Size, in bytes, of every raw token minted by Auth.generate_token (via
# secrets.token_bytes) and the exact length Auth.check_token / delete_token
# accept: 12 bytes, i.e. 96 random bits.
TOKEN_BYTES = 12
class Store:
    """Blob store: file bodies in ``file``, named pointers to them in ``link``.

    Blobs are keyed by the SQL ``DATA_HASH`` function and links are looked up
    by ``name_hash = DATA_HASH(name)``.  ``DATA_HASH`` and the table schema
    live outside this class (presumably set up by ``db``); in particular the
    deduplication done by ``INSERT OR IGNORE INTO file`` relies on a
    uniqueness constraint that is assumed, not shown here.
    """

    def __init__(self, conn: Connection):
        self.conn = conn

    def _insert_file(self, content: bytes):
        """Ensure *content* is stored in ``file`` and return ``DATA_HASH(content)``.

        Must be called inside an open ``with self.conn`` transaction; the
        caller decides whether it commits or rolls back.
        """
        # OR IGNORE: storing the same bytes twice is not an error -- the
        # existing row (matched by the schema's constraint) is reused.
        self.conn.execute(
            "INSERT OR IGNORE INTO file (content) VALUES (?)",
            (content,),
        )
        (content_hash,) = self.conn.execute(
            "SELECT DATA_HASH(?)", (content,)
        ).fetchone()
        return content_hash

    def put(self, name: str, content: bytes, content_type: str):
        """Create or overwrite the link *name* so it points at *content*.

        Returns:
            ``(created, content_hash)``: *created* is False when an existing
            link was updated in place, True when a new link row was inserted.
        """
        with self.conn:
            content_hash = self._insert_file(content)
            cur = self.conn.execute(
                """UPDATE link
                SET content_type = ?, file_hash = ?
                WHERE name_hash = DATA_HASH(?)""",
                (content_type, content_hash, name),
            )
            if cur.rowcount == 1:
                return False, content_hash
            # No existing link under this name: create one.
            self.conn.execute(
                """INSERT INTO link (
                    name, content_type, file_hash
                ) VALUES (?, ?, ?)""",
                (name, content_type, content_hash),
            )
        return True, content_hash

    def post(self, prefix: str, content: bytes, content_type: str):
        """Store *content* under a new random name that starts with *prefix*.

        Up to 16 names of the form ``prefix + token_urlsafe(5)`` are tried.
        The file row and the link are committed together, or both are rolled
        back on failure.

        Returns:
            ``(name, content_hash)`` for the link that was created.

        Raises:
            RuntimeError: every attempt raised ``IntegrityError``.  The last
                one is chained as ``__cause__`` so that a failure which was
                not a name collision (e.g. a NOT NULL violation) stays visible.
        """
        with self.conn:
            content_hash = self._insert_file(content)
            last_error = None
            for _ in range(16):
                name = prefix + token_urlsafe(5)
                try:
                    self.conn.execute(
                        """INSERT INTO link (name, content_type, file_hash)
                        VALUES (?, ?, ?)""",
                        (name, content_type, content_hash),
                    )
                except IntegrityError as exc:
                    # Usually a collision on the random name: remember the
                    # error and try another suffix.
                    last_error = exc
                    continue
                return name, content_hash
            raise RuntimeError(
                "Could not insert a link in 16 attempts"
            ) from last_error

    def get(self, name: str):
        """Return ``(content_type, file_hash, content)`` for *name*, or None if absent."""
        row = self.conn.execute(
            """SELECT link.content_type, file.hash, file.content
            FROM link
            JOIN file ON file.hash = link.file_hash
            WHERE name_hash = DATA_HASH(?)""",
            (name,),
        ).fetchone()
        return row

    def head(self, name: str):
        """Like :meth:`get` but return the content length instead of the content.

        Returns ``(content_type, file_hash, length)`` or None if absent.
        """
        row = self.conn.execute(
            """SELECT link.content_type, file.hash, length(file.content)
            FROM link
            JOIN file ON file.hash = link.file_hash
            WHERE name_hash = DATA_HASH(?)""",
            (name,),
        ).fetchone()
        return row

    def delete(self, name: str):
        """Remove the link *name*; return True if exactly one row was deleted.

        Only the ``link`` row is deleted here; the referenced ``file`` row is
        not touched by this method.
        """
        with self.conn:
            cur = self.conn.execute(
                "DELETE FROM link WHERE name_hash = DATA_HASH(?)", (name,)
            )
        return cur.rowcount == 1
class Auth:
    """Issue, list, revoke and verify random bearer tokens.

    Raw tokens are never written to the database; only their digest is kept
    in ``token.hash``.  Inserts and lookups hash with the SQL ``SHA256``
    function (registered outside this class -- presumably by ``db``), while
    :meth:`delete_token` hashes locally with :func:`hashlib.sha256`, so the
    two are assumed to produce identical digests.
    """

    def __init__(self, conn: Connection):
        self.conn = conn

    @staticmethod
    def _require_token_length(token: bytes) -> None:
        """Raise ValueError unless *token* is exactly TOKEN_BYTES long."""
        if len(token) != TOKEN_BYTES:
            raise ValueError("Invalid token")

    def generate_token(self):
        """Mint a new token, persist its digest, and return the raw bytes.

        The raw value is only available from this return value; the database
        keeps just its hash.
        """
        raw = token_bytes(TOKEN_BYTES)
        with self.conn:
            self.conn.execute("INSERT INTO token (hash) VALUES (SHA256(?))", (raw,))
        return raw

    def get_tokens(self):
        """Return every ``(hash, created_at)`` row, ordered by ``created_at`` ascending."""
        query = "SELECT hash, created_at FROM token ORDER BY created_at"
        return self.conn.execute(query).fetchall()

    def delete_token_hash(self, token_hash: bytes):
        """Revoke the token whose stored digest equals *token_hash*.

        Raises:
            ValueError: *token_hash* is not 32 bytes long (SHA-256 size).
            KeyError: no stored token has that digest.
        """
        sha256_digest_len = 256 // 8
        if len(token_hash) != sha256_digest_len:
            raise ValueError("Invalid token hash")
        with self.conn:
            removed = self.conn.execute(
                "DELETE FROM token WHERE hash = ?", (token_hash,)
            ).rowcount
        if removed <= 0:
            raise KeyError("Token hash does not exist")

    def delete_token(self, token: bytes):
        """Revoke a token given its raw bytes.

        Raises:
            ValueError: *token* is not TOKEN_BYTES long.
            KeyError: the token is not stored (from delete_token_hash).
        """
        self._require_token_length(token)
        digest = sha256(token).digest()
        return self.delete_token_hash(digest)

    def check_token(self, token: bytes):
        """Return True if the digest of *token* is present in the ``token`` table.

        Raises:
            ValueError: *token* is not TOKEN_BYTES long.
        """
        self._require_token_length(token)
        matches = self.conn.execute(
            "SELECT COUNT(*) FROM token WHERE hash = SHA256(?)", (token,)
        ).fetchone()[0]
        return matches > 0
|