Fix Model Relationship, Auto Convert IP Addresses to Integers

This commit is contained in:
Evan Reichard 2021-03-17 22:44:50 -04:00
parent 8ff4aaebbe
commit 2e0284c33e
2 changed files with 36 additions and 18 deletions

View File

@ -1,54 +1,69 @@
import models import models
import ipaddress
from datetime import datetime from datetime import datetime
from os import path from os import path
from sqlalchemy import create_engine, or_, insert, Table, Column, Integer, String, ForeignKey, DateTime from sqlalchemy import create_engine, or_, insert, Table, Column, Integer, String, ForeignKey, DateTime
from sqlalchemy.orm import declarative_base, Session from sqlalchemy.orm import declarative_base, Session
class DatabaseConnector: class DatabaseConnector:
def __init__(self, data_path): def __init__(self, data_path, in_memory=False):
if in_memory:
self.engine = create_engine("sqlite+pysqlite:///:memory:", echo=True, future=True)
else:
self.engine = create_engine("sqlite+pysqlite:///%s" % path.join(data_path, "overseer.sqlite"), echo=True, future=True) self.engine = create_engine("sqlite+pysqlite:///%s" % path.join(data_path, "overseer.sqlite"), echo=True, future=True)
models.Base.metadata.create_all(self.engine) models.Base.metadata.create_all(self.engine)
def create_scan_result(self, ip_address, scan_results, hostname=None): def create_scan_result(self, ip_address, scan_results, hostname=None):
int_ip_address = int(ipaddress.ip_address(ip_address))
session = Session(bind=self.engine) session = Session(bind=self.engine)
# Does an existing target exist? # Does an existing target exist?
scan_target = session.query(models.ScanTarget).filter(or_( scan_target = session.query(models.ScanTarget).filter(or_(
models.ScanTarget.ip==ip_address, models.ScanTarget.ip==int_ip_address,
models.ScanTarget.hostname==hostname, models.ScanTarget.hostname==hostname,
)).first() )).first()
# TODO: Do we need to update hostname?
# Nope, create one # Nope, create one
if not scan_target: if not scan_target:
scan_target = models.ScanTarget(ip=ip_address, hostname=hostname) scan_target = models.ScanTarget(ip=int_ip_address, hostname=hostname)
session.add(scan_target) session.add(scan_target)
session.commit() session.commit()
# Create scan history # Create scan history
scan_history = models.ScanHistory(target=scan_target.id, results=",".join(map(str, scan_results)), datetime=datetime.now()) scan_history = models.ScanHistory(target_id=scan_target.id, results=",".join(map(str, scan_results)), datetime=datetime.now())
session.add(scan_history) session.add(scan_history)
session.commit() session.commit()
session.close() session.close()
def create_target(self, ip_address, hostname): def get_scan_results(self, **kwargs):
stmt = insert(models.ScanTarget).values(ip=ip_address, hostname=hostname) if len(kwargs.keys() & {'ip_address', 'hostname'}) != 1:
compiled = stmt.compile() raise ValueError('Missing keyword argument: ip_address or hostname')
self.__execute_statement(stmt)
return stmt
hostname = kwargs["hostname"] if "hostname" in kwargs else None
ip_address = kwargs["ip_address"] if "ip_address" in kwargs else None
int_ip_address = int(ipaddress.ip_address(ip_address)) if ip_address else None
def __execute_statement(self, stmt): session = Session(bind=self.engine)
with self.engine.connect() as conn:
result = conn.execute(stmt) # Get all scan histories
conn.commit() scan_histories = session.query(models.ScanHistory).join(models.ScanHistory.target).filter(or_(
return result models.ScanTarget.ip==int_ip_address,
models.ScanTarget.hostname==hostname,
)).all()
session.close()
return scan_histories
# FOR TESTING PURPOSES # FOR TESTING PURPOSES
def main(): def main():
db = DatabaseConnector("/Users/evanreichard/Development/git/overseer/src/overseer") db = DatabaseConnector("/Users/evanreichard/Development/git/overseer/src/overseer")
db.create_scan_result(1234577, [5,6,7,8], "test222.com") db.create_scan_result("1.2.3.4", [5,6,7,8], "test222.com")
db.get_scan_results(ip_address="1.2.3.4")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -1,5 +1,5 @@
from sqlalchemy import Column, Integer, String, ForeignKey, DateTime from sqlalchemy import Column, Integer, String, ForeignKey, DateTime
from sqlalchemy.orm import declarative_base from sqlalchemy.orm import declarative_base, relationship
Base = declarative_base() Base = declarative_base()
@ -25,7 +25,7 @@ class ScanHistory(Base):
id = Column(Integer, primary_key=True, unique=True) id = Column(Integer, primary_key=True, unique=True)
# Scan Target Reference # Scan Target Reference
target = Column(Integer, ForeignKey("scan_target.id")) target_id = Column(Integer, ForeignKey("scan_target.id"))
# Results # Results
results = Column(String) results = Column(String)
@ -33,5 +33,8 @@ class ScanHistory(Base):
# DateTime # DateTime
datetime = Column(DateTime()) datetime = Column(DateTime())
# Relationship
target = relationship("ScanTarget", foreign_keys=[target_id])
def __repr__(self): def __repr__(self):
return f"ScanHistory(id={self.id!r}, target={self.target!r}, results={self.results!r}, datetime={self.datetime!r})" return f"ScanHistory(id={self.id!r}, target={self.target!r}, results={self.results!r}, datetime={self.datetime!r})"