diff --git a/src/overseer/__init__.py b/src/overseer/__init__.py index d9efbdc..1ececa9 100644 --- a/src/overseer/__init__.py +++ b/src/overseer/__init__.py @@ -20,7 +20,8 @@ def create_app(): @click.group(cls=FlaskGroup, create_app=create_app) def cli(): - """Management script for the Wiki application.""" + """Management script for the application.""" +# Import all flask views import overseer.overseer # noqa: E501,F401,E402 diff --git a/src/overseer/api/__init__.py b/src/overseer/api/__init__.py index 1403f6f..3003a98 100644 --- a/src/overseer/api/__init__.py +++ b/src/overseer/api/__init__.py @@ -4,8 +4,10 @@ open_websockets = [] def send_websocket_event(data): """Send an event to all registered websockets. - Arguments: - data -- Data to send over the websocket(s) + Arguments + --------- + data : obj + Data to send over the websocket(s) """ for socket in open_websockets: socket.send(data) diff --git a/src/overseer/api/v1/__init__.py b/src/overseer/api/v1/__init__.py index b860caf..1cc2ec4 100644 --- a/src/overseer/api/v1/__init__.py +++ b/src/overseer/api/v1/__init__.py @@ -1,4 +1,5 @@ from flask import Blueprint +# Setup the API blueprint api = Blueprint("v1", __name__, url_prefix="/api/v1") from overseer.api.v1 import routes, events # noqa: F401,E402 diff --git a/src/overseer/api/v1/events.py b/src/overseer/api/v1/events.py index 8b5ba52..e52304e 100644 --- a/src/overseer/api/v1/events.py +++ b/src/overseer/api/v1/events.py @@ -2,15 +2,18 @@ from overseer import app from overseer.api import open_websockets from flask_socketio import SocketIO +# Create and register new websocket for v1 API socketio = SocketIO(app, path="/api/v1/socket.io") open_websockets.append(socketio) @socketio.on("message") def handle_message(data): + """Not Implemented - Used to response to client sent WebSocket messages""" print("RAW DATA: %s" % data) @socketio.on("json") def handle_json(json): + """Not Implemented - Used to response to client sent WebSocket messages""" print("JSON DATA: %s" % json) diff --git a/src/overseer/api/v1/routes.py b/src/overseer/api/v1/routes.py index a261fad..d18b889 100644 --- a/src/overseer/api/v1/routes.py +++ b/src/overseer/api/v1/routes.py @@ -40,7 +40,7 @@ def get_scans_by_target(target): /api/v1/scans/1.1.1.1 RESPONSE: { "data": [ ] } """ - page = 1 + page = 1 # TODO: Pagination if overseer.scan_manager.is_ip(target): scan_results = overseer.database.get_scan_results_by_target( @@ -56,7 +56,7 @@ def get_scans_by_target(target): @api.route("/search", methods=["GET"]) def get_scans(search): - """ + """Not Implemented GET: REQUEST: /api/v1/search?query=1.1.1.1 /api/v1/search?query=192.168.0.0/24 @@ -67,6 +67,7 @@ def get_scans(search): def __normalize_scan_results(scan_results, target): + """Returns a normalized list of objects for ScanHistory items""" return list( map( lambda x: { diff --git a/src/overseer/config.py b/src/overseer/config.py index 2689ef9..f353840 100644 --- a/src/overseer/config.py +++ b/src/overseer/config.py @@ -2,11 +2,22 @@ import os def get_env(key, default=None, required=False): + """Wrapper for gathering env vars.""" if required: assert key in os.environ, "Missing Environment Variable: %s" % key return os.environ.get(key, default) class EnvConfig: + """Wrap application configurations + + Attributes + ---------- + DATABASE : str + The specied desired database (default: sqlite) + DATA_PATH : str + The path where to store any resources (default: ./) + """ + DATABASE = get_env("OVERSEER_DB", default="sqlite") DATA_PATH = get_env("OVERSEER_DATA_PATH", default="./") diff --git a/src/overseer/database.py b/src/overseer/database.py index a6e3cb0..a679a72 100644 --- a/src/overseer/database.py +++ b/src/overseer/database.py @@ -8,23 +8,52 @@ from sqlalchemy.orm.exc import NoResultFound class DatabaseConnector: + """Used to manipulate the database + + Methods + ------- + create_scan_target(**kwargs) + Create a new ScanTarget database row with the provided hostname or + ip_addr + get_scan_target(**kwargs) + Get a ScanTarget by either its hostname or ip_addr + get_all_scan_targets(page=1) + Get all ScanTargets ordered by most recent, limited to 25 / page + create_scan_result(status, results=[], error=None, **kwargs) + Create a new ScanHistory for the provided hostname or ip_addr + update_scan_result(history_id, status, results=None, error=None) + Update the referenced ScanHistory ID with the provided values + get_scan_results_by_target(page=1, **kwargs) + Get all ScanHistory for the provided hostname or ip_add, limited to + 25 / page + """ + def __init__(self, data_path, in_memory=False): + """ + Parameters + ---------- + data_path : str + Directory to store the sqlite file + in_memory : bool, optional + Directive to store the DB in memory + """ if in_memory: - self.engine = create_engine( + self.__engine = create_engine( "sqlite+pysqlite:///:memory:", echo=False, future=True ) else: db_path = path.join(data_path, "overseer.sqlite") - self.engine = create_engine( + self.__engine = create_engine( "sqlite+pysqlite:///%s" % db_path, echo=False, future=True, ) - Base.metadata.create_all(self.engine) + Base.metadata.create_all(self.__engine) self.__cleanup_stale_records() def __cleanup_stale_records(self): - session = Session(bind=self.engine) + """Cleans up any stale ScanHistory records""" + session = Session(bind=self.__engine) history_filter = ScanHistory.status == "IN_PROGRESS" all_stale = session.query(ScanHistory).filter(history_filter).all() @@ -36,6 +65,19 @@ class DatabaseConnector: session.close() def create_scan_target(self, **kwargs): + """Create a new ScanTarget database row with the provided hostname or + ip_addr + + Parameters + ---------- + **kwargs + Either hostname or ip_addr + + Returns + ------- + ScanTarget + The created ScanTarget + """ if len(kwargs.keys() & {"ip_addr", "hostname"}) != 1: raise ValueError("Missing keyword argument: ip_addr or hostname") @@ -47,13 +89,25 @@ class DatabaseConnector: hostname=kwargs["hostname"], updated_at=datetime.now() ) - session = Session(bind=self.engine, expire_on_commit=False) + session = Session(bind=self.__engine, expire_on_commit=False) session.add(scan_target) session.commit() session.close() return scan_target def get_scan_target(self, **kwargs): + """Get a ScanTarget by either its hostname or ip_addr + + Parameters + ---------- + **kwargs + Either hostname or ip_addr + + Returns + ------- + ScanTarget + The requested ScanTarget + """ if len(kwargs.keys() & {"ip_addr", "hostname"}) != 1: raise ValueError("Missing keyword argument: ip_addr or hostname") @@ -63,13 +117,25 @@ class DatabaseConnector: elif "hostname" in kwargs: target_filter = ScanTarget.hostname == kwargs["hostname"] - session = Session(bind=self.engine) + session = Session(bind=self.__engine) scan_target = session.query(ScanTarget).filter(target_filter).first() session.close() return scan_target def get_all_scan_targets(self, page=1): - session = Session(bind=self.engine) + """Get all ScanTargets ordered by most recent, limited to 25 / page + + Parameters + ---------- + page : int, optional + The desired ScanTarget page + + Returns + ------- + list + List of ScanTarget + """ + session = Session(bind=self.__engine) scan_targets = ( session.query(ScanTarget) .order_by(ScanTarget.updated_at.desc()) @@ -81,6 +147,24 @@ class DatabaseConnector: return scan_targets def create_scan_result(self, status, results=[], error=None, **kwargs): + """Create a new ScanHistory for the provided hostname or ip_addr + + Parameters + ---------- + status : str + The status of the scan (IN_PROGRESS, FAILED, COMPLETE) + results : list + List of strings of open ports (E.g. ["53 UDP", "53 TCP"] + error : str, optional + Error message, if any + **kwargs + Either hostname or ip_addr + + Returns + ------- + ScanHistory + The created ScanHistory + """ scan_target = self.get_scan_target(**kwargs) if not scan_target: scan_target = self.create_scan_target(**kwargs) @@ -92,14 +176,32 @@ class DatabaseConnector: error=error, created_at=datetime.now(), ) - session = Session(bind=self.engine, expire_on_commit=False) + session = Session(bind=self.__engine, expire_on_commit=False) session.add(scan_history) session.commit() session.close() return scan_history def update_scan_result(self, history_id, status, results=None, error=None): - session = Session(bind=self.engine, expire_on_commit=False) + """Update the referenced ScanHistory ID with the provided values + + Parameters + ---------- + history_id : int + The ScanHistory ID to update + status : str + The status of the scan (IN_PROGRESS, FAILED, COMPLETE) + results : list, optional + List of strings of open ports (E.g. ["53 UDP", "53 TCP"] + error : str, optional + Error message, if any + + Returns + ------- + ScanHistory + The updated ScanHistory + """ + session = Session(bind=self.__engine, expire_on_commit=False) scan_history = session.query(ScanHistory).get(history_id) if scan_history is None: @@ -116,6 +218,19 @@ class DatabaseConnector: return scan_history def get_scan_results_by_target(self, page=1, **kwargs): + """Get all ScanHistory for the provided hostname or ip_add, limited to + 25 / page + + Parameters + ---------- + page : int, optional + The desired ScanResult page + + Returns + ------- + list + List of ScanHistory + """ if len(kwargs.keys() & {"ip_addr", "hostname"}) != 1: raise ValueError("Missing keyword argument: ip_addr or hostname") @@ -125,7 +240,7 @@ class DatabaseConnector: elif "hostname" in kwargs: history_filter = ScanTarget.hostname == kwargs["hostname"] - session = Session(bind=self.engine) + session = Session(bind=self.__engine) scan_histories = ( session.query(ScanHistory) .join(ScanHistory.target) diff --git a/src/overseer/models.py b/src/overseer/models.py index fb1eb2f..e3fd28c 100644 --- a/src/overseer/models.py +++ b/src/overseer/models.py @@ -5,6 +5,20 @@ Base = declarative_base() class ScanTarget(Base): + """ScanTarget DB Model + + Attributes + ------- + id : Column(Integer) + The ID in the database + ip : Column(Integer), optional + The integer represented IP Address in the database + hostname : Column(String) + The hostname in the database + updated_at : Column(DateTime) + The DateTime when the ScanTarget was last updated + """ + __tablename__ = "scan_target" # Unique ID @@ -20,6 +34,7 @@ class ScanTarget(Base): updated_at = Column(DateTime()) def __repr__(self): + """The string representation of the class.""" return ( f"ScanTarget(id={self.id!r}, ip={self.ip!r}, " f"hostname={self.hostname!r}, updated_at={self.updated_at!r})" @@ -27,6 +42,26 @@ class ScanTarget(Base): class ScanHistory(Base): + """ScanHistory DB Model + + Attributes + ------- + id : Column(Integer) + The ID in the database + target_id : Column(Integer) + The ScanTarget ID reference + results : Column(String) + The CSV delimited string representing all open ports. + created_at : Column(DateTime) + The DateTime when the ScanHistory was created + status : Column(String) + The status of the ScanHistory (IN_PROGRESS, COMPLETE, FAILED) + error : Column(String) + The error, if any, of the ScanHistory + target : relationship + The ScanTarget relationship reference + """ + __tablename__ = "scan_history" # Unique ID @@ -51,6 +86,7 @@ class ScanHistory(Base): target = relationship("ScanTarget", foreign_keys=[target_id]) def __repr__(self): + """The string representation of the class.""" return ( f"ScanHistory(id={self.id!r}, target_id={self.target_id!r}, " f"results={self.results!r}, created_at={self.created_at!r}, " diff --git a/src/overseer/overseer.py b/src/overseer/overseer.py index 47d81f9..1077c36 100644 --- a/src/overseer/overseer.py +++ b/src/overseer/overseer.py @@ -5,25 +5,19 @@ from overseer.api.v1 import api as api_v1 @app.route("/", methods=["GET"]) def main_entry(): - """ - Initial Entrypoint to the SPA (i.e. 'index.html') - """ + """Initial Entrypoint to the SPA (i.e. 'index.html')""" return make_response(render_template("index.html")) @app.route("/", methods=["GET"]) def catch_all(path): - """ - Necessary due to client side SPA route handling. - """ + """Necessary due to client side SPA route handling""" return make_response(render_template("index.html")) @app.route("/static/") def static_resources(path): - """ - Front End Static Resources - """ + """Front end static resources""" return send_from_directory("static", path) diff --git a/src/overseer/scanner.py b/src/overseer/scanner.py index ce413b8..4ab0e29 100644 --- a/src/overseer/scanner.py +++ b/src/overseer/scanner.py @@ -9,19 +9,39 @@ import time class ScanManager: - def __init__(self): - self.pending_shutdown = False - self.active_scans = [] - self.broadcast_thread = Thread(target=self.__broadcast_thread) - self.broadcast_thread.start() + """Used to manage any ongoing scans - def __broadcast_thread(self): - while not self.pending_shutdown: + Methods + ------- + shutdown() + Shutdown and cleanup the scan monitor thread + get_status() + Get a normalized list of dicts detailing outstanding scans + perform_scan(target) + Initiate a scan on target (Can be a hostname, or ip_addr) + is_ip(target) + Determines if a target is a hostname of ip_addr + """ + + def __init__(self): + """Create instance and start thread""" + self.__pending_shutdown = False + self.__active_scans = [] + self.__monitor_thread = Thread(target=self.__scan_monitor) + self.__monitor_thread.start() + + def __scan_monitor(self): + """Monitors active and completed scans + + Responsible for publishing status to the websockets, cleaning up + completed scans, and updating the database accordingly. + """ + while not self.__pending_shutdown: time.sleep(1) - if len(self.active_scans) == 0: + if len(self.__active_scans) == 0: continue - for scan in self.active_scans: + for scan in self.__active_scans: # WebSocket progress total_progress = (scan.tcp_progress + scan.udp_progress) / 2 results = scan.get_results() @@ -61,13 +81,21 @@ class ScanManager: # Cleanup active scan scan.join() - self.active_scans.remove(scan) + self.__active_scans.remove(scan) def shutdown(self): - self.pending_shutdown = True - self.broadcast_thread.join() + """Shutdown and cleanup the scan monitor thread""" + self.__pending_shutdown = True + self.__monitor_thread.join() def get_status(self): + """Get a normalized list of dicts detailing outstanding scans + + Returns + ------- + list + List of normalized details on current active scans + """ return list( map( lambda x: { @@ -83,11 +111,18 @@ class ScanManager: (x.tcp_progress + x.udp_progress) / 2 ), # noqa: E501 }, - self.active_scans, + self.__active_scans, ) ) def perform_scan(self, target): + """Initiate a scan on target (Can be a hostname, or ip_addr) + + Parameters + ---------- + target : str + Either a hostname or IP address of the endpoint to scan + """ try: target = socket.gethostbyname(target) except socket.error: @@ -104,10 +139,22 @@ class ScanManager: new_scan = Scanner(target, scan_history) new_scan.start() - self.active_scans.append(new_scan) + self.__active_scans.append(new_scan) return scan_history def is_ip(self, target): + """Determines if a target is a hostname of ip_addr + + Parameters + ---------- + target : str + Either a hostname or IP address of the endpoint to scan + + Returns + ------- + bool + Whether the target is an IP or not + """ try: ipaddress.ip_address(target) return True @@ -116,12 +163,46 @@ class ScanManager: class Scanner(Thread): + """Subclass of Thread - used to perform a TCP and UDP scan + + Attributes + ---------- + target : str + The hostname or ip_addr to scan + scan_history : ScanHistory + The ScanHistory DB model reference for this scan + tcp_progress : int + The current progress percentage of the threaded TCP scan + udp_progress : int + The current progress percentage of the threaded UDP scan + tcp_results : list + The current list if ints of open TCP ports + udp_results : list + The current list if ints of open UDP ports + + Methods + ------- + run() + Overridden run method from the Thread superclass. Starts the scan. + get_results() + Returns a normalized list of string of open ports + (E.g. ["53 UDP", "53 TCP"]) + """ + def __init__(self, target, scan_history): + """ + Parameters + ---------- + target : str + The hostname or ip_addr to scan + scan_history : ScanHistory + The ScanHistory DB model reference for this scan + """ Thread.__init__(self) self.target = target self.scan_history = scan_history - self.port_count = 1000 + self.__port_count = 1000 self.tcp_progress = 0 self.udp_progress = 0 @@ -129,11 +210,19 @@ class Scanner(Thread): self.tcp_results = [] self.udp_results = [] - self.udp_payloads = {} + self.__udp_payloads = {} self.__load_nmap_payloads() def __load_nmap_payloads(self): - """Load and parse nmap UDP payloads""" + """Load and parse nmap UDP payloads + + Because of how UDP is designed, we have to test with specific payloads. + This parses nmaps protocol specific payloads [0] in preperation. + + TODO: This should be cached. No need to parse it on every scan. + + [0] https://nmap.org/book/nmap-payloads.html + """ # Open file & remove comments nmap_payloads = os.path.join( @@ -165,25 +254,33 @@ class Scanner(Thread): start_port = int(port_range[0]) end_port = int(port_range[1]) for port_match in range(start_port, end_port + 1): - if port_match not in self.udp_payloads: - self.udp_payloads[port_match] = [] - self.udp_payloads[port_match].extend(match_payloads) + if port_match not in self.__udp_payloads: + self.__udp_payloads[port_match] = [] + self.__udp_payloads[port_match].extend(match_payloads) else: port_match = int(raw_port) - if port_match not in self.udp_payloads: - self.udp_payloads[port_match] = [] - self.udp_payloads[port_match].extend(match_payloads) + if port_match not in self.__udp_payloads: + self.__udp_payloads[port_match] = [] + self.__udp_payloads[port_match].extend(match_payloads) def run(self): + """Overridden run method from the Thread superclass. Starts the scan""" tcp_thread = Thread(target=self.__scan_tcp) udp_thread = Thread(target=self.__scan_udp) tcp_thread.start() udp_thread.start() tcp_thread.join() udp_thread.join() - return {"TCP": self.tcp_results, "UDP": self.udp_results} def get_results(self): + """Returns a normalized list of string of open ports + + Returns + ------- + list + List of open ports (E.g. ["53 UDP", "53 TCP"]) + """ + results = list(map(lambda x: "%s UDP" % x, self.udp_results)) results.extend( list(map(lambda x: "%s TCP" % x, self.tcp_results)) @@ -192,23 +289,24 @@ class Scanner(Thread): return results def __scan_tcp(self): - for port in range(1, self.port_count): + """Threaded TCP Scanner""" + for port in range(1, self.__port_count): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.settimeout(0.1) result = s.connect_ex((self.target, port)) if result == 0: self.tcp_results.append(port) s.close() - self.tcp_progress = round(port / self.port_count * 100) + self.tcp_progress = round(port / self.__port_count * 100) def __scan_udp(self): - for port in range(1, self.port_count): - # print("UDP port %s..." % port) + """Threaded UDP Scanner""" + for port in range(1, self.__port_count): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.settimeout(0.01) payloads = ["\x00"] - if port in self.udp_payloads: - payloads = self.udp_payloads[port] + if port in self.__udp_payloads: + payloads.extend(self.__udp_payloads[port]) for payload in payloads: s.sendto(payload.encode("utf-8"), (self.target, port)) try: @@ -220,16 +318,4 @@ class Scanner(Thread): except socket.timeout: pass s.close() - self.udp_progress = round(port / self.port_count * 100) - - -# FOR TESTING PURPOSES -def main(): - sm = ScanManager() - sm.perform_scan("localhost") - # sm.perform_scan("10.0.20.254") - # sm.perform_scan("10.0.21.20") - - -if __name__ == "__main__": - main() + self.udp_progress = round(port / self.__port_count * 100)