aboutsummaryrefslogtreecommitdiffstats
path: root/paste/__main__.py
blob: 389e215342eab08f3d0485c174ad631b085b409f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import sys
from base64 import b64decode, b64encode
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timezone
from os import getenv
from typing import TextIO
from wsgiref.simple_server import make_server

from . import DB_PATH, application, db, store

# Program name shown in the usage synopsis.
PROGRAM_NAME: str = "paste"


def print_usage(file: TextIO) -> None:
    """Write the one-line command synopsis to *file*.

    The synopsis includes ``verify-token``, and it marks the commands that
    take a single TOKEN argument.
    """
    print(
        f"Usage: {PROGRAM_NAME} "
        "[-h|serve|new-token|list-tokens|delete-token TOKEN|verify-token TOKEN]",
        file=file,
    )


def print_help(file: TextIO) -> None:
    """Write the full help text (commands, options, environment) to *file*.

    The command list names ``verify-token``, which is what ``main`` dispatches.
    The default database path shown is ``DB_PATH``.
    """
    print(
        f"""A simple WSGI paste site

Commands:
    serve         (default) Start a basic HTTP server (do NOT use in production)
    new-token     Generate a new access token
    list-tokens   List access token hashes (sha256 hash, created at)
    verify-token  Check if a token is valid and exists in the database
    delete-token  Delete a token (specify base64 token or hex hash)

Options:
    -h, --help  Show this help

Environment:
    PASTE_HOST  The HTTP server host (default: localhost)
    PASTE_PORT  The HTTP server port (default: 8080)
    PASTE_DB    The path to the sqlite3 database (default: {DB_PATH})""",
        file=file,
    )


# SQLite database location: $PASTE_DB when set, else the package default.
# Evaluated once, at import time.
db_path = getenv("PASTE_DB", default=DB_PATH)


@contextmanager
def open_auth() -> Iterator[store.Auth]:
    """Open the database at ``db_path`` and yield a ``store.Auth`` over it.

    The connection's lifetime is governed by ``db.connect``'s own context
    manager; it is exited when the caller's ``with`` block ends.
    """
    # NOTE(review): if db.connect returns a plain sqlite3.Connection, leaving
    # its with-block only commits/rolls back and does not close it — confirm.
    with db.connect(db_path) as connection:
        auth_store = store.Auth(connection)
        yield auth_store


def _serve(stderr: TextIO) -> None:
    """Run the wsgiref development server for ``application``.

    Host and port come from PASTE_HOST / PASTE_PORT. This call blocks in
    ``serve_forever()``.
    """
    host = getenv("PASTE_HOST", "localhost")
    port_text = getenv("PASTE_PORT", "8080")
    try:
        port = int(port_text)
    except ValueError:
        print(f"Invalid PASTE_PORT: {port_text!r}", file=stderr)
        sys.exit(1)
    # The server is a context manager; leaving the with-block closes its
    # listening socket, including when serve_forever() exits via an exception.
    with make_server(host, port, application) as httpd:
        httpd.serve_forever()


def _new_token() -> None:
    """Generate a token and print it base64-encoded on stdout."""
    with open_auth() as auth:
        print(b64encode(auth.generate_token()).decode())


def _list_tokens() -> None:
    """Print one ``<hex hash>\\t<UTC ctime>`` line per stored token."""
    with open_auth() as auth:
        for token_hash, created_at in auth.get_tokens():
            # created_at is passed to fromtimestamp, so it is presumably a
            # Unix timestamp; it is rendered in UTC.
            created = datetime.fromtimestamp(created_at, timezone.utc)
            print(f"{token_hash.hex()}\t{created.ctime()}")


def _delete_token(token: str, stderr: TextIO) -> None:
    """Delete a token given either its base64 form or its hex hash.

    Hex text is also valid base64, so a successful base64 decode does not prove
    the argument was a token. Every interpretation that parses is tried in
    order (base64 token first, then hex hash). The first one the store accepts
    wins. Exits with status 1 if none succeeds.
    """
    with open_auth() as auth:
        attempts = []
        try:
            attempts.append((auth.delete_token, b64decode(token, validate=True)))
        except ValueError:  # binascii.Error is a ValueError subclass
            pass
        try:
            attempts.append((auth.delete_token_hash, bytes.fromhex(token)))
        except ValueError:
            pass

        not_found = False
        for delete, key in attempts:
            try:
                delete(key)
                return
            except KeyError:
                not_found = True
            except ValueError:
                pass  # rejected by the store; try the next interpretation

        print("Token not found" if not_found else "Malformed token", file=stderr)
        sys.exit(1)


def _verify_token(token: str, stderr: TextIO) -> None:
    """Print ``Found`` if the base64 *token* exists; otherwise exit with status 1."""
    with open_auth() as auth:
        try:
            found = auth.check_token(b64decode(token, validate=True))
        except ValueError:  # bad base64, or rejected by the store
            print("Malformed token", file=stderr)
            sys.exit(1)
    if not found:
        print("Token not found", file=stderr)
        sys.exit(1)
    print("Found")


def main(argv: list[str] = sys.argv, stderr: TextIO = sys.stderr) -> None:
    """Entry point for the ``paste`` command line.

    Args:
        argv: Argument vector. ``argv[0]`` is the program name and is ignored;
            with no further arguments the ``serve`` command runs.
        stderr: Stream for usage, help and error messages. Both defaults are
            bound when this module is imported.

    Raises:
        SystemExit: Status 1 in any of these cases: an unknown command or wrong
            argument count, a malformed or unknown token, or an invalid
            PASTE_PORT.
    """
    command = argv[1] if len(argv) > 1 else "serve"
    args = argv[2:]

    if command == "serve" and not args:
        _serve(stderr)
    elif command == "new-token" and not args:
        _new_token()
    elif command == "list-tokens" and not args:
        _list_tokens()
    elif command in ("-h", "--help") and not args:
        print_usage(stderr)
        print_help(stderr)
    elif command == "delete-token" and len(args) == 1:
        _delete_token(args[0], stderr)
    # "check-token" is accepted as an alias: the help text has advertised it.
    elif command in ("verify-token", "check-token") and len(args) == 1:
        _verify_token(args[0], stderr)
    else:
        # Unknown commands and wrong argument counts are usage errors.
        print_usage(stderr)
        sys.exit(1)


if __name__ == "__main__":
    # Run the CLI when this module is executed directly (e.g. `python -m paste`).
    main()