Slight refactor

This commit is contained in:
Evan Reichard 2021-03-22 12:59:11 -04:00
parent d93cfd183c
commit 5ab87e1c6a
10 changed files with 56 additions and 66 deletions

View File

@ -1,7 +1,8 @@
FROM python:3.9.2-slim as build FROM python:3.9.2-slim as build
COPY . /app COPY . /app
WORKDIR /app WORKDIR /app
RUN python setup.py install RUN python setup.py clean --all install
FROM python:3.9.2-alpine3.13 FROM python:3.9.2-alpine3.13
COPY --from=build /usr/local/lib/python3.9/site-packages /usr/local/lib/python3.9/site-packages COPY --from=build /usr/local/lib/python3.9/site-packages /usr/local/lib/python3.9/site-packages
@ -9,4 +10,4 @@ COPY --from=build /usr/local/lib/python3.9/site-packages /usr/local/lib/python3.
RUN pip3 install gunicorn RUN pip3 install gunicorn
ENTRYPOINT ["gunicorn"] ENTRYPOINT ["gunicorn"]
CMD ["overseer:app", "--bind", "0.0.0.0:5000", "--threads=4"] CMD ["overseer:create_app()", "--bind", "0.0.0.0:5000", "--threads=4"]

View File

@ -2,7 +2,7 @@ import click
import signal import signal
import sys import sys
from importlib.metadata import version from importlib.metadata import version
from overseer.config import EnvConfig from overseer.config import Config
from overseer.scanner import ScanManager from overseer.scanner import ScanManager
from overseer.database import DatabaseConnector from overseer.database import DatabaseConnector
from flask import Flask from flask import Flask
@ -11,17 +11,23 @@ from flask.cli import FlaskGroup
__version__ = version("overseer") __version__ = version("overseer")
app = Flask(__name__) app = Flask(__name__)
config = EnvConfig() database = DatabaseConnector(Config.DB_TYPE, Config.DATA_PATH)
database = DatabaseConnector(config.DATA_PATH)
scan_manager = ScanManager() scan_manager = ScanManager()
def signal_handler(sig, frame): def signal_handler(sig, frame):
scan_manager.shutdown() scan_manager.stop()
sys.exit(0) sys.exit(0)
def create_app(): def create_app():
import overseer.api.common as api_common
import overseer.api.v1 as api_v1
app.register_blueprint(api_common.bp)
app.register_blueprint(api_v1.bp)
scan_manager.start()
return app return app
@ -32,6 +38,3 @@ def cli():
# Handle SIGINT # Handle SIGINT
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
# Import all flask views
import overseer.overseer # noqa: E501,F401,E402

View File

@ -1,25 +1,27 @@
from overseer import app
from flask import make_response, render_template, send_from_directory from flask import make_response, render_template, send_from_directory
from overseer.api.v1 import api as api_v1 from flask import Blueprint
# Setup blueprint
bp = Blueprint("common", __name__)
@app.route("/", methods=["GET"]) @bp.route("/", methods=["GET"])
def main_entry(): def main_entry():
"""Initial Entrypoint to the SPA (i.e. 'index.html')""" """Initial Entrypoint to the SPA (i.e. 'index.html')"""
return make_response(render_template("index.html")) return make_response(render_template("index.html"))
@app.route("/<path:path>", methods=["GET"]) @bp.route("/<path:path>", methods=["GET"])
def catch_all(path): def catch_all(path):
"""Necessary due to client side SPA route handling""" """Necessary due to client side SPA route handling"""
return make_response(render_template("index.html")) return make_response(render_template("index.html"))
@app.route("/static/<path:path>") @bp.route("/static/<path:path>")
def static_resources(path): def static_resources(path):
"""Front end static resources""" """Front end static resources"""
return send_from_directory("static", path) return send_from_directory("static", path)
# Version API's # Version API's
app.register_blueprint(api_v1) # app.register_blueprint(api_v1)

View File

