aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--paste/__main__.py25
-rw-r--r--tests/test_cli.py10
2 files changed, 19 insertions, 16 deletions
diff --git a/paste/__main__.py b/paste/__main__.py
index 7260516..389e215 100644
--- a/paste/__main__.py
+++ b/paste/__main__.py
@@ -4,6 +4,7 @@ from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timezone
from os import getenv
+from typing import TextIO
from wsgiref.simple_server import make_server
from . import DB_PATH, application, db, store
@@ -11,14 +12,14 @@ from . import DB_PATH, application, db, store
PROGRAM_NAME = "paste"
-def print_usage() -> None:
+def print_usage(file: TextIO) -> None:
print(
f"Usage: {PROGRAM_NAME} [-h|serve|new-token|list-tokens|delete-token]",
- file=sys.stderr,
+ file=file,
)
-def print_help() -> None:
+def print_help(file: TextIO) -> None:
print(
f"""A simple WSGI paste site
@@ -36,7 +37,7 @@ Environment:
PASTE_HOST The HTTP server host (default: localhost)
PASTE_PORT The HTTP server port (default: 8080)
PASTE_DB The path to the sqlite3 database (default: {DB_PATH})""",
- file=sys.stderr,
+ file=file,
)
@@ -49,7 +50,7 @@ def open_auth() -> Iterator[store.Auth]:
yield store.Auth(conn)
-def main(argv=sys.argv):
+def main(argv: list[str] = sys.argv, stderr: TextIO = sys.stderr):
if len(argv) <= 1 or len(argv) == 2 and argv[1] == "serve":
host = getenv("PASTE_HOST", "localhost")
port = int(getenv("PASTE_PORT", "8080"))
@@ -65,10 +66,10 @@ def main(argv=sys.argv):
created_at = datetime.fromtimestamp(created_at, timezone.utc)
print(f"{token_hash.hex()}\t{created_at.ctime()}")
elif argv[1] == "-h" or argv[1] == "--help":
- print_usage()
- print_help()
+ print_usage(stderr)
+ print_help(stderr)
else:
- print_usage()
+ print_usage(stderr)
exit(1)
elif len(argv) == 3:
if argv[1] == "delete-token":
@@ -80,20 +81,20 @@ def main(argv=sys.argv):
except ValueError:
auth.delete_token_hash(bytes.fromhex(token))
except ValueError:
- print("Malformed token", file=sys.stderr)
+ print("Malformed token", file=stderr)
exit(1)
except KeyError:
- print("Token not found", file=sys.stderr)
+ print("Token not found", file=stderr)
exit(1)
elif argv[1] == "verify-token":
with open_auth() as auth:
try:
if not auth.check_token(b64decode(argv[2])):
- print("Token not found", file=sys.stderr)
+ print("Token not found", file=stderr)
exit(1)
print("Found")
except ValueError:
- print("Malformed token", file=sys.stderr)
+ print("Malformed token", file=stderr)
if __name__ == "__main__":
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 6f3beb2..5e0acb3 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,12 +1,14 @@
+from io import StringIO
+
from paste.__main__ import main
-def test_cli(monkeypatch, capfd):
+def test_cli(monkeypatch):
monkeypatch.setattr("paste.__main__.db", None)
monkeypatch.setattr("paste.__main__.application", None)
monkeypatch.setattr("paste.__main__.store", None)
monkeypatch.setenv("PASTE_DB", "test_value")
- main(["paste", "--help"])
- _, err = capfd.readouterr()
- assert "paste" in err
+ buf = StringIO()
+ main(["paste", "--help"], stderr=buf)
+ assert "paste" in buf.getvalue()