mirror of https://github.com/OWASP/Nettacker.git
Unittets for database files (#1077)
* unittests for database files * ruff fixes
This commit is contained in:
parent
4fd743a15d
commit
af7abb683c
|
|
@ -0,0 +1,94 @@
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from nettacker.database.models import Base, Report, TempEvents, HostsLog
|
||||||
|
from tests.common import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestModels(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
# Creating an in-memory SQLite database for testing
|
||||||
|
self.engine = create_engine("sqlite:///:memory:")
|
||||||
|
Base.metadata.create_all(self.engine)
|
||||||
|
Session = sessionmaker(bind=self.engine)
|
||||||
|
self.session = Session()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.session.close()
|
||||||
|
Base.metadata.drop_all(self.engine)
|
||||||
|
|
||||||
|
def test_report_model(self):
|
||||||
|
test_date = datetime.now()
|
||||||
|
test_report = Report(
|
||||||
|
date=test_date,
|
||||||
|
scan_unique_id="test123",
|
||||||
|
report_path_filename="/path/to/report.txt",
|
||||||
|
options='{"option1": "value1"}',
|
||||||
|
)
|
||||||
|
|
||||||
|
self.session.add(test_report)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
retrieved_report = self.session.query(Report).first()
|
||||||
|
self.assertIsNotNone(retrieved_report)
|
||||||
|
self.assertEqual(retrieved_report.scan_unique_id, "test123")
|
||||||
|
self.assertEqual(retrieved_report.report_path_filename, "/path/to/report.txt")
|
||||||
|
self.assertEqual(retrieved_report.options, '{"option1": "value1"}')
|
||||||
|
|
||||||
|
repr_string = repr(retrieved_report)
|
||||||
|
self.assertIn("test123", repr_string)
|
||||||
|
self.assertIn("/path/to/report.txt", repr_string)
|
||||||
|
|
||||||
|
def test_temp_events_model(self):
|
||||||
|
test_date = datetime.now()
|
||||||
|
test_event = TempEvents(
|
||||||
|
date=test_date,
|
||||||
|
target="192.168.1.1",
|
||||||
|
module_name="port_scan",
|
||||||
|
scan_unique_id="test123",
|
||||||
|
event_name="open_port",
|
||||||
|
port="80",
|
||||||
|
event="Port 80 is open",
|
||||||
|
data='{"details": "HTTP server running"}',
|
||||||
|
)
|
||||||
|
|
||||||
|
self.session.add(test_event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
retrieved_event = self.session.query(TempEvents).first()
|
||||||
|
self.assertIsNotNone(retrieved_event)
|
||||||
|
self.assertEqual(retrieved_event.target, "192.168.1.1")
|
||||||
|
self.assertEqual(retrieved_event.module_name, "port_scan")
|
||||||
|
self.assertEqual(retrieved_event.port, "80")
|
||||||
|
|
||||||
|
repr_string = repr(retrieved_event)
|
||||||
|
self.assertIn("192.168.1.1", repr_string)
|
||||||
|
self.assertIn("port_scan", repr_string)
|
||||||
|
|
||||||
|
def test_hosts_log_model(self):
|
||||||
|
test_date = datetime.now()
|
||||||
|
test_log = HostsLog(
|
||||||
|
date=test_date,
|
||||||
|
target="192.168.1.1",
|
||||||
|
module_name="vulnerability_scan",
|
||||||
|
scan_unique_id="test123",
|
||||||
|
port="443",
|
||||||
|
event="Found vulnerability CVE-2021-12345",
|
||||||
|
json_event='{"vulnerability": "CVE-2021-12345", "severity": "high"}',
|
||||||
|
)
|
||||||
|
|
||||||
|
self.session.add(test_log)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
retrieved_log = self.session.query(HostsLog).first()
|
||||||
|
self.assertIsNotNone(retrieved_log)
|
||||||
|
self.assertEqual(retrieved_log.target, "192.168.1.1")
|
||||||
|
self.assertEqual(retrieved_log.module_name, "vulnerability_scan")
|
||||||
|
self.assertEqual(retrieved_log.port, "443")
|
||||||
|
self.assertEqual(retrieved_log.event, "Found vulnerability CVE-2021-12345")
|
||||||
|
|
||||||
|
repr_string = repr(retrieved_log)
|
||||||
|
self.assertIn("192.168.1.1", repr_string)
|
||||||
|
self.assertIn("vulnerability_scan", repr_string)
|
||||||
|
|
@ -0,0 +1,141 @@
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from sqlalchemy.exc import SQLAlchemyError
|
||||||
|
|
||||||
|
from nettacker.config import Config
|
||||||
|
from nettacker.database.models import Base
|
||||||
|
from nettacker.database.mysql import mysql_create_database, mysql_create_tables
|
||||||
|
from tests.common import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestMySQLFunctions(TestCase):
|
||||||
|
"""Test cases for mysql.py functions"""
|
||||||
|
|
||||||
|
@patch("nettacker.database.mysql.create_engine")
|
||||||
|
def test_mysql_create_database_success(self, mock_create_engine):
|
||||||
|
"""Test successful database creation"""
|
||||||
|
# Set up mock config
|
||||||
|
Config.db = MagicMock()
|
||||||
|
Config.db.as_dict.return_value = {
|
||||||
|
"username": "test_user",
|
||||||
|
"password": "test_pass",
|
||||||
|
"host": "localhost",
|
||||||
|
"port": "3306",
|
||||||
|
"name": "test_db",
|
||||||
|
}
|
||||||
|
Config.db.name = "test_db"
|
||||||
|
|
||||||
|
# Set up mock connection and execution
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_create_engine.return_value = mock_engine
|
||||||
|
mock_engine.connect.return_value.__enter__.return_value = mock_conn
|
||||||
|
|
||||||
|
# Mock database query results - database doesn't exist yet
|
||||||
|
mock_conn.execute.return_value = [("mysql",), ("information_schema",)]
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
mysql_create_database()
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
mock_create_engine.assert_called_once_with(
|
||||||
|
"mysql+pymysql://test_user:test_pass@localhost:3306"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that execute was called with any text object that has the expected SQL
|
||||||
|
call_args_list = mock_conn.execute.call_args_list
|
||||||
|
self.assertEqual(len(call_args_list), 2) # Two calls to execute
|
||||||
|
|
||||||
|
# Check that the first call is SHOW DATABASES
|
||||||
|
first_call_arg = call_args_list[0][0][0]
|
||||||
|
self.assertEqual(str(first_call_arg), "SHOW DATABASES;")
|
||||||
|
|
||||||
|
# Check that the second call is CREATE DATABASE
|
||||||
|
second_call_arg = call_args_list[1][0][0]
|
||||||
|
self.assertEqual(str(second_call_arg), "CREATE DATABASE test_db ")
|
||||||
|
|
||||||
|
@patch("nettacker.database.mysql.create_engine")
|
||||||
|
def test_mysql_create_database_already_exists(self, mock_create_engine):
|
||||||
|
"""Test when database already exists"""
|
||||||
|
# Set up mock config
|
||||||
|
Config.db = MagicMock()
|
||||||
|
Config.db.as_dict.return_value = {
|
||||||
|
"username": "test_user",
|
||||||
|
"password": "test_pass",
|
||||||
|
"host": "localhost",
|
||||||
|
"port": "3306",
|
||||||
|
"name": "test_db",
|
||||||
|
}
|
||||||
|
Config.db.name = "test_db"
|
||||||
|
|
||||||
|
# Set up mock connection and execution
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_create_engine.return_value = mock_engine
|
||||||
|
mock_engine.connect.return_value.__enter__.return_value = mock_conn
|
||||||
|
|
||||||
|
# Mock database query results - database already exists
|
||||||
|
mock_conn.execute.return_value = [("mysql",), ("information_schema",), ("test_db",)]
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
mysql_create_database()
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
mock_create_engine.assert_called_once_with(
|
||||||
|
"mysql+pymysql://test_user:test_pass@localhost:3306"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that execute was called once with SHOW DATABASES
|
||||||
|
self.assertEqual(mock_conn.execute.call_count, 1)
|
||||||
|
call_arg = mock_conn.execute.call_args[0][0]
|
||||||
|
self.assertEqual(str(call_arg), "SHOW DATABASES;")
|
||||||
|
|
||||||
|
@patch("nettacker.database.mysql.create_engine")
|
||||||
|
def test_mysql_create_database_exception(self, mock_create_engine):
|
||||||
|
"""Test exception handling in create database"""
|
||||||
|
# Set up mock config
|
||||||
|
Config.db = MagicMock()
|
||||||
|
Config.db.as_dict.return_value = {
|
||||||
|
"username": "test_user",
|
||||||
|
"password": "test_pass",
|
||||||
|
"host": "localhost",
|
||||||
|
"port": "3306",
|
||||||
|
"name": "test_db",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Set up mock to raise exception
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_create_engine.return_value = mock_engine
|
||||||
|
mock_engine.connect.side_effect = SQLAlchemyError("Connection error")
|
||||||
|
|
||||||
|
# Call the function (should not raise exception)
|
||||||
|
with patch("builtins.print") as mock_print:
|
||||||
|
mysql_create_database()
|
||||||
|
mock_print.assert_called_once()
|
||||||
|
|
||||||
|
@patch("nettacker.database.mysql.create_engine")
|
||||||
|
def test_mysql_create_tables(self, mock_create_engine):
|
||||||
|
"""Test table creation function"""
|
||||||
|
# Set up mock config
|
||||||
|
Config.db = MagicMock()
|
||||||
|
Config.db.as_dict.return_value = {
|
||||||
|
"username": "test_user",
|
||||||
|
"password": "test_pass",
|
||||||
|
"host": "localhost",
|
||||||
|
"port": "3306",
|
||||||
|
"name": "test_db",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Set up mock engine
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_create_engine.return_value = mock_engine
|
||||||
|
|
||||||
|
# Call the function
|
||||||
|
with patch.object(Base.metadata, "create_all") as mock_create_all:
|
||||||
|
mysql_create_tables()
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
mock_create_engine.assert_called_once_with(
|
||||||
|
"mysql+pymysql://test_user:test_pass@localhost:3306/test_db"
|
||||||
|
)
|
||||||
|
mock_create_all.assert_called_once_with(mock_engine)
|
||||||
|
|
@ -0,0 +1,92 @@
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
|
from nettacker.config import Config
|
||||||
|
from nettacker.database.models import Base
|
||||||
|
from nettacker.database.postgresql import postgres_create_database
|
||||||
|
from tests.common import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestPostgresFunctions(TestCase):
|
||||||
|
@patch("nettacker.database.postgresql.create_engine")
|
||||||
|
def test_postgres_create_database_success(self, mock_create_engine):
|
||||||
|
Config.db = MagicMock()
|
||||||
|
Config.db.as_dict.return_value = {
|
||||||
|
"username": "user",
|
||||||
|
"password": "pass",
|
||||||
|
"host": "localhost",
|
||||||
|
"port": "5432",
|
||||||
|
"name": "nettacker_db",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_create_engine.return_value = mock_engine
|
||||||
|
|
||||||
|
with patch.object(Base.metadata, "create_all") as mock_create_all:
|
||||||
|
postgres_create_database()
|
||||||
|
|
||||||
|
mock_create_engine.assert_called_once_with(
|
||||||
|
"postgresql+psycopg2://user:pass@localhost:5432/nettacker_db"
|
||||||
|
)
|
||||||
|
mock_create_all.assert_called_once_with(mock_engine)
|
||||||
|
|
||||||
|
@patch("nettacker.database.postgresql.create_engine")
|
||||||
|
def test_postgres_create_database_if_not_exists(self, mock_create_engine):
|
||||||
|
Config.db = MagicMock()
|
||||||
|
Config.db.as_dict.return_value = {
|
||||||
|
"username": "user",
|
||||||
|
"password": "pass",
|
||||||
|
"host": "localhost",
|
||||||
|
"port": "5432",
|
||||||
|
"name": "nettacker_db",
|
||||||
|
}
|
||||||
|
Config.db.name = "nettacker_db"
|
||||||
|
|
||||||
|
mock_engine_initial = MagicMock()
|
||||||
|
mock_engine_fallback = MagicMock()
|
||||||
|
mock_engine_final = MagicMock()
|
||||||
|
|
||||||
|
mock_create_engine.side_effect = [
|
||||||
|
mock_engine_initial,
|
||||||
|
mock_engine_fallback,
|
||||||
|
mock_engine_final,
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
Base.metadata, "create_all", side_effect=[OperationalError("fail", None, None), None]
|
||||||
|
):
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_engine_fallback.connect.return_value = mock_conn
|
||||||
|
mock_conn.execution_options.return_value = mock_conn
|
||||||
|
|
||||||
|
postgres_create_database()
|
||||||
|
|
||||||
|
assert mock_create_engine.call_count == 3
|
||||||
|
args, _ = mock_conn.execute.call_args
|
||||||
|
assert str(args[0]) == "CREATE DATABASE nettacker_db"
|
||||||
|
mock_conn.close.assert_called_once()
|
||||||
|
|
||||||
|
@patch("nettacker.database.postgresql.create_engine")
|
||||||
|
def test_postgres_create_database_create_fail(self, mock_create_engine):
|
||||||
|
Config.db = MagicMock()
|
||||||
|
Config.db.as_dict.return_value = {
|
||||||
|
"username": "user",
|
||||||
|
"password": "pass",
|
||||||
|
"host": "localhost",
|
||||||
|
"port": "5432",
|
||||||
|
"name": "nettacker_db",
|
||||||
|
}
|
||||||
|
|
||||||
|
mock_engine_initial = MagicMock()
|
||||||
|
mock_engine_fallback = MagicMock()
|
||||||
|
|
||||||
|
mock_create_engine.side_effect = [mock_engine_initial, mock_engine_fallback]
|
||||||
|
|
||||||
|
mock_engine_fallback.connect.side_effect = OperationalError("fail again", None, None)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
Base.metadata, "create_all", side_effect=OperationalError("fail", None, None)
|
||||||
|
):
|
||||||
|
with self.assertRaises(OperationalError):
|
||||||
|
postgres_create_database()
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine, inspect
|
||||||
|
|
||||||
|
from nettacker.config import Config
|
||||||
|
from nettacker.database.models import Base
|
||||||
|
from nettacker.database.sqlite import sqlite_create_tables
|
||||||
|
from tests.common import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestSQLiteFunctions(TestCase):
|
||||||
|
@patch("nettacker.database.sqlite.create_engine")
|
||||||
|
def test_sqlite_create_tables(self, mock_create_engine):
|
||||||
|
Config.db = MagicMock()
|
||||||
|
Config.db.as_dict.return_value = {"name": "/path/to/test.db"}
|
||||||
|
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_create_engine.return_value = mock_engine
|
||||||
|
|
||||||
|
with patch.object(Base.metadata, "create_all") as mock_create_all:
|
||||||
|
sqlite_create_tables()
|
||||||
|
|
||||||
|
mock_create_engine.assert_called_once_with(
|
||||||
|
"sqlite:////path/to/test.db", connect_args={"check_same_thread": False}
|
||||||
|
)
|
||||||
|
mock_create_all.assert_called_once_with(mock_engine)
|
||||||
|
|
||||||
|
def test_sqlite_create_tables_integration(self):
|
||||||
|
engine = create_engine("sqlite:///:memory:")
|
||||||
|
|
||||||
|
Config.db = MagicMock()
|
||||||
|
Config.db.as_dict.return_value = {"name": ":memory:"}
|
||||||
|
|
||||||
|
with patch("nettacker.database.sqlite.create_engine", return_value=engine):
|
||||||
|
sqlite_create_tables()
|
||||||
|
|
||||||
|
inspector = inspect(engine)
|
||||||
|
tables = inspector.get_table_names()
|
||||||
|
|
||||||
|
self.assertIn("reports", tables, "Reports table was not created")
|
||||||
|
self.assertIn("temp_events", tables, "Temp events table was not created")
|
||||||
|
self.assertIn("scan_events", tables, "Scan events table was not created")
|
||||||
Loading…
Reference in New Issue