Slight refactor

This commit is contained in:
Evan Reichard 2021-03-22 12:59:11 -04:00
parent d93cfd183c
commit 5ab87e1c6a
10 changed files with 56 additions and 66 deletions

View File

@ -1,7 +1,8 @@
FROM python:3.9.2-slim as build
COPY . /app
WORKDIR /app
RUN python setup.py install
RUN python setup.py clean --all install
FROM python:3.9.2-alpine3.13
COPY --from=build /usr/local/lib/python3.9/site-packages /usr/local/lib/python3.9/site-packages
@ -9,4 +10,4 @@ COPY --from=build /usr/local/lib/python3.9/site-packages /usr/local/lib/python3.
RUN pip3 install gunicorn
ENTRYPOINT ["gunicorn"]
CMD ["overseer:app", "--bind", "0.0.0.0:5000", "--threads=4"]
CMD ["overseer:create_app()", "--bind", "0.0.0.0:5000", "--threads=4"]

View File

@ -2,7 +2,7 @@ import click
import signal
import sys
from importlib.metadata import version
from overseer.config import EnvConfig
from overseer.config import Config
from overseer.scanner import ScanManager
from overseer.database import DatabaseConnector
from flask import Flask
@ -11,17 +11,23 @@ from flask.cli import FlaskGroup
__version__ = version("overseer")
app = Flask(__name__)
config = EnvConfig()
database = DatabaseConnector(config.DATA_PATH)
database = DatabaseConnector(Config.DB_TYPE, Config.DATA_PATH)
scan_manager = ScanManager()
def signal_handler(sig, frame):
scan_manager.shutdown()
scan_manager.stop()
sys.exit(0)
def create_app():
import overseer.api.common as api_common
import overseer.api.v1 as api_v1
app.register_blueprint(api_common.bp)
app.register_blueprint(api_v1.bp)
scan_manager.start()
return app
@ -32,6 +38,3 @@ def cli():
# Handle SIGINT
signal.signal(signal.SIGINT, signal_handler)
# Import all flask views
import overseer.overseer # noqa: E501,F401,E402

View File

@ -1,25 +1,27 @@
from overseer import app
from flask import make_response, render_template, send_from_directory
from overseer.api.v1 import api as api_v1
from flask import Blueprint
# Setup blueprint
bp = Blueprint("common", __name__)
@app.route("/", methods=["GET"])
@bp.route("/", methods=["GET"])
def main_entry():
"""Initial Entrypoint to the SPA (i.e. 'index.html')"""
return make_response(render_template("index.html"))
@app.route("/<path:path>", methods=["GET"])
@bp.route("/<path:path>", methods=["GET"])
def catch_all(path):
"""Necessary due to client side SPA route handling"""
return make_response(render_template("index.html"))
@app.route("/static/<path:path>")
@bp.route("/static/<path:path>")
def static_resources(path):
"""Front end static resources"""
return send_from_directory("static", path)
# Version API's
app.register_blueprint(api_v1)
# app.register_blueprint(api_v1)

View File

