150 lines
4.4 KiB
Python
150 lines
4.4 KiB
Python
from torinator.models import Base, IPAddress, ExclusionIPAddress, TorIPAddress
|
|
import threading
|
|
from os import path
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
class DatabaseConnector:
    """Access layer for the torinator IP address database.

    Every public method holds an internal lock for its whole duration and
    works in its own short-lived ORM session, so concurrent callers are
    serialized. The lock and session are released even if a query raises.
    """

    def __init__(self, db_type, data_path=None):
        """
        Parameters
        ----------
        db_type : str
            Database type (e.g. sqlite, memory); compared case-insensitively.
        data_path : str
            Directory to store the sqlite file. Required when db_type is
            "sqlite"; ignored for "memory".

        Raises
        ------
        ValueError
            If db_type is not supported, or data_path is missing for sqlite.
        """
        kind = db_type.lower()
        if kind == "memory":
            # Local import: only needed for the in-memory configuration.
            from sqlalchemy.pool import StaticPool

            # An in-memory SQLite database exists only inside its connection.
            # SQLAlchemy's default pool for :memory: hands each thread its own
            # connection (i.e. its own empty database), so share one connection
            # across threads instead; self.__lock serializes its use.
            self.__engine = create_engine(
                "sqlite+pysqlite:///:memory:",
                echo=False,
                future=True,
                connect_args={"check_same_thread": False},
                poolclass=StaticPool,
            )
        elif kind == "sqlite":
            if data_path is None:
                raise ValueError("data_path is required when db_type is 'sqlite'")
            db_path = path.join(data_path, "torinator.sqlite")
            self.__engine = create_engine(
                "sqlite+pysqlite:///%s" % db_path,
                echo=False,
                future=True,
            )
        else:
            raise ValueError("Unsupported db_type: %r" % (db_type,))

        Base.metadata.create_all(self.__engine)
        self.__lock = threading.Lock()

    def update_tor_ips(self, ip_strs, list_name):
        """Replace the Tor list ``list_name`` with the given IP addresses.

        IPs not yet in the database are inserted. Every existing Tor link for
        ``list_name`` is deleted, and one link per distinct IP is recreated.
        The whole operation is committed as a single transaction.

        Parameters
        ----------
        ip_strs : iterable of str
            IP addresses on the list; duplicates are ignored.
        list_name : str
            Name of the Tor list being replaced.
        """
        # De-duplicate while keeping the caller's order, so a repeated IP
        # cannot produce duplicate IPAddress rows or duplicate Tor links.
        unique_ips = list(dict.fromkeys(ip_strs))

        with self.__lock, Session(bind=self.__engine) as session:
            # Look up or create the IPAddress row for every IP.
            ip_rows = []
            for ip_str in unique_ips:
                ip_row = (
                    session.query(IPAddress).filter(IPAddress.ip == ip_str).first()
                )
                if ip_row is None:
                    ip_row = IPAddress(ip=ip_str)
                    session.add(ip_row)
                ip_rows.append(ip_row)
            # Emit pending INSERTs so newly created rows get their primary keys.
            session.flush()

            # Remove Old Tor Relationships
            session.query(TorIPAddress).filter(
                TorIPAddress.list_name == list_name
            ).delete()

            # Add New Tor Relationships
            session.bulk_save_objects(
                [TorIPAddress(ip_id=row.id, list_name=list_name) for row in ip_rows]
            )
            session.commit()

    def get_ips(self, include_exclusions=False):
        """Return IP address strings from the database.

        Parameters
        ----------
        include_exclusions : bool
            If True, return every stored IP address. If False, return only
            IPs linked to at least one Tor list that have no exclusion entry.

        Returns
        -------
        list of str
        """
        with self.__lock, Session(bind=self.__engine) as session:
            # Return All or Tor?
            query = session.query(IPAddress)
            if not include_exclusions:
                query = (
                    query.join(TorIPAddress)
                    .outerjoin(ExclusionIPAddress)
                    .filter(ExclusionIPAddress.id.is_(None))
                )
            return [row.ip for row in query.all()]

    def get_exclusions(self):
        """Return the IP address strings that have an exclusion entry.

        Returns
        -------
        list of str
        """
        with self.__lock, Session(bind=self.__engine) as session:
            rows = (
                session.query(IPAddress)
                .join(ExclusionIPAddress, IPAddress.id == ExclusionIPAddress.ip_id)
                .all()
            )
            return [row.ip for row in rows]

    def add_exclusion(self, ip_str):
        """Mark ``ip_str`` as excluded, creating its IPAddress row if needed.

        Adding an exclusion that already exists is a no-op. The IP row and
        the exclusion are committed together.

        Parameters
        ----------
        ip_str : str
            IP address to exclude.
        """
        with self.__lock, Session(bind=self.__engine) as session:
            # Add New IP
            ip_row = session.query(IPAddress).filter(IPAddress.ip == ip_str).first()
            if ip_row is None:
                ip_row = IPAddress(ip=ip_str)
                session.add(ip_row)
                # Emit the INSERT so ip_row.id is populated before it is used.
                session.flush()

            # Add New Exclusion
            existing = (
                session.query(ExclusionIPAddress)
                .filter(ExclusionIPAddress.ip_id == ip_row.id)
                .first()
            )
            if existing is None:
                session.add(ExclusionIPAddress(ip_id=ip_row.id))

            session.commit()

    def delete_exclusion(self, ip_str):
        """Remove every exclusion entry for ``ip_str``.

        The IPAddress row itself is kept. Nothing happens if the IP is unknown.

        Parameters
        ----------
        ip_str : str
            IP address whose exclusion should be removed.
        """
        with self.__lock, Session(bind=self.__engine) as session:
            ip_row = session.query(IPAddress).filter(IPAddress.ip == ip_str).first()
            if ip_row is not None:
                session.query(ExclusionIPAddress).filter(
                    ExclusionIPAddress.ip_id == ip_row.id
                ).delete()
                session.commit()
|