fixing test_db.py, the other one is still failing because of some mocking issues

This commit is contained in:
pUrGe12 2025-09-02 14:44:10 +05:30
parent d7c7fd473b
commit 6979f79a39
5 changed files with 102 additions and 75 deletions

View File

@ -83,7 +83,7 @@ class DbConfig(ConfigBase):
fill the name of the DB as sqlite, fill the name of the DB as sqlite,
DATABASE as the name of the db user wants DATABASE as the name of the db user wants
Set the journal_mode (default="WAL") and Set the journal_mode (default="WAL") and
synchronous_mode (deafault="NORMAL"). Rest synchronous_mode (default="NORMAL"). Rest
of the fields can be left emptyAdd commentMore actions of the fields can be left emptyAdd commentMore actions
This is the default database: This is the default database:
str(CWD / ".data/nettacker.db") str(CWD / ".data/nettacker.db")

View File

@ -1,7 +1,11 @@
import json import json
import time import time
import apsw try:
import apsw
except ImportError:
apsw = None
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
@ -12,7 +16,7 @@ from nettacker.core.messages import messages
from nettacker.database.models import HostsLog, Report, TempEvents from nettacker.database.models import HostsLog, Report, TempEvents
config = Config() config = Config()
logging = logger.get_logger() logger = logger.get_logger()
def db_inputs(connection_type): def db_inputs(connection_type):
@ -44,17 +48,24 @@ def create_connection():
connection failed. connection failed.
""" """
if Config.db.engine.startswith("sqlite"): if Config.db.engine.startswith("sqlite"):
if apsw is None:
raise ImportError("APSW is required for SQLite backend.")
# In case of sqlite, the name parameter is the database path # In case of sqlite, the name parameter is the database path
DB_PATH = config.db.as_dict()["name"]
connection = apsw.Connection(DB_PATH)
connection.setbusytimeout(int(config.settings.timeout) * 100)
cursor = connection.cursor()
# Performance enhancing configurations. Put WAL cause that helps with concurrency try:
cursor.execute(f"PRAGMA journal_mode={Config.db.journal_mode}") DB_PATH = config.db.as_dict()["name"]
cursor.execute(f"PRAGMA synchronous={Config.db.synchronous_mode}") connection = apsw.Connection(DB_PATH)
connection.setbusytimeout(int(config.settings.timeout) * 100)
cursor = connection.cursor()
return connection, cursor # Performance enhancing configurations. Put WAL cause that helps with concurrency
cursor.execute(f"PRAGMA journal_mode={Config.db.journal_mode}")
cursor.execute(f"PRAGMA synchronous={Config.db.synchronous_mode}")
return connection, cursor
except Exception as e:
logger.error(f"Failed to create APSW connection: {e}")
raise
else: else:
# Both MySQL and PostgreSQL don't need a # Both MySQL and PostgreSQL don't need a
@ -94,7 +105,7 @@ def send_submit_query(session):
finally: finally:
connection.close() connection.close()
connection.close() connection.close()
logging.warn(messages("database_connect_fail")) logger.warn(messages("database_connect_fail"))
return False return False
else: else:
try: try:
@ -104,10 +115,10 @@ def send_submit_query(session):
return True return True
except Exception: except Exception:
time.sleep(0.1) time.sleep(0.1)
logging.warn(messages("database_connect_fail")) logger.warn(messages("database_connect_fail"))
return False return False
except Exception: except Exception:
logging.warn(messages("database_connect_fail")) logger.warn(messages("database_connect_fail"))
return False return False
return False return False
@ -123,7 +134,7 @@ def submit_report_to_db(event):
Returns: Returns:
return True if submitted otherwise False return True if submitted otherwise False
""" """
logging.verbose_info(messages("inserting_report_db")) logger.verbose_info(messages("inserting_report_db"))
session = create_connection() session = create_connection()
if isinstance(session, tuple): if isinstance(session, tuple):
@ -146,7 +157,7 @@ def submit_report_to_db(event):
return send_submit_query(session) return send_submit_query(session)
except Exception: except Exception:
cursor.execute("ROLLBACK") cursor.execute("ROLLBACK")
logging.warn("Could not insert report...") logger.warn("Could not insert report...")
return False return False
finally: finally:
cursor.close() cursor.close()
@ -197,7 +208,7 @@ def remove_old_logs(options):
return send_submit_query(session) return send_submit_query(session)
except Exception: except Exception:
cursor.execute("ROLLBACK") cursor.execute("ROLLBACK")
logging.warn("Could not remove old logs...") logger.warn("Could not remove old logs...")
return False return False
finally: finally:
cursor.close() cursor.close()
@ -253,7 +264,7 @@ def submit_logs_to_db(log):
except apsw.BusyError as e: except apsw.BusyError as e:
if "database is locked" in str(e).lower(): if "database is locked" in str(e).lower():
logging.warn( logger.warn(
f"[Retry {_ + 1}/{Config.settings.max_retries}] Database is locked. Retrying..." f"[Retry {_ + 1}/{Config.settings.max_retries}] Database is locked. Retrying..."
) )
if connection.in_transaction: if connection.in_transaction:
@ -272,7 +283,7 @@ def submit_logs_to_db(log):
pass pass
return False return False
# All retires exhausted but we want to continue operation # All retires exhausted but we want to continue operation
logging.warn("All retries exhausted. Skipping this log.") logger.warn("All retries exhausted. Skipping this log.")
return True return True
finally: finally:
cursor.close() cursor.close()
@ -290,7 +301,7 @@ def submit_logs_to_db(log):
) )
return send_submit_query(session) return send_submit_query(session)
else: else:
logging.warn(messages("invalid_json_type_to_db").format(log)) logger.warn(messages("invalid_json_type_to_db").format(log))
return False return False
@ -335,7 +346,7 @@ def submit_temp_logs_to_db(log):
return send_submit_query(session) return send_submit_query(session)
except apsw.BusyError as e: except apsw.BusyError as e:
if "database is locked" in str(e).lower(): if "database is locked" in str(e).lower():
logging.warn( logger.warn(
f"[Retry {_ + 1}/{Config.settings.max_retries}] Database is locked. Retrying..." f"[Retry {_ + 1}/{Config.settings.max_retries}] Database is locked. Retrying..."
) )
try: try:
@ -360,7 +371,7 @@ def submit_temp_logs_to_db(log):
pass pass
return False return False
# All retires exhausted but we want to continue operation # All retires exhausted but we want to continue operation
logging.warn("All retries exhausted. Skipping this log.") logger.warn("All retries exhausted. Skipping this log.")
return True return True
finally: finally:
cursor.close() cursor.close()
@ -379,7 +390,7 @@ def submit_temp_logs_to_db(log):
) )
return send_submit_query(session) return send_submit_query(session)
else: else:
logging.warn(messages("invalid_json_type_to_db").format(log)) logger.warn(messages("invalid_json_type_to_db").format(log))
return False return False
@ -400,28 +411,23 @@ def find_temp_events(target, module_name, scan_id, event_name):
if isinstance(session, tuple): if isinstance(session, tuple):
connection, cursor = session connection, cursor = session
try: try:
for _ in range(100): cursor.execute(
try: """
cursor.execute( SELECT event
""" FROM temp_events
SELECT event WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ?
FROM temp_events LIMIT 1
WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ? """,
LIMIT 1 (target, module_name, scan_id, event_name),
""", )
(target, module_name, scan_id, event_name),
)
row = cursor.fetchone() row = cursor.fetchone()
cursor.close() cursor.close()
if row: if row:
return json.loads(row[0]) return json.loads(row[0])
return [] return []
except Exception:
logging.warn("Database query failed...")
return []
except Exception: except Exception:
logging.warn(messages("database_connect_fail")) logger.warn(messages("database_connect_fail"))
return [] return []
return [] return []
else: else:
@ -441,7 +447,7 @@ def find_temp_events(target, module_name, scan_id, event_name):
except Exception: except Exception:
time.sleep(0.1) time.sleep(0.1)
except Exception: except Exception:
logging.warn(messages("database_connect_fail")) logger.warn(messages("database_connect_fail"))
return [] return []
return [] return []
@ -477,7 +483,7 @@ def find_events(target, module_name, scan_id):
return [json.dumps((json.loads(row[0]))) for row in rows] return [json.dumps((json.loads(row[0]))) for row in rows]
return [] return []
except Exception: except Exception:
logging.warn("Database query failed...") logger.warn("Database query failed...")
return [] return []
else: else:
return [ return [
@ -536,7 +542,7 @@ def select_reports(page):
return selected return selected
except Exception: except Exception:
logging.warn("Could not retrieve report...") logger.warn("Could not retrieve report...")
return structure(status="error", msg="database error!") return structure(status="error", msg="database error!")
else: else:
try: try:
@ -582,13 +588,23 @@ def get_scan_result(id):
cursor.close() cursor.close()
if row: if row:
filename = row[0] filename = row[0]
return filename, open(str(filename), "rb").read() try:
return filename, open(str(filename), "rb").read()
except IOError as e:
logger.error(f"Failed to read report file: {e}")
return None
else: else:
return structure(status="error", msg="database error!") return structure(status="error", msg="database error!")
else: else:
filename = session.query(Report).filter_by(id=id).first().report_path_filename report = session.query(Report).filter_by(id=id).first()
if not report:
return None
return filename, open(str(filename), "rb").read() try:
return report.report_path_filename, open(str(report.report_path_filename), "rb").read()
except IOError as e:
logger.error(f"Failed to read report file: {e}")
return None
def last_host_logs(page): def last_host_logs(page):
@ -656,7 +672,6 @@ def last_host_logs(page):
) )
events = [row[0] for row in cursor.fetchall()] events = [row[0] for row in cursor.fetchall()]
cursor.close()
hosts.append( hosts.append(
{ {
"target": target, "target": target,
@ -667,11 +682,11 @@ def last_host_logs(page):
}, },
} }
) )
cursor.close()
return hosts return hosts
except Exception: except Exception:
logging.warn("Database query failed...") logger.warn("Database query failed...")
return structure(status="error", msg="Database error!") return structure(status="error", msg="Database error!")
else: else:
@ -834,7 +849,7 @@ def logs_to_report_json(target):
"event": json.loads(log[3]), "event": json.loads(log[3]),
"json_event": json.loads(log[4]), "json_event": json.loads(log[4]),
} }
return_logs.append(data) return_logs.append(data)
return return_logs return return_logs
else: else:
@ -1001,7 +1016,6 @@ def search_logs(page, query):
), ),
) )
targets = cursor.fetchall() targets = cursor.fetchall()
cursor.close()
for target_row in targets: for target_row in targets:
target = target_row[0] target = target_row[0]
# Fetch data for each target grouped by key fields # Fetch data for each target grouped by key fields
@ -1044,6 +1058,7 @@ def search_logs(page, query):
tmp["info"]["json_event"].append(parsed_json_event) tmp["info"]["json_event"].append(parsed_json_event)
selected.append(tmp) selected.append(tmp)
cursor.close()
except Exception: except Exception:
return structure(status="error", msg="database error!") return structure(status="error", msg="database error!")

