fixing test_db.py; the other test is still failing because of some mocking issues

pUrGe12 2025-09-02 14:44:10 +05:30
parent d7c7fd473b
commit 6979f79a39
5 changed files with 102 additions and 75 deletions

View File

@ -83,7 +83,7 @@ class DbConfig(ConfigBase):
fill the name of the DB as sqlite,
DATABASE as the name of the db user wants
Set the journal_mode (default="WAL") and
synchronous_mode (deafault="NORMAL"). Rest
synchronous_mode (default="NORMAL"). Rest
of the fields can be left empty
This is the default database:
str(CWD / ".data/nettacker.db")
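
For context, a minimal standalone sketch of what these two settings control at the SQLite level (using the stdlib sqlite3 module here rather than apsw, purely for illustration; the path is hypothetical):

import sqlite3

# Illustrative only: the two pragmas the config exposes, with the documented defaults.
conn = sqlite3.connect("example.db")  # hypothetical path; Nettacker's default is .data/nettacker.db
conn.execute("PRAGMA journal_mode=WAL")    # write-ahead log: readers no longer block writers
conn.execute("PRAGMA synchronous=NORMAL")  # fewer fsyncs than FULL; considered safe under WAL
print(conn.execute("PRAGMA journal_mode").fetchone())  # ('wal',)
conn.close()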

View File

@ -1,7 +1,11 @@
import json
import time
import apsw
try:
import apsw
except ImportError:
apsw = None
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
@ -12,7 +16,7 @@ from nettacker.core.messages import messages
from nettacker.database.models import HostsLog, Report, TempEvents
config = Config()
logging = logger.get_logger()
logger = logger.get_logger()
def db_inputs(connection_type):
@ -44,17 +48,24 @@ def create_connection():
connection failed.
"""
if Config.db.engine.startswith("sqlite"):
if apsw is None:
raise ImportError("APSW is required for SQLite backend.")
# In case of sqlite, the name parameter is the database path
DB_PATH = config.db.as_dict()["name"]
connection = apsw.Connection(DB_PATH)
connection.setbusytimeout(int(config.settings.timeout) * 100)
cursor = connection.cursor()
# Performance-enhancing configuration. WAL helps with concurrency
cursor.execute(f"PRAGMA journal_mode={Config.db.journal_mode}")
cursor.execute(f"PRAGMA synchronous={Config.db.synchronous_mode}")
try:
DB_PATH = config.db.as_dict()["name"]
connection = apsw.Connection(DB_PATH)
connection.setbusytimeout(int(config.settings.timeout) * 100)
cursor = connection.cursor()
return connection, cursor
# Performance-enhancing configuration. WAL helps with concurrency
cursor.execute(f"PRAGMA journal_mode={Config.db.journal_mode}")
cursor.execute(f"PRAGMA synchronous={Config.db.synchronous_mode}")
return connection, cursor
except Exception as e:
logger.error(f"Failed to create APSW connection: {e}")
raise
else:
# Both MySQL and PostgreSQL don't need a
@ -94,7 +105,7 @@ def send_submit_query(session):
finally:
connection.close()
connection.close()
logging.warn(messages("database_connect_fail"))
logger.warn(messages("database_connect_fail"))
return False
else:
try:
@ -104,10 +115,10 @@ def send_submit_query(session):
return True
except Exception:
time.sleep(0.1)
logging.warn(messages("database_connect_fail"))
logger.warn(messages("database_connect_fail"))
return False
except Exception:
logging.warn(messages("database_connect_fail"))
logger.warn(messages("database_connect_fail"))
return False
return False
@ -123,7 +134,7 @@ def submit_report_to_db(event):
Returns:
return True if submitted otherwise False
"""
logging.verbose_info(messages("inserting_report_db"))
logger.verbose_info(messages("inserting_report_db"))
session = create_connection()
if isinstance(session, tuple):
@ -146,7 +157,7 @@ def submit_report_to_db(event):
return send_submit_query(session)
except Exception:
cursor.execute("ROLLBACK")
logging.warn("Could not insert report...")
logger.warn("Could not insert report...")
return False
finally:
cursor.close()
@ -197,7 +208,7 @@ def remove_old_logs(options):
return send_submit_query(session)
except Exception:
cursor.execute("ROLLBACK")
logging.warn("Could not remove old logs...")
logger.warn("Could not remove old logs...")
return False
finally:
cursor.close()
@ -253,7 +264,7 @@ def submit_logs_to_db(log):
except apsw.BusyError as e:
if "database is locked" in str(e).lower():
logging.warn(
logger.warn(
f"[Retry {_ + 1}/{Config.settings.max_retries}] Database is locked. Retrying..."
)
if connection.in_transaction:
@ -272,7 +283,7 @@ def submit_logs_to_db(log):
pass
return False
# All retries exhausted but we want to continue operation
logging.warn("All retries exhausted. Skipping this log.")
logger.warn("All retries exhausted. Skipping this log.")
return True
finally:
cursor.close()
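
The locked-database handling above follows a standard bounded-retry pattern; a minimal standalone sketch of the same idea (hypothetical constants standing in for Config.settings.max_retries and Config.settings.retry_delay, assuming apsw is installed):

import time

import apsw  # assumed installed, per the guarded import above

MAX_RETRIES = 3    # hypothetical stand-in for Config.settings.max_retries
RETRY_DELAY = 0.1  # hypothetical stand-in for Config.settings.retry_delay

def execute_with_retry(cursor, sql, params=()):
    """Retry a write a bounded number of times when SQLite reports a lock."""
    for attempt in range(MAX_RETRIES):
        try:
            cursor.execute(sql, params)
            return True
        except apsw.BusyError:
            # Another writer holds the lock; back off briefly and retry.
            time.sleep(RETRY_DELAY)
    return False  # retries exhausted; the caller decides whether to carry on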
@ -290,7 +301,7 @@ def submit_logs_to_db(log):
)
return send_submit_query(session)
else:
logging.warn(messages("invalid_json_type_to_db").format(log))
logger.warn(messages("invalid_json_type_to_db").format(log))
return False
@ -335,7 +346,7 @@ def submit_temp_logs_to_db(log):
return send_submit_query(session)
except apsw.BusyError as e:
if "database is locked" in str(e).lower():
logging.warn(
logger.warn(
f"[Retry {_ + 1}/{Config.settings.max_retries}] Database is locked. Retrying..."
)
try:
@ -360,7 +371,7 @@ def submit_temp_logs_to_db(log):
pass
return False
# All retries exhausted but we want to continue operation
logging.warn("All retries exhausted. Skipping this log.")
logger.warn("All retries exhausted. Skipping this log.")
return True
finally:
cursor.close()
@ -379,7 +390,7 @@ def submit_temp_logs_to_db(log):
)
return send_submit_query(session)
else:
logging.warn(messages("invalid_json_type_to_db").format(log))
logger.warn(messages("invalid_json_type_to_db").format(log))
return False
@ -400,28 +411,23 @@ def find_temp_events(target, module_name, scan_id, event_name):
if isinstance(session, tuple):
connection, cursor = session
try:
for _ in range(100):
try:
cursor.execute(
"""
SELECT event
FROM temp_events
WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ?
LIMIT 1
""",
(target, module_name, scan_id, event_name),
)
cursor.execute(
"""
SELECT event
FROM temp_events
WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ?
LIMIT 1
""",
(target, module_name, scan_id, event_name),
)
row = cursor.fetchone()
cursor.close()
if row:
return json.loads(row[0])
return []
except Exception:
logging.warn("Database query failed...")
return []
row = cursor.fetchone()
cursor.close()
if row:
return json.loads(row[0])
return []
except Exception:
logging.warn(messages("database_connect_fail"))
logger.warn(messages("database_connect_fail"))
return []
return []
else:
@ -441,7 +447,7 @@ def find_temp_events(target, module_name, scan_id, event_name):
except Exception:
time.sleep(0.1)
except Exception:
logging.warn(messages("database_connect_fail"))
logger.warn(messages("database_connect_fail"))
return []
return []
@ -477,7 +483,7 @@ def find_events(target, module_name, scan_id):
return [json.dumps((json.loads(row[0]))) for row in rows]
return []
except Exception:
logging.warn("Database query failed...")
logger.warn("Database query failed...")
return []
else:
return [
@ -536,7 +542,7 @@ def select_reports(page):
return selected
except Exception:
logging.warn("Could not retrieve report...")
logger.warn("Could not retrieve report...")
return structure(status="error", msg="database error!")
else:
try:
@ -582,13 +588,23 @@ def get_scan_result(id):
cursor.close()
if row:
filename = row[0]
return filename, open(str(filename), "rb").read()
try:
return filename, open(str(filename), "rb").read()
except IOError as e:
logger.error(f"Failed to read report file: {e}")
return None
else:
return structure(status="error", msg="database error!")
else:
filename = session.query(Report).filter_by(id=id).first().report_path_filename
report = session.query(Report).filter_by(id=id).first()
if not report:
return None
return filename, open(str(filename), "rb").read()
try:
return report.report_path_filename, open(str(report.report_path_filename), "rb").read()
except IOError as e:
logger.error(f"Failed to read report file: {e}")
return None
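
One design note on the file reads in get_scan_result: open(...) without a context manager leaks the handle on the success path. A minimal sketch of the same guarded read using with, assuming the module-level logger from this file (same return contract, handle always closed):

def read_report_file(filename):
    """Sketch only: return (filename, contents) on success, None on failure."""
    try:
        with open(str(filename), "rb") as report_file:
            return filename, report_file.read()
    except IOError as e:
        logger.error(f"Failed to read report file: {e}")
        return None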
def last_host_logs(page):
@ -656,7 +672,6 @@ def last_host_logs(page):
)
events = [row[0] for row in cursor.fetchall()]
cursor.close()
hosts.append(
{
"target": target,
@ -667,11 +682,11 @@ def last_host_logs(page):
},
}
)
cursor.close()
return hosts
except Exception:
logging.warn("Database query failed...")
logger.warn("Database query failed...")
return structure(status="error", msg="Database error!")
else:
@ -834,7 +849,7 @@ def logs_to_report_json(target):
"event": json.loads(log[3]),
"json_event": json.loads(log[4]),
}
return_logs.append(data)
return_logs.append(data)
return return_logs
else:
@ -1001,7 +1016,6 @@ def search_logs(page, query):
),
)
targets = cursor.fetchall()
cursor.close()
for target_row in targets:
target = target_row[0]
# Fetch data for each target grouped by key fields
@ -1044,6 +1058,7 @@ def search_logs(page, query):
tmp["info"]["json_event"].append(parsed_json_event)
selected.append(tmp)
cursor.close()
except Exception:
return structure(status="error", msg="database error!")

View File

@ -94,6 +94,9 @@ profile = "black"
addopts = "--cov=nettacker --cov-config=pyproject.toml --cov-report term --cov-report xml --dist loadscope --no-cov-on-fail --numprocesses auto"
asyncio_default_fixture_loop_scope = "function"
testpaths = ["tests"]
markers = [
"asyncio: mark test as async"
]
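
Registering the marker documents intent and stops pytest from warning about an unknown mark; a minimal usage sketch (hypothetical test, assuming pytest-asyncio supplies the asyncio marker as it does in this project):

import pytest

@pytest.mark.asyncio  # now a registered marker, so pytest won't warn about it
async def test_example_async():  # hypothetical test
    assert True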
[tool.ruff]
line-length = 99

View File

@ -59,9 +59,12 @@ def test_load_with_service_discovery(
}
mock_loader.return_value = mock_loader_inst
mock_find_events.return_value = [
MagicMock(json_event='{"port": 80, "response": {"conditions_results": {"http": {}}}}')
]
mock_event1 = MagicMock()
mock_event1.json_event = json.dumps(
{"port": 80, "response": {"conditions_results": {"http": {}}}}
)
mock_find_events.return_value = [mock_event1]
module = Module("test_module", options, **module_args)
module.load()

View File

@ -181,7 +181,7 @@ class TestDatabase:
assert result is True
@patch("nettacker.database.db.messages", return_value="mocked fail message")
@patch("nettacker.database.db.logging.warn")
@patch("nettacker.database.db.logger.warn")
def test_send_submit_query_sqlite_failure(self, mock_warn, mock_messages):
def sqlite_execute_side_effect(query):
if query == "COMMIT":
@ -210,7 +210,7 @@ class TestDatabase:
assert result is True
@patch("nettacker.database.db.messages", return_value="mocked fail message")
@patch("nettacker.database.db.logging.warn")
@patch("nettacker.database.db.logger.warn")
def test_send_submit_query_sqlalchemy_failure(self, mock_warn, mock_messages):
mock_session = Mock()
mock_session.commit.side_effect = [Exception("fail")] * 100
@ -372,7 +372,7 @@ class TestDatabase:
assert result is True
@patch("nettacker.database.db.messages", return_value="invalid log")
@patch("nettacker.database.db.logging.warn")
@patch("nettacker.database.db.logger.warn")
def test_log_not_dict(self, mock_warn, mock_messages):
result = submit_logs_to_db("notadict")
assert result is False
@ -402,7 +402,7 @@ class TestDatabase:
@patch("nettacker.database.db.Config.settings.retry_delay", 0)
@patch("nettacker.database.db.Config.settings.max_retries", 1)
@patch("nettacker.database.db.logging.warn")
@patch("nettacker.database.db.logger.warn")
@patch("nettacker.database.db.create_connection")
def test_apsw_busy_error(self, mock_create_conn, mock_warn):
mock_conn = Mock()
@ -544,7 +544,7 @@ class TestDatabase:
assert result is True
@patch("nettacker.database.db.messages", return_value="invalid log")
@patch("nettacker.database.db.logging.warn")
@patch("nettacker.database.db.logger.warn")
def test_temp_log_not_dict(self, mock_warn, mock_messages):
result = submit_temp_logs_to_db("notadict")
assert result is False
@ -552,7 +552,7 @@ class TestDatabase:
@patch("nettacker.database.db.Config.settings.retry_delay", 0)
@patch("nettacker.database.db.Config.settings.max_retries", 1)
@patch("nettacker.database.db.logging.warn")
@patch("nettacker.database.db.logger.warn")
@patch("nettacker.database.db.create_connection")
def test_temp_log_busy_error(self, mock_create_conn, mock_warn):
mock_conn = Mock()
@ -687,7 +687,7 @@ class TestDatabase:
result = find_temp_events(self.target, self.module, self.scan_id, self.event_name)
assert result == []
@patch("nettacker.database.db.logging.warn")
@patch("nettacker.database.db.logger.warn")
@patch("nettacker.database.db.messages", return_value="database fail")
@patch("nettacker.database.db.create_connection")
def test_sqlite_outer_exception(self, mock_create_conn, mock_messages, mock_warn):
@ -732,15 +732,21 @@ class TestDatabase:
result = find_temp_events("192.168.1.1", "port_scan", "scan_123", "event_1")
mock_cursor.execute.assert_called_with(
"""
SELECT event
FROM temp_events
WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ?
LIMIT 1
""",
("192.168.1.1", "port_scan", "scan_123", "event_1"),
)
called_query, called_params = mock_cursor.execute.call_args[0]
expected_query = """
SELECT event
FROM temp_events
WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ?
LIMIT 1
"""
# Normalize whitespace (collapse multiple spaces/newlines into one space)
def normalize(sql: str) -> str:
return " ".join(sql.split())
assert normalize(called_query) == normalize(expected_query)
assert called_params == ("192.168.1.1", "port_scan", "scan_123", "event_1")
assert result == {"test": "data"}
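
Comparing normalized strings makes the assertion robust to the indentation of the SQL literal; a quick sanity check of the helper:

assert normalize("SELECT  event\n  FROM temp_events") == "SELECT event FROM temp_events"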
# -------------------------------------------------------
@ -768,7 +774,7 @@ class TestDatabase:
expected = ['{"event1": "data1"}', '{"event2": "data2"}']
assert result == expected
@patch("nettacker.database.db.logging.warn")
@patch("nettacker.database.db.logger.warn")
@patch("nettacker.database.db.create_connection")
def test_find_events_sqlite_exception(self, mock_create_conn, mock_warn):
mock_connection = Mock()
@ -839,7 +845,7 @@ class TestDatabase:
]
assert result == expected
@patch("nettacker.database.db.logging.warn")
@patch("nettacker.database.db.logger.warn")
@patch("nettacker.database.db.create_connection")
def test_select_reports_sqlite_exception(self, mock_create_conn, mock_warn):
mock_connection = Mock()