import threading
from contextlib import contextmanager
from os import path

from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sqlalchemy.pool import StaticPool

from torinator.models import Base, ExclusionIPAddress, IPAddress, TorIPAddress


class DatabaseConnector:
    """Storage for Tor IP lists and user-defined IP exclusions.

    Every public method runs under one instance-wide lock with a
    short-lived SQLAlchemy session, so calls made from different threads
    are serialised against each other.
    """

    # Maximum number of values bound into a single ``IN (...)`` clause.
    # SQLite builds older than 3.32.0 reject statements with more than
    # 999 bound parameters.
    _IN_CHUNK_SIZE = 500

    def __init__(self, db_type, data_path=None):
        """
        Parameters
        ----------
        db_type : str
            Database type, case-insensitive: ``"memory"`` or ``"sqlite"``.
        data_path : str, optional
            Directory in which to store ``torinator.sqlite``; required
            when ``db_type`` is ``"sqlite"``.

        Raises
        ------
        ValueError
            If ``db_type`` is not recognised, or ``data_path`` is missing
            for a sqlite database.
        """
        kind = db_type.lower()
        if kind == "memory":
            # By default pysqlite keeps one connection per thread for
            # :memory:, and every connection is a separate, empty
            # database. Share a single connection instead; the instance
            # lock serialises all access to it.
            self.__engine = create_engine(
                "sqlite+pysqlite:///:memory:",
                echo=False,
                future=True,
                connect_args={"check_same_thread": False},
                poolclass=StaticPool,
            )
        elif kind == "sqlite":
            if data_path is None:
                raise ValueError(
                    "data_path is required when db_type is 'sqlite'"
                )
            db_path = path.join(data_path, "torinator.sqlite")
            self.__engine = create_engine(
                "sqlite+pysqlite:///%s" % db_path,
                echo=False,
                future=True,
            )
        else:
            raise ValueError("Unsupported db_type: %r" % db_type)

        Base.metadata.create_all(self.__engine)
        self.__lock = threading.Lock()

    @contextmanager
    def _locked_session(self):
        """Hold the instance lock and yield a session; release both on exit.

        Using ``with`` guarantees the session is closed and the lock is
        released even when a query raises. Without it, a single failure
        would leave the lock held and block every later call.
        """
        with self.__lock, Session(bind=self.__engine) as session:
            yield session

    def _ip_ids(self, session, ip_strs):
        """Return ``{ip: id}`` for those entries of ``ip_strs`` already stored.

        ``ip_strs`` must be a sequence (it is sliced). Lookups are chunked
        to stay under SQLite's bound-parameter limit.
        """
        ids = {}
        step = self._IN_CHUNK_SIZE
        for start in range(0, len(ip_strs), step):
            chunk = ip_strs[start:start + step]
            rows = (
                session.query(IPAddress.ip, IPAddress.id)
                .filter(IPAddress.ip.in_(chunk))
                .all()
            )
            ids.update((ip, ip_id) for ip, ip_id in rows)
        return ids

    def update_tor_ips(self, ip_strs, list_name):
        """Replace the membership of Tor list ``list_name`` with ``ip_strs``.

        IPs not yet in the database are inserted. All previous Tor
        relationships for ``list_name`` are then deleted, and one
        relationship per distinct IP is created. Everything is committed
        in a single transaction.

        Parameters
        ----------
        ip_strs : iterable of str
            IP addresses currently on the list. Any iterable is accepted,
            including a one-shot generator.
        list_name : str
            Name of the Tor list being refreshed.
        """
        # Materialise once (the IPs are used more than once below) and
        # drop duplicates so a repeated IP does not create duplicate rows.
        unique_ips = list(dict.fromkeys(ip_strs))

        with self._locked_session() as session:
            # Add any new IPs
            ip_ids = self._ip_ids(session, unique_ips)
            new_ips = [
                IPAddress(ip=ip) for ip in unique_ips if ip not in ip_ids
            ]
            session.bulk_save_objects(new_ips)
            # bulk_save_objects does not populate primary keys; look the
            # freshly inserted rows up to obtain their ids.
            ip_ids.update(self._ip_ids(session, [row.ip for row in new_ips]))

            # Remove old Tor relationships
            session.query(TorIPAddress).filter(
                TorIPAddress.list_name == list_name
            ).delete()

            # Add new Tor relationships
            session.bulk_save_objects(
                [
                    TorIPAddress(ip_id=ip_ids[ip], list_name=list_name)
                    for ip in unique_ips
                ]
            )
            session.commit()

    def get_ips(self, include_exclusions=False):
        """Return stored IP addresses as a list of strings.

        Parameters
        ----------
        include_exclusions : bool, default False
            If True, return every stored IP address. If False, return only
            IPs that belong to at least one Tor list and are not excluded.
        """
        with self._locked_session() as session:
            query = session.query(IPAddress.ip)
            if not include_exclusions:
                query = (
                    query.join(
                        TorIPAddress, TorIPAddress.ip_id == IPAddress.id
                    )
                    .outerjoin(
                        ExclusionIPAddress,
                        ExclusionIPAddress.ip_id == IPAddress.id,
                    )
                    .filter(ExclusionIPAddress.id.is_(None))
                    # An IP on several Tor lists joins once per list, and
                    # column queries are not de-duplicated automatically.
                    .distinct()
                )
            return [ip for (ip,) in query.all()]

    def get_exclusions(self):
        """Return every excluded IP address as a list of strings."""
        with self._locked_session() as session:
            rows = (
                session.query(IPAddress.ip)
                .join(
                    ExclusionIPAddress,
                    IPAddress.id == ExclusionIPAddress.ip_id,
                )
                .distinct()
                .all()
            )
            return [ip for (ip,) in rows]

    def add_exclusion(self, ip_str):
        """Exclude ``ip_str``, creating its IP address row if needed.

        Idempotent: an IP that is already excluded is left unchanged. The
        new IP row (if any) and the exclusion are committed together.
        """
        with self._locked_session() as session:
            # Add new IP
            ip_row = (
                session.query(IPAddress).filter(IPAddress.ip == ip_str).first()
            )
            if ip_row is None:
                ip_row = IPAddress(ip=ip_str)
                session.add(ip_row)
                # Flush (not commit) so ip_row.id is assigned while the IP
                # and its exclusion still share one transaction.
                session.flush()

            # Add new exclusion
            existing = (
                session.query(ExclusionIPAddress)
                .filter(ExclusionIPAddress.ip_id == ip_row.id)
                .first()
            )
            if existing is None:
                session.add(ExclusionIPAddress(ip_id=ip_row.id))
            session.commit()

    def delete_exclusion(self, ip_str):
        """Remove any exclusion for ``ip_str``; a no-op if the IP is unknown.

        The IP address row itself is kept.
        """
        with self._locked_session() as session:
            ip_row = (
                session.query(IPAddress).filter(IPAddress.ip == ip_str).first()
            )
            if ip_row is not None:
                session.query(ExclusionIPAddress).filter(
                    ExclusionIPAddress.ip_id == ip_row.id
                ).delete()
                session.commit()