TORinator/src/torinator/database.py
2022-09-27 22:05:01 -04:00

150 lines
4.4 KiB
Python

import threading
from os import path

from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from sqlalchemy.pool import StaticPool

from torinator.models import Base, IPAddress, ExclusionIPAddress, TorIPAddress
class DatabaseConnector:
    """Lock-guarded access layer for the TORinator SQLite database.

    Every public method acquires an internal ``threading.Lock`` and opens a
    fresh ORM ``Session`` for the duration of the call. Both are managed by
    ``with`` statements, so the lock is released and the session closed even
    when a query raises.
    """

    def __init__(self, db_type, data_path=None):
        """
        Parameters
        ----------
        db_type : str
            Database type, case-insensitive: ``"memory"`` (in-memory SQLite)
            or ``"sqlite"`` (file-backed SQLite).
        data_path : str, optional
            Directory in which the ``torinator.sqlite`` file is stored.
            Required when ``db_type`` is ``"sqlite"``; ignored otherwise.

        Raises
        ------
        ValueError
            If ``db_type`` is not recognised, or if ``data_path`` is missing
            for ``"sqlite"``.
        """
        kind = db_type.lower()
        if kind == "memory":
            engine = create_engine(
                "sqlite+pysqlite:///:memory:",
                echo=False,
                future=True,
                # An in-memory SQLite database lives only inside the connection
                # that created it, and SQLAlchemy's default pool for ":memory:"
                # gives each thread its own connection (i.e. its own empty
                # database). Share a single connection across threads instead;
                # concurrent use is serialised by self.__lock.
                connect_args={"check_same_thread": False},
                poolclass=StaticPool,
            )
        elif kind == "sqlite":
            if data_path is None:
                raise ValueError("data_path is required when db_type is 'sqlite'")
            db_path = path.join(data_path, "torinator.sqlite")
            engine = create_engine(
                f"sqlite+pysqlite:///{db_path}",
                echo=False,
                future=True,
            )
        else:
            # Fail here with a clear message rather than later with an
            # AttributeError on a missing engine.
            raise ValueError(
                f"Unsupported db_type {db_type!r}; expected 'memory' or 'sqlite'"
            )
        Base.metadata.create_all(engine)
        self.__engine = engine
        self.__lock = threading.Lock()

    @staticmethod
    def _find_ip(session, ip_str):
        """Return the ``IPAddress`` row whose ``ip`` equals ``ip_str``, or None."""
        return session.query(IPAddress).filter(IPAddress.ip == ip_str).first()

    def update_tor_ips(self, ip_strs, list_name):
        """Replace the IPs recorded for one Tor list.

        Inserts any IP not yet in the ``IPAddress`` table, deletes every
        existing ``TorIPAddress`` link for ``list_name``, then creates one link
        per IP. Everything is committed in a single transaction.

        Parameters
        ----------
        ip_strs : iterable of str
            IP addresses currently on the list. Duplicates are ignored, and
            a one-shot iterator (e.g. a generator) is accepted.
        list_name : str
            Name of the Tor list being refreshed.
        """
        # Materialise once: the input is used after the lookup loop, and
        # repeated strings must not produce duplicate IPAddress rows.
        unique_ips = list(dict.fromkeys(ip_strs))
        with self.__lock, Session(bind=self.__engine) as session:
            # Look up or create an IPAddress row for every IP.
            rows = []
            for ip_str in unique_ips:
                row = self._find_ip(session, ip_str)
                if row is None:
                    row = IPAddress(ip=ip_str)
                    session.add(row)
                rows.append(row)
            # Flush so newly added rows have primary keys before we link them.
            session.flush()

            # Drop the old links for this list, then add the current ones.
            session.query(TorIPAddress).filter(
                TorIPAddress.list_name == list_name
            ).delete()
            session.add_all(
                [TorIPAddress(ip_id=row.id, list_name=list_name) for row in rows]
            )
            session.commit()

    def get_ips(self, include_exclusions=False):
        """Return IP address strings from the database.

        Parameters
        ----------
        include_exclusions : bool
            If False (default), return only IPs linked to at least one Tor
            list and without an exclusion entry. If True, return every row
            of the ``IPAddress`` table, whether or not it is on a Tor list.

        Returns
        -------
        list of str
        """
        with self.__lock, Session(bind=self.__engine) as session:
            query = session.query(IPAddress)
            if not include_exclusions:
                # The outer join plus IS NULL keeps only IPs that have no
                # ExclusionIPAddress row (an anti-join).
                query = (
                    query.join(TorIPAddress)
                    .outerjoin(ExclusionIPAddress)
                    .filter(ExclusionIPAddress.id.is_(None))
                )
            return [row.ip for row in query.all()]

    def get_exclusions(self):
        """Return the IP address strings that have an exclusion entry.

        Returns
        -------
        list of str
        """
        with self.__lock, Session(bind=self.__engine) as session:
            rows = (
                session.query(IPAddress)
                .join(ExclusionIPAddress, IPAddress.id == ExclusionIPAddress.ip_id)
                .all()
            )
            return [row.ip for row in rows]

    def add_exclusion(self, ip_str):
        """Exclude ``ip_str``, creating its ``IPAddress`` row if needed.

        Calling this for an already-excluded IP has no effect. The new IP row
        (if any) and the exclusion are committed together.

        Parameters
        ----------
        ip_str : str
            IP address to exclude.
        """
        with self.__lock, Session(bind=self.__engine) as session:
            ip_row = self._find_ip(session, ip_str)
            if ip_row is None:
                ip_row = IPAddress(ip=ip_str)
                session.add(ip_row)
                # Flush (not commit) to obtain the primary key while keeping
                # the IP insert and the exclusion in one transaction.
                session.flush()
            existing = (
                session.query(ExclusionIPAddress)
                .filter(ExclusionIPAddress.ip_id == ip_row.id)
                .first()
            )
            if existing is None:
                session.add(ExclusionIPAddress(ip_id=ip_row.id))
            session.commit()

    def delete_exclusion(self, ip_str):
        """Remove any exclusion entries for ``ip_str``.

        The ``IPAddress`` row itself is kept. Unknown IPs are ignored.

        Parameters
        ----------
        ip_str : str
            IP address whose exclusion should be removed.
        """
        with self.__lock, Session(bind=self.__engine) as session:
            ip_row = self._find_ip(session, ip_str)
            if ip_row is not None:
                session.query(ExclusionIPAddress).filter(
                    ExclusionIPAddress.ip_id == ip_row.id
                ).delete()
            session.commit()