aboutsummaryrefslogtreecommitdiffstats
path: root/paste/__main__.py
blob: f1aa117d9e9e39bd7aac652e35ba4701da709a09 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from base64 import b64encode, b64decode
from datetime import datetime, timezone
from hashlib import sha256
from os import getenv
from secrets import token_bytes
from sys import argv, stderr
from wsgiref.simple_server import make_server
from sqlite3 import Connection
from contextlib import AbstractContextManager

from . import application, DB_PATH
from . import db

# Raw access tokens are 96 bits (12 bytes) of randomness.
TOKEN_BYTES = 12
# Name shown in the usage line.
PROGRAM_NAME = "paste"


def print_usage() -> None:
    """Print the one-line command synopsis to stderr.

    Lists every sub-command the CLI understands, including the argument
    taken by the token-checking and token-deleting commands.
    """
    print(
        f"Usage: {PROGRAM_NAME} "
        "[-h|serve|new-token|list-tokens|check-token TOKEN|delete-token TOKEN_OR_HASH]",
        file=stderr,
    )


def print_help() -> None:
    """Print the full command, option and environment reference to stderr."""
    # Documents both -h and --help (both are accepted by the CLI) and the
    # argument each token command expects: a base64 token, or for
    # delete-token alternatively the hex sha256 hash shown by list-tokens.
    print(
        f"""A simple WSGI paste site

Commands:
    serve                    (default) Start a basic HTTP server (do NOT use in production)
    new-token                Generate a new access token
    list-tokens              List access token hashes (sha256 hash, created at)
    check-token TOKEN        Check if a token is valid and exists in the database
    delete-token TOKEN|HASH  Delete a token (base64 token or hex sha256 hash)

Options:
    -h, --help  Show this help

Environment:
    PASTE_HOST  The HTTP server host (default: localhost)
    PASTE_PORT  The HTTP server port (default: 8080)
    PASTE_DB    The path to the sqlite3 database (default: {DB_PATH})""",
        file=stderr,
    )


def generate_token(conn: Connection):
    """Create a new random access token and store its SHA-256 hash.

    Only the digest is written to the ``token`` table; the raw token
    bytes are returned to the caller and are not stored anywhere.
    The insert runs inside ``with conn`` so it is committed on success.
    """
    raw = token_bytes(TOKEN_BYTES)
    digest = sha256(raw).digest()
    with conn:
        conn.execute("INSERT INTO token (hash) VALUES (?)", (digest,))
    return raw


def get_tokens(conn: Connection):
    """Return all stored tokens as ``(hash, created_at)`` rows, oldest first."""
    query = "SELECT hash, created_at FROM token ORDER BY created_at"
    cursor = conn.execute(query)
    return cursor.fetchall()


def delete_token_hash(conn: Connection, token_hash: bytes):
    """Delete the stored token whose SHA-256 digest equals *token_hash*.

    Raises:
        ValueError: *token_hash* is not a 32-byte SHA-256 digest.
        KeyError: no row in ``token`` has that hash.
    """
    if len(token_hash) != sha256().digest_size:
        raise ValueError("Invalid token hash")
    with conn:
        removed = conn.execute(
            "DELETE FROM token WHERE hash = ?", (token_hash,)
        ).rowcount
    if removed <= 0:
        raise KeyError("Token hash does not exist")


def delete_token(conn: Connection, token: bytes):
    """Delete a stored token given its raw (decoded) bytes.

    Raises ValueError when *token* is not TOKEN_BYTES long; otherwise
    delegates to delete_token_hash, which raises KeyError if the token's
    hash is not stored.
    """
    if len(token) == TOKEN_BYTES:
        return delete_token_hash(conn, sha256(token).digest())
    raise ValueError("Invalid token")


