Unittets for database files (#1077)

* unittests for database files

* ruff fixes
This commit is contained in:
Achintya Jai 2025-06-09 04:29:42 +05:30 committed by GitHub
parent 4fd743a15d
commit af7abb683c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 369 additions and 0 deletions

View File

@ -0,0 +1,94 @@
from datetime import datetime
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from nettacker.database.models import Base, Report, TempEvents, HostsLog
from tests.common import TestCase
class TestModels(TestCase):
def setUp(self):
# Creating an in-memory SQLite database for testing
self.engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(self.engine)
Session = sessionmaker(bind=self.engine)
self.session = Session()
def tearDown(self):
self.session.close()
Base.metadata.drop_all(self.engine)
def test_report_model(self):
test_date = datetime.now()
test_report = Report(
date=test_date,
scan_unique_id="test123",
report_path_filename="/path/to/report.txt",
options='{"option1": "value1"}',
)
self.session.add(test_report)
self.session.commit()
retrieved_report = self.session.query(Report).first()
self.assertIsNotNone(retrieved_report)
self.assertEqual(retrieved_report.scan_unique_id, "test123")
self.assertEqual(retrieved_report.report_path_filename, "/path/to/report.txt")
self.assertEqual(retrieved_report.options, '{"option1": "value1"}')
repr_string = repr(retrieved_report)
self.assertIn("test123", repr_string)
self.assertIn("/path/to/report.txt", repr_string)
def test_temp_events_model(self):
test_date = datetime.now()
test_event = TempEvents(
date=test_date,
target="192.168.1.1",
module_name="port_scan",
scan_unique_id="test123",
event_name="open_port",
port="80",
event="Port 80 is open",
data='{"details": "HTTP server running"}',
)
self.session.add(test_event)
self.session.commit()
retrieved_event = self.session.query(TempEvents).first()
self.assertIsNotNone(retrieved_event)
self.assertEqual(retrieved_event.target, "192.168.1.1")
self.assertEqual(retrieved_event.module_name, "port_scan")
self.assertEqual(retrieved_event.port, "80")
repr_string = repr(retrieved_event)
self.assertIn("192.168.1.1", repr_string)
self.assertIn("port_scan", repr_string)
def test_hosts_log_model(self):
test_date = datetime.now()
test_log = HostsLog(
date=test_date,
target="192.168.1.1",
module_name="vulnerability_scan",
scan_unique_id="test123",
port="443",
event="Found vulnerability CVE-2021-12345",
json_event='{"vulnerability": "CVE-2021-12345", "severity": "high"}',
)
self.session.add(test_log)
self.session.commit()
retrieved_log = self.session.query(HostsLog).first()
self.assertIsNotNone(retrieved_log)
self.assertEqual(retrieved_log.target, "192.168.1.1")
self.assertEqual(retrieved_log.module_name, "vulnerability_scan")
self.assertEqual(retrieved_log.port, "443")
self.assertEqual(retrieved_log.event, "Found vulnerability CVE-2021-12345")
repr_string = repr(retrieved_log)
self.assertIn("192.168.1.1", repr_string)
self.assertIn("vulnerability_scan", repr_string)

View File

@ -0,0 +1,141 @@
from unittest.mock import patch, MagicMock
from sqlalchemy.exc import SQLAlchemyError
from nettacker.config import Config
from nettacker.database.models import Base
from nettacker.database.mysql import mysql_create_database, mysql_create_tables
from tests.common import TestCase
class TestMySQLFunctions(TestCase):
"""Test cases for mysql.py functions"""
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_success(self, mock_create_engine):
"""Test successful database creation"""
# Set up mock config
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
Config.db.name = "test_db"
# Set up mock connection and execution
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine.connect.return_value.__enter__.return_value = mock_conn
# Mock database query results - database doesn't exist yet
mock_conn.execute.return_value = [("mysql",), ("information_schema",)]
# Call the function
mysql_create_database()
# Assertions
mock_create_engine.assert_called_once_with(
"mysql+pymysql://test_user:test_pass@localhost:3306"
)
# Check that execute was called with any text object that has the expected SQL
call_args_list = mock_conn.execute.call_args_list
self.assertEqual(len(call_args_list), 2) # Two calls to execute
# Check that the first call is SHOW DATABASES
first_call_arg = call_args_list[0][0][0]
self.assertEqual(str(first_call_arg), "SHOW DATABASES;")
# Check that the second call is CREATE DATABASE
second_call_arg = call_args_list[1][0][0]
self.assertEqual(str(second_call_arg), "CREATE DATABASE test_db ")
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_already_exists(self, mock_create_engine):
"""Test when database already exists"""
# Set up mock config
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
Config.db.name = "test_db"
# Set up mock connection and execution
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine.connect.return_value.__enter__.return_value = mock_conn
# Mock database query results - database already exists
mock_conn.execute.return_value = [("mysql",), ("information_schema",), ("test_db",)]
# Call the function
mysql_create_database()
# Assertions
mock_create_engine.assert_called_once_with(
"mysql+pymysql://test_user:test_pass@localhost:3306"
)
# Check that execute was called once with SHOW DATABASES
self.assertEqual(mock_conn.execute.call_count, 1)
call_arg = mock_conn.execute.call_args[0][0]
self.assertEqual(str(call_arg), "SHOW DATABASES;")
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_exception(self, mock_create_engine):
"""Test exception handling in create database"""
# Set up mock config
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
# Set up mock to raise exception
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine.connect.side_effect = SQLAlchemyError("Connection error")
# Call the function (should not raise exception)
with patch("builtins.print") as mock_print:
mysql_create_database()
mock_print.assert_called_once()
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_tables(self, mock_create_engine):
"""Test table creation function"""
# Set up mock config
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
# Set up mock engine
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
# Call the function
with patch.object(Base.metadata, "create_all") as mock_create_all:
mysql_create_tables()
# Assertions
mock_create_engine.assert_called_once_with(
"mysql+pymysql://test_user:test_pass@localhost:3306/test_db"
)
mock_create_all.assert_called_once_with(mock_engine)

View File

@ -0,0 +1,92 @@
from unittest.mock import patch, MagicMock
from sqlalchemy.exc import OperationalError
from nettacker.config import Config
from nettacker.database.models import Base
from nettacker.database.postgresql import postgres_create_database
from tests.common import TestCase
class TestPostgresFunctions(TestCase):
@patch("nettacker.database.postgresql.create_engine")
def test_postgres_create_database_success(self, mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "user",
"password": "pass",
"host": "localhost",
"port": "5432",
"name": "nettacker_db",
}
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
with patch.object(Base.metadata, "create_all") as mock_create_all:
postgres_create_database()
mock_create_engine.assert_called_once_with(
"postgresql+psycopg2://user:pass@localhost:5432/nettacker_db"
)
mock_create_all.assert_called_once_with(mock_engine)
@patch("nettacker.database.postgresql.create_engine")
def test_postgres_create_database_if_not_exists(self, mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "user",
"password": "pass",
"host": "localhost",
"port": "5432",
"name": "nettacker_db",
}
Config.db.name = "nettacker_db"
mock_engine_initial = MagicMock()
mock_engine_fallback = MagicMock()
mock_engine_final = MagicMock()
mock_create_engine.side_effect = [
mock_engine_initial,
mock_engine_fallback,
mock_engine_final,
]
with patch.object(
Base.metadata, "create_all", side_effect=[OperationalError("fail", None, None), None]
):
mock_conn = MagicMock()
mock_engine_fallback.connect.return_value = mock_conn
mock_conn.execution_options.return_value = mock_conn
postgres_create_database()
assert mock_create_engine.call_count == 3
args, _ = mock_conn.execute.call_args
assert str(args[0]) == "CREATE DATABASE nettacker_db"
mock_conn.close.assert_called_once()
@patch("nettacker.database.postgresql.create_engine")
def test_postgres_create_database_create_fail(self, mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "user",
"password": "pass",
"host": "localhost",
"port": "5432",
"name": "nettacker_db",
}
mock_engine_initial = MagicMock()
mock_engine_fallback = MagicMock()
mock_create_engine.side_effect = [mock_engine_initial, mock_engine_fallback]
mock_engine_fallback.connect.side_effect = OperationalError("fail again", None, None)
with patch.object(
Base.metadata, "create_all", side_effect=OperationalError("fail", None, None)
):
with self.assertRaises(OperationalError):
postgres_create_database()

View File

@ -0,0 +1,42 @@
from unittest.mock import patch, MagicMock
from sqlalchemy import create_engine, inspect
from nettacker.config import Config
from nettacker.database.models import Base
from nettacker.database.sqlite import sqlite_create_tables
from tests.common import TestCase
class TestSQLiteFunctions(TestCase):
@patch("nettacker.database.sqlite.create_engine")
def test_sqlite_create_tables(self, mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {"name": "/path/to/test.db"}
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
with patch.object(Base.metadata, "create_all") as mock_create_all:
sqlite_create_tables()
mock_create_engine.assert_called_once_with(
"sqlite:////path/to/test.db", connect_args={"check_same_thread": False}
)
mock_create_all.assert_called_once_with(mock_engine)
def test_sqlite_create_tables_integration(self):
engine = create_engine("sqlite:///:memory:")
Config.db = MagicMock()
Config.db.as_dict.return_value = {"name": ":memory:"}
with patch("nettacker.database.sqlite.create_engine", return_value=engine):
sqlite_create_tables()
inspector = inspect(engine)
tables = inspector.get_table_names()
self.assertIn("reports", tables, "Reports table was not created")
self.assertIn("temp_events", tables, "Temp events table was not created")
self.assertIn("scan_events", tables, "Scan events table was not created")