diff --git a/tests/database/test_models.py b/tests/database/test_models.py new file mode 100644 index 00000000..47a6b705 --- /dev/null +++ b/tests/database/test_models.py @@ -0,0 +1,94 @@ +from datetime import datetime + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from nettacker.database.models import Base, Report, TempEvents, HostsLog +from tests.common import TestCase + + +class TestModels(TestCase): + def setUp(self): + # Creating an in-memory SQLite database for testing + self.engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(self.engine) + Session = sessionmaker(bind=self.engine) + self.session = Session() + + def tearDown(self): + self.session.close() + Base.metadata.drop_all(self.engine) + + def test_report_model(self): + test_date = datetime.now() + test_report = Report( + date=test_date, + scan_unique_id="test123", + report_path_filename="/path/to/report.txt", + options='{"option1": "value1"}', + ) + + self.session.add(test_report) + self.session.commit() + + retrieved_report = self.session.query(Report).first() + self.assertIsNotNone(retrieved_report) + self.assertEqual(retrieved_report.scan_unique_id, "test123") + self.assertEqual(retrieved_report.report_path_filename, "/path/to/report.txt") + self.assertEqual(retrieved_report.options, '{"option1": "value1"}') + + repr_string = repr(retrieved_report) + self.assertIn("test123", repr_string) + self.assertIn("/path/to/report.txt", repr_string) + + def test_temp_events_model(self): + test_date = datetime.now() + test_event = TempEvents( + date=test_date, + target="192.168.1.1", + module_name="port_scan", + scan_unique_id="test123", + event_name="open_port", + port="80", + event="Port 80 is open", + data='{"details": "HTTP server running"}', + ) + + self.session.add(test_event) + self.session.commit() + + retrieved_event = self.session.query(TempEvents).first() + self.assertIsNotNone(retrieved_event) + self.assertEqual(retrieved_event.target, "192.168.1.1") + self.assertEqual(retrieved_event.module_name, "port_scan") + self.assertEqual(retrieved_event.port, "80") + + repr_string = repr(retrieved_event) + self.assertIn("192.168.1.1", repr_string) + self.assertIn("port_scan", repr_string) + + def test_hosts_log_model(self): + test_date = datetime.now() + test_log = HostsLog( + date=test_date, + target="192.168.1.1", + module_name="vulnerability_scan", + scan_unique_id="test123", + port="443", + event="Found vulnerability CVE-2021-12345", + json_event='{"vulnerability": "CVE-2021-12345", "severity": "high"}', + ) + + self.session.add(test_log) + self.session.commit() + + retrieved_log = self.session.query(HostsLog).first() + self.assertIsNotNone(retrieved_log) + self.assertEqual(retrieved_log.target, "192.168.1.1") + self.assertEqual(retrieved_log.module_name, "vulnerability_scan") + self.assertEqual(retrieved_log.port, "443") + self.assertEqual(retrieved_log.event, "Found vulnerability CVE-2021-12345") + + repr_string = repr(retrieved_log) + self.assertIn("192.168.1.1", repr_string) + self.assertIn("vulnerability_scan", repr_string) diff --git a/tests/database/test_mysql.py b/tests/database/test_mysql.py new file mode 100644 index 00000000..0dd3335e --- /dev/null +++ b/tests/database/test_mysql.py @@ -0,0 +1,141 @@ +from unittest.mock import patch, MagicMock + +from sqlalchemy.exc import SQLAlchemyError + +from nettacker.config import Config +from nettacker.database.models import Base +from nettacker.database.mysql import mysql_create_database, mysql_create_tables +from tests.common import TestCase + + +class TestMySQLFunctions(TestCase): + """Test cases for mysql.py functions""" + + @patch("nettacker.database.mysql.create_engine") + def test_mysql_create_database_success(self, mock_create_engine): + """Test successful database creation""" + # Set up mock config + Config.db = MagicMock() + Config.db.as_dict.return_value = { + "username": "test_user", + "password": "test_pass", + "host": "localhost", + "port": "3306", + "name": "test_db", + } + Config.db.name = "test_db" + + # Set up mock connection and execution + mock_conn = MagicMock() + mock_engine = MagicMock() + mock_create_engine.return_value = mock_engine + mock_engine.connect.return_value.__enter__.return_value = mock_conn + + # Mock database query results - database doesn't exist yet + mock_conn.execute.return_value = [("mysql",), ("information_schema",)] + + # Call the function + mysql_create_database() + + # Assertions + mock_create_engine.assert_called_once_with( + "mysql+pymysql://test_user:test_pass@localhost:3306" + ) + + # Check that execute was called with any text object that has the expected SQL + call_args_list = mock_conn.execute.call_args_list + self.assertEqual(len(call_args_list), 2) # Two calls to execute + + # Check that the first call is SHOW DATABASES + first_call_arg = call_args_list[0][0][0] + self.assertEqual(str(first_call_arg), "SHOW DATABASES;") + + # Check that the second call is CREATE DATABASE + second_call_arg = call_args_list[1][0][0] + self.assertEqual(str(second_call_arg), "CREATE DATABASE test_db ") + + @patch("nettacker.database.mysql.create_engine") + def test_mysql_create_database_already_exists(self, mock_create_engine): + """Test when database already exists""" + # Set up mock config + Config.db = MagicMock() + Config.db.as_dict.return_value = { + "username": "test_user", + "password": "test_pass", + "host": "localhost", + "port": "3306", + "name": "test_db", + } + Config.db.name = "test_db" + + # Set up mock connection and execution + mock_conn = MagicMock() + mock_engine = MagicMock() + mock_create_engine.return_value = mock_engine + mock_engine.connect.return_value.__enter__.return_value = mock_conn + + # Mock database query results - database already exists + mock_conn.execute.return_value = [("mysql",), ("information_schema",), ("test_db",)] + + # Call the function + mysql_create_database() + + # Assertions + mock_create_engine.assert_called_once_with( + "mysql+pymysql://test_user:test_pass@localhost:3306" + ) + + # Check that execute was called once with SHOW DATABASES + self.assertEqual(mock_conn.execute.call_count, 1) + call_arg = mock_conn.execute.call_args[0][0] + self.assertEqual(str(call_arg), "SHOW DATABASES;") + + @patch("nettacker.database.mysql.create_engine") + def test_mysql_create_database_exception(self, mock_create_engine): + """Test exception handling in create database""" + # Set up mock config + Config.db = MagicMock() + Config.db.as_dict.return_value = { + "username": "test_user", + "password": "test_pass", + "host": "localhost", + "port": "3306", + "name": "test_db", + } + + # Set up mock to raise exception + mock_engine = MagicMock() + mock_create_engine.return_value = mock_engine + mock_engine.connect.side_effect = SQLAlchemyError("Connection error") + + # Call the function (should not raise exception) + with patch("builtins.print") as mock_print: + mysql_create_database() + mock_print.assert_called_once() + + @patch("nettacker.database.mysql.create_engine") + def test_mysql_create_tables(self, mock_create_engine): + """Test table creation function""" + # Set up mock config + Config.db = MagicMock() + Config.db.as_dict.return_value = { + "username": "test_user", + "password": "test_pass", + "host": "localhost", + "port": "3306", + "name": "test_db", + } + + # Set up mock engine + mock_engine = MagicMock() + mock_create_engine.return_value = mock_engine + + # Call the function + with patch.object(Base.metadata, "create_all") as mock_create_all: + mysql_create_tables() + + # Assertions + mock_create_engine.assert_called_once_with( + "mysql+pymysql://test_user:test_pass@localhost:3306/test_db" + ) + mock_create_all.assert_called_once_with(mock_engine) diff --git a/tests/database/test_postgresql.py b/tests/database/test_postgresql.py new file mode 100644 index 00000000..a7be39fe --- /dev/null +++ b/tests/database/test_postgresql.py @@ -0,0 +1,92 @@ +from unittest.mock import patch, MagicMock + +from sqlalchemy.exc import OperationalError + +from nettacker.config import Config +from nettacker.database.models import Base +from nettacker.database.postgresql import postgres_create_database +from tests.common import TestCase + + +class TestPostgresFunctions(TestCase): + @patch("nettacker.database.postgresql.create_engine") + def test_postgres_create_database_success(self, mock_create_engine): + Config.db = MagicMock() + Config.db.as_dict.return_value = { + "username": "user", + "password": "pass", + "host": "localhost", + "port": "5432", + "name": "nettacker_db", + } + + mock_engine = MagicMock() + mock_create_engine.return_value = mock_engine + + with patch.object(Base.metadata, "create_all") as mock_create_all: + postgres_create_database() + + mock_create_engine.assert_called_once_with( + "postgresql+psycopg2://user:pass@localhost:5432/nettacker_db" + ) + mock_create_all.assert_called_once_with(mock_engine) + + @patch("nettacker.database.postgresql.create_engine") + def test_postgres_create_database_if_not_exists(self, mock_create_engine): + Config.db = MagicMock() + Config.db.as_dict.return_value = { + "username": "user", + "password": "pass", + "host": "localhost", + "port": "5432", + "name": "nettacker_db", + } + Config.db.name = "nettacker_db" + + mock_engine_initial = MagicMock() + mock_engine_fallback = MagicMock() + mock_engine_final = MagicMock() + + mock_create_engine.side_effect = [ + mock_engine_initial, + mock_engine_fallback, + mock_engine_final, + ] + + with patch.object( + Base.metadata, "create_all", side_effect=[OperationalError("fail", None, None), None] + ): + mock_conn = MagicMock() + mock_engine_fallback.connect.return_value = mock_conn + mock_conn.execution_options.return_value = mock_conn + + postgres_create_database() + + assert mock_create_engine.call_count == 3 + args, _ = mock_conn.execute.call_args + assert str(args[0]) == "CREATE DATABASE nettacker_db" + mock_conn.close.assert_called_once() + + @patch("nettacker.database.postgresql.create_engine") + def test_postgres_create_database_create_fail(self, mock_create_engine): + Config.db = MagicMock() + Config.db.as_dict.return_value = { + "username": "user", + "password": "pass", + "host": "localhost", + "port": "5432", + "name": "nettacker_db", + } + + mock_engine_initial = MagicMock() + mock_engine_fallback = MagicMock() + + mock_create_engine.side_effect = [mock_engine_initial, mock_engine_fallback] + + mock_engine_fallback.connect.side_effect = OperationalError("fail again", None, None) + + with patch.object( + Base.metadata, "create_all", side_effect=OperationalError("fail", None, None) + ): + with self.assertRaises(OperationalError): + postgres_create_database() diff --git a/tests/database/test_sqlite.py b/tests/database/test_sqlite.py new file mode 100644 index 00000000..599f0251 --- /dev/null +++ b/tests/database/test_sqlite.py @@ -0,0 +1,42 @@ +from unittest.mock import patch, MagicMock + +from sqlalchemy import create_engine, inspect + +from nettacker.config import Config +from nettacker.database.models import Base +from nettacker.database.sqlite import sqlite_create_tables +from tests.common import TestCase + + +class TestSQLiteFunctions(TestCase): + @patch("nettacker.database.sqlite.create_engine") + def test_sqlite_create_tables(self, mock_create_engine): + Config.db = MagicMock() + Config.db.as_dict.return_value = {"name": "/path/to/test.db"} + + mock_engine = MagicMock() + mock_create_engine.return_value = mock_engine + + with patch.object(Base.metadata, "create_all") as mock_create_all: + sqlite_create_tables() + + mock_create_engine.assert_called_once_with( + "sqlite:////path/to/test.db", connect_args={"check_same_thread": False} + ) + mock_create_all.assert_called_once_with(mock_engine) + + def test_sqlite_create_tables_integration(self): + engine = create_engine("sqlite:///:memory:") + + Config.db = MagicMock() + Config.db.as_dict.return_value = {"name": ":memory:"} + + with patch("nettacker.database.sqlite.create_engine", return_value=engine): + sqlite_create_tables() + + inspector = inspect(engine) + tables = inspector.get_table_names() + + self.assertIn("reports", tables, "Reports table was not created") + self.assertIn("temp_events", tables, "Temp events table was not created") + self.assertIn("scan_events", tables, "Scan events table was not created")