Fix database lock issues by serializing writes through a litequeue-backed background writer

DavdaJames 2025-10-06 03:14:58 +05:30
parent ad76ce537a
commit dc1ea7bd9e
6 changed files with 550 additions and 4 deletions

View File

@@ -0,0 +1,48 @@
import argparse
import signal
import time

from nettacker.database.writer import get_writer


def _handle_sig(signum, frame):
    writer = get_writer()
    writer.stop()


def run():
    parser = argparse.ArgumentParser()
    parser.add_argument("--once", action="store_true", help="Drain the queue once and exit")
    parser.add_argument("--batch-size", type=int, default=None, help="Writer batch size")
    parser.add_argument("--interval", type=float, default=None, help="Writer sleep interval")
    parser.add_argument(
        "--max-items", type=int, default=None, help="Max items to process in --once mode"
    )
    parser.add_argument("--summary", action="store_true", help="Print a summary after --once")
    args = parser.parse_args()
    signal.signal(signal.SIGINT, _handle_sig)
    signal.signal(signal.SIGTERM, _handle_sig)
    # apply runtime config
    from nettacker.database.writer import get_writer_configured, get_stats

    writer = get_writer_configured(batch_size=args.batch_size, interval=args.interval)
    if args.once:
        processed = writer.drain_once(max_iterations=args.max_items or 100000)
        if args.summary:
            stats = get_stats()
            print(
                f"processed={processed} total_processed={stats.get('processed')} queue_size={stats.get('queue_size')}"
            )
        return
    try:
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        writer.stop()


if __name__ == "__main__":
    run()
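As a rough usage sketch, the same drain that the registered nettacker-db-worker entry point performs with --once --summary can be driven from Python; the settings below are illustrative, not part of the commit:

# Sketch only: drain whatever is currently queued, then print the writer stats.
from nettacker.database.writer import get_stats, get_writer_configured

writer = get_writer_configured(batch_size=50, interval=0.2)  # illustrative values
processed = writer.drain_once(max_iterations=1000)
print(f"processed={processed} stats={get_stats()}")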

View File

@@ -9,6 +9,7 @@ from nettacker.api.helpers import structure
from nettacker.config import Config
from nettacker.core.messages import messages
from nettacker.database.models import HostsLog, Report, TempEvents
from nettacker.database.writer import get_writer

config = Config()
log = logger.get_logger()
@@ -95,6 +96,22 @@ def submit_report_to_db(event):
    return True if submitted otherwise False
    """
    log.verbose_info(messages("inserting_report_db"))
    writer = get_writer()
    job = {
        "action": "insert_report",
        "payload": {
            "date": event["date"],
            "scan_id": event["scan_id"],
            "report_path_filename": event["options"]["report_path_filename"],
            "options": event["options"],
        },
    }
    try:
        if writer.enqueue(job):
            return True
    except Exception:
        pass
    # Fallback to direct write
    session = create_connection()
    session.add(
        Report(
@@ -140,6 +157,14 @@ def submit_logs_to_db(log):
    True if success otherwise False
    """
    if isinstance(log, dict):
        writer = get_writer()
        job = {"action": "insert_hostslog", "payload": log}
        try:
            if writer.enqueue(job):
                return True
        except Exception:
            pass
        # Fallback
        session = create_connection()
        session.add(
            HostsLog(
@@ -169,6 +194,14 @@ def submit_temp_logs_to_db(log):
    True if success otherwise False
    """
    if isinstance(log, dict):
        writer = get_writer()
        job = {"action": "insert_tempevent", "payload": log}
        try:
            if writer.enqueue(job):
                return True
        except Exception:
            pass
        # Fallback
        session = create_connection()
        session.add(
            TempEvents(
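For reference, the payload keys a hosts-log job is expected to carry mirror the HostsLog columns handled by _apply_job in the new writer (next file); a hypothetical example with illustrative values:

# Hypothetical job shape for the writer's "insert_hostslog" action (illustration only).
job = {
    "action": "insert_hostslog",
    "payload": {
        "target": "127.0.0.1",
        "date": None,
        "module_name": "port_scan",
        "scan_id": "example-scan",
        "port": [80],
        "event": {},
        "json_event": {},
    },
}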

View File

@@ -0,0 +1,336 @@
import json
import threading
import time
from multiprocessing import Queue
from pathlib import Path

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

from nettacker import logger
from nettacker.config import Config
from nettacker.database.models import Report, HostsLog, TempEvents

log = logger.get_logger()
class DBWriter:
    def __init__(self, batch_size=100, interval=0.5):
        self.batch_size = int(batch_size)
        self.interval = float(interval)
        self._stop = threading.Event()
        self._thread = None
        # total processed across lifetime
        self._processed_count = 0
        self._use_litequeue = False
        self._lq = None
        self._lq_put = None
        self._lq_get = None
        try:
            import litequeue as _litequeue

            queue_file = Path(Config.path.data_dir) / "nettacker_db_queue.lq"
            queue_file.parent.mkdir(parents=True, exist_ok=True)
            # try common constructors
            if hasattr(_litequeue, "LiteQueue"):
                self._lq = _litequeue.LiteQueue(str(queue_file))
            elif hasattr(_litequeue, "Queue"):
                self._lq = _litequeue.Queue(str(queue_file))
            else:
                # fallback to a module-level factory
                try:
                    self._lq = _litequeue.open(str(queue_file))
                except Exception:
                    self._lq = None
            if self._lq is not None:
                # prefer destructive pop/get ordering
                if hasattr(self._lq, "put"):
                    self._lq_put = self._lq.put
                elif hasattr(self._lq, "push"):
                    self._lq_put = self._lq.push
                elif hasattr(self._lq, "add"):
                    self._lq_put = self._lq.add
                if hasattr(self._lq, "pop"):
                    self._lq_get = self._lq.pop
                elif hasattr(self._lq, "get"):
                    # note: some implementations require message_id; prefer pop above
                    self._lq_get = self._lq.get
                elif hasattr(self._lq, "take"):
                    self._lq_get = self._lq.take
                if self._lq_put and self._lq_get:
                    self._use_litequeue = True
        except Exception:
            self._use_litequeue = False
        if not self._use_litequeue:
            self.queue = Queue()
        db_url = Config.db.as_dict()
        engine_url = (
            "sqlite:///{name}".format(**db_url)
            if Config.db.engine.startswith("sqlite")
            else Config.db.engine
        )
        connect_args = {}
        if engine_url.startswith("sqlite"):
            connect_args["check_same_thread"] = False
        self.engine = create_engine(engine_url, connect_args=connect_args, pool_pre_ping=True)
        if engine_url.startswith("sqlite"):
            try:
                with self.engine.connect() as conn:
                    # text() keeps the raw PRAGMA working on SQLAlchemy 1.4 and 2.x
                    conn.execute(text("PRAGMA journal_mode=WAL"))
            except Exception:
                pass
        self.Session = sessionmaker(bind=self.engine)
    def start(self):
        if self._thread and self._thread.is_alive():
            return
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, name="nettacker-db-writer", daemon=True)
        self._thread.start()

    def stop(self):
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=5)

    def enqueue(self, job):
        try:
            if self._use_litequeue:
                self._lq_put(json.dumps(job))
                return True
            self.queue.put(job)
            return True
        except Exception:
            log.warn("DBWriter: failed to enqueue job")
            return False
    def _pop_one(self):
        if self._use_litequeue:
            try:
                # litequeue: use peek() to get the next message, then mark it done()
                # peek returns a Message object or None
                msg = None
                if hasattr(self._lq, "peek"):
                    msg = self._lq.peek()
                elif hasattr(self._lq, "get"):
                    # get() requires a message_id, so there is no non-destructive
                    # fetch to fall back on; leave msg as None
                    msg = None
                if msg is None:
                    return None
                if hasattr(msg, "data"):
                    payload = msg.data
                elif hasattr(msg, "message"):
                    payload = msg.message
                else:
                    payload = str(msg)
                if isinstance(payload, (bytes, bytearray)):
                    payload = payload.decode()
                # mark message done to remove it from the queue
                try:
                    if hasattr(self._lq, "done") and hasattr(msg, "message_id"):
                        self._lq.done(msg.message_id)
                except Exception:
                    pass
                return json.loads(payload)
            except Exception:
                return None
        else:
            try:
                return self.queue.get_nowait()
            except Exception:
                return None
    def _run(self):
        session = self.Session()
        pending = []
        while not self._stop.is_set():
            try:
                while len(pending) < self.batch_size:
                    job = self._pop_one()
                    if job is None:
                        break
                    pending.append(job)
                if pending:
                    success_count = 0
                    for job in pending:
                        try:
                            self._apply_job(session, job)
                            success_count += 1
                        except Exception:
                            session.rollback()
                    try:
                        session.commit()
                        self._processed_count += success_count
                    except Exception:
                        session.rollback()
                    pending = []
                else:
                    time.sleep(self.interval)
            except Exception:
                time.sleep(0.1)
        # drain anything still queued once stop() has been requested
        try:
            while True:
                job = self._pop_one()
                if job is None:
                    break
                try:
                    self._apply_job(session, job)
                    session.commit()
                    self._processed_count += 1
                except Exception:
                    session.rollback()
        except Exception:
            pass
        try:
            session.close()
        except Exception:
            pass
    def drain_once(self, max_iterations=100000):
        """Consume all queued jobs and return when the queue is empty.

        This method is intended for on-demand draining (not long-lived).
        """
        session = self.Session()
        iterations = 0
        processed = 0
        try:
            while iterations < max_iterations:
                job = self._pop_one()
                if job is None:
                    break
                try:
                    self._apply_job(session, job)
                    processed += 1
                except Exception:
                    session.rollback()
                iterations += 1
            try:
                session.commit()
                self._processed_count += processed
            except Exception:
                session.rollback()
        finally:
            try:
                session.close()
            except Exception:
                pass
        return processed
    def _apply_job(self, session, job):
        action = job.get("action")
        payload = job.get("payload", {})
        if action == "insert_report":
            session.add(
                Report(
                    date=payload.get("date"),
                    scan_unique_id=payload.get("scan_id"),
                    report_path_filename=payload.get("report_path_filename"),
                    options=json.dumps(payload.get("options", {})),
                )
            )
            return
        if action == "insert_hostslog":
            session.add(
                HostsLog(
                    target=payload.get("target"),
                    date=payload.get("date"),
                    module_name=payload.get("module_name"),
                    scan_unique_id=payload.get("scan_id"),
                    port=json.dumps(payload.get("port")),
                    event=json.dumps(payload.get("event")),
                    json_event=json.dumps(payload.get("json_event")),
                )
            )
            return
        if action == "insert_tempevent":
            session.add(
                TempEvents(
                    target=payload.get("target"),
                    date=payload.get("date"),
                    module_name=payload.get("module_name"),
                    scan_unique_id=payload.get("scan_id"),
                    event_name=payload.get("event_name"),
                    port=json.dumps(payload.get("port")),
                    event=json.dumps(payload.get("event")),
                    data=json.dumps(payload.get("data")),
                )
            )
            return
        log.warn(f"DBWriter: unsupported job action {action}")
# singleton writer
_writer = None


def get_writer():
    global _writer
    if _writer is None:
        _writer = DBWriter()
        try:
            _writer.start()
        except Exception:
            pass
    return _writer


def get_writer_configured(batch_size=None, interval=None):
    """Return the singleton writer, applying optional configuration.

    If the writer already exists, provided parameters will update its settings.
    """
    w = get_writer()
    if batch_size is not None:
        try:
            w.batch_size = int(batch_size)
        except Exception:
            pass
    if interval is not None:
        try:
            w.interval = float(interval)
        except Exception:
            pass
    return w
def get_stats():
    w = get_writer()
    queue_size = None
    if getattr(w, "_use_litequeue", False) and getattr(w, "_lq", None) is not None:
        try:
            if hasattr(w._lq, "qsize"):
                queue_size = w._lq.qsize()
            elif hasattr(w._lq, "__len__"):
                queue_size = len(w._lq)
            elif hasattr(w._lq, "size"):
                queue_size = w._lq.size()
        except Exception:
            queue_size = None
    else:
        try:
            queue_size = w.queue.qsize()
        except Exception:
            queue_size = None
    return {"processed": getattr(w, "_processed_count", 0), "queue_size": queue_size}

poetry.lock (generated)
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.

[[package]]
name = "aiohappyeyeballs"

@@ -973,7 +973,7 @@ description = "Read metadata from Python packages"
optional = false
python-versions = ">=3.8"
groups = ["main"]
markers = "python_version < \"3.10\""
markers = "python_version == \"3.9\""
files = [
    {file = "importlib_metadata-8.4.0-py3-none-any.whl", hash = "sha256:66f342cc6ac9818fc6ff340576acd24d65ba0b3efabb2b4ac08b598965a4a2f1"},
    {file = "importlib_metadata-8.4.0.tar.gz", hash = "sha256:9a547d3bc3608b025f93d403fdd1aae741c24fbb8314df4b155675742ce303c5"},

@@ -1129,6 +1129,18 @@ files = [
dnspython = "*"
ldap3 = ">2.5.0,<2.5.2 || >2.5.2,<2.6 || >2.6"

[[package]]
name = "litequeue"
version = "0.9"
description = "Simple queue built on top of SQLite"
optional = false
python-versions = ">=3.6"
groups = ["main"]
files = [
    {file = "litequeue-0.9-py3-none-any.whl", hash = "sha256:344312748f9d118253ecaa2fb82d432ab6f63dd3bd5c40922a63933bc47cd2e3"},
    {file = "litequeue-0.9.tar.gz", hash = "sha256:368f56b9de5c76fc6f2adc66177e81a59f545f3b5b95a26cae8562b858914647"},
]

[[package]]
name = "markupsafe"
version = "2.1.5"

@@ -2023,7 +2035,7 @@ files = [
    {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"},
    {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
]
markers = {dev = "python_version < \"3.10\""}
markers = {dev = "python_version == \"3.9\""}

[[package]]
name = "urllib3"

@@ -2254,4 +2266,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = "^3.9, <3.13"
content-hash = "0e1731401cd6acfc4d45ede5e18668530aae6a6b2e359d7dc8d8d635635a1257"
content-hash = "aa676fcd9a242a436052e31b320c0c3d99451dc4323d6ae99fbe8f4f49e0d747"

View File

@@ -43,6 +43,7 @@ release_name = "QUIN"

[tool.poetry.scripts]
nettacker = "nettacker.main:run"
nettacker-db-worker = "nettacker.cli.db_worker:run"

[tool.poetry.dependencies]
python = "^3.9, <3.13"

@@ -65,6 +66,7 @@ zipp = "^3.19.1"
uvloop = "^0.21.0"
pymysql = "^1.1.1"
impacket = "^0.11.0"
litequeue = "^0.9"

[tool.poetry.group.dev.dependencies]
ipython = "^8.16.1"

View File

@@ -0,0 +1,115 @@
import os
import subprocess
from pathlib import Path


def start_worker_process(tmp_path):
    # run worker in separate process
    env = os.environ.copy()
    # ensure current project is first on PYTHONPATH
    env["PYTHONPATH"] = str(Path.cwd()) + os.pathsep + env.get("PYTHONPATH", "")
    # Pass config through environment
    data_dir = tmp_path / ".data"
    env["NETTACKER_DATA_DIR"] = str(data_dir)
    env["NETTACKER_DB_NAME"] = str(data_dir / "nettacker.db")
    proc = subprocess.Popen(
        [
            env.get("PYTHON_BIN", "python"),
            "-m",
            "nettacker.cli.db_worker",
            "--once",  # Process all items and exit
            "--max-items",
            "10",
            "--summary",  # Show processing stats
        ],
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    print(f"Started worker process {proc.pid} with data_dir={data_dir}")
    return proc
def test_worker_writes(tmp_path):
    """Test that the database writer correctly processes queued jobs and writes to the database."""
    # Create test database
    data_dir = tmp_path / ".data"
    data_dir.mkdir()
    db_path = str(data_dir / "nettacker.db")

    # Create database tables
    from sqlalchemy import create_engine, text
    from nettacker.database.models import Base

    engine = create_engine(f"sqlite:///{db_path}")
    Base.metadata.create_all(engine)
    engine.dispose()

    # Create a writer configured to use the test database
    from sqlalchemy.orm import sessionmaker
    from nettacker.database.writer import DBWriter

    writer = DBWriter()
    # Override the database connection to use our test database
    writer.engine = create_engine(
        f"sqlite:///{db_path}", connect_args={"check_same_thread": False}, pool_pre_ping=True
    )
    # Enable WAL mode for better concurrency
    with writer.engine.connect() as conn:
        conn.execute(text("PRAGMA journal_mode=WAL"))
        conn.commit()
    writer.Session = sessionmaker(bind=writer.engine)

    # Create test jobs for both report and hosts log
    jobs = [
        {
            "action": "insert_report",
            "payload": {
                "date": None,
                "scan_id": "test-scan",
                "report_path_filename": str(data_dir / "r.html"),
                "options": {"report_path_filename": str(data_dir / "r.html")},
            },
        },
        {
            "action": "insert_hostslog",
            "payload": {
                "date": None,
                "target": "127.0.0.1",
                "module_name": "m",
                "scan_id": "test-scan",
                "port": [],
                "event": {},
                "json_event": {},
            },
        },
    ]

    # Enqueue jobs to the writer
    for job in jobs:
        writer.enqueue(job)

    # Process all queued jobs
    processed_count = writer.drain_once(max_iterations=10)
    assert processed_count == 2

    # Verify the jobs were written to the database
    import sqlite3

    conn = sqlite3.connect(db_path)
    c = conn.cursor()
    c.execute("select count(*) from reports where scan_unique_id = ?", ("test-scan",))
    report_count = c.fetchone()[0]
    c.execute("select count(*) from scan_events where scan_unique_id = ?", ("test-scan",))
    hosts_count = c.fetchone()[0]
    conn.close()
    assert report_count == 1
    assert hosts_count == 1