From 72e3796c9e89946fd820fda96912dba2519e0828 Mon Sep 17 00:00:00 2001 From: Tomasz Kramkowski Date: Wed, 29 Mar 2023 18:00:10 +0100 Subject: very basic cli test --- paste/__main__.py | 16 ++++++++-------- tests/test_cli.py | 12 ++++++++++++ 2 files changed, 20 insertions(+), 8 deletions(-) create mode 100644 tests/test_cli.py diff --git a/paste/__main__.py b/paste/__main__.py index 6e684c4..7260516 100644 --- a/paste/__main__.py +++ b/paste/__main__.py @@ -1,9 +1,9 @@ +import sys from base64 import b64decode, b64encode from collections.abc import Iterator from contextlib import contextmanager from datetime import datetime, timezone from os import getenv -from sys import argv, stderr from wsgiref.simple_server import make_server from . import DB_PATH, application, db, store @@ -14,7 +14,7 @@ PROGRAM_NAME = "paste" def print_usage() -> None: print( f"Usage: {PROGRAM_NAME} [-h|serve|new-token|list-tokens|delete-token]", - file=stderr, + file=sys.stderr, ) @@ -36,7 +36,7 @@ Environment: PASTE_HOST The HTTP server host (default: localhost) PASTE_PORT The HTTP server port (default: 8080) PASTE_DB The path to the sqlite3 database (default: {DB_PATH})""", - file=stderr, + file=sys.stderr, ) @@ -49,7 +49,7 @@ def open_auth() -> Iterator[store.Auth]: yield store.Auth(conn) -def main(): +def main(argv=sys.argv): if len(argv) <= 1 or len(argv) == 2 and argv[1] == "serve": host = getenv("PASTE_HOST", "localhost") port = int(getenv("PASTE_PORT", "8080")) @@ -80,20 +80,20 @@ def main(): except ValueError: auth.delete_token_hash(bytes.fromhex(token)) except ValueError: - print("Malformed token", file=stderr) + print("Malformed token", file=sys.stderr) exit(1) except KeyError: - print("Token not found", file=stderr) + print("Token not found", file=sys.stderr) exit(1) elif argv[1] == "verify-token": with open_auth() as auth: try: if not auth.check_token(b64decode(argv[2])): - print("Token not found", file=stderr) + print("Token not found", file=sys.stderr) exit(1) print("Found") except ValueError: - print("Malformed token", file=stderr) + print("Malformed token", file=sys.stderr) if __name__ == "__main__": diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..6f3beb2 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,12 @@ +from paste.__main__ import main + + +def test_cli(monkeypatch, capfd): + monkeypatch.setattr("paste.__main__.db", None) + monkeypatch.setattr("paste.__main__.application", None) + monkeypatch.setattr("paste.__main__.store", None) + monkeypatch.setenv("PASTE_DB", "test_value") + + main(["paste", "--help"]) + _, err = capfd.readouterr() + assert "paste" in err -- cgit v1.2.3-54-g00ecf