View File

@ -94,6 +94,9 @@ profile = "black"
addopts = "--cov=nettacker --cov-config=pyproject.toml --cov-report term --cov-report xml --dist loadscope --no-cov-on-fail --numprocesses auto" addopts = "--cov=nettacker --cov-config=pyproject.toml --cov-report term --cov-report xml --dist loadscope --no-cov-on-fail --numprocesses auto"
asyncio_default_fixture_loop_scope = "function" asyncio_default_fixture_loop_scope = "function"
testpaths = ["tests"] testpaths = ["tests"]
markers = [
"asyncio: mark test as async"
]
[tool.ruff] [tool.ruff]
line-length = 99 line-length = 99

View File

@ -59,9 +59,12 @@ def test_load_with_service_discovery(
} }
mock_loader.return_value = mock_loader_inst mock_loader.return_value = mock_loader_inst
mock_find_events.return_value = [ mock_event1 = MagicMock()
MagicMock(json_event='{"port": 80, "response": {"conditions_results": {"http": {}}}}') mock_event1.json_event = json.dumps(
] {"port": 80, "response": {"conditions_results": {"http": {}}}}
)
mock_find_events.return_value = [mock_event1]
module = Module("test_module", options, **module_args) module = Module("test_module", options, **module_args)
module.load() module.load()