@ -1,9 +1,15 @@
import overseer import overseer
from overseer.api.v1 import api from overseer.api import open_websockets
from flask import request from flask import Blueprint, request
from flask_socketio import SocketIO
# Setup blueprint & websocket
bp = Blueprint("v1", __name__, url_prefix="/api/v1")
socketio = SocketIO(overseer.app, path="/api/v1/socket.io")
open_websockets.append(socketio)
@api.route("/status", methods=["GET"]) @bp.route("/status", methods=["GET"])
def get_status(): def get_status():
"""Get server version and all active scans.""" """Get server version and all active scans."""
return { return {
@ -12,7 +18,7 @@ def get_status():
} }
@api.route("/scans", methods=["POST"]) @bp.route("/scans", methods=["POST"])
def post_scans(): def post_scans():
""" """
POST: POST:
@ -35,7 +41,7 @@ def post_scans():
return __normalize_scan_results([scan_history], data["target"])[0] return __normalize_scan_results([scan_history], data["target"])[0]
@api.route("/scans/<string:target>", methods=["GET"]) @bp.route("/scans/<string:target>", methods=["GET"])
def get_scans_by_target(target): def get_scans_by_target(target):
""" """
GET: GET:
@ -57,7 +63,7 @@ def get_scans_by_target(target):
return {"data": __normalize_scan_results(scan_results, target)} return {"data": __normalize_scan_results(scan_results, target)}
@api.route("/search", methods=["GET"]) @bp.route("/search", methods=["GET"])
def get_scans(search): def get_scans(search):
"""Not Implemented """Not Implemented
GET: GET:

View File

@ -1,5 +0,0 @@
from flask import Blueprint
# Setup the API blueprint
api = Blueprint("v1", __name__, url_prefix="/api/v1")
from overseer.api.v1 import routes, events # noqa: F401,E402

View File

@ -1,19 +0,0 @@
from overseer import app
from overseer.api import open_websockets
from flask_socketio import SocketIO
# Create and register new websocket for v1 API
socketio = SocketIO(app, path="/api/v1/socket.io")
open_websockets.append(socketio)
@socketio.on("message")
def handle_message(data):
"""Not Implemented - Used to response to client sent WebSocket messages"""
print("RAW DATA: %s" % data)
@socketio.on("json")
def handle_json(json):
"""Not Implemented - Used to response to client sent WebSocket messages"""
print("JSON DATA: %s" % json)

View File

@ -8,7 +8,7 @@ def get_env(key, default=None, required=False):
return os.environ.get(key, default) return os.environ.get(key, default)
class EnvConfig: class Config:
"""Wrap application configurations """Wrap application configurations
Attributes Attributes
@ -19,5 +19,5 @@ class EnvConfig:
The path where to store any resources (default: ./) The path where to store any resources (default: ./)
""" """
DATABASE = get_env("OVERSEER_DB", default="sqlite") DB_TYPE = get_env("OVERSEER_DB_TYPE", default="sqlite")
DATA_PATH = get_env("OVERSEER_DATA_PATH", default="./") DATA_PATH = get_env("OVERSEER_DATA_PATH", default="./")

View File

@ -28,20 +28,20 @@ class DatabaseConnector:
25 / page 25 / page
""" """
def __init__(self, data_path, in_memory=False): def __init__(self, db_type, data_path=None):
""" """
Parameters Parameters
---------- ----------
db_type : str
Database tyle (e.g. sqlite, memory)
data_path : str data_path : str
Directory to store the sqlite file Directory to store the sqlite file
in_memory : bool, optional
Directive to store the DB in memory
""" """
if in_memory: if db_type.lower() == "memory":
self.__engine = create_engine( self.__engine = create_engine(
"sqlite+pysqlite:///:memory:", echo=False, future=True "sqlite+pysqlite:///:memory:", echo=False, future=True
) )
else: elif db_type.lower() == "sqlite":
db_path = path.join(data_path, "overseer.sqlite") db_path = path.join(data_path, "overseer.sqlite")
self.__engine = create_engine( self.__engine = create_engine(
"sqlite+pysqlite:///%s" % db_path, "sqlite+pysqlite:///%s" % db_path,

View File

@ -13,8 +13,10 @@ class ScanManager:
Methods Methods
------- -------
shutdown() start()
Shutdown and cleanup the scan monitor thread Start the scan monitor thread
stop()
Stop and cleanup the scan monitor thread
get_status() get_status()
Get a normalized list of dicts detailing outstanding scans Get a normalized list of dicts detailing outstanding scans
perform_scan(target) perform_scan(target)
@ -27,8 +29,6 @@ class ScanManager:
"""Create instance and start thread""" """Create instance and start thread"""
self.__pending_shutdown = False self.__pending_shutdown = False
self.__active_scans = [] self.__active_scans = []
self.__monitor_thread = Thread(target=self.__scan_monitor)
self.__monitor_thread.start()
def __scan_monitor(self): def __scan_monitor(self):
"""Monitors active and completed scans """Monitors active and completed scans
@ -83,7 +83,11 @@ class ScanManager:
scan.join() scan.join()
self.__active_scans.remove(scan) self.__active_scans.remove(scan)
def shutdown(self): def start(self):
self.__monitor_thread = Thread(target=self.__scan_monitor)
self.__monitor_thread.start()
def stop(self):
"""Shutdown and cleanup the scan monitor thread""" """Shutdown and cleanup the scan monitor thread"""
self.__pending_shutdown = True self.__pending_shutdown = True
self.__monitor_thread.join() self.__monitor_thread.join()

View File

@ -1,14 +1,12 @@
# import pytest # import pytest
import ipaddress import ipaddress
from overseer import scan_manager
from overseer.database import DatabaseConnector from overseer.database import DatabaseConnector
# We're not testing this & this will stall tests DB_TYPE = "MEMORY"
scan_manager.shutdown()
def test_create_scan_target(): def test_create_scan_target():
db = DatabaseConnector(None, in_memory=True) db = DatabaseConnector(DB_TYPE)
hostname = db.create_scan_target(hostname="google.com") hostname = db.create_scan_target(hostname="google.com")
ip_address = db.create_scan_target(ip_addr="1.1.1.1") ip_address = db.create_scan_target(ip_addr="1.1.1.1")
@ -19,7 +17,7 @@ def test_create_scan_target():
def test_get_scan_target(): def test_get_scan_target():
db = DatabaseConnector(None, in_memory=True) db = DatabaseConnector(DB_TYPE)
created_target = db.create_scan_target(hostname="google.com") created_target = db.create_scan_target(hostname="google.com")
found_target = db.get_scan_target(hostname="google.com") found_target = db.get_scan_target(hostname="google.com")
@ -31,7 +29,7 @@ def test_get_scan_target():
def test_get_all_scan_targets(): def test_get_all_scan_targets():
db = DatabaseConnector(None, in_memory=True) db = DatabaseConnector(DB_TYPE)
for i in range(1, 6): for i in range(1, 6):
db.create_scan_target(ip_addr="127.0.0." + str(i)) db.create_scan_target(ip_addr="127.0.0." + str(i))
@ -47,7 +45,7 @@ def test_get_all_scan_targets():
def test_create_scan_result(): def test_create_scan_result():
db = DatabaseConnector(None, in_memory=True) db = DatabaseConnector(DB_TYPE)
scan_target = db.create_scan_target(ip_addr="127.0.0.1") scan_target = db.create_scan_target(ip_addr="127.0.0.1")
scan_history = db.create_scan_result("IN_PROGRESS", ip_addr="127.0.0.1") scan_history = db.create_scan_result("IN_PROGRESS", ip_addr="127.0.0.1")
scan_history_2 = db.create_scan_result("COMPLETE", ip_addr="127.0.0.2") scan_history_2 = db.create_scan_result("COMPLETE", ip_addr="127.0.0.2")
@ -62,7 +60,7 @@ def test_create_scan_result():
def test_update_scan_result(): def test_update_scan_result():
db = DatabaseConnector(None, in_memory=True) db = DatabaseConnector(DB_TYPE)
scan_history = db.create_scan_result("IN_PROGRESS", ip_addr="127.0.0.1") scan_history = db.create_scan_result("IN_PROGRESS", ip_addr="127.0.0.1")
updated_scan_history = db.update_scan_result( updated_scan_history = db.update_scan_result(
scan_history.id, "COMPLETE", ["53 UDP", "53 TCP"] scan_history.id, "COMPLETE", ["53 UDP", "53 TCP"]
@ -75,7 +73,7 @@ def test_update_scan_result():
def test_get_scan_results_by_target(): def test_get_scan_results_by_target():
db = DatabaseConnector(None, in_memory=True) db = DatabaseConnector(DB_TYPE)
for i in range(1, 6): for i in range(1, 6):
db.create_scan_result("IN_PROGRESS", ip_addr="127.0.0.1") db.create_scan_result("IN_PROGRESS", ip_addr="127.0.0.1")