From 2e0284c33e707469d20c9ea6b96d580cdf015d9e Mon Sep 17 00:00:00 2001 From: Evan Reichard Date: Wed, 17 Mar 2021 22:44:50 -0400 Subject: [PATCH] Fix Model Relationship, Auto Convert IP Addresses to Integers --- src/overseer/database.py | 47 ++++++++++++++++++++++++++-------------- src/overseer/models.py | 7 ++++-- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/src/overseer/database.py b/src/overseer/database.py index 500d0ef..de0c85d 100644 --- a/src/overseer/database.py +++ b/src/overseer/database.py @@ -1,54 +1,69 @@ import models +import ipaddress from datetime import datetime from os import path from sqlalchemy import create_engine, or_, insert, Table, Column, Integer, String, ForeignKey, DateTime from sqlalchemy.orm import declarative_base, Session class DatabaseConnector: - def __init__(self, data_path): - self.engine = create_engine("sqlite+pysqlite:///%s" % path.join(data_path, "overseer.sqlite"), echo=True, future=True) + def __init__(self, data_path, in_memory=False): + if in_memory: + self.engine = create_engine("sqlite+pysqlite:///:memory:", echo=True, future=True) + else: + self.engine = create_engine("sqlite+pysqlite:///%s" % path.join(data_path, "overseer.sqlite"), echo=True, future=True) models.Base.metadata.create_all(self.engine) def create_scan_result(self, ip_address, scan_results, hostname=None): + int_ip_address = int(ipaddress.ip_address(ip_address)) session = Session(bind=self.engine) # Does an existing target exist? scan_target = session.query(models.ScanTarget).filter(or_( - models.ScanTarget.ip==ip_address, + models.ScanTarget.ip==int_ip_address, models.ScanTarget.hostname==hostname, )).first() + # TODO: Do we need to update hostname? + # Nope, create one if not scan_target: - scan_target = models.ScanTarget(ip=ip_address, hostname=hostname) + scan_target = models.ScanTarget(ip=int_ip_address, hostname=hostname) session.add(scan_target) session.commit() # Create scan history - scan_history = models.ScanHistory(target=scan_target.id, results=",".join(map(str, scan_results)), datetime=datetime.now()) + scan_history = models.ScanHistory(target_id=scan_target.id, results=",".join(map(str, scan_results)), datetime=datetime.now()) session.add(scan_history) session.commit() session.close() - def create_target(self, ip_address, hostname): - stmt = insert(models.ScanTarget).values(ip=ip_address, hostname=hostname) - compiled = stmt.compile() - self.__execute_statement(stmt) - return stmt + def get_scan_results(self, **kwargs): + if len(kwargs.keys() & {'ip_address', 'hostname'}) != 1: + raise ValueError('Missing keyword argument: ip_address or hostname') + hostname = kwargs["hostname"] if "hostname" in kwargs else None + ip_address = kwargs["ip_address"] if "ip_address" in kwargs else None + int_ip_address = int(ipaddress.ip_address(ip_address)) if ip_address else None - def __execute_statement(self, stmt): - with self.engine.connect() as conn: - result = conn.execute(stmt) - conn.commit() - return result + session = Session(bind=self.engine) + + # Get all scan histories + scan_histories = session.query(models.ScanHistory).join(models.ScanHistory.target).filter(or_( + models.ScanTarget.ip==int_ip_address, + models.ScanTarget.hostname==hostname, + )).all() + + session.close() + + return scan_histories # FOR TESTING PURPOSES def main(): db = DatabaseConnector("/Users/evanreichard/Development/git/overseer/src/overseer") - db.create_scan_result(1234577, [5,6,7,8], "test222.com") + db.create_scan_result("1.2.3.4", [5,6,7,8], "test222.com") + db.get_scan_results(ip_address="1.2.3.4") if __name__ == "__main__": main() diff --git a/src/overseer/models.py b/src/overseer/models.py index f875702..e3302b1 100644 --- a/src/overseer/models.py +++ b/src/overseer/models.py @@ -1,5 +1,5 @@ from sqlalchemy import Column, Integer, String, ForeignKey, DateTime -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import declarative_base, relationship Base = declarative_base() @@ -25,7 +25,7 @@ class ScanHistory(Base): id = Column(Integer, primary_key=True, unique=True) # Scan Target Reference - target = Column(Integer, ForeignKey("scan_target.id")) + target_id = Column(Integer, ForeignKey("scan_target.id")) # Results results = Column(String) @@ -33,5 +33,8 @@ class ScanHistory(Base): # DateTime datetime = Column(DateTime()) + # Relationship + target = relationship("ScanTarget", foreign_keys=[target_id]) + def __repr__(self): return f"ScanHistory(id={self.id!r}, target={self.target!r}, results={self.results!r}, datetime={self.datetime!r})"