View File

@ -181,7 +181,7 @@ class TestDatabase:
assert result is True assert result is True
@patch("nettacker.database.db.messages", return_value="mocked fail message") @patch("nettacker.database.db.messages", return_value="mocked fail message")
@patch("nettacker.database.db.logging.warn") @patch("nettacker.database.db.logger.warn")
def test_send_submit_query_sqlite_failure(self, mock_warn, mock_messages): def test_send_submit_query_sqlite_failure(self, mock_warn, mock_messages):
def sqlite_execute_side_effect(query): def sqlite_execute_side_effect(query):
if query == "COMMIT": if query == "COMMIT":
@ -210,7 +210,7 @@ class TestDatabase:
assert result is True assert result is True
@patch("nettacker.database.db.messages", return_value="mocked fail message") @patch("nettacker.database.db.messages", return_value="mocked fail message")
@patch("nettacker.database.db.logging.warn") @patch("nettacker.database.db.logger.warn")
def test_send_submit_query_sqlalchemy_failure(self, mock_warn, mock_messages): def test_send_submit_query_sqlalchemy_failure(self, mock_warn, mock_messages):
mock_session = Mock() mock_session = Mock()
mock_session.commit.side_effect = [Exception("fail")] * 100 mock_session.commit.side_effect = [Exception("fail")] * 100
@ -372,7 +372,7 @@ class TestDatabase:
assert result is True assert result is True
@patch("nettacker.database.db.messages", return_value="invalid log") @patch("nettacker.database.db.messages", return_value="invalid log")
@patch("nettacker.database.db.logging.warn") @patch("nettacker.database.db.logger.warn")
def test_log_not_dict(self, mock_warn, mock_messages): def test_log_not_dict(self, mock_warn, mock_messages):
result = submit_logs_to_db("notadict") result = submit_logs_to_db("notadict")
assert result is False assert result is False
@ -402,7 +402,7 @@ class TestDatabase:
@patch("nettacker.database.db.Config.settings.retry_delay", 0) @patch("nettacker.database.db.Config.settings.retry_delay", 0)
@patch("nettacker.database.db.Config.settings.max_retries", 1) @patch("nettacker.database.db.Config.settings.max_retries", 1)
@patch("nettacker.database.db.logging.warn") @patch("nettacker.database.db.logger.warn")
@patch("nettacker.database.db.create_connection") @patch("nettacker.database.db.create_connection")
def test_apsw_busy_error(self, mock_create_conn, mock_warn): def test_apsw_busy_error(self, mock_create_conn, mock_warn):
mock_conn = Mock() mock_conn = Mock()
@ -544,7 +544,7 @@ class TestDatabase:
assert result is True assert result is True
@patch("nettacker.database.db.messages", return_value="invalid log") @patch("nettacker.database.db.messages", return_value="invalid log")
@patch("nettacker.database.db.logging.warn") @patch("nettacker.database.db.logger.warn")
def test_temp_log_not_dict(self, mock_warn, mock_messages): def test_temp_log_not_dict(self, mock_warn, mock_messages):
result = submit_temp_logs_to_db("notadict") result = submit_temp_logs_to_db("notadict")
assert result is False assert result is False
@ -552,7 +552,7 @@ class TestDatabase:
@patch("nettacker.database.db.Config.settings.retry_delay", 0) @patch("nettacker.database.db.Config.settings.retry_delay", 0)
@patch("nettacker.database.db.Config.settings.max_retries", 1) @patch("nettacker.database.db.Config.settings.max_retries", 1)
@patch("nettacker.database.db.logging.warn") @patch("nettacker.database.db.logger.warn")
@patch("nettacker.database.db.create_connection") @patch("nettacker.database.db.create_connection")
def test_temp_log_busy_error(self, mock_create_conn, mock_warn): def test_temp_log_busy_error(self, mock_create_conn, mock_warn):
mock_conn = Mock() mock_conn = Mock()
@ -687,7 +687,7 @@ class TestDatabase:
result = find_temp_events(self.target, self.module, self.scan_id, self.event_name) result = find_temp_events(self.target, self.module, self.scan_id, self.event_name)
assert result == [] assert result == []
@patch("nettacker.database.db.logging.warn") @patch("nettacker.database.db.logger.warn")
@patch("nettacker.database.db.messages", return_value="database fail") @patch("nettacker.database.db.messages", return_value="database fail")
@patch("nettacker.database.db.create_connection") @patch("nettacker.database.db.create_connection")
def test_sqlite_outer_exception(self, mock_create_conn, mock_messages, mock_warn): def test_sqlite_outer_exception(self, mock_create_conn, mock_messages, mock_warn):
@ -732,15 +732,21 @@ class TestDatabase:
result = find_temp_events("192.168.1.1", "port_scan", "scan_123", "event_1") result = find_temp_events("192.168.1.1", "port_scan", "scan_123", "event_1")
mock_cursor.execute.assert_called_with( called_query, called_params = mock_cursor.execute.call_args[0]
"""
SELECT event expected_query = """
FROM temp_events SELECT event
WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ? FROM temp_events
LIMIT 1 WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ?
""", LIMIT 1
("192.168.1.1", "port_scan", "scan_123", "event_1"), """
)
# Normalize whitespace (collapse multiple spaces/newlines into one space)
def normalize(sql: str) -> str:
return " ".join(sql.split())
assert normalize(called_query) == normalize(expected_query)
assert called_params == ("192.168.1.1", "port_scan", "scan_123", "event_1")
assert result == {"test": "data"} assert result == {"test": "data"}
# ------------------------------------------------------- # -------------------------------------------------------
@ -768,7 +774,7 @@ class TestDatabase:
expected = ['{"event1": "data1"}', '{"event2": "data2"}'] expected = ['{"event1": "data1"}', '{"event2": "data2"}']
assert result == expected assert result == expected
@patch("nettacker.database.db.logging.warn") @patch("nettacker.database.db.logger.warn")
@patch("nettacker.database.db.create_connection") @patch("nettacker.database.db.create_connection")
def test_find_events_sqlite_exception(self, mock_create_conn, mock_warn): def test_find_events_sqlite_exception(self, mock_create_conn, mock_warn):
mock_connection = Mock() mock_connection = Mock()
@ -839,7 +845,7 @@ class TestDatabase:
] ]
assert result == expected assert result == expected
@patch("nettacker.database.db.logging.warn") @patch("nettacker.database.db.logger.warn")
@patch("nettacker.database.db.create_connection") @patch("nettacker.database.db.create_connection")
def test_select_reports_sqlite_exception(self, mock_create_conn, mock_warn): def test_select_reports_sqlite_exception(self, mock_create_conn, mock_warn):
mock_connection = Mock() mock_connection = Mock()