commit cc915590c2
Author: James
Date: 2025-11-20 18:54:07 +07:00 (committed via GitHub)
6 changed files with 581 additions and 4 deletions

View File

@ -0,0 +1,47 @@
import argparse
import signal
import time

from nettacker.database.writer import get_stats, get_writer, get_writer_configured


def _handle_sig(signum, frame):
    # Stop the background writer cleanly before exiting.
    writer = get_writer()
    writer.stop()
    raise SystemExit(0)


def run():
    parser = argparse.ArgumentParser()
    parser.add_argument("--once", action="store_true", help="Drain the queue once and exit")
    parser.add_argument("--batch-size", type=int, default=None, help="Writer batch size")
    parser.add_argument("--interval", type=float, default=None, help="Writer sleep interval")
    parser.add_argument(
        "--max-items", type=int, default=None, help="Max items to process in --once mode"
    )
    parser.add_argument("--summary", action="store_true", help="Print a summary after --once")
    args = parser.parse_args()

    signal.signal(signal.SIGINT, _handle_sig)
    signal.signal(signal.SIGTERM, _handle_sig)

    # Apply runtime configuration to the shared writer singleton.
    writer = get_writer_configured(batch_size=args.batch_size, interval=args.interval)

    if args.once:
        processed = writer.drain_once(max_iterations=args.max_items or 100000)
        if args.summary:
            stats = get_stats()
            print(
                f"processed={processed} total_processed={stats.get('processed')} queue_size={stats.get('queue_size')}"
            )
        return

    # Main loop: the background writer thread does the work; this just keeps the
    # process alive until a signal handler raises SystemExit.
    while True:
        time.sleep(1)


if __name__ == "__main__":
    run()
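For context, a minimal sketch of how this one-shot mode can be driven from another Python process, using the same `-m nettacker.cli.db_worker` invocation as the test at the end of this commit; the flag values are illustrative, not defaults:

import subprocess
import sys

# Drain the queue once and print the worker's summary line
# (format: "processed=... total_processed=... queue_size=...").
result = subprocess.run(
    [sys.executable, "-m", "nettacker.cli.db_worker", "--once", "--max-items", "100", "--summary"],
    capture_output=True,
    text=True,
    check=False,
)
print(result.stdout.strip())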

View File

@ -9,6 +9,7 @@ from nettacker.api.helpers import structure
from nettacker.config import Config
from nettacker.core.messages import messages
from nettacker.database.models import HostsLog, Report, TempEvents
from nettacker.database.writer import get_writer
config = Config()
log = logger.get_logger()
@ -95,6 +96,22 @@ def submit_report_to_db(event):
        return True if submitted otherwise False
    """
    log.verbose_info(messages("inserting_report_db"))
    writer = get_writer()
    job = {
        "action": "insert_report",
        "payload": {
            "date": event["date"],
            "scan_id": event["scan_id"],
            "report_path_filename": event["options"]["report_path_filename"],
            "options": event["options"],
        },
    }
    try:
        if writer.enqueue(job):
            return True
    except Exception:
        pass
    # Fallback to direct write
    session = create_connection()
    session.add(
        Report(
@ -140,6 +157,14 @@ def submit_logs_to_db(log):
        True if success otherwise False
    """
    if isinstance(log, dict):
        writer = get_writer()
        job = {"action": "insert_hostslog", "payload": log}
        try:
            if writer.enqueue(job):
                return True
        except Exception:
            pass
        # Fallback
        session = create_connection()
        session.add(
            HostsLog(
@ -169,6 +194,14 @@ def submit_temp_logs_to_db(log):
        True if success otherwise False
    """
    if isinstance(log, dict):
        writer = get_writer()
        job = {"action": "insert_tempevent", "payload": log}
        try:
            if writer.enqueue(job):
                return True
        except Exception:
            pass
        # Fallback
        session = create_connection()
        session.add(
            TempEvents(
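The three submit_* hunks above share one pattern: try to hand the row to the background writer, and only fall back to the original synchronous write if enqueueing fails or raises. A condensed sketch of that control flow; the helper name try_enqueue is hypothetical (the commit repeats the try/except inline instead):

from nettacker.database.writer import get_writer


def try_enqueue(action, payload):
    """Best-effort handoff to the background writer; False means 'write synchronously'."""
    try:
        return bool(get_writer().enqueue({"action": action, "payload": payload}))
    except Exception:
        return False


# Inside submit_logs_to_db(log), the new code is equivalent to:
#     if try_enqueue("insert_hostslog", log):
#         return True
#     # ...otherwise fall through to create_connection() and the direct session.add() path.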

View File

@ -0,0 +1,368 @@
import json
import threading
import time
from multiprocessing import Queue
from pathlib import Path
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from nettacker import logger
from nettacker.config import Config
from nettacker.database.models import Report, HostsLog, TempEvents
log = logger.get_logger()
class DBWriter:
    def __init__(self, batch_size=100, interval=0.5):
        self.batch_size = int(batch_size)
        self.interval = float(interval)
        self._stop = threading.Event()
        self._thread = None
        # total processed across lifetime
        self._processed_count = 0
        self._use_litequeue = False
        self._lq = None
        self._lq_put = None
        self._lq_get = None
        try:
            import litequeue as _litequeue

            queue_file = Path(Config.path.data_dir) / "nettacker_db_queue.lq"
            queue_file.parent.mkdir(parents=True, exist_ok=True)
            # try common constructors
            if hasattr(_litequeue, "LiteQueue"):
                self._lq = _litequeue.LiteQueue(str(queue_file))
            elif hasattr(_litequeue, "Queue"):
                self._lq = _litequeue.Queue(str(queue_file))
            else:
                # fallback to a module-level factory
                try:
                    self._lq = _litequeue.open(str(queue_file))
                except Exception:
                    self._lq = None
            if self._lq is not None:
                # prefer destructive pop/get ordering
                if hasattr(self._lq, "put"):
                    self._lq_put = self._lq.put
                elif hasattr(self._lq, "push"):
                    self._lq_put = self._lq.push
                elif hasattr(self._lq, "add"):
                    self._lq_put = self._lq.add
                if hasattr(self._lq, "pop"):
                    self._lq_get = self._lq.pop
                elif hasattr(self._lq, "get"):
                    # note: some implementations require message_id; prefer pop above
                    self._lq_get = self._lq.get
                elif hasattr(self._lq, "take"):
                    self._lq_get = self._lq.take
                if self._lq_put and self._lq_get:
                    self._use_litequeue = True
        except Exception:
            self._use_litequeue = False
        if not self._use_litequeue:
            self.queue = Queue()

        db_url = Config.db.as_dict()
        engine_url = (
            "sqlite:///{name}".format(**db_url)
            if Config.db.engine.startswith("sqlite")
            else Config.db.engine
        )
        connect_args = {}
        if engine_url.startswith("sqlite"):
            connect_args["check_same_thread"] = False
        self.engine = create_engine(engine_url, connect_args=connect_args, pool_pre_ping=True)
        if engine_url.startswith("sqlite"):
            try:
                with self.engine.connect() as conn:
                    conn.execute(text("PRAGMA journal_mode=WAL"))
                    conn.commit()
            except Exception:
                pass
        self.Session = sessionmaker(bind=self.engine)
    def start(self):
        if self._thread and self._thread.is_alive():
            return
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, name="nettacker-db-writer", daemon=True)
        self._thread.start()

    def stop(self):
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=5)

    def enqueue(self, job):
        try:
            if self._use_litequeue:
                self._lq_put(json.dumps(job))
                return True
            self.queue.put(job)
            return True
        except Exception:
            log.warn("DBWriter: failed to enqueue job")
            return False

    def _acknowledge_message(self, message_id):
        """Acknowledge a successfully processed message."""
        if self._use_litequeue and message_id is not None:
            try:
                if hasattr(self._lq, "done"):
                    self._lq.done(message_id)
            except Exception:
                pass
    def _pop_one(self):
        if self._use_litequeue:
            try:
                # litequeue: use pop() to get and lock the message, then mark done() AFTER processing
                msg = None
                if hasattr(self._lq, "pop"):
                    msg = self._lq.pop()
                elif hasattr(self._lq, "get"):
                    # fallback: try to get the next message via get
                    msg = self._lq.get()
                if msg is None:
                    return None
                if hasattr(msg, "data"):
                    payload = msg.data
                elif hasattr(msg, "message"):
                    payload = msg.message
                else:
                    payload = str(msg)
                if isinstance(payload, (bytes, bytearray)):
                    payload = payload.decode()
                # Return both the payload and message_id for deferred acknowledgment
                job_data = json.loads(payload)
                if hasattr(msg, "message_id"):
                    return {"data": job_data, "message_id": msg.message_id}
                else:
                    return {"data": job_data, "message_id": None}
            except Exception:
                return None
        else:
            try:
                job_data = self.queue.get_nowait()
                return {"data": job_data, "message_id": None}
            except Exception:
                return None
    def _run(self):
        pending = []
        while not self._stop.is_set():
            try:
                while len(pending) < self.batch_size:
                    job = self._pop_one()
                    if job is None:
                        break
                    pending.append(job)
                if pending:
                    # Process each job individually with an immediate commit
                    for job in pending:
                        job_session = self.Session()  # Fresh session per job
                        try:
                            # Handle both litequeue format {"data": job, "message_id": id} and direct job
                            job_data = (
                                job["data"] if isinstance(job, dict) and "data" in job else job
                            )
                            self._apply_job(job_session, job_data)
                            job_session.commit()  # Immediate commit per job
                            self._processed_count += 1
                            # Only acknowledge after a successful commit
                            if isinstance(job, dict) and "message_id" in job:
                                self._acknowledge_message(job["message_id"])
                        except Exception as e:
                            job_session.rollback()
                            log.error(f"Failed to process job: {e}")
                            # Job is not acknowledged, so it can be retried
                        finally:
                            job_session.close()
                    pending = []
                else:
                    time.sleep(self.interval)
            except Exception:
                time.sleep(0.1)

        # Final cleanup: process any remaining jobs individually
        try:
            while True:
                job = self._pop_one()
                if job is None:
                    break
                # Process final job individually with an immediate commit
                cleanup_session = self.Session()
                try:
                    job_data = job["data"] if isinstance(job, dict) and "data" in job else job
                    self._apply_job(cleanup_session, job_data)
                    cleanup_session.commit()
                    self._processed_count += 1
                    # Only acknowledge after a successful commit
                    if isinstance(job, dict) and "message_id" in job:
                        self._acknowledge_message(job["message_id"])
                except Exception as e:
                    cleanup_session.rollback()
                    log.error(f"Failed to process cleanup job: {e}")
                finally:
                    cleanup_session.close()
        except Exception:
            pass
    def drain_once(self, max_iterations=100000):
        """Consume all queued jobs and return when the queue is empty.

        This method is intended for on-demand draining (not long-lived).
        """
        iterations = 0
        processed = 0
        try:
            while iterations < max_iterations:
                job = self._pop_one()
                if job is None:
                    break
                # Process each job individually with an immediate commit for durability
                job_session = self.Session()  # Fresh session per job
                try:
                    # Handle both litequeue format {"data": job, "message_id": id} and direct job
                    job_data = job["data"] if isinstance(job, dict) and "data" in job else job
                    self._apply_job(job_session, job_data)
                    job_session.commit()  # Immediate commit per job
                    processed += 1
                    self._processed_count += 1
                    # Only acknowledge after a successful commit
                    if isinstance(job, dict) and "message_id" in job:
                        self._acknowledge_message(job["message_id"])
                except Exception as e:
                    job_session.rollback()
                    log.error(f"Failed to process job during drain: {e}")
                    # Job is not acknowledged, so it can be retried
                finally:
                    job_session.close()
                iterations += 1
        except Exception as e:
            log.error(f"Error during drain operation: {e}")
        return processed
    def _apply_job(self, session, job):
        action = job.get("action")
        payload = job.get("payload", {})
        if action == "insert_report":
            session.add(
                Report(
                    date=payload.get("date"),
                    scan_unique_id=payload.get("scan_id"),
                    report_path_filename=payload.get("report_path_filename"),
                    options=json.dumps(payload.get("options", {})),
                )
            )
            return
        if action == "insert_hostslog":
            session.add(
                HostsLog(
                    target=payload.get("target"),
                    date=payload.get("date"),
                    module_name=payload.get("module_name"),
                    scan_unique_id=payload.get("scan_id"),
                    port=json.dumps(payload.get("port")),
                    event=json.dumps(payload.get("event")),
                    json_event=json.dumps(payload.get("json_event")),
                )
            )
            return
        if action == "insert_tempevent":
            session.add(
                TempEvents(
                    target=payload.get("target"),
                    date=payload.get("date"),
                    module_name=payload.get("module_name"),
                    scan_unique_id=payload.get("scan_id"),
                    event_name=payload.get("event_name"),
                    port=json.dumps(payload.get("port")),
                    event=json.dumps(payload.get("event")),
                    data=json.dumps(payload.get("data")),
                )
            )
            return
        log.warn(f"DBWriter: unsupported job action {action}")
# singleton writer
_writer = None


def get_writer():
    global _writer
    if _writer is None:
        _writer = DBWriter()
        try:
            _writer.start()
        except Exception:
            pass
    return _writer


def get_writer_configured(batch_size=None, interval=None):
    """Return the singleton writer, applying optional configuration.

    If the writer already exists, provided parameters will update its settings.
    """
    w = get_writer()
    if batch_size is not None:
        try:
            w.batch_size = int(batch_size)
        except Exception:
            pass
    if interval is not None:
        try:
            w.interval = float(interval)
        except Exception:
            pass
    return w


def get_stats():
    w = get_writer()
    queue_size = None
    if getattr(w, "_use_litequeue", False) and getattr(w, "_lq", None) is not None:
        try:
            if hasattr(w._lq, "qsize"):
                queue_size = w._lq.qsize()
            elif hasattr(w._lq, "__len__"):
                queue_size = len(w._lq)
            elif hasattr(w._lq, "size"):
                queue_size = w._lq.size()
        except Exception:
            queue_size = None
    else:
        try:
            queue_size = w.queue.qsize()
        except Exception:
            queue_size = None
    return {"processed": getattr(w, "_processed_count", 0), "queue_size": queue_size}
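The module's public surface, as used by the CLI worker and the database layer, is small: get_writer() / get_writer_configured(), enqueue(), drain_once(), and get_stats(). A brief usage sketch; the job dict mirrors the shapes handled by _apply_job above, and the payload values are placeholders:

from nettacker.database.writer import get_stats, get_writer_configured

# Reuse (or lazily create and start) the process-wide writer.
writer = get_writer_configured(batch_size=50, interval=0.25)

# Jobs are plain dicts routed by "action"; keys follow _apply_job's expectations.
writer.enqueue(
    {
        "action": "insert_hostslog",
        "payload": {
            "date": None,  # placeholder values, for illustration only
            "target": "127.0.0.1",
            "module_name": "port_scan",
            "scan_id": "example-scan",
            "port": [80],
            "event": {},
            "json_event": {},
        },
    }
)

# Either rely on the daemon thread started by get_writer(), or flush synchronously:
processed = writer.drain_once(max_iterations=1000)
print(processed, get_stats())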

poetry.lock (generated)
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@ -973,7 +973,7 @@ description = "Read metadata from Python packages"
optional = false
python-versions = ">=3.8"
groups = ["main"]
markers = "python_version < \"3.10\""
markers = "python_version == \"3.9\""
files = [
{file = "importlib_metadata-8.4.0-py3-none-any.whl", hash = "sha256:66f342cc6ac9818fc6ff340576acd24d65ba0b3efabb2b4ac08b598965a4a2f1"},
{file = "importlib_metadata-8.4.0.tar.gz", hash = "sha256:9a547d3bc3608b025f93d403fdd1aae741c24fbb8314df4b155675742ce303c5"},
@ -1129,6 +1129,18 @@ files = [
dnspython = "*"
ldap3 = ">2.5.0,<2.5.2 || >2.5.2,<2.6 || >2.6"
[[package]]
name = "litequeue"
version = "0.9"
description = "Simple queue built on top of SQLite"
optional = false
python-versions = ">=3.6"
groups = ["main"]
files = [
{file = "litequeue-0.9-py3-none-any.whl", hash = "sha256:344312748f9d118253ecaa2fb82d432ab6f63dd3bd5c40922a63933bc47cd2e3"},
{file = "litequeue-0.9.tar.gz", hash = "sha256:368f56b9de5c76fc6f2adc66177e81a59f545f3b5b95a26cae8562b858914647"},
]
[[package]]
name = "markupsafe"
version = "2.1.5"
@ -2023,7 +2035,7 @@ files = [
{file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"},
{file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
]
markers = {dev = "python_version < \"3.10\""}
markers = {dev = "python_version == \"3.9\""}
[[package]]
name = "urllib3"
@ -2254,4 +2266,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = "^3.9, <3.13"
content-hash = "0e1731401cd6acfc4d45ede5e18668530aae6a6b2e359d7dc8d8d635635a1257"
content-hash = "aa676fcd9a242a436052e31b320c0c3d99451dc4323d6ae99fbe8f4f49e0d747"

View File

@ -43,6 +43,7 @@ release_name = "QUIN"
[tool.poetry.scripts]
nettacker = "nettacker.main:run"
nettacker-db-worker = "nettacker.cli.db_worker:run"
[tool.poetry.dependencies]
python = "^3.9, <3.13"
@ -65,6 +66,7 @@ zipp = "^3.19.1"
uvloop = "^0.21.0"
pymysql = "^1.1.1"
impacket = "^0.11.0"
litequeue = "^0.9"
[tool.poetry.group.dev.dependencies]
ipython = "^8.16.1"

View File

@ -0,0 +1,115 @@
import os
import subprocess
from pathlib import Path


def start_worker_process(tmp_path):
    # run worker in separate process
    env = os.environ.copy()
    # ensure current project is first on PYTHONPATH
    env["PYTHONPATH"] = str(Path.cwd()) + os.pathsep + env.get("PYTHONPATH", "")
    # Pass config through environment
    data_dir = tmp_path / ".data"
    env["NETTACKER_DATA_DIR"] = str(data_dir)
    env["NETTACKER_DB_NAME"] = str(data_dir / "nettacker.db")
    proc = subprocess.Popen(
        [
            env.get("PYTHON_BIN", "python"),
            "-m",
            "nettacker.cli.db_worker",
            "--once",  # Process all items and exit
            "--max-items",
            "10",
            "--summary",  # Show processing stats
        ],
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    print(f"Started worker process {proc.pid} with data_dir={data_dir}")
    return proc
def test_worker_writes(tmp_path):
    """Test that the database writer correctly processes queued jobs and writes to the database."""
    # Create test database
    data_dir = tmp_path / ".data"
    data_dir.mkdir()
    db_path = str(data_dir / "nettacker.db")

    # Create database tables
    from sqlalchemy import create_engine, text

    from nettacker.database.models import Base

    engine = create_engine(f"sqlite:///{db_path}")
    Base.metadata.create_all(engine)
    engine.dispose()

    # Create a writer configured to use the test database
    from sqlalchemy.orm import sessionmaker

    from nettacker.database.writer import DBWriter

    writer = DBWriter()
    # Override the database connection to use our test database
    writer.engine = create_engine(
        f"sqlite:///{db_path}", connect_args={"check_same_thread": False}, pool_pre_ping=True
    )
    # Enable WAL mode for better concurrency
    with writer.engine.connect() as conn:
        conn.execute(text("PRAGMA journal_mode=WAL"))
        conn.commit()
    writer.Session = sessionmaker(bind=writer.engine)

    # Create test jobs for both report and hosts log
    jobs = [
        {
            "action": "insert_report",
            "payload": {
                "date": None,
                "scan_id": "test-scan",
                "report_path_filename": str(data_dir / "r.html"),
                "options": {"report_path_filename": str(data_dir / "r.html")},
            },
        },
        {
            "action": "insert_hostslog",
            "payload": {
                "date": None,
                "target": "127.0.0.1",
                "module_name": "m",
                "scan_id": "test-scan",
                "port": [],
                "event": {},
                "json_event": {},
            },
        },
    ]

    # Enqueue jobs to the writer
    for job in jobs:
        writer.enqueue(job)

    # Process all queued jobs
    processed_count = writer.drain_once(max_iterations=10)
    assert processed_count == 2

    # Verify the jobs were written to the database
    import sqlite3

    conn = sqlite3.connect(db_path)
    c = conn.cursor()
    c.execute("select count(*) from reports where scan_unique_id = ?", ("test-scan",))
    report_count = c.fetchone()[0]
    c.execute("select count(*) from scan_events where scan_unique_id = ?", ("test-scan",))
    hosts_count = c.fetchone()[0]
    conn.close()

    assert report_count == 1
    assert hosts_count == 1