mirror of https://github.com/OWASP/Nettacker.git
database lock issues fixed using litequeue
This commit is contained in:
parent
ad76ce537a
commit
dc1ea7bd9e
|
|
@ -0,0 +1,48 @@
|
||||||
|
import argparse
|
||||||
|
import signal
|
||||||
|
import time
|
||||||
|
|
||||||
|
from nettacker.database.writer import get_writer
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_sig(signum, frame):
|
||||||
|
writer = get_writer()
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def run():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--once", action="store_true", help="Drain the queue once and exit")
|
||||||
|
parser.add_argument("--batch-size", type=int, default=None, help="Writer batch size")
|
||||||
|
parser.add_argument("--interval", type=float, default=None, help="Writer sleep interval")
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-items", type=int, default=None, help="Max items to process in --once mode"
|
||||||
|
)
|
||||||
|
parser.add_argument("--summary", action="store_true", help="Print a summary after --once")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGINT, _handle_sig)
|
||||||
|
signal.signal(signal.SIGTERM, _handle_sig)
|
||||||
|
|
||||||
|
# apply runtime config
|
||||||
|
from nettacker.database.writer import get_writer_configured, get_stats
|
||||||
|
|
||||||
|
writer = get_writer_configured(batch_size=args.batch_size, interval=args.interval)
|
||||||
|
if args.once:
|
||||||
|
processed = writer.drain_once(max_iterations=args.max_items or 100000)
|
||||||
|
if args.summary:
|
||||||
|
stats = get_stats()
|
||||||
|
print(
|
||||||
|
f"processed={processed} total_processed={stats.get('processed')} queue_size={stats.get('queue_size')}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
writer.stop()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run()
|
||||||
|
|
@ -9,6 +9,7 @@ from nettacker.api.helpers import structure
|
||||||
from nettacker.config import Config
|
from nettacker.config import Config
|
||||||
from nettacker.core.messages import messages
|
from nettacker.core.messages import messages
|
||||||
from nettacker.database.models import HostsLog, Report, TempEvents
|
from nettacker.database.models import HostsLog, Report, TempEvents
|
||||||
|
from nettacker.database.writer import get_writer
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
log = logger.get_logger()
|
log = logger.get_logger()
|
||||||
|
|
@ -95,6 +96,22 @@ def submit_report_to_db(event):
|
||||||
return True if submitted otherwise False
|
return True if submitted otherwise False
|
||||||
"""
|
"""
|
||||||
log.verbose_info(messages("inserting_report_db"))
|
log.verbose_info(messages("inserting_report_db"))
|
||||||
|
writer = get_writer()
|
||||||
|
job = {
|
||||||
|
"action": "insert_report",
|
||||||
|
"payload": {
|
||||||
|
"date": event["date"],
|
||||||
|
"scan_id": event["scan_id"],
|
||||||
|
"report_path_filename": event["options"]["report_path_filename"],
|
||||||
|
"options": event["options"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
if writer.enqueue(job):
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Fallback to direct write
|
||||||
session = create_connection()
|
session = create_connection()
|
||||||
session.add(
|
session.add(
|
||||||
Report(
|
Report(
|
||||||
|
|
@ -140,6 +157,14 @@ def submit_logs_to_db(log):
|
||||||
True if success otherwise False
|
True if success otherwise False
|
||||||
"""
|
"""
|
||||||
if isinstance(log, dict):
|
if isinstance(log, dict):
|
||||||
|
writer = get_writer()
|
||||||
|
job = {"action": "insert_hostslog", "payload": log}
|
||||||
|
try:
|
||||||
|
if writer.enqueue(job):
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Fallback
|
||||||
session = create_connection()
|
session = create_connection()
|
||||||
session.add(
|
session.add(
|
||||||
HostsLog(
|
HostsLog(
|
||||||
|
|
@ -169,6 +194,14 @@ def submit_temp_logs_to_db(log):
|
||||||
True if success otherwise False
|
True if success otherwise False
|
||||||
"""
|
"""
|
||||||
if isinstance(log, dict):
|
if isinstance(log, dict):
|
||||||
|
writer = get_writer()
|
||||||
|
job = {"action": "insert_tempevent", "payload": log}
|
||||||
|
try:
|
||||||
|
if writer.enqueue(job):
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
# Fallback
|
||||||
session = create_connection()
|
session = create_connection()
|
||||||
session.add(
|
session.add(
|
||||||
TempEvents(
|
TempEvents(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,336 @@
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from multiprocessing import Queue
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from nettacker import logger
|
||||||
|
from nettacker.config import Config
|
||||||
|
from nettacker.database.models import Report, HostsLog, TempEvents
|
||||||
|
|
||||||
|
log = logger.get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class DBWriter:
|
||||||
|
def __init__(self, batch_size=100, interval=0.5):
|
||||||
|
self.batch_size = int(batch_size)
|
||||||
|
self.interval = float(interval)
|
||||||
|
self._stop = threading.Event()
|
||||||
|
self._thread = None
|
||||||
|
# total processed across lifetime
|
||||||
|
self._processed_count = 0
|
||||||
|
|
||||||
|
self._use_litequeue = False
|
||||||
|
self._lq = None
|
||||||
|
self._lq_put = None
|
||||||
|
self._lq_get = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import litequeue as _litequeue
|
||||||
|
|
||||||
|
queue_file = Path(Config.path.data_dir) / "nettacker_db_queue.lq"
|
||||||
|
queue_file.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
# try common constructors
|
||||||
|
if hasattr(_litequeue, "LiteQueue"):
|
||||||
|
self._lq = _litequeue.LiteQueue(str(queue_file))
|
||||||
|
elif hasattr(_litequeue, "Queue"):
|
||||||
|
self._lq = _litequeue.Queue(str(queue_file))
|
||||||
|
else:
|
||||||
|
# fallback to a module-level factory
|
||||||
|
try:
|
||||||
|
self._lq = _litequeue.open(str(queue_file))
|
||||||
|
except Exception:
|
||||||
|
self._lq = None
|
||||||
|
|
||||||
|
if self._lq is not None:
|
||||||
|
# prefer destructive pop/get ordering
|
||||||
|
if hasattr(self._lq, "put"):
|
||||||
|
self._lq_put = self._lq.put
|
||||||
|
elif hasattr(self._lq, "push"):
|
||||||
|
self._lq_put = self._lq.push
|
||||||
|
elif hasattr(self._lq, "add"):
|
||||||
|
self._lq_put = self._lq.add
|
||||||
|
|
||||||
|
if hasattr(self._lq, "pop"):
|
||||||
|
self._lq_get = self._lq.pop
|
||||||
|
elif hasattr(self._lq, "get"):
|
||||||
|
# note: some implementations require message_id; prefer pop above
|
||||||
|
self._lq_get = self._lq.get
|
||||||
|
elif hasattr(self._lq, "take"):
|
||||||
|
self._lq_get = self._lq.take
|
||||||
|
|
||||||
|
if self._lq_put and self._lq_get:
|
||||||
|
self._use_litequeue = True
|
||||||
|
except Exception:
|
||||||
|
self._use_litequeue = False
|
||||||
|
|
||||||
|
if not self._use_litequeue:
|
||||||
|
self.queue = Queue()
|
||||||
|
|
||||||
|
db_url = Config.db.as_dict()
|
||||||
|
engine_url = (
|
||||||
|
"sqlite:///{name}".format(**db_url)
|
||||||
|
if Config.db.engine.startswith("sqlite")
|
||||||
|
else Config.db.engine
|
||||||
|
)
|
||||||
|
connect_args = {}
|
||||||
|
if engine_url.startswith("sqlite"):
|
||||||
|
connect_args["check_same_thread"] = False
|
||||||
|
|
||||||
|
self.engine = create_engine(engine_url, connect_args=connect_args, pool_pre_ping=True)
|
||||||
|
if engine_url.startswith("sqlite"):
|
||||||
|
try:
|
||||||
|
with self.engine.connect() as conn:
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.Session = sessionmaker(bind=self.engine)
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
if self._thread and self._thread.is_alive():
|
||||||
|
return
|
||||||
|
self._stop.clear()
|
||||||
|
self._thread = threading.Thread(target=self._run, name="nettacker-db-writer", daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self._stop.set()
|
||||||
|
if self._thread:
|
||||||
|
self._thread.join(timeout=5)
|
||||||
|
|
||||||
|
def enqueue(self, job):
|
||||||
|
try:
|
||||||
|
if self._use_litequeue:
|
||||||
|
self._lq_put(json.dumps(job))
|
||||||
|
return True
|
||||||
|
self.queue.put(job)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
log.warn("DBWriter: failed to enqueue job")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _pop_one(self):
|
||||||
|
if self._use_litequeue:
|
||||||
|
try:
|
||||||
|
# litequeue: use peek() to get next message then mark done()
|
||||||
|
# peek returns a Message object or None
|
||||||
|
msg = None
|
||||||
|
if hasattr(self._lq, "peek"):
|
||||||
|
msg = self._lq.peek()
|
||||||
|
elif hasattr(self._lq, "get"):
|
||||||
|
# fallback: try to get next via get with id if available
|
||||||
|
try:
|
||||||
|
# attempt to fetch first ready message via SQL using qsize/read
|
||||||
|
msg = None
|
||||||
|
except Exception:
|
||||||
|
msg = None
|
||||||
|
|
||||||
|
if msg is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if hasattr(msg, "data"):
|
||||||
|
payload = msg.data
|
||||||
|
elif hasattr(msg, "message"):
|
||||||
|
payload = msg.message
|
||||||
|
else:
|
||||||
|
payload = str(msg)
|
||||||
|
|
||||||
|
if isinstance(payload, (bytes, bytearray)):
|
||||||
|
payload = payload.decode()
|
||||||
|
|
||||||
|
# mark message done to remove it from queue
|
||||||
|
try:
|
||||||
|
if hasattr(self._lq, "done") and hasattr(msg, "message_id"):
|
||||||
|
self._lq.done(msg.message_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return json.loads(payload)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
return self.queue.get_nowait()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
session = self.Session()
|
||||||
|
pending = []
|
||||||
|
while not self._stop.is_set():
|
||||||
|
try:
|
||||||
|
while len(pending) < self.batch_size:
|
||||||
|
job = self._pop_one()
|
||||||
|
if job is None:
|
||||||
|
break
|
||||||
|
pending.append(job)
|
||||||
|
|
||||||
|
if pending:
|
||||||
|
success_count = 0
|
||||||
|
for job in pending:
|
||||||
|
try:
|
||||||
|
self._apply_job(session, job)
|
||||||
|
success_count += 1
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
try:
|
||||||
|
session.commit()
|
||||||
|
self._processed_count += success_count
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
pending = []
|
||||||
|
else:
|
||||||
|
time.sleep(self.interval)
|
||||||
|
except Exception:
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
job = self._pop_one()
|
||||||
|
if job is None:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
self._apply_job(session, job)
|
||||||
|
session.commit()
|
||||||
|
self._processed_count += 1
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
session.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def drain_once(self, max_iterations=100000):
|
||||||
|
"""Consume all queued jobs and return when queue is empty.
|
||||||
|
|
||||||
|
This method is intended for on-demand draining (not long-lived).
|
||||||
|
"""
|
||||||
|
session = self.Session()
|
||||||
|
iterations = 0
|
||||||
|
processed = 0
|
||||||
|
try:
|
||||||
|
while iterations < max_iterations:
|
||||||
|
job = self._pop_one()
|
||||||
|
if job is None:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
self._apply_job(session, job)
|
||||||
|
processed += 1
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
iterations += 1
|
||||||
|
try:
|
||||||
|
session.commit()
|
||||||
|
self._processed_count += processed
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
session.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
def _apply_job(self, session, job):
|
||||||
|
action = job.get("action")
|
||||||
|
payload = job.get("payload", {})
|
||||||
|
if action == "insert_report":
|
||||||
|
session.add(
|
||||||
|
Report(
|
||||||
|
date=payload.get("date"),
|
||||||
|
scan_unique_id=payload.get("scan_id"),
|
||||||
|
report_path_filename=payload.get("report_path_filename"),
|
||||||
|
options=json.dumps(payload.get("options", {})),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if action == "insert_hostslog":
|
||||||
|
session.add(
|
||||||
|
HostsLog(
|
||||||
|
target=payload.get("target"),
|
||||||
|
date=payload.get("date"),
|
||||||
|
module_name=payload.get("module_name"),
|
||||||
|
scan_unique_id=payload.get("scan_id"),
|
||||||
|
port=json.dumps(payload.get("port")),
|
||||||
|
event=json.dumps(payload.get("event")),
|
||||||
|
json_event=json.dumps(payload.get("json_event")),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if action == "insert_tempevent":
|
||||||
|
session.add(
|
||||||
|
TempEvents(
|
||||||
|
target=payload.get("target"),
|
||||||
|
date=payload.get("date"),
|
||||||
|
module_name=payload.get("module_name"),
|
||||||
|
scan_unique_id=payload.get("scan_id"),
|
||||||
|
event_name=payload.get("event_name"),
|
||||||
|
port=json.dumps(payload.get("port")),
|
||||||
|
event=json.dumps(payload.get("event")),
|
||||||
|
data=json.dumps(payload.get("data")),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
log.warn(f"DBWriter: unsupported job action {action}")
|
||||||
|
|
||||||
|
|
||||||
|
# singleton writer
|
||||||
|
_writer = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_writer():
|
||||||
|
global _writer
|
||||||
|
if _writer is None:
|
||||||
|
_writer = DBWriter()
|
||||||
|
try:
|
||||||
|
_writer.start()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return _writer
|
||||||
|
|
||||||
|
|
||||||
|
def get_writer_configured(batch_size=None, interval=None):
|
||||||
|
"""Return singleton writer, applying optional configuration.
|
||||||
|
|
||||||
|
If the writer already exists, provided parameters will update its settings.
|
||||||
|
"""
|
||||||
|
w = get_writer()
|
||||||
|
if batch_size is not None:
|
||||||
|
try:
|
||||||
|
w.batch_size = int(batch_size)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if interval is not None:
|
||||||
|
try:
|
||||||
|
w.interval = float(interval)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return w
|
||||||
|
|
||||||
|
|
||||||
|
def get_stats():
|
||||||
|
w = get_writer()
|
||||||
|
queue_size = None
|
||||||
|
if getattr(w, "_use_litequeue", False) and getattr(w, "_lq", None) is not None:
|
||||||
|
try:
|
||||||
|
if hasattr(w._lq, "qsize"):
|
||||||
|
queue_size = w._lq.qsize()
|
||||||
|
elif hasattr(w._lq, "__len__"):
|
||||||
|
queue_size = len(w._lq)
|
||||||
|
elif hasattr(w._lq, "size"):
|
||||||
|
queue_size = w._lq.size()
|
||||||
|
except Exception:
|
||||||
|
queue_size = None
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
queue_size = w.queue.qsize()
|
||||||
|
except Exception:
|
||||||
|
queue_size = None
|
||||||
|
return {"processed": getattr(w, "_processed_count", 0), "queue_size": queue_size}
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aiohappyeyeballs"
|
name = "aiohappyeyeballs"
|
||||||
|
|
@ -973,7 +973,7 @@ description = "Read metadata from Python packages"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
markers = "python_version < \"3.10\""
|
markers = "python_version == \"3.9\""
|
||||||
files = [
|
files = [
|
||||||
{file = "importlib_metadata-8.4.0-py3-none-any.whl", hash = "sha256:66f342cc6ac9818fc6ff340576acd24d65ba0b3efabb2b4ac08b598965a4a2f1"},
|
{file = "importlib_metadata-8.4.0-py3-none-any.whl", hash = "sha256:66f342cc6ac9818fc6ff340576acd24d65ba0b3efabb2b4ac08b598965a4a2f1"},
|
||||||
{file = "importlib_metadata-8.4.0.tar.gz", hash = "sha256:9a547d3bc3608b025f93d403fdd1aae741c24fbb8314df4b155675742ce303c5"},
|
{file = "importlib_metadata-8.4.0.tar.gz", hash = "sha256:9a547d3bc3608b025f93d403fdd1aae741c24fbb8314df4b155675742ce303c5"},
|
||||||
|
|
@ -1129,6 +1129,18 @@ files = [
|
||||||
dnspython = "*"
|
dnspython = "*"
|
||||||
ldap3 = ">2.5.0,<2.5.2 || >2.5.2,<2.6 || >2.6"
|
ldap3 = ">2.5.0,<2.5.2 || >2.5.2,<2.6 || >2.6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "litequeue"
|
||||||
|
version = "0.9"
|
||||||
|
description = "Simple queue built on top of SQLite"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.6"
|
||||||
|
groups = ["main"]
|
||||||
|
files = [
|
||||||
|
{file = "litequeue-0.9-py3-none-any.whl", hash = "sha256:344312748f9d118253ecaa2fb82d432ab6f63dd3bd5c40922a63933bc47cd2e3"},
|
||||||
|
{file = "litequeue-0.9.tar.gz", hash = "sha256:368f56b9de5c76fc6f2adc66177e81a59f545f3b5b95a26cae8562b858914647"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "markupsafe"
|
name = "markupsafe"
|
||||||
version = "2.1.5"
|
version = "2.1.5"
|
||||||
|
|
@ -2023,7 +2035,7 @@ files = [
|
||||||
{file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"},
|
{file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"},
|
||||||
{file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
|
{file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
|
||||||
]
|
]
|
||||||
markers = {dev = "python_version < \"3.10\""}
|
markers = {dev = "python_version == \"3.9\""}
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "urllib3"
|
name = "urllib3"
|
||||||
|
|
@ -2254,4 +2266,4 @@ type = ["pytest-mypy"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = "^3.9, <3.13"
|
python-versions = "^3.9, <3.13"
|
||||||
content-hash = "0e1731401cd6acfc4d45ede5e18668530aae6a6b2e359d7dc8d8d635635a1257"
|
content-hash = "aa676fcd9a242a436052e31b320c0c3d99451dc4323d6ae99fbe8f4f49e0d747"
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,7 @@ release_name = "QUIN"
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
nettacker = "nettacker.main:run"
|
nettacker = "nettacker.main:run"
|
||||||
|
nettacker-db-worker = "nettacker.cli.db_worker:run"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.9, <3.13"
|
python = "^3.9, <3.13"
|
||||||
|
|
@ -65,6 +66,7 @@ zipp = "^3.19.1"
|
||||||
uvloop = "^0.21.0"
|
uvloop = "^0.21.0"
|
||||||
pymysql = "^1.1.1"
|
pymysql = "^1.1.1"
|
||||||
impacket = "^0.11.0"
|
impacket = "^0.11.0"
|
||||||
|
litequeue = "^0.9"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
ipython = "^8.16.1"
|
ipython = "^8.16.1"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,115 @@
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def start_worker_process(tmp_path):
|
||||||
|
# run worker in separate process
|
||||||
|
env = os.environ.copy()
|
||||||
|
# ensure current project is first on PYTHONPATH
|
||||||
|
env["PYTHONPATH"] = str(Path.cwd()) + os.pathsep + env.get("PYTHONPATH", "")
|
||||||
|
# Pass config through environment
|
||||||
|
data_dir = tmp_path / ".data"
|
||||||
|
env["NETTACKER_DATA_DIR"] = str(data_dir)
|
||||||
|
env["NETTACKER_DB_NAME"] = str(data_dir / "nettacker.db")
|
||||||
|
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
[
|
||||||
|
env.get("PYTHON_BIN", "python"),
|
||||||
|
"-m",
|
||||||
|
"nettacker.cli.db_worker",
|
||||||
|
"--once", # Process all items and exit
|
||||||
|
"--max-items",
|
||||||
|
"10",
|
||||||
|
"--summary", # Show processing stats
|
||||||
|
],
|
||||||
|
env=env,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
print(f"Started worker process {proc.pid} with data_dir={data_dir}")
|
||||||
|
return proc
|
||||||
|
|
||||||
|
|
||||||
|
def test_worker_writes(tmp_path):
|
||||||
|
"""Test that the database writer correctly processes queued jobs and writes to database."""
|
||||||
|
# Create test database
|
||||||
|
data_dir = tmp_path / ".data"
|
||||||
|
data_dir.mkdir()
|
||||||
|
db_path = str(data_dir / "nettacker.db")
|
||||||
|
|
||||||
|
# Create database tables
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
|
|
||||||
|
from nettacker.database.models import Base
|
||||||
|
|
||||||
|
engine = create_engine(f"sqlite:///{db_path}")
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
engine.dispose()
|
||||||
|
|
||||||
|
# Create a writer configured to use the test database
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from nettacker.database.writer import DBWriter
|
||||||
|
|
||||||
|
writer = DBWriter()
|
||||||
|
# Override the database connection to use our test database
|
||||||
|
writer.engine = create_engine(
|
||||||
|
f"sqlite:///{db_path}", connect_args={"check_same_thread": False}, pool_pre_ping=True
|
||||||
|
)
|
||||||
|
# Enable WAL mode for better concurrency
|
||||||
|
with writer.engine.connect() as conn:
|
||||||
|
conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||||
|
conn.commit()
|
||||||
|
writer.Session = sessionmaker(bind=writer.engine)
|
||||||
|
|
||||||
|
# Create test jobs for both report and hosts log
|
||||||
|
jobs = [
|
||||||
|
{
|
||||||
|
"action": "insert_report",
|
||||||
|
"payload": {
|
||||||
|
"date": None,
|
||||||
|
"scan_id": "test-scan",
|
||||||
|
"report_path_filename": str(data_dir / "r.html"),
|
||||||
|
"options": {"report_path_filename": str(data_dir / "r.html")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"action": "insert_hostslog",
|
||||||
|
"payload": {
|
||||||
|
"date": None,
|
||||||
|
"target": "127.0.0.1",
|
||||||
|
"module_name": "m",
|
||||||
|
"scan_id": "test-scan",
|
||||||
|
"port": [],
|
||||||
|
"event": {},
|
||||||
|
"json_event": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Enqueue jobs to the writer
|
||||||
|
for job in jobs:
|
||||||
|
writer.enqueue(job)
|
||||||
|
|
||||||
|
# Process all queued jobs
|
||||||
|
processed_count = writer.drain_once(max_iterations=10)
|
||||||
|
assert processed_count == 2
|
||||||
|
|
||||||
|
# Verify the jobs were written to the database
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
c = conn.cursor()
|
||||||
|
|
||||||
|
c.execute("select count(*) from reports where scan_unique_id = ?", ("test-scan",))
|
||||||
|
report_count = c.fetchone()[0]
|
||||||
|
|
||||||
|
c.execute("select count(*) from scan_events where scan_unique_id = ?", ("test-scan",))
|
||||||
|
hosts_count = c.fetchone()[0]
|
||||||
|
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
assert report_count == 1
|
||||||
|
assert hosts_count == 1
|
||||||
Loading…
Reference in New Issue