aboutsummaryrefslogtreecommitdiffstats
path: root/paste/store.py
blob: f69c03df00d89f6a4dc8539a24b1c3294f258e0b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from collections.abc import Iterator
from contextlib import contextmanager
from hashlib import sha256
from secrets import token_bytes, token_urlsafe
from sqlite3 import Connection, IntegrityError

from . import db

# Length in bytes of a raw API token drawn from secrets.token_bytes (96 bits).
TOKEN_BYTES = 12


class Store:
    """SQLite-backed storage for paste files, named links and API tokens.

    All queries run on ``conn``. The ``DATA_HASH`` and ``SHA256`` SQL
    functions used below are not defined in this class; presumably they are
    registered on the connection elsewhere (e.g. by the ``db`` module) —
    NOTE(review): confirm.
    """

    def __init__(self, conn: Connection):
        self.conn = conn

    def _insert_file(self, content: bytes):
        """Store *content* in the ``file`` table and return ``DATA_HASH(content)``.

        ``INSERT OR IGNORE`` means an insert that hits a constraint (presumably
        a uniqueness constraint on the content hash — TODO confirm schema) is
        silently skipped, so identical content is not stored twice. This does
        not open or commit a transaction; callers wrap it in ``with self.conn``.
        """
        self.conn.execute(
            "INSERT OR IGNORE INTO file (content) VALUES (?)",
            (content,),
        )
        (content_hash,) = self.conn.execute(
            "SELECT DATA_HASH(?)", (content,)
        ).fetchone()
        return content_hash

    def put(self, name: str, content: bytes, content_type: str):
        """Create or overwrite the link *name* so that it points at *content*.

        Returns ``(created, content_hash)``. ``created`` is ``False`` when an
        existing link was updated and ``True`` when a new link was inserted.
        The previously linked ``file`` row is not deleted here. Everything runs
        in a single transaction that is rolled back if any statement raises.
        """
        with self.conn:
            content_hash = self._insert_file(content)
            # Try to repoint an existing link first; fall back to inserting.
            cur = self.conn.execute(
                """UPDATE link
                SET content_type = ?, file_hash = ?
                WHERE name_hash = DATA_HASH(?)""",
                (content_type, content_hash, name),
            )
            if cur.rowcount == 1:
                return False, content_hash
            self.conn.execute(
                """INSERT INTO link (
                    name, content_type, file_hash
                ) VALUES (?, ?, ?)""",
                (name, content_type, content_hash),
            )
            return True, content_hash

    def post(self, prefix: str, content: bytes, content_type: str):
        """Store *content* under a new random link name starting with *prefix*.

        The name is *prefix* followed by ``token_urlsafe(5)``. If inserting the
        link raises ``IntegrityError`` (presumably a name collision), a fresh
        name is tried, for up to 16 attempts. Returns ``(name, content_hash)``.

        Raises:
            RuntimeError: if all 16 attempts fail. The last ``IntegrityError``
                is attached as ``__cause__``, so failures that are not name
                collisions (e.g. a NOT NULL violation) remain diagnosable. The
                whole transaction, including the file insert, is rolled back.
        """
        with self.conn:
            content_hash = self._insert_file(content)
            last_error = None
            for _ in range(16):
                name = prefix + token_urlsafe(5)
                try:
                    self.conn.execute(
                        """INSERT INTO link (name, content_type, file_hash)
                            VALUES (?, ?, ?)""",
                        (name, content_type, content_hash),
                    )
                except IntegrityError as err:
                    # Under SQLite's default ABORT conflict handling only the
                    # failed statement is undone; the transaction stays open,
                    # so retrying with a new name is safe.
                    last_error = err
                    continue
                return name, content_hash
            raise RuntimeError(
                "Could not insert a link in 16 attempts"
            ) from last_error

    def get(self, name: str):
        """Return ``(content_type, hash, content)`` for link *name*, or ``None``."""
        row = self.conn.execute(
            """SELECT link.content_type, file.hash, file.content
            FROM link
            JOIN file ON file.hash = link.file_hash
            WHERE name_hash = DATA_HASH(?)""",
            (name,),
        ).fetchone()
        return row

    def head(self, name: str):
        """Like :meth:`get`, but return ``length(file.content)`` instead of the content.

        Content is written from ``bytes`` (a BLOB), for which SQLite's
        ``length`` is a byte count. Returns ``None`` if the link is absent.
        """
        row = self.conn.execute(
            """SELECT link.content_type, file.hash, length(file.content)
            FROM link
            JOIN file ON file.hash = link.file_hash
            WHERE name_hash = DATA_HASH(?)""",
            (name,),
        ).fetchone()
        return row

    def delete(self, name: str) -> bool:
        """Delete link *name*; return ``True`` if exactly one row was removed.

        The linked ``file`` row is left in place.
        """
        with self.conn:
            cur = self.conn.execute(
                "DELETE FROM link WHERE name_hash = DATA_HASH(?)", (name,)
            )
            return cur.rowcount == 1

    def generate_token(self) -> bytes:
        """Create a random ``TOKEN_BYTES``-long token and return it.

        Only ``SHA256(token)``, computed by the SQL function, is stored. The
        raw token is never persisted.
        """
        token = token_bytes(TOKEN_BYTES)
        with self.conn:
            self.conn.execute("INSERT INTO token (hash) VALUES (SHA256(?))", (token,))
        return token

    def get_tokens(self):
        """Return all ``(hash, created_at)`` token rows ordered by ``created_at``."""
        return self.conn.execute(
            "SELECT hash, created_at FROM token ORDER BY created_at"
        ).fetchall()

    def delete_token_hash(self, token_hash: bytes) -> None:
        """Delete the token whose stored hash equals *token_hash*.

        Raises:
            ValueError: if *token_hash* is not 32 bytes (a SHA-256 digest).
            KeyError: if no token row has that hash.
        """
        if len(token_hash) != 256 // 8:
            raise ValueError("Invalid token hash")
        with self.conn:
            cur = self.conn.execute("DELETE FROM token WHERE hash = ?", (token_hash,))
        if cur.rowcount <= 0:
            raise KeyError("Token hash does not exist")

    def delete_token(self, token: bytes) -> None:
        """Delete the stored entry for raw *token*.

        The hash is computed in Python with ``hashlib.sha256``. This assumes it
        matches the SQL ``SHA256`` used by :meth:`generate_token` — confirm.

        Raises:
            ValueError: if *token* is not ``TOKEN_BYTES`` long.
            KeyError: if the token is not stored.
        """
        if len(token) != TOKEN_BYTES:
            raise ValueError("Invalid token")
        return self.delete_token_hash(sha256(token).digest())

    def check_token(self, token: bytes) -> bool:
        """Return ``True`` if ``SHA256(token)`` is present in the token table.

        Raises:
            ValueError: if *token* is not ``TOKEN_BYTES`` long.
        """
        if len(token) != TOKEN_BYTES:
            raise ValueError("Invalid token")
        (count,) = self.conn.execute(
            "SELECT COUNT(*) FROM token WHERE hash = SHA256(?)", (token,)
        ).fetchone()
        return count > 0