mirror of https://github.com/OWASP/Nettacker.git
fixing test_db.py, the other one is still failing because of some mocking issues
This commit is contained in:
parent
d7c7fd473b
commit
6979f79a39
|
|
@ -83,7 +83,7 @@ class DbConfig(ConfigBase):
|
|||
fill the name of the DB as sqlite,
|
||||
DATABASE as the name of the db user wants
|
||||
Set the journal_mode (default="WAL") and
|
||||
synchronous_mode (deafault="NORMAL"). Rest
|
||||
synchronous_mode (default="NORMAL"). Rest
|
||||
of the fields can be left emptyAdd commentMore actions
|
||||
This is the default database:
|
||||
str(CWD / ".data/nettacker.db")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,11 @@
|
|||
import json
|
||||
import time
|
||||
|
||||
import apsw
|
||||
try:
|
||||
import apsw
|
||||
except ImportError:
|
||||
apsw = None
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
|
@ -12,7 +16,7 @@ from nettacker.core.messages import messages
|
|||
from nettacker.database.models import HostsLog, Report, TempEvents
|
||||
|
||||
config = Config()
|
||||
logging = logger.get_logger()
|
||||
logger = logger.get_logger()
|
||||
|
||||
|
||||
def db_inputs(connection_type):
|
||||
|
|
@ -44,7 +48,11 @@ def create_connection():
|
|||
connection failed.
|
||||
"""
|
||||
if Config.db.engine.startswith("sqlite"):
|
||||
if apsw is None:
|
||||
raise ImportError("APSW is required for SQLite backend.")
|
||||
# In case of sqlite, the name parameter is the database path
|
||||
|
||||
try:
|
||||
DB_PATH = config.db.as_dict()["name"]
|
||||
connection = apsw.Connection(DB_PATH)
|
||||
connection.setbusytimeout(int(config.settings.timeout) * 100)
|
||||
|
|
@ -55,6 +63,9 @@ def create_connection():
|
|||
cursor.execute(f"PRAGMA synchronous={Config.db.synchronous_mode}")
|
||||
|
||||
return connection, cursor
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create APSW connection: {e}")
|
||||
raise
|
||||
|
||||
else:
|
||||
# Both MySQL and PostgreSQL don't need a
|
||||
|
|
@ -94,7 +105,7 @@ def send_submit_query(session):
|
|||
finally:
|
||||
connection.close()
|
||||
connection.close()
|
||||
logging.warn(messages("database_connect_fail"))
|
||||
logger.warn(messages("database_connect_fail"))
|
||||
return False
|
||||
else:
|
||||
try:
|
||||
|
|
@ -104,10 +115,10 @@ def send_submit_query(session):
|
|||
return True
|
||||
except Exception:
|
||||
time.sleep(0.1)
|
||||
logging.warn(messages("database_connect_fail"))
|
||||
logger.warn(messages("database_connect_fail"))
|
||||
return False
|
||||
except Exception:
|
||||
logging.warn(messages("database_connect_fail"))
|
||||
logger.warn(messages("database_connect_fail"))
|
||||
return False
|
||||
return False
|
||||
|
||||
|
|
@ -123,7 +134,7 @@ def submit_report_to_db(event):
|
|||
Returns:
|
||||
return True if submitted otherwise False
|
||||
"""
|
||||
logging.verbose_info(messages("inserting_report_db"))
|
||||
logger.verbose_info(messages("inserting_report_db"))
|
||||
session = create_connection()
|
||||
|
||||
if isinstance(session, tuple):
|
||||
|
|
@ -146,7 +157,7 @@ def submit_report_to_db(event):
|
|||
return send_submit_query(session)
|
||||
except Exception:
|
||||
cursor.execute("ROLLBACK")
|
||||
logging.warn("Could not insert report...")
|
||||
logger.warn("Could not insert report...")
|
||||
return False
|
||||
finally:
|
||||
cursor.close()
|
||||
|
|
@ -197,7 +208,7 @@ def remove_old_logs(options):
|
|||
return send_submit_query(session)
|
||||
except Exception:
|
||||
cursor.execute("ROLLBACK")
|
||||
logging.warn("Could not remove old logs...")
|
||||
logger.warn("Could not remove old logs...")
|
||||
return False
|
||||
finally:
|
||||
cursor.close()
|
||||
|
|
@ -253,7 +264,7 @@ def submit_logs_to_db(log):
|
|||
|
||||
except apsw.BusyError as e:
|
||||
if "database is locked" in str(e).lower():
|
||||
logging.warn(
|
||||
logger.warn(
|
||||
f"[Retry {_ + 1}/{Config.settings.max_retries}] Database is locked. Retrying..."
|
||||
)
|
||||
if connection.in_transaction:
|
||||
|
|
@ -272,7 +283,7 @@ def submit_logs_to_db(log):
|
|||
pass
|
||||
return False
|
||||
# All retires exhausted but we want to continue operation
|
||||
logging.warn("All retries exhausted. Skipping this log.")
|
||||
logger.warn("All retries exhausted. Skipping this log.")
|
||||
return True
|
||||
finally:
|
||||
cursor.close()
|
||||
|
|
@ -290,7 +301,7 @@ def submit_logs_to_db(log):
|
|||
)
|
||||
return send_submit_query(session)
|
||||
else:
|
||||
logging.warn(messages("invalid_json_type_to_db").format(log))
|
||||
logger.warn(messages("invalid_json_type_to_db").format(log))
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -335,7 +346,7 @@ def submit_temp_logs_to_db(log):
|
|||
return send_submit_query(session)
|
||||
except apsw.BusyError as e:
|
||||
if "database is locked" in str(e).lower():
|
||||
logging.warn(
|
||||
logger.warn(
|
||||
f"[Retry {_ + 1}/{Config.settings.max_retries}] Database is locked. Retrying..."
|
||||
)
|
||||
try:
|
||||
|
|
@ -360,7 +371,7 @@ def submit_temp_logs_to_db(log):
|
|||
pass
|
||||
return False
|
||||
# All retires exhausted but we want to continue operation
|
||||
logging.warn("All retries exhausted. Skipping this log.")
|
||||
logger.warn("All retries exhausted. Skipping this log.")
|
||||
return True
|
||||
finally:
|
||||
cursor.close()
|
||||
|
|
@ -379,7 +390,7 @@ def submit_temp_logs_to_db(log):
|
|||
)
|
||||
return send_submit_query(session)
|
||||
else:
|
||||
logging.warn(messages("invalid_json_type_to_db").format(log))
|
||||
logger.warn(messages("invalid_json_type_to_db").format(log))
|
||||
return False
|
||||
|
||||
|
||||
|
|
@ -399,8 +410,6 @@ def find_temp_events(target, module_name, scan_id, event_name):
|
|||
session = create_connection()
|
||||
if isinstance(session, tuple):
|
||||
connection, cursor = session
|
||||
try:
|
||||
for _ in range(100):
|
||||
try:
|
||||
cursor.execute(
|
||||
"""
|
||||
|
|
@ -418,10 +427,7 @@ def find_temp_events(target, module_name, scan_id, event_name):
|
|||
return json.loads(row[0])
|
||||
return []
|
||||
except Exception:
|
||||
logging.warn("Database query failed...")
|
||||
return []
|
||||
except Exception:
|
||||
logging.warn(messages("database_connect_fail"))
|
||||
logger.warn(messages("database_connect_fail"))
|
||||
return []
|
||||
return []
|
||||
else:
|
||||
|
|
@ -441,7 +447,7 @@ def find_temp_events(target, module_name, scan_id, event_name):
|
|||
except Exception:
|
||||
time.sleep(0.1)
|
||||
except Exception:
|
||||
logging.warn(messages("database_connect_fail"))
|
||||
logger.warn(messages("database_connect_fail"))
|
||||
return []
|
||||
return []
|
||||
|
||||
|
|
@ -477,7 +483,7 @@ def find_events(target, module_name, scan_id):
|
|||
return [json.dumps((json.loads(row[0]))) for row in rows]
|
||||
return []
|
||||
except Exception:
|
||||
logging.warn("Database query failed...")
|
||||
logger.warn("Database query failed...")
|
||||
return []
|
||||
else:
|
||||
return [
|
||||
|
|
@ -536,7 +542,7 @@ def select_reports(page):
|
|||
return selected
|
||||
|
||||
except Exception:
|
||||
logging.warn("Could not retrieve report...")
|
||||
logger.warn("Could not retrieve report...")
|
||||
return structure(status="error", msg="database error!")
|
||||
else:
|
||||
try:
|
||||
|
|
@ -582,13 +588,23 @@ def get_scan_result(id):
|
|||
cursor.close()
|
||||
if row:
|
||||
filename = row[0]
|
||||
try:
|
||||
return filename, open(str(filename), "rb").read()
|
||||
except IOError as e:
|
||||
logger.error(f"Failed to read report file: {e}")
|
||||
return None
|
||||
else:
|
||||
return structure(status="error", msg="database error!")
|
||||
else:
|
||||
filename = session.query(Report).filter_by(id=id).first().report_path_filename
|
||||
report = session.query(Report).filter_by(id=id).first()
|
||||
if not report:
|
||||
return None
|
||||
|
||||
return filename, open(str(filename), "rb").read()
|
||||
try:
|
||||
return report.report_path_filename, open(str(report.report_path_filename), "rb").read()
|
||||
except IOError as e:
|
||||
logger.error(f"Failed to read report file: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def last_host_logs(page):
|
||||
|
|
@ -656,7 +672,6 @@ def last_host_logs(page):
|
|||
)
|
||||
events = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
cursor.close()
|
||||
hosts.append(
|
||||
{
|
||||
"target": target,
|
||||
|
|
@ -667,11 +682,11 @@ def last_host_logs(page):
|
|||
},
|
||||
}
|
||||
)
|
||||
|
||||
cursor.close()
|
||||
return hosts
|
||||
|
||||
except Exception:
|
||||
logging.warn("Database query failed...")
|
||||
logger.warn("Database query failed...")
|
||||
return structure(status="error", msg="Database error!")
|
||||
|
||||
else:
|
||||
|
|
@ -1001,7 +1016,6 @@ def search_logs(page, query):
|
|||
),
|
||||
)
|
||||
targets = cursor.fetchall()
|
||||
cursor.close()
|
||||
for target_row in targets:
|
||||
target = target_row[0]
|
||||
# Fetch data for each target grouped by key fields
|
||||
|
|
@ -1044,6 +1058,7 @@ def search_logs(page, query):
|
|||
tmp["info"]["json_event"].append(parsed_json_event)
|
||||
|
||||
selected.append(tmp)
|
||||
cursor.close()
|
||||
|
||||
except Exception:
|
||||
return structure(status="error", msg="database error!")
|
||||
|
|
|
|||
|
|
@ -94,6 +94,9 @@ profile = "black"
|
|||
addopts = "--cov=nettacker --cov-config=pyproject.toml --cov-report term --cov-report xml --dist loadscope --no-cov-on-fail --numprocesses auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
testpaths = ["tests"]
|
||||
markers = [
|
||||
"asyncio: mark test as async"
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 99
|
||||
|
|
|
|||
|
|
@ -59,9 +59,12 @@ def test_load_with_service_discovery(
|
|||
}
|
||||
mock_loader.return_value = mock_loader_inst
|
||||
|
||||
mock_find_events.return_value = [
|
||||
MagicMock(json_event='{"port": 80, "response": {"conditions_results": {"http": {}}}}')
|
||||
]
|
||||
mock_event1 = MagicMock()
|
||||
mock_event1.json_event = json.dumps(
|
||||
{"port": 80, "response": {"conditions_results": {"http": {}}}}
|
||||
)
|
||||
|
||||
mock_find_events.return_value = [mock_event1]
|
||||
|
||||
module = Module("test_module", options, **module_args)
|
||||
module.load()
|
||||
|
|
|
|||
|
|
@ -181,7 +181,7 @@ class TestDatabase:
|
|||
assert result is True
|
||||
|
||||
@patch("nettacker.database.db.messages", return_value="mocked fail message")
|
||||
@patch("nettacker.database.db.logging.warn")
|
||||
@patch("nettacker.database.db.logger.warn")
|
||||
def test_send_submit_query_sqlite_failure(self, mock_warn, mock_messages):
|
||||
def sqlite_execute_side_effect(query):
|
||||
if query == "COMMIT":
|
||||
|
|
@ -210,7 +210,7 @@ class TestDatabase:
|
|||
assert result is True
|
||||
|
||||
@patch("nettacker.database.db.messages", return_value="mocked fail message")
|
||||
@patch("nettacker.database.db.logging.warn")
|
||||
@patch("nettacker.database.db.logger.warn")
|
||||
def test_send_submit_query_sqlalchemy_failure(self, mock_warn, mock_messages):
|
||||
mock_session = Mock()
|
||||
mock_session.commit.side_effect = [Exception("fail")] * 100
|
||||
|
|
@ -372,7 +372,7 @@ class TestDatabase:
|
|||
assert result is True
|
||||
|
||||
@patch("nettacker.database.db.messages", return_value="invalid log")
|
||||
@patch("nettacker.database.db.logging.warn")
|
||||
@patch("nettacker.database.db.logger.warn")
|
||||
def test_log_not_dict(self, mock_warn, mock_messages):
|
||||
result = submit_logs_to_db("notadict")
|
||||
assert result is False
|
||||
|
|
@ -402,7 +402,7 @@ class TestDatabase:
|
|||
|
||||
@patch("nettacker.database.db.Config.settings.retry_delay", 0)
|
||||
@patch("nettacker.database.db.Config.settings.max_retries", 1)
|
||||
@patch("nettacker.database.db.logging.warn")
|
||||
@patch("nettacker.database.db.logger.warn")
|
||||
@patch("nettacker.database.db.create_connection")
|
||||
def test_apsw_busy_error(self, mock_create_conn, mock_warn):
|
||||
mock_conn = Mock()
|
||||
|
|
@ -544,7 +544,7 @@ class TestDatabase:
|
|||
assert result is True
|
||||
|
||||
@patch("nettacker.database.db.messages", return_value="invalid log")
|
||||
@patch("nettacker.database.db.logging.warn")
|
||||
@patch("nettacker.database.db.logger.warn")
|
||||
def test_temp_log_not_dict(self, mock_warn, mock_messages):
|
||||
result = submit_temp_logs_to_db("notadict")
|
||||
assert result is False
|
||||
|
|
@ -552,7 +552,7 @@ class TestDatabase:
|
|||
|
||||
@patch("nettacker.database.db.Config.settings.retry_delay", 0)
|
||||
@patch("nettacker.database.db.Config.settings.max_retries", 1)
|
||||
@patch("nettacker.database.db.logging.warn")
|
||||
@patch("nettacker.database.db.logger.warn")
|
||||
@patch("nettacker.database.db.create_connection")
|
||||
def test_temp_log_busy_error(self, mock_create_conn, mock_warn):
|
||||
mock_conn = Mock()
|
||||
|
|
@ -687,7 +687,7 @@ class TestDatabase:
|
|||
result = find_temp_events(self.target, self.module, self.scan_id, self.event_name)
|
||||
assert result == []
|
||||
|
||||
@patch("nettacker.database.db.logging.warn")
|
||||
@patch("nettacker.database.db.logger.warn")
|
||||
@patch("nettacker.database.db.messages", return_value="database fail")
|
||||
@patch("nettacker.database.db.create_connection")
|
||||
def test_sqlite_outer_exception(self, mock_create_conn, mock_messages, mock_warn):
|
||||
|
|
@ -732,15 +732,21 @@ class TestDatabase:
|
|||
|
||||
result = find_temp_events("192.168.1.1", "port_scan", "scan_123", "event_1")
|
||||
|
||||
mock_cursor.execute.assert_called_with(
|
||||
"""
|
||||
called_query, called_params = mock_cursor.execute.call_args[0]
|
||||
|
||||
expected_query = """
|
||||
SELECT event
|
||||
FROM temp_events
|
||||
WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ?
|
||||
LIMIT 1
|
||||
""",
|
||||
("192.168.1.1", "port_scan", "scan_123", "event_1"),
|
||||
)
|
||||
"""
|
||||
|
||||
# Normalize whitespace (collapse multiple spaces/newlines into one space)
|
||||
def normalize(sql: str) -> str:
|
||||
return " ".join(sql.split())
|
||||
|
||||
assert normalize(called_query) == normalize(expected_query)
|
||||
assert called_params == ("192.168.1.1", "port_scan", "scan_123", "event_1")
|
||||
assert result == {"test": "data"}
|
||||
|
||||
# -------------------------------------------------------
|
||||
|
|
@ -768,7 +774,7 @@ class TestDatabase:
|
|||
expected = ['{"event1": "data1"}', '{"event2": "data2"}']
|
||||
assert result == expected
|
||||
|
||||
@patch("nettacker.database.db.logging.warn")
|
||||
@patch("nettacker.database.db.logger.warn")
|
||||
@patch("nettacker.database.db.create_connection")
|
||||
def test_find_events_sqlite_exception(self, mock_create_conn, mock_warn):
|
||||
mock_connection = Mock()
|
||||
|
|
@ -839,7 +845,7 @@ class TestDatabase:
|
|||
]
|
||||
assert result == expected
|
||||
|
||||
@patch("nettacker.database.db.logging.warn")
|
||||
@patch("nettacker.database.db.logger.warn")
|
||||
@patch("nettacker.database.db.create_connection")
|
||||
def test_select_reports_sqlite_exception(self, mock_create_conn, mock_warn):
|
||||
mock_connection = Mock()
|
||||
|
|
|
|||
Loading…
Reference in New Issue