@ -1,9 +1,15 @@
import overseer
from overseer.api.v1 import api
from flask import request
from overseer.api import open_websockets
from flask import Blueprint, request
from flask_socketio import SocketIO
# Setup blueprint & websocket
bp = Blueprint("v1", __name__, url_prefix="/api/v1")
socketio = SocketIO(overseer.app, path="/api/v1/socket.io")
open_websockets.append(socketio)
@api.route("/status", methods=["GET"])
@bp.route("/status", methods=["GET"])
def get_status():
"""Get server version and all active scans."""
return {
@ -12,7 +18,7 @@ def get_status():
}
@api.route("/scans", methods=["POST"])
@bp.route("/scans", methods=["POST"])
def post_scans():
"""
POST:
@ -35,7 +41,7 @@ def post_scans():
return __normalize_scan_results([scan_history], data["target"])[0]
@api.route("/scans/<string:target>", methods=["GET"])
@bp.route("/scans/<string:target>", methods=["GET"])
def get_scans_by_target(target):
"""
GET:
@ -57,7 +63,7 @@ def get_scans_by_target(target):
return {"data": __normalize_scan_results(scan_results, target)}
@api.route("/search", methods=["GET"])
@bp.route("/search", methods=["GET"])
def get_scans(search):
"""Not Implemented
GET:

View File

@ -1,5 +0,0 @@
from flask import Blueprint
# Setup the API blueprint
api = Blueprint("v1", __name__, url_prefix="/api/v1")
from overseer.api.v1 import routes, events # noqa: F401,E402

View File

@ -1,19 +0,0 @@
from overseer import app
from overseer.api import open_websockets
from flask_socketio import SocketIO
# Create and register new websocket for v1 API
socketio = SocketIO(app, path="/api/v1/socket.io")
open_websockets.append(socketio)
@socketio.on("message")
def handle_message(data):
"""Not Implemented - Used to response to client sent WebSocket messages"""
print("RAW DATA: %s" % data)
@socketio.on("json")
def handle_json(json):
"""Not Implemented - Used to response to client sent WebSocket messages"""
print("JSON DATA: %s" % json)

View File

@ -8,7 +8,7 @@ def get_env(key, default=None, required=False):
return os.environ.get(key, default)
class EnvConfig:
class Config:
"""Wrap application configurations
Attributes
@ -19,5 +19,5 @@ class EnvConfig:
The path where to store any resources (default: ./)
"""
DATABASE = get_env("OVERSEER_DB", default="sqlite")
DB_TYPE = get_env("OVERSEER_DB_TYPE", default="sqlite")
DATA_PATH = get_env("OVERSEER_DATA_PATH", default="./")

View File

@ -28,20 +28,20 @@ class DatabaseConnector:
25 / page
"""
def __init__(self, data_path, in_memory=False):
def __init__(self, db_type, data_path=None):
"""
Parameters
----------
db_type : str
Database tyle (e.g. sqlite, memory)
data_path : str
Directory to store the sqlite file
in_memory : bool, optional
Directive to store the DB in memory
"""
if in_memory:
if db_type.lower() == "memory":
self.__engine = create_engine(
"sqlite+pysqlite:///:memory:", echo=False, future=True
)
else:
elif db_type.lower() == "sqlite":
db_path = path.join(data_path, "overseer.sqlite")
self.__engine = create_engine(
"sqlite+pysqlite:///%s" % db_path,

View File

@ -13,8 +13,10 @@ class ScanManager:
Methods
-------
shutdown()
Shutdown and cleanup the scan monitor thread
start()
Start the scan monitor thread
stop()
Stop and cleanup the scan monitor thread
get_status()
Get a normalized list of dicts detailing outstanding scans
perform_scan(target)
@ -27,8 +29,6 @@ class ScanManager:
"""Create instance and start thread"""
self.__pending_shutdown = False
self.__active_scans = []
self.__monitor_thread = Thread(target=self.__scan_monitor)
self.__monitor_thread.start()
def __scan_monitor(self):
"""Monitors active and completed scans
@ -83,7 +83,11 @@ class ScanManager:
scan.join()
self.__active_scans.remove(scan)
def shutdown(self):
def start(self):
self.__monitor_thread = Thread(target=self.__scan_monitor)
self.__monitor_thread.start()
def stop(self):
"""Shutdown and cleanup the scan monitor thread"""
self.__pending_shutdown = True
self.__monitor_thread.join()

View File

@ -1,14 +1,12 @@
# import pytest
import ipaddress
from overseer import scan_manager
from overseer.database import DatabaseConnector
# We're not testing this & this will stall tests
scan_manager.shutdown()
DB_TYPE = "MEMORY"
def test_create_scan_target():
db = DatabaseConnector(None, in_memory=True)
db = DatabaseConnector(DB_TYPE)
hostname = db.create_scan_target(hostname="google.com")
ip_address = db.create_scan_target(ip_addr="1.1.1.1")
@ -19,7 +17,7 @@ def test_create_scan_target():
def test_get_scan_target():
db = DatabaseConnector(None, in_memory=True)
db = DatabaseConnector(DB_TYPE)
created_target = db.create_scan_target(hostname="google.com")
found_target = db.get_scan_target(hostname="google.com")
@ -31,7 +29,7 @@ def test_get_scan_target():
def test_get_all_scan_targets():
db = DatabaseConnector(None, in_memory=True)
db = DatabaseConnector(DB_TYPE)
for i in range(1, 6):
db.create_scan_target(ip_addr="127.0.0." + str(i))
@ -47,7 +45,7 @@ def test_get_all_scan_targets():
def test_create_scan_result():
db = DatabaseConnector(None, in_memory=True)
db = DatabaseConnector(DB_TYPE)
scan_target = db.create_scan_target(ip_addr="127.0.0.1")
scan_history = db.create_scan_result("IN_PROGRESS", ip_addr="127.0.0.1")
scan_history_2 = db.create_scan_result("COMPLETE", ip_addr="127.0.0.2")
@ -62,7 +60,7 @@ def test_create_scan_result():
def test_update_scan_result():
db = DatabaseConnector(None, in_memory=True)
db = DatabaseConnector(DB_TYPE)
scan_history = db.create_scan_result("IN_PROGRESS", ip_addr="127.0.0.1")
updated_scan_history = db.update_scan_result(
scan_history.id, "COMPLETE", ["53 UDP", "53 TCP"]
@ -75,7 +73,7 @@ def test_update_scan_result():
def test_get_scan_results_by_target():
db = DatabaseConnector(None, in_memory=True)
db = DatabaseConnector(DB_TYPE)
for i in range(1, 6):
db.create_scan_result("IN_PROGRESS", ip_addr="127.0.0.1")