mirror of https://github.com/OWASP/Nettacker.git

commit 6979f79a39 (parent d7c7fd473b)

    fixing test_db.py, the other one is still failing because of some mocking issues

@@ -83,7 +83,7 @@ class DbConfig(ConfigBase):
     fill the name of the DB as sqlite,
     DATABASE as the name of the db user wants
     Set the journal_mode (default="WAL") and
-    synchronous_mode (deafault="NORMAL"). Rest
+    synchronous_mode (default="NORMAL"). Rest
     of the fields can be left empty
     This is the default database:
     str(CWD / ".data/nettacker.db")

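For readers unfamiliar with these PRAGMAs, a minimal sketch of what the two documented defaults do, using the stdlib sqlite3 driver purely for illustration (Nettacker applies them through APSW):

    import sqlite3

    conn = sqlite3.connect("example.db")
    # WAL journaling lets readers proceed concurrently with a single writer,
    # which suits a multi-process scanner writing logs to one file.
    conn.execute("PRAGMA journal_mode=WAL")
    # NORMAL syncs less aggressively than FULL: commits get faster at the
    # cost of some durability if the OS crashes mid-write.
    conn.execute("PRAGMA synchronous=NORMAL")
    print(conn.execute("PRAGMA journal_mode").fetchone())  # ('wal',)
    conn.close()
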
@@ -1,7 +1,11 @@
 import json
 import time

-import apsw
+try:
+    import apsw
+except ImportError:
+    apsw = None
+
 from sqlalchemy import create_engine
 from sqlalchemy.orm import sessionmaker

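The guarded import makes APSW an optional dependency: the module still imports on installs without APSW, and only the SQLite code path fails later. A minimal sketch of the pattern (require_apsw is a hypothetical helper, not part of the diff):

    try:
        import apsw
    except ImportError:
        apsw = None  # resolved lazily: only SQLite users need the package

    def require_apsw():
        # Hypothetical guard a SQLite-only code path would call first.
        if apsw is None:
            raise ImportError("APSW is required for SQLite backend.")
        return apsw
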
@@ -12,7 +16,7 @@ from nettacker.core.messages import messages
 from nettacker.database.models import HostsLog, Report, TempEvents

 config = Config()
-logging = logger.get_logger()
+logger = logger.get_logger()


 def db_inputs(connection_type):

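A plausible motivation for this rename (inferred, not stated in the commit): binding the logger instance to the name logging shadows the stdlib logging module everywhere in db.py. A small self-contained illustration:

    import logging

    print(logging.DEBUG)  # 10: the stdlib module is still reachable

    class FakeLogger:
        def warn(self, msg):
            print(msg)

    logging = FakeLogger()       # the old binding style: stdlib module now shadowed
    logging.warn("still works")  # the instance method resolves fine
    # logging.DEBUG              # would now raise AttributeError
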
@@ -44,17 +48,24 @@ def create_connection():
         connection failed.
     """
     if Config.db.engine.startswith("sqlite"):
+        if apsw is None:
+            raise ImportError("APSW is required for SQLite backend.")
         # In case of sqlite, the name parameter is the database path
-        DB_PATH = config.db.as_dict()["name"]
-        connection = apsw.Connection(DB_PATH)
-        connection.setbusytimeout(int(config.settings.timeout) * 100)
-        cursor = connection.cursor()
-
-        # Performance enhancing configurations. Put WAL cause that helps with concurrency
-        cursor.execute(f"PRAGMA journal_mode={Config.db.journal_mode}")
-        cursor.execute(f"PRAGMA synchronous={Config.db.synchronous_mode}")
-
-        return connection, cursor
+        try:
+            DB_PATH = config.db.as_dict()["name"]
+            connection = apsw.Connection(DB_PATH)
+            connection.setbusytimeout(int(config.settings.timeout) * 100)
+            cursor = connection.cursor()
+
+            # Performance enhancing configurations. Put WAL cause that helps with concurrency
+            cursor.execute(f"PRAGMA journal_mode={Config.db.journal_mode}")
+            cursor.execute(f"PRAGMA synchronous={Config.db.synchronous_mode}")
+
+            return connection, cursor
+        except Exception as e:
+            logger.error(f"Failed to create APSW connection: {e}")
+            raise
     else:
         # Both MySQL and PostgreSQL don't need a

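One reading of the APSW API worth noting: Connection.setbusytimeout takes milliseconds, so multiplying a value expressed in seconds by 100 yields a tenth of that value in wall-clock wait. A sketch (assuming config.settings.timeout is in seconds, which the diff does not state):

    import apsw  # assumes the apsw package is installed

    connection = apsw.Connection(":memory:")
    timeout = 3  # hypothetical value, in seconds
    connection.setbusytimeout(int(timeout) * 100)   # 300 ms, as written in the diff
    connection.setbusytimeout(int(timeout) * 1000)  # 3 s, if full seconds were intended
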
@@ -94,7 +105,7 @@ def send_submit_query(session):
             finally:
                 connection.close()
         connection.close()
-        logging.warn(messages("database_connect_fail"))
+        logger.warn(messages("database_connect_fail"))
         return False
     else:
         try:

@@ -104,10 +115,10 @@ def send_submit_query(session):
                     return True
                 except Exception:
                     time.sleep(0.1)
-            logging.warn(messages("database_connect_fail"))
+            logger.warn(messages("database_connect_fail"))
             return False
         except Exception:
-            logging.warn(messages("database_connect_fail"))
+            logger.warn(messages("database_connect_fail"))
             return False
     return False

@@ -123,7 +134,7 @@ def submit_report_to_db(event):
     Returns:
         return True if submitted otherwise False
     """
-    logging.verbose_info(messages("inserting_report_db"))
+    logger.verbose_info(messages("inserting_report_db"))
     session = create_connection()

     if isinstance(session, tuple):

@@ -146,7 +157,7 @@ def submit_report_to_db(event):
             return send_submit_query(session)
         except Exception:
             cursor.execute("ROLLBACK")
-            logging.warn("Could not insert report...")
+            logger.warn("Could not insert report...")
             return False
         finally:
             cursor.close()

@@ -197,7 +208,7 @@ def remove_old_logs(options):
             return send_submit_query(session)
         except Exception:
             cursor.execute("ROLLBACK")
-            logging.warn("Could not remove old logs...")
+            logger.warn("Could not remove old logs...")
             return False
         finally:
             cursor.close()

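Both write paths above follow the same explicit-transaction shape. A self-contained sketch of that shape with the stdlib sqlite3 driver (table and values are illustrative):

    import sqlite3

    conn = sqlite3.connect(":memory:", isolation_level=None)  # manual BEGIN/COMMIT
    cur = conn.cursor()
    cur.execute("CREATE TABLE reports (id INTEGER PRIMARY KEY, body TEXT)")
    try:
        cur.execute("BEGIN")
        cur.execute("INSERT INTO reports (body) VALUES (?)", ("scan result",))
        cur.execute("COMMIT")
    except Exception:
        cur.execute("ROLLBACK")  # undo the partial write before reporting failure
    finally:
        cur.close()              # release the cursor either way, as the diff does
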
@@ -253,7 +264,7 @@ def submit_logs_to_db(log):

         except apsw.BusyError as e:
             if "database is locked" in str(e).lower():
-                logging.warn(
+                logger.warn(
                     f"[Retry {_ + 1}/{Config.settings.max_retries}] Database is locked. Retrying..."
                 )
                 if connection.in_transaction:

@@ -272,7 +283,7 @@ def submit_logs_to_db(log):
                     pass
             return False
         # All retires exhausted but we want to continue operation
-        logging.warn("All retries exhausted. Skipping this log.")
+        logger.warn("All retries exhausted. Skipping this log.")
         return True
     finally:
         cursor.close()

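The locked-database handling above amounts to a bounded retry loop. A condensed sketch of that control flow (max_retries and retry_delay mirror the Config.settings names in the diff; RuntimeError stands in for apsw.BusyError):

    import time

    max_retries = 3
    retry_delay = 0.1  # seconds

    def submit_with_retry(write):
        for attempt in range(max_retries):
            try:
                write()
                return True
            except RuntimeError:  # stand-in for apsw.BusyError
                print(f"[Retry {attempt + 1}/{max_retries}] Database is locked. Retrying...")
                time.sleep(retry_delay)
        # All retries exhausted, but the scan should keep going: the entry is
        # skipped and the caller is told to continue, matching the diff.
        return True
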
@@ -290,7 +301,7 @@ def submit_logs_to_db(log):
             )
             return send_submit_query(session)
     else:
-        logging.warn(messages("invalid_json_type_to_db").format(log))
+        logger.warn(messages("invalid_json_type_to_db").format(log))
         return False

@@ -335,7 +346,7 @@ def submit_temp_logs_to_db(log):
             return send_submit_query(session)
         except apsw.BusyError as e:
             if "database is locked" in str(e).lower():
-                logging.warn(
+                logger.warn(
                     f"[Retry {_ + 1}/{Config.settings.max_retries}] Database is locked. Retrying..."
                 )
                 try:

@@ -360,7 +371,7 @@ def submit_temp_logs_to_db(log):
                     pass
             return False
         # All retires exhausted but we want to continue operation
-        logging.warn("All retries exhausted. Skipping this log.")
+        logger.warn("All retries exhausted. Skipping this log.")
         return True
     finally:
         cursor.close()

@@ -379,7 +390,7 @@ def submit_temp_logs_to_db(log):
             )
             return send_submit_query(session)
     else:
-        logging.warn(messages("invalid_json_type_to_db").format(log))
+        logger.warn(messages("invalid_json_type_to_db").format(log))
         return False

@@ -400,28 +411,23 @@ def find_temp_events(target, module_name, scan_id, event_name):
     if isinstance(session, tuple):
         connection, cursor = session
         try:
-            for _ in range(100):
-                try:
-                    cursor.execute(
-                        """
-                        SELECT event
-                        FROM temp_events
-                        WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ?
-                        LIMIT 1
-                        """,
-                        (target, module_name, scan_id, event_name),
-                    )
-
-                    row = cursor.fetchone()
-                    cursor.close()
-                    if row:
-                        return json.loads(row[0])
-                    return []
-                except Exception:
-                    logging.warn("Database query failed...")
-                    return []
+            cursor.execute(
+                """
+                SELECT event
+                FROM temp_events
+                WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ?
+                LIMIT 1
+                """,
+                (target, module_name, scan_id, event_name),
+            )
+
+            row = cursor.fetchone()
+            cursor.close()
+            if row:
+                return json.loads(row[0])
+            return []
         except Exception:
-            logging.warn(messages("database_connect_fail"))
+            logger.warn(messages("database_connect_fail"))
             return []
         return []
     else:

@@ -441,7 +447,7 @@ def find_temp_events(target, module_name, scan_id, event_name):
             except Exception:
                 time.sleep(0.1)
     except Exception:
-        logging.warn(messages("database_connect_fail"))
+        logger.warn(messages("database_connect_fail"))
         return []
     return []

@@ -477,7 +483,7 @@ def find_events(target, module_name, scan_id):
             return [json.dumps((json.loads(row[0]))) for row in rows]
             return []
         except Exception:
-            logging.warn("Database query failed...")
+            logger.warn("Database query failed...")
             return []
     else:
         return [

@@ -536,7 +542,7 @@ def select_reports(page):
             return selected

         except Exception:
-            logging.warn("Could not retrieve report...")
+            logger.warn("Could not retrieve report...")
             return structure(status="error", msg="database error!")
     else:
         try:

@@ -582,13 +588,23 @@ def get_scan_result(id):
         cursor.close()
         if row:
             filename = row[0]
-            return filename, open(str(filename), "rb").read()
+            try:
+                return filename, open(str(filename), "rb").read()
+            except IOError as e:
+                logger.error(f"Failed to read report file: {e}")
+                return None
         else:
             return structure(status="error", msg="database error!")
     else:
-        filename = session.query(Report).filter_by(id=id).first().report_path_filename
-
-        return filename, open(str(filename), "rb").read()
+        report = session.query(Report).filter_by(id=id).first()
+        if not report:
+            return None
+
+        try:
+            return report.report_path_filename, open(str(report.report_path_filename), "rb").read()
+        except IOError as e:
+            logger.error(f"Failed to read report file: {e}")
+            return None


 def last_host_logs(page):

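A design note rather than a correction: the bare open(...).read() in both branches leaks the file handle on non-refcounted runtimes. A with block closes it deterministically and still fits the new error handling; OSError subsumes IOError since Python 3.3. A minimal sketch:

    import logging

    logger = logging.getLogger(__name__)

    def read_report(filename):
        try:
            # The context manager closes the handle even if read() raises.
            with open(str(filename), "rb") as report_file:
                return filename, report_file.read()
        except OSError as e:  # IOError is an alias of OSError
            logger.error(f"Failed to read report file: {e}")
            return None
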
@@ -656,7 +672,6 @@ def last_host_logs(page):
             )
             events = [row[0] for row in cursor.fetchall()]

-            cursor.close()
             hosts.append(
                 {
                     "target": target,

@@ -667,11 +682,11 @@ def last_host_logs(page):
                     },
                 }
             )
-
+            cursor.close()
             return hosts

         except Exception:
-            logging.warn("Database query failed...")
+            logger.warn("Database query failed...")
             return structure(status="error", msg="Database error!")

     else:

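This pair of hunks moves cursor.close() from inside the per-host loop to after it; closing on the first iteration would kill the cursor for every later query. A self-contained illustration of the corrected lifetime:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        "CREATE TABLE logs (target TEXT, event TEXT);"
        "INSERT INTO logs VALUES ('a', 'e1'), ('b', 'e2');"
    )
    cursor = conn.cursor()
    hosts = []
    for target in ("a", "b"):
        cursor.execute("SELECT event FROM logs WHERE target = ?", (target,))
        hosts.append([row[0] for row in cursor.fetchall()])
    cursor.close()  # close once, after the last query; inside the loop it breaks iteration 2
    print(hosts)    # [['e1'], ['e2']]
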
@@ -834,7 +849,7 @@ def logs_to_report_json(target):
                 "event": json.loads(log[3]),
                 "json_event": json.loads(log[4]),
             }
             return_logs.append(data)
         return return_logs

     else:

@@ -1001,7 +1016,6 @@ def search_logs(page, query):
                 ),
             )
             targets = cursor.fetchall()
-            cursor.close()
             for target_row in targets:
                 target = target_row[0]
                 # Fetch data for each target grouped by key fields

@@ -1044,6 +1058,7 @@ def search_logs(page, query):
                 tmp["info"]["json_event"].append(parsed_json_event)

                 selected.append(tmp)
+            cursor.close()

         except Exception:
             return structure(status="error", msg="database error!")

@@ -94,6 +94,9 @@ profile = "black"
 addopts = "--cov=nettacker --cov-config=pyproject.toml --cov-report term --cov-report xml --dist loadscope --no-cov-on-fail --numprocesses auto"
 asyncio_default_fixture_loop_scope = "function"
 testpaths = ["tests"]
+markers = [
+    "asyncio: mark test as async"
+]

 [tool.ruff]
 line-length = 99

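Registering the marker under [tool.pytest.ini_options] silences pytest's unknown-marker warnings (and errors under --strict-markers). Note that pytest-asyncio, which the asyncio_default_fixture_loop_scope key implies is in use, normally registers its own asyncio marker, so redeclaring it here should be harmless. A usage sketch:

    # test_example.py -- hypothetical test exercising the registered marker
    import asyncio

    import pytest

    @pytest.mark.asyncio
    async def test_sleep_returns_none():
        assert await asyncio.sleep(0) is None
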
@@ -59,9 +59,12 @@ def test_load_with_service_discovery(
     }
     mock_loader.return_value = mock_loader_inst

-    mock_find_events.return_value = [
-        MagicMock(json_event='{"port": 80, "response": {"conditions_results": {"http": {}}}}')
-    ]
+    mock_event1 = MagicMock()
+    mock_event1.json_event = json.dumps(
+        {"port": 80, "response": {"conditions_results": {"http": {}}}}
+    )
+
+    mock_find_events.return_value = [mock_event1]

     module = Module("test_module", options, **module_args)
     module.load()

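Both spellings configure a plain attribute on the mock; the rewrite mainly swaps the hand-written JSON string for json.dumps, so the payload is valid JSON by construction. A self-contained check:

    import json
    from unittest.mock import MagicMock

    payload = {"port": 80, "response": {"conditions_results": {"http": {}}}}

    event_a = MagicMock(json_event=json.dumps(payload))  # constructor kwarg
    event_b = MagicMock()
    event_b.json_event = json.dumps(payload)             # attribute assignment

    assert json.loads(event_a.json_event) == json.loads(event_b.json_event) == payload
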
@@ -181,7 +181,7 @@ class TestDatabase:
         assert result is True

     @patch("nettacker.database.db.messages", return_value="mocked fail message")
-    @patch("nettacker.database.db.logging.warn")
+    @patch("nettacker.database.db.logger.warn")
     def test_send_submit_query_sqlite_failure(self, mock_warn, mock_messages):
         def sqlite_execute_side_effect(query):
             if query == "COMMIT":

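These test edits all track the module-level rename: unittest.mock.patch targets the name as looked up in the module under test, not the object's home module. A self-contained sketch (mymodule is fabricated inline so the example runs):

    import sys
    import types
    from unittest.mock import patch

    # Build a stand-in module so the sketch is self-contained.
    mymodule = types.ModuleType("mymodule")
    mymodule.logger = types.SimpleNamespace(warn=lambda msg: print(msg))
    mymodule.warn_once = lambda: mymodule.logger.warn("hello")
    sys.modules["mymodule"] = mymodule

    # Patch where the name is looked up: this is why db.logging -> db.logger
    # forces every @patch("nettacker.database.db.logging.warn") to become
    # @patch("nettacker.database.db.logger.warn").
    with patch("mymodule.logger.warn") as mock_warn:
        mymodule.warn_once()
    mock_warn.assert_called_once_with("hello")
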
@@ -210,7 +210,7 @@ class TestDatabase:
         assert result is True

     @patch("nettacker.database.db.messages", return_value="mocked fail message")
-    @patch("nettacker.database.db.logging.warn")
+    @patch("nettacker.database.db.logger.warn")
     def test_send_submit_query_sqlalchemy_failure(self, mock_warn, mock_messages):
         mock_session = Mock()
         mock_session.commit.side_effect = [Exception("fail")] * 100

@@ -372,7 +372,7 @@ class TestDatabase:
         assert result is True

     @patch("nettacker.database.db.messages", return_value="invalid log")
-    @patch("nettacker.database.db.logging.warn")
+    @patch("nettacker.database.db.logger.warn")
     def test_log_not_dict(self, mock_warn, mock_messages):
         result = submit_logs_to_db("notadict")
         assert result is False

@@ -402,7 +402,7 @@ class TestDatabase:

     @patch("nettacker.database.db.Config.settings.retry_delay", 0)
     @patch("nettacker.database.db.Config.settings.max_retries", 1)
-    @patch("nettacker.database.db.logging.warn")
+    @patch("nettacker.database.db.logger.warn")
     @patch("nettacker.database.db.create_connection")
     def test_apsw_busy_error(self, mock_create_conn, mock_warn):
         mock_conn = Mock()

@@ -544,7 +544,7 @@ class TestDatabase:
         assert result is True

     @patch("nettacker.database.db.messages", return_value="invalid log")
-    @patch("nettacker.database.db.logging.warn")
+    @patch("nettacker.database.db.logger.warn")
     def test_temp_log_not_dict(self, mock_warn, mock_messages):
         result = submit_temp_logs_to_db("notadict")
         assert result is False

@@ -552,7 +552,7 @@ class TestDatabase:

     @patch("nettacker.database.db.Config.settings.retry_delay", 0)
     @patch("nettacker.database.db.Config.settings.max_retries", 1)
-    @patch("nettacker.database.db.logging.warn")
+    @patch("nettacker.database.db.logger.warn")
     @patch("nettacker.database.db.create_connection")
     def test_temp_log_busy_error(self, mock_create_conn, mock_warn):
         mock_conn = Mock()

@@ -687,7 +687,7 @@ class TestDatabase:
         result = find_temp_events(self.target, self.module, self.scan_id, self.event_name)
         assert result == []

-    @patch("nettacker.database.db.logging.warn")
+    @patch("nettacker.database.db.logger.warn")
     @patch("nettacker.database.db.messages", return_value="database fail")
     @patch("nettacker.database.db.create_connection")
     def test_sqlite_outer_exception(self, mock_create_conn, mock_messages, mock_warn):

@@ -732,15 +732,21 @@ class TestDatabase:

         result = find_temp_events("192.168.1.1", "port_scan", "scan_123", "event_1")

-        mock_cursor.execute.assert_called_with(
-            """
-            SELECT event
-            FROM temp_events
-            WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ?
-            LIMIT 1
-            """,
-            ("192.168.1.1", "port_scan", "scan_123", "event_1"),
-        )
+        called_query, called_params = mock_cursor.execute.call_args[0]
+
+        expected_query = """
+            SELECT event
+            FROM temp_events
+            WHERE target = ? AND module_name = ? AND scan_unique_id = ? AND event_name = ?
+            LIMIT 1
+        """
+
+        # Normalize whitespace (collapse multiple spaces/newlines into one space)
+        def normalize(sql: str) -> str:
+            return " ".join(sql.split())
+
+        assert normalize(called_query) == normalize(expected_query)
+        assert called_params == ("192.168.1.1", "port_scan", "scan_123", "event_1")
         assert result == {"test": "data"}

         # -------------------------------------------------------

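assert_called_with compares the SQL string byte-for-byte, so any reindentation of the triple-quoted literal broke the old test. Pulling call_args apart and comparing whitespace-normalized text keeps the assertion meaningful. The mechanics in isolation:

    from unittest.mock import Mock

    cursor = Mock()
    cursor.execute("SELECT  event\n  FROM temp_events", ("192.168.1.1",))

    # call_args[0] is the positional-argument tuple of the most recent call.
    query, params = cursor.execute.call_args[0]

    def normalize(sql: str) -> str:
        return " ".join(sql.split())

    assert normalize(query) == "SELECT event FROM temp_events"
    assert params == ("192.168.1.1",)
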
@@ -768,7 +774,7 @@ class TestDatabase:
         expected = ['{"event1": "data1"}', '{"event2": "data2"}']
         assert result == expected

-    @patch("nettacker.database.db.logging.warn")
+    @patch("nettacker.database.db.logger.warn")
     @patch("nettacker.database.db.create_connection")
     def test_find_events_sqlite_exception(self, mock_create_conn, mock_warn):
         mock_connection = Mock()

|
||||||
]
|
]
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
@patch("nettacker.database.db.logging.warn")
|
@patch("nettacker.database.db.logger.warn")
|
||||||
@patch("nettacker.database.db.create_connection")
|
@patch("nettacker.database.db.create_connection")
|
||||||
def test_select_reports_sqlite_exception(self, mock_create_conn, mock_warn):
|
def test_select_reports_sqlite_exception(self, mock_create_conn, mock_warn):
|
||||||
mock_connection = Mock()
|
mock_connection = Mock()
|
||||||