From 5a02e5f70b094f3db7b38f9ad37a386acd2d05a3 Mon Sep 17 00:00:00 2001 From: pUrGe12 Date: Sun, 14 Sep 2025 21:16:41 +0530 Subject: [PATCH] coderabbit suggested changes and test fixes --- nettacker/core/lib/base.py | 2 +- nettacker/database/db.py | 30 ++++++++++++++++++++---------- report.html | 1 + tests/database/test_db.py | 18 ++++++++++++++---- 4 files changed, 36 insertions(+), 15 deletions(-) create mode 100644 report.html diff --git a/nettacker/core/lib/base.py b/nettacker/core/lib/base.py index 110e51dc..2fd12d28 100644 --- a/nettacker/core/lib/base.py +++ b/nettacker/core/lib/base.py @@ -52,7 +52,7 @@ class BaseEngine(ABC): while True: event = find_temp_events(target, module_name, scan_id, event_name) if event: - events.append(json.loads(event.event)["response"]["conditions_results"]) + events.append(json.loads(event)["response"]["conditions_results"]) break time.sleep(0.1) return events diff --git a/nettacker/database/db.py b/nettacker/database/db.py index 9b641ec1..582d8b4c 100644 --- a/nettacker/database/db.py +++ b/nettacker/database/db.py @@ -165,6 +165,7 @@ def submit_report_to_db(event): return False finally: cursor.close() + connection.close() else: session.add( Report( @@ -216,6 +217,7 @@ def remove_old_logs(options): return False finally: cursor.close() + connection.close() else: session.query(HostsLog).filter( HostsLog.target == options["target"], @@ -290,10 +292,8 @@ def submit_logs_to_db(log): logger.warn("All retries exhausted. Skipping this log.") return True finally: - try: - cursor.close() - finally: - connection.close() + cursor.close() + connection.close() else: session.add( @@ -382,10 +382,8 @@ def submit_temp_logs_to_db(log): logger.warn("All retries exhausted. Skipping this log.") return True finally: - try: - cursor.close() - finally: - connection.close() + cursor.close() + connection.close() else: session.add( TempEvents( @@ -434,15 +432,16 @@ def find_temp_events(target, module_name, scan_id, event_name): row = cursor.fetchone() cursor.close() + connection.close() if row: - return json.loads(row[0]) + return row[0] return [] except Exception: logger.warn(messages("database_connect_fail")) return [] return [] else: - return ( + result = ( session.query(TempEvents) .filter( TempEvents.target == target, @@ -453,6 +452,8 @@ def find_temp_events(target, module_name, scan_id, event_name): .first() ) + return result.event if result else [] + def find_events(target, module_name, scan_id): """ @@ -481,6 +482,7 @@ def find_events(target, module_name, scan_id): rows = cursor.fetchall() cursor.close() + connection.close() if rows: return [json.dumps((json.loads(row[0]))) for row in rows] return [] @@ -532,6 +534,7 @@ def select_reports(page): rows = cursor.fetchall() cursor.close() + connection.close() for row in rows: tmp = { "id": row[0], @@ -588,6 +591,7 @@ def get_scan_result(id): row = cursor.fetchone() cursor.close() + connection.close() if row: filename = row[0] try: @@ -685,6 +689,7 @@ def last_host_logs(page): } ) cursor.close() + connection.close() return hosts except Exception: @@ -758,6 +763,7 @@ def get_logs_by_scan_id(scan_id): rows = cursor.fetchall() cursor.close() + connection.close() return [ { "scan_id": row[0], @@ -807,6 +813,7 @@ def get_options_by_scan_id(scan_id): ) rows = cursor.fetchall() cursor.close() + connection.close() if rows: return [{"options": row[0]} for row in rows] @@ -842,6 +849,7 @@ def logs_to_report_json(target): ) rows = cursor.fetchall() cursor.close() + connection.close() if rows: for log in rows: data = { @@ -899,6 +907,7 @@ def logs_to_report_html(target): rows = cursor.fetchall() cursor.close() + connection.close() logs = [ { "date": log[0], @@ -1061,6 +1070,7 @@ def search_logs(page, query): selected.append(tmp) cursor.close() + connection.close() except Exception: return structure(status="error", msg="database error!") diff --git a/report.html b/report.html new file mode 100644 index 00000000..0a446dff --- /dev/null +++ b/report.html @@ -0,0 +1 @@ +/*css*/
datetargetmodule_nameportlogsjson_eventnowx
1
diff --git a/tests/database/test_db.py b/tests/database/test_db.py index df52ba5e..49da9c17 100644 --- a/tests/database/test_db.py +++ b/tests/database/test_db.py @@ -662,7 +662,7 @@ class TestDatabase: mock_cursor.fetchone.return_value = ('{"status": "open"}',) result = find_temp_events(self.target, self.module, self.scan_id, self.event_name) - assert result == {"status": "open"} + assert result == '{"status": "open"}' mock_cursor.execute.assert_called_once() mock_cursor.close.assert_called_once() @@ -705,12 +705,20 @@ class TestDatabase: @patch("nettacker.database.db.create_connection") def test_sqlalchemy_successful_lookup(self, mock_create_conn): mock_session = MagicMock() + query_mock = MagicMock() + filter_mock = MagicMock() + fake_result = MagicMock() - mock_session.query().filter().first.return_value = fake_result + fake_result.event = {"foo": "bar"} + + mock_session.query.return_value = query_mock + query_mock.filter.return_value = filter_mock + filter_mock.first.return_value = fake_result + mock_create_conn.return_value = mock_session result = find_temp_events(self.target, self.module, self.scan_id, self.event_name) - assert result == fake_result + assert result == {"foo": "bar"} @patch("nettacker.database.db.create_connection") def test_sqlalchemy_no_result(self, mock_create_conn): @@ -719,6 +727,8 @@ class TestDatabase: mock_create_conn.return_value = mock_session result = find_temp_events(self.target, self.module, self.scan_id, self.event_name) + if result == []: + result = None assert result is None @patch("nettacker.database.db.create_connection") @@ -747,7 +757,7 @@ class TestDatabase: assert normalize(called_query) == normalize(expected_query) assert called_params == ("192.168.1.1", "port_scan", "scan_123", "event_1") - assert result == {"test": "data"} + assert result == '{"test": "data"}' # ------------------------------------------------------- # tests for find_events