aboutsummaryrefslogtreecommitdiffstats
path: root/paste/__main__.py
blob: 27ff72b5409de6bf975f2f6a24b598cb6978bcbc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from base64 import b64decode, b64encode
from contextlib import AbstractContextManager
from datetime import datetime, timezone
from os import getenv
from sys import argv, stderr
from wsgiref.simple_server import make_server

from . import DB_PATH, application, store

# Program name shown in the usage synopsis.
PROGRAM_NAME: str = "paste"


def print_usage() -> None:
    """Print the one-line command synopsis to stderr.

    Lists every command ``main`` dispatches, including the ones that take a
    TOKEN argument.
    """
    print(
        f"Usage: {PROGRAM_NAME} "
        "[-h|serve|new-token|list-tokens|verify-token TOKEN|delete-token TOKEN]",
        file=stderr,
    )


def print_help() -> None:
    """Print the command, option and environment-variable reference to stderr.

    Command names here must match what ``main`` dispatches on; the command
    that checks a token is ``verify-token``.
    """
    print(
        f"""A simple WSGI paste site

Commands:
    serve               (default) Start a basic HTTP server (do NOT use in production)
    new-token           Generate a new access token
    list-tokens         List access token hashes (sha256 hash, created at)
    verify-token TOKEN  Check if a token is valid and exists in the database
    delete-token TOKEN  Delete a token (specify the token or its hash)

Options:
    -h, --help  Show this help

Environment:
    PASTE_HOST  The HTTP server host (default: localhost)
    PASTE_PORT  The HTTP server port (default: 8080)
    PASTE_DB    The path to the sqlite3 database (default: {DB_PATH})""",
        file=stderr,
    )


# Database location: the PASTE_DB environment variable if it is set,
# otherwise the package-level default DB_PATH. Read once, at import time.
db_path = getenv("PASTE_DB", default=DB_PATH)


def open_db() -> AbstractContextManager[store.Store]:
    """Open the configured database via ``store.open``.

    Returns the context manager produced by ``store.open``; callers use it in
    a ``with`` statement to obtain the ``store.Store``.
    """
    opener = store.open
    return opener(db_path)


def _fail(message: str):
    """Print *message* to stderr and terminate the process with status 1."""
    print(message, file=stderr)
    # SystemExit rather than the ``exit()`` builtin: that name is injected by
    # the ``site`` module and is absent when Python runs with ``-S``.
    raise SystemExit(1)


def _serve() -> None:
    """Run the development HTTP server configured by PASTE_HOST/PASTE_PORT."""
    host = getenv("PASTE_HOST", "localhost")
    port = int(getenv("PASTE_PORT", "8080"))
    # ``with`` closes the listening socket once serve_forever() returns.
    with make_server(host, port, application) as httpd:
        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            pass  # Ctrl-C is the normal way to stop this dev server.


def _new_token() -> None:
    """Generate a token in the database and print it base64-encoded to stdout."""
    with open_db() as db:
        print(b64encode(db.generate_token()).decode())


def _list_tokens() -> None:
    """Print each stored token's hex hash and UTC creation time, tab-separated."""
    with open_db() as db:
        for token_hash, created_ts in db.get_tokens():
            created_at = datetime.fromtimestamp(created_ts, timezone.utc)
            print(f"{token_hash.hex()}\t{created_at.ctime()}")


def _delete_token(token: str) -> None:
    """Delete a token given either the token itself or its hash.

    *token* is tried as base64 (the form ``new-token`` prints) and then as
    hex (the form ``list-tokens`` prints). Every interpretation that decodes
    is attempted in that order until one succeeds. Exits with status 1 and a
    message on stderr if none decodes or none is found.
    """
    with open_db() as db:
        attempts = []
        try:
            attempts.append((db.delete_token, b64decode(token, validate=True)))
        except ValueError:  # binascii.Error is a ValueError subclass
            pass
        try:
            attempts.append((db.delete_token_hash, bytes.fromhex(token)))
        except ValueError:
            pass
        if not attempts:
            _fail("Malformed token")

        # A hex-encoded sha256 digest (64 characters) is also valid base64,
        # so a failed lookup as a token must not stop the hash attempt.
        not_found = False
        for delete, value in attempts:
            try:
                delete(value)
            except KeyError:
                not_found = True
            except ValueError:
                pass  # the store rejected this interpretation; try the next
            else:
                return
    _fail("Token not found" if not_found else "Malformed token")


def _verify_token(token: str) -> None:
    """Print ``Found`` if the base64 *token* exists; otherwise exit with status 1."""
    with open_db() as db:
        try:
            found = db.check_token(b64decode(token, validate=True))
        except ValueError:
            _fail("Malformed token")
        else:
            if not found:
                _fail("Token not found")
    print("Found")


def main() -> None:
    """Parse the command line and run the requested command.

    With no arguments the development server is started. Unknown commands or
    a wrong number of arguments print the usage line to stderr and exit with
    status 1.
    """
    args = argv[1:]
    if not args or args == ["serve"]:
        _serve()
    elif args == ["new-token"]:
        _new_token()
    elif args == ["list-tokens"]:
        _list_tokens()
    elif args in (["-h"], ["--help"]):
        print_usage()
        print_help()
    elif len(args) == 2 and args[0] == "delete-token":
        _delete_token(args[1])
    elif len(args) == 2 and args[0] in ("verify-token", "check-token"):
        # "check-token" is accepted as an alias of "verify-token".
        _verify_token(args[1])
    else:
        print_usage()
        raise SystemExit(1)


# Runs only when this module is executed as a program, not when imported.
# main() returns None, so SystemExit(None) ends the process with status 0.
if __name__ == "__main__":
    raise SystemExit(main())