From 0632c069785c7a1b7f02514ef2fa69d3fcecdc04 Mon Sep 17 00:00:00 2001 From: Evan Reichard Date: Sat, 20 Mar 2021 23:15:22 -0400 Subject: [PATCH] Docs, Tests, SIGINT --- .gitignore | 1 + .pre-commit-config.yaml | 4 +- setup.py | 1 + src/overseer/__init__.py | 10 ++++ src/overseer/database.py | 15 ++++++ tests/overseer/test_database.py | 85 ++++++++++++++++++++++++++++++--- tests/overseer/test_scanner.py | 0 7 files changed, 107 insertions(+), 9 deletions(-) create mode 100644 tests/overseer/test_scanner.py diff --git a/.gitignore b/.gitignore index a9dbd1c..53181bb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ **/__pycache__/ +**/.coverage src/overseer.egg-info/ build/ dist/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 02fd918..dcf56fb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,9 +4,9 @@ repos: hooks: - id: black language_version: python3.9 - files: '^src/overseer/|^setup.py' + files: '^src/overseer/|^setup.py|^tests/overseer/' - repo: https://gitlab.com/pycqa/flake8 rev: 3.9.0 hooks: - id: flake8 - files: '^src/overseer/|^setup.py' + files: '^src/overseer/|^setup.py|^tests/overseer/' diff --git a/setup.py b/setup.py index 964390d..5c806f7 100644 --- a/setup.py +++ b/setup.py @@ -15,5 +15,6 @@ setup( "sqlalchemy", "Flask", ], + tests_require=["pytest"], extras_require={"dev": ["pre-commit", "black", "flake8", "pytest"]}, ) diff --git a/src/overseer/__init__.py b/src/overseer/__init__.py index 1ececa9..484013b 100644 --- a/src/overseer/__init__.py +++ b/src/overseer/__init__.py @@ -1,4 +1,6 @@ import click +import signal +import sys from importlib.metadata import version from overseer.config import EnvConfig from overseer.scanner import ScanManager @@ -14,6 +16,11 @@ database = DatabaseConnector(config.DATA_PATH) scan_manager = ScanManager() +def signal_handler(sig, frame): + scan_manager.shutdown() + sys.exit(0) + + def create_app(): return app @@ -23,5 +30,8 @@ def cli(): """Management script for the application.""" +# Handle SIGINT +signal.signal(signal.SIGINT, signal_handler) + # Import all flask views import overseer.overseer # noqa: E501,F401,E402 diff --git a/src/overseer/database.py b/src/overseer/database.py index a679a72..1aeef9f 100644 --- a/src/overseer/database.py +++ b/src/overseer/database.py @@ -73,6 +73,11 @@ class DatabaseConnector: **kwargs Either hostname or ip_addr + Raises + ------ + ValueError + If hostname or ip_addr isn't provided in kwargs + Returns ------- ScanTarget @@ -103,6 +108,11 @@ class DatabaseConnector: **kwargs Either hostname or ip_addr + Raises + ------ + ValueError + If hostname or ip_addr isn't provided in kwargs + Returns ------- ScanTarget @@ -196,6 +206,11 @@ class DatabaseConnector: error : str, optional Error message, if any + Raises + ------ + NoResultFound + If we cannot find the desired ScanHistory by ID + Returns ------- ScanHistory diff --git a/tests/overseer/test_database.py b/tests/overseer/test_database.py index 5e7ae8d..694aa70 100644 --- a/tests/overseer/test_database.py +++ b/tests/overseer/test_database.py @@ -1,9 +1,14 @@ -import pytest +# import pytest import ipaddress +from overseer import scan_manager from overseer.database import DatabaseConnector +# We're not testing this & this will stall tests +scan_manager.shutdown() + + def test_create_scan_target(): - db = DatabaseConnector(None, True) + db = DatabaseConnector(None, in_memory=True) hostname = db.create_scan_target(hostname="google.com") ip_address = db.create_scan_target(ip_addr="1.1.1.1") @@ -12,8 +17,74 @@ def test_create_scan_target(): assert ip_address.id == 2 assert ip_address.ip == int(ipaddress.ip_address("1.1.1.1")) -# def test_get_scan_target(): -# def test_get_all_scan_target(): -# def test_create_scan_result(): -# def test_update_scan_result(): -# def test_get_scan_results_by_target(): + +def test_get_scan_target(): + db = DatabaseConnector(None, in_memory=True) + created_target = db.create_scan_target(hostname="google.com") + found_target = db.get_scan_target(hostname="google.com") + + assert created_target.id == found_target.id + assert created_target.ip == found_target.ip + assert created_target.hostname == found_target.hostname + assert created_target.updated_at == found_target.updated_at + assert found_target.hostname == "google.com" + + +def test_get_all_scan_targets(): + db = DatabaseConnector(None, in_memory=True) + for i in range(1, 6): + db.create_scan_target(ip_addr="127.0.0." + str(i)) + + found_targets = db.get_all_scan_targets() + + assert len(found_targets) == 5 + + # This checks for properly ordered items + for i, target in enumerate(found_targets): + desired_target_ip = "127.0.0." + str(5 - i) + assert target.id == 5 - i + assert target.ip == int(ipaddress.ip_address(desired_target_ip)) + + +def test_create_scan_result(): + db = DatabaseConnector(None, in_memory=True) + scan_target = db.create_scan_target(ip_addr="127.0.0.1") + scan_history = db.create_scan_result("IN_PROGRESS", ip_addr="127.0.0.1") + scan_history_2 = db.create_scan_result("COMPLETE", ip_addr="127.0.0.2") + + assert scan_history.id == 1 + assert scan_history.target_id == scan_target.id + assert scan_history.status == "IN_PROGRESS" + assert scan_history_2.id == 2 + assert scan_history_2.error is None + assert scan_history_2.target_id != scan_target.id + assert scan_history_2.status == "COMPLETE" + + +def test_update_scan_result(): + db = DatabaseConnector(None, in_memory=True) + scan_history = db.create_scan_result("IN_PROGRESS", ip_addr="127.0.0.1") + updated_scan_history = db.update_scan_result( + scan_history.id, "COMPLETE", ["53 UDP", "53 TCP"] + ) + + assert scan_history.id == updated_scan_history.id + assert scan_history.status == "IN_PROGRESS" + assert updated_scan_history.status == "COMPLETE" + assert updated_scan_history.results == "53 UDP,53 TCP" + + +def test_get_scan_results_by_target(): + db = DatabaseConnector(None, in_memory=True) + for i in range(1, 6): + db.create_scan_result("IN_PROGRESS", ip_addr="127.0.0.1") + + for i in range(1, 3): + db.create_scan_result("COMPLETE", ip_addr="127.0.0.2") + + scan_history_1 = db.get_scan_results_by_target(ip_addr="127.0.0.1") + scan_history_2 = db.get_scan_results_by_target(ip_addr="127.0.0.2") + + assert len(scan_history_1) == 5 + assert len(scan_history_2) == 2 + assert scan_history_1[3].status == "IN_PROGRESS" diff --git a/tests/overseer/test_scanner.py b/tests/overseer/test_scanner.py new file mode 100644 index 0000000..e69de29