def check_token(conn: Connection, token: bytes):
    """Return True if the SHA-256 hash of *token* is stored in ``token``.

    Raises ValueError when *token* is not TOKEN_BYTES long.
    """
    if len(token) != TOKEN_BYTES:
        raise ValueError("Invalid token")
    digest = sha256(token).digest()
    cursor = conn.execute("SELECT COUNT(*) FROM token WHERE hash = ?", (digest,))
    matches = cursor.fetchone()[0]
    return matches > 0


# Database location: PASTE_DB overrides the package default DB_PATH.
db_path = getenv("PASTE_DB", default=DB_PATH)


def connect() -> AbstractContextManager[Connection]:
    """Open the configured paste database (``db_path``).

    Returns a context manager that yields a sqlite3 Connection; the CLI
    commands use it as ``with connect() as conn:``. The actual opening
    (and any schema setup) is done by ``db.connect`` — presumably it also
    closes the connection on exit; confirm against the ``db`` module.
    """
    return db.connect(db_path)


def _fail(message: str):
    """Print *message* to stderr and exit the process with status 1."""
    print(message, file=stderr)
    # SystemExit is always available; the bare ``exit`` builtin is only
    # injected by the ``site`` module and may be missing (e.g. python -S).
    raise SystemExit(1)


def _serve() -> None:
    """Run the development WSGI server on PASTE_HOST:PASTE_PORT forever."""
    host = getenv("PASTE_HOST", "localhost")
    port = int(getenv("PASTE_PORT", "8080"))
    # The server is a socketserver; ``with`` closes the listening socket
    # when serve_forever() exits, e.g. on KeyboardInterrupt.
    with make_server(host, port, application) as httpd:
        httpd.serve_forever()


def _list_tokens_command() -> None:
    """Print each stored token hash (hex) and its UTC creation time."""
    with connect() as conn:
        for token_hash, created_at in get_tokens(conn):
            created_at = datetime.fromtimestamp(created_at, timezone.utc)
            print(f"{token_hash.hex()}\t{created_at.ctime()}")


def _delete_token_command(token: str) -> None:
    """Delete a token given as base64 token or as hex sha256 hash."""
    with connect() as conn:
        try:
            try:
                delete_token(conn, b64decode(token))
            except ValueError:
                # Not a valid base64 token of the right length: treat the
                # argument as the hex hash printed by list-tokens.
                delete_token_hash(conn, bytes.fromhex(token))
        except ValueError:
            _fail("Malformed token")
        except KeyError:
            _fail("Token not found")


def _check_token_command(token: str) -> None:
    """Print "Found" if the base64 *token* exists; otherwise exit with 1."""
    with connect() as conn:
        try:
            found = check_token(conn, b64decode(token))
        except ValueError:
            # Previously this printed the error but still exited with 0.
            _fail("Malformed token")
        if not found:
            _fail("Token not found")
        print("Found")


def main():
    """Parse ``sys.argv`` and run the requested sub-command.

    With no arguments (or ``serve``) the development HTTP server starts.
    ``check-token`` is the documented name of the token check;
    ``verify-token`` is still accepted for backward compatibility.
    Unknown commands or wrong argument counts print the usage line and
    exit with status 1.
    """
    args = argv[1:]
    if not args or args == ["serve"]:
        _serve()
    elif len(args) == 1 and args[0] in ("-h", "--help"):
        print_usage()
        print_help()
    elif args == ["new-token"]:
        with connect() as conn:
            print(b64encode(generate_token(conn)).decode())
    elif args == ["list-tokens"]:
        _list_tokens_command()
    elif len(args) == 2 and args[0] == "delete-token":
        _delete_token_command(args[1])
    elif len(args) == 2 and args[0] in ("check-token", "verify-token"):
        _check_token_command(args[1])
    else:
        # Covers unknown commands and wrong argument counts, which
        # previously fell through silently with exit status 0.
        print_usage()
        raise SystemExit(1)


# Only run the CLI when executed as a program (``python -m paste``); a
# plain import, e.g. by a console-script entry point, must not start it.
if __name__ == "__main__":
    main()