refactor tests and migrate to pytest (#1081)

* refactor tests and migrate to pytest

* Update tests

---------

Co-authored-by: Arkadii Yakovets <arkadii.yakovets@owasp.org>
This commit is contained in:
Achintya Jai 2025-06-11 06:57:16 +05:30 committed by GitHub
parent 8748df910b
commit 74e494dd1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 746 additions and 679 deletions

View File

@ -1,7 +1,8 @@
from unittest.mock import patch
import pytest
from nettacker.core.lib.socket import create_tcp_socket, SocketEngine
from tests.common import TestCase
class Responses:
@ -123,7 +124,22 @@ class Substeps:
}
class TestSocketMethod(TestCase):
@pytest.fixture
def socket_engine():
return SocketEngine()
@pytest.fixture
def substeps():
return Substeps()
@pytest.fixture
def responses():
return Responses()
class TestSocketMethod:
@patch("socket.socket")
@patch("ssl.wrap_socket")
def test_create_tcp_socket(self, mock_wrap, mock_socket):
@ -137,50 +153,43 @@ class TestSocketMethod(TestCase):
socket_instance.connect.assert_called_with((HOST, PORT))
mock_wrap.assert_called_with(socket_instance)
def test_response_conditions_matched(self):
# tests the response conditions matched for different scan methods
engine = SocketEngine()
Substep = Substeps()
Response = Responses()
def test_response_conditions_matched_socket_icmp(self, socket_engine, substeps, responses):
result = socket_engine.response_conditions_matched(
substeps.socket_icmp, responses.socket_icmp
)
assert result == responses.socket_icmp
# socket_icmp
self.assertEqual(
engine.response_conditions_matched(Substep.socket_icmp, Response.socket_icmp),
Response.socket_icmp,
def test_response_conditions_matched_tcp_connect_send_and_receive(
self, socket_engine, substeps, responses
):
result = socket_engine.response_conditions_matched(
substeps.tcp_connect_send_and_receive, responses.tcp_connect_send_and_receive
)
# tcp_connect_send_and_receive, Port scan's substeps are taken for the test
self.assertEqual(
sorted(
engine.response_conditions_matched(
Substep.tcp_connect_send_and_receive, Response.tcp_connect_send_and_receive
)
),
sorted(
{
"http": ["Content-Type: ", "Content-Length: 302", "HTTP/1.1 400", "Server: "],
"log": [
"{'running_service': 'http', 'matched_regex': ['Server: ', 'HTTP/1.1 400', 'Content-Length: 302', 'Content-Type: '], 'default_service': 'http', 'ssl_flag': True}"
],
"service": [
"{'running_service': 'http', 'matched_regex': ['Server: ', 'HTTP/1.1 400', 'Content-Length: 302', 'Content-Type: '], 'default_service': 'http', 'ssl_flag': True}"
],
}
),
)
expected = {
"http": ["Content-Type: ", "Content-Length: 302", "HTTP/1.1 400", "Server: "],
"log": [
"{'running_service': 'http', 'matched_regex': ['Server: ', 'HTTP/1.1 400', 'Content-Length: 302', 'Content-Type: '], 'default_service': 'http', 'ssl_flag': True}"
],
"service": [
"{'running_service': 'http', 'matched_regex': ['Server: ', 'HTTP/1.1 400', 'Content-Length: 302', 'Content-Type: '], 'default_service': 'http', 'ssl_flag': True}"
],
}
# tcp_connect_only
self.assertEqual(
engine.response_conditions_matched(
Substep.tcp_connect_only, Response.tcp_connect_only
),
Response.tcp_connect_only,
)
assert sorted(result) == sorted(expected)
# * scans with response None i.e. TCP connection failed(None)
self.assertEqual(
engine.response_conditions_matched(
Substep.tcp_connect_send_and_receive, Response.none
),
[],
def test_response_conditions_matched_tcp_connect_only(
self, socket_engine, substeps, responses
):
result = socket_engine.response_conditions_matched(
substeps.tcp_connect_only, responses.tcp_connect_only
)
assert result == responses.tcp_connect_only
def test_response_conditions_matched_with_none_response(
self, socket_engine, substeps, responses
):
result = socket_engine.response_conditions_matched(
substeps.tcp_connect_send_and_receive, responses.none
)
assert result == []

View File

@ -1,6 +1,8 @@
import ssl
from unittest.mock import patch
import pytest
from nettacker.core.lib.ssl import (
SslEngine,
SslLibrary,
@ -9,7 +11,6 @@ from nettacker.core.lib.ssl import (
is_weak_ssl_version,
is_weak_cipher_suite,
)
from tests.common import TestCase
class MockConnectionObject:
@ -151,92 +152,148 @@ class Substeps:
}
class TestSocketMethod(TestCase):
@pytest.fixture
def ssl_engine():
return SslEngine()
@pytest.fixture
def ssl_library():
return SslLibrary()
@pytest.fixture
def substeps():
return Substeps()
@pytest.fixture
def responses():
return Responses()
@pytest.fixture
def connection_params():
return {"HOST": "example.com", "PORT": 80, "TIMEOUT": 60}
class TestSslMethod:
@patch("socket.socket")
@patch("ssl.wrap_socket")
def test_create_tcp_socket(self, mock_wrap, mock_socket):
HOST = "example.com"
PORT = 80
TIMEOUT = 60
def test_create_tcp_socket(self, mock_wrap, mock_socket, connection_params):
create_tcp_socket(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
create_tcp_socket(HOST, PORT, TIMEOUT)
socket_instance = mock_socket.return_value
socket_instance.settimeout.assert_called_with(TIMEOUT)
socket_instance.connect.assert_called_with((HOST, PORT))
socket_instance.settimeout.assert_called_with(connection_params["TIMEOUT"])
socket_instance.connect.assert_called_with(
(connection_params["HOST"], connection_params["PORT"])
)
mock_wrap.assert_called_with(socket_instance)
@patch("nettacker.core.lib.ssl.is_weak_cipher_suite")
@patch("nettacker.core.lib.ssl.is_weak_ssl_version")
@patch("nettacker.core.lib.ssl.create_tcp_socket")
def test_ssl_version_and_cipher_scan(self, mock_connection, mock_ssl_check, mock_cipher_check):
library = SslLibrary()
HOST = "example.com"
PORT = 80
TIMEOUT = 60
mock_connection.return_value = (MockConnectionObject(HOST, "TLSv1.3"), True)
def test_ssl_version_and_cipher_scan_secure(
self, mock_connection, mock_ssl_check, mock_cipher_check, ssl_library, connection_params
):
mock_connection.return_value = (
MockConnectionObject(connection_params["HOST"], "TLSv1.3"),
True,
)
mock_ssl_check.return_value = ("TLSv1.3", False)
mock_cipher_check.return_value = (["HIGH"], False)
self.assertEqual(
library.ssl_version_and_cipher_scan(HOST, PORT, TIMEOUT),
{
"ssl_flag": True,
"service": "http",
"weak_version": False,
"ssl_version": "TLSv1.3",
"peer_name": "example.com",
"cipher_suite": ["HIGH"],
"weak_cipher_suite": False,
"issuer": "NA",
"subject": "NA",
"expiration_date": "NA",
},
result = ssl_library.ssl_version_and_cipher_scan(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
mock_connection.return_value = (MockConnectionObject(HOST, "TLSv1.1"), True)
expected = {
"ssl_flag": True,
"service": "http",
"weak_version": False,
"ssl_version": "TLSv1.3",
"peer_name": "example.com",
"cipher_suite": ["HIGH"],
"weak_cipher_suite": False,
"issuer": "NA",
"subject": "NA",
"expiration_date": "NA",
}
assert result == expected
@patch("nettacker.core.lib.ssl.is_weak_cipher_suite")
@patch("nettacker.core.lib.ssl.is_weak_ssl_version")
@patch("nettacker.core.lib.ssl.create_tcp_socket")
def test_ssl_version_and_cipher_scan_weak(
self, mock_connection, mock_ssl_check, mock_cipher_check, ssl_library, connection_params
):
mock_connection.return_value = (
MockConnectionObject(connection_params["HOST"], "TLSv1.1"),
True,
)
mock_ssl_check.return_value = ("TLSv1.1", True)
mock_cipher_check.return_value = (["LOW"], True)
self.assertEqual(
library.ssl_version_and_cipher_scan(HOST, PORT, TIMEOUT),
{
"ssl_flag": True,
"service": "http",
"weak_version": True,
"ssl_version": "TLSv1.1",
"peer_name": "example.com",
"cipher_suite": ["LOW"],
"weak_cipher_suite": True,
"issuer": "NA",
"subject": "NA",
"expiration_date": "NA",
},
result = ssl_library.ssl_version_and_cipher_scan(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
mock_connection.return_value = (MockConnectionObject(HOST), False)
self.assertEqual(
library.ssl_version_and_cipher_scan(HOST, PORT, TIMEOUT),
{
"ssl_flag": False,
"service": "http",
"peer_name": "example.com",
},
expected = {
"ssl_flag": True,
"service": "http",
"weak_version": True,
"ssl_version": "TLSv1.1",
"peer_name": "example.com",
"cipher_suite": ["LOW"],
"weak_cipher_suite": True,
"issuer": "NA",
"subject": "NA",
"expiration_date": "NA",
}
assert result == expected
@patch("nettacker.core.lib.ssl.is_weak_cipher_suite")
@patch("nettacker.core.lib.ssl.is_weak_ssl_version")
@patch("nettacker.core.lib.ssl.create_tcp_socket")
def test_ssl_version_and_cipher_scan_no_ssl(
self, mock_connection, mock_ssl_check, mock_cipher_check, ssl_library, connection_params
):
mock_connection.return_value = (MockConnectionObject(connection_params["HOST"]), False)
result = ssl_library.ssl_version_and_cipher_scan(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
expected = {
"ssl_flag": False,
"service": "http",
"peer_name": "example.com",
}
assert result == expected
@patch("nettacker.core.lib.ssl.create_tcp_socket")
@patch("nettacker.core.lib.ssl.is_weak_hash_algo")
@patch("nettacker.core.lib.ssl.crypto.load_certificate")
@patch("nettacker.core.lib.ssl.ssl.get_server_certificate")
def test_ssl_certificate_scan(
self, mock_certificate, mock_x509, mock_hash_check, mock_connection
def test_ssl_certificate_scan_valid_cert(
self,
mock_certificate,
mock_x509,
mock_hash_check,
mock_connection,
ssl_library,
connection_params,
):
library = SslLibrary()
HOST = "example.com"
PORT = 80
TIMEOUT = 60
# TESTING AGAINST A CORRECT CERTIFICATE
mock_hash_check.return_value = False
mock_connection.return_value = (MockConnectionObject(HOST, "TLSv1.3"), True)
mock_connection.return_value = (
MockConnectionObject(connection_params["HOST"], "TLSv1.3"),
True,
)
mock_x509.return_value = Mockx509Object(
is_expired=False,
issuer="test_issuer",
@ -246,28 +303,46 @@ class TestSocketMethod(TestCase):
activation_date=b"20231207153045Z",
)
self.assertEqual(
library.ssl_certificate_scan(HOST, PORT, TIMEOUT),
{
"expired": False,
"ssl_flag": True,
"service": "http",
"self_signed": False,
"issuer": "component=test_issuer",
"subject": "component=test_subject",
"expiring_soon": False,
"expiration_date": "2100-12-07",
"not_activated": False,
"activation_date": "2023-12-07",
"signing_algo": "test_algo",
"weak_signing_algo": False,
"peer_name": "example.com",
},
result = ssl_library.ssl_certificate_scan(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
# TESTING AGAINST A SELF-SIGNED CERTIFICATE
expected = {
"expired": False,
"ssl_flag": True,
"service": "http",
"self_signed": False,
"issuer": "component=test_issuer",
"subject": "component=test_subject",
"expiring_soon": False,
"expiration_date": "2100-12-07",
"not_activated": False,
"activation_date": "2023-12-07",
"signing_algo": "test_algo",
"weak_signing_algo": False,
"peer_name": "example.com",
}
assert result == expected
@patch("nettacker.core.lib.ssl.create_tcp_socket")
@patch("nettacker.core.lib.ssl.is_weak_hash_algo")
@patch("nettacker.core.lib.ssl.crypto.load_certificate")
@patch("nettacker.core.lib.ssl.ssl.get_server_certificate")
def test_ssl_certificate_scan_self_signed(
self,
mock_certificate,
mock_x509,
mock_hash_check,
mock_connection,
ssl_library,
connection_params,
):
mock_hash_check.return_value = True
mock_connection.return_value = (MockConnectionObject(HOST, "TLSv1.3"), True)
mock_connection.return_value = (
MockConnectionObject(connection_params["HOST"], "TLSv1.3"),
True,
)
mock_x509.return_value = Mockx509Object(
is_expired=True,
issuer="test_issuer_subject",
@ -276,46 +351,62 @@ class TestSocketMethod(TestCase):
expire_date=b"21001207153045Z",
activation_date=b"21001207153045Z",
)
self.assertEqual(
library.ssl_certificate_scan(HOST, PORT, TIMEOUT),
{
"expired": True,
"ssl_flag": True,
"service": "http",
"self_signed": True,
"issuer": "component=test_issuer_subject",
"subject": "component=test_issuer_subject",
"expiring_soon": False,
"expiration_date": "2100-12-07",
"not_activated": True,
"activation_date": "2100-12-07",
"signing_algo": "test_algo",
"weak_signing_algo": True,
"peer_name": "example.com",
},
result = ssl_library.ssl_certificate_scan(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
# TESTING IF ssl_flag is False
mock_connection.return_value = (MockConnectionObject(HOST), False)
self.assertEqual(
library.ssl_certificate_scan(HOST, PORT, TIMEOUT),
{
"service": "http",
"ssl_flag": False,
"peer_name": "example.com",
},
expected = {
"expired": True,
"ssl_flag": True,
"service": "http",
"self_signed": True,
"issuer": "component=test_issuer_subject",
"subject": "component=test_issuer_subject",
"expiring_soon": False,
"expiration_date": "2100-12-07",
"not_activated": True,
"activation_date": "2100-12-07",
"signing_algo": "test_algo",
"weak_signing_algo": True,
"peer_name": "example.com",
}
assert result == expected
@patch("nettacker.core.lib.ssl.create_tcp_socket")
@patch("nettacker.core.lib.ssl.is_weak_hash_algo")
@patch("nettacker.core.lib.ssl.crypto.load_certificate")
@patch("nettacker.core.lib.ssl.ssl.get_server_certificate")
def test_ssl_certificate_scan_no_ssl(
self,
mock_certificate,
mock_x509,
mock_hash_check,
mock_connection,
ssl_library,
connection_params,
):
mock_connection.return_value = (MockConnectionObject(connection_params["HOST"]), False)
result = ssl_library.ssl_certificate_scan(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
mock_certificate.assert_called_with((HOST, PORT))
expected = {
"service": "http",
"ssl_flag": False,
"peer_name": "example.com",
}
assert result == expected
@patch("socket.socket")
@patch("ssl.create_default_context")
def test_is_weak_cipher_suite(self, mock_context, mock_socket):
HOST = "example.com"
PORT = 80
TIMEOUT = 60
def test_is_weak_cipher_suite_success(self, mock_context, mock_socket, connection_params):
socket_instance = mock_socket.return_value
context_instance = mock_context.return_value
cipher_list = [
"HIGH",
"MEDIUM",
@ -337,98 +428,147 @@ class TestSocketMethod(TestCase):
"TLSv1.2",
"TLSv1.3",
]
self.assertEqual(is_weak_cipher_suite(HOST, PORT, TIMEOUT), (cipher_list, True))
context_instance.wrap_socket.assert_called_with(socket_instance, server_hostname=HOST)
socket_instance.settimeout.assert_called_with(TIMEOUT)
socket_instance.connect.assert_called_with((HOST, PORT))
result = is_weak_cipher_suite(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
assert result == (cipher_list, True)
context_instance.wrap_socket.assert_called_with(
socket_instance, server_hostname=connection_params["HOST"]
)
socket_instance.settimeout.assert_called_with(connection_params["TIMEOUT"])
socket_instance.connect.assert_called_with(
(connection_params["HOST"], connection_params["PORT"])
)
@patch("socket.socket")
@patch("ssl.create_default_context")
def test_is_weak_cipher_suite_ssl_error(self, mock_context, mock_socket, connection_params):
context_instance = mock_context.return_value
context_instance.wrap_socket.side_effect = ssl.SSLError
self.assertEqual(is_weak_cipher_suite(HOST, PORT, TIMEOUT), ([], False))
def test_is_weak_hash_algo(self):
for algo in ("md2", "md4", "md5", "sha1"):
self.assertTrue(is_weak_hash_algo(algo))
self.assertFalse(is_weak_hash_algo("test_aglo"))
result = is_weak_cipher_suite(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
assert result == ([], False)
@pytest.mark.parametrize(
"algo,expected",
[
("md2", True),
("md4", True),
("md5", True),
("sha1", True),
("test_algo", False),
("sha256", False),
],
)
def test_is_weak_hash_algo(self, algo, expected):
assert is_weak_hash_algo(algo) == expected
@patch("socket.socket")
@patch("ssl.SSLContext")
def test_is_weak_ssl_version(self, mock_context, mock_socket):
HOST = "example.com"
PORT = 80
TIMEOUT = 60
def test_is_weak_ssl_version_secure(self, mock_context, mock_socket, connection_params):
context_instance = mock_context.return_value
context_instance.wrap_socket.return_value = MockConnectionObject(
connection_params["HOST"], "TLSv1.3"
)
result = is_weak_ssl_version(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
assert result == (["TLSv1.3", "TLSv1.3", "TLSv1.3", "TLSv1.3"], False)
@patch("socket.socket")
@patch("ssl.SSLContext")
def test_is_weak_ssl_version_weak(self, mock_context, mock_socket, connection_params):
context_instance = mock_context.return_value
context_instance.wrap_socket.return_value = MockConnectionObject(
connection_params["HOST"], "TLSv1.1"
)
result = is_weak_ssl_version(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
assert result == (["TLSv1.1", "TLSv1.1", "TLSv1.1", "TLSv1.1"], True)
@pytest.mark.parametrize("exception", [ssl.SSLError, ConnectionRefusedError])
@patch("socket.socket")
@patch("ssl.SSLContext")
def test_is_weak_ssl_version_exceptions(
self, mock_context, mock_socket, exception, connection_params
):
socket_instance = mock_socket.return_value
context_instance = mock_context.return_value
context_instance.wrap_socket.side_effect = exception
context_instance.wrap_socket.return_value = MockConnectionObject(HOST, "TLSv1.3")
self.assertEqual(
is_weak_ssl_version(HOST, PORT, TIMEOUT),
(["TLSv1.3", "TLSv1.3", "TLSv1.3", "TLSv1.3"], False),
result = is_weak_ssl_version(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
context_instance.wrap_socket.return_value = MockConnectionObject(HOST, "TLSv1.1")
self.assertEqual(
is_weak_ssl_version(HOST, PORT, TIMEOUT),
(["TLSv1.1", "TLSv1.1", "TLSv1.1", "TLSv1.1"], True),
assert result == ([], True)
socket_instance.settimeout.assert_called_with(connection_params["TIMEOUT"])
socket_instance.connect.assert_called_with(
(connection_params["HOST"], connection_params["PORT"])
)
context_instance.wrap_socket.assert_called_with(
socket_instance, server_hostname=connection_params["HOST"]
)
context_instance.wrap_socket.side_effect = ssl.SSLError
self.assertEqual(is_weak_ssl_version(HOST, PORT, TIMEOUT), ([], True))
context_instance.wrap_socket.side_effect = ConnectionRefusedError
self.assertEqual(is_weak_ssl_version(HOST, PORT, TIMEOUT), ([], True))
socket_instance.settimeout.assert_called_with(TIMEOUT)
socket_instance.connect.assert_called_with((HOST, PORT))
context_instance.wrap_socket.assert_called_with(socket_instance, server_hostname=HOST)
def test_response_conditions_matched(self):
# tests the response conditions matched for different scan methods
engine = SslEngine()
Substep = Substeps()
Response = Responses()
# ssl_certificate_expired_vuln
self.assertEqual(
engine.response_conditions_matched(
Substep.ssl_certificate_expired_vuln, Response.ssl_certificate_expired
),
{"subject": "component=subject", "expired": True, "expiration_date": "2023-12-07"},
)
# ssl_certificate_expired_vuln(not activated)
self.assertEqual(
engine.response_conditions_matched(
Substep.ssl_certificate_expired_vuln,
Response.ssl_certificate_deactivated,
),
{
"subject": "component=subject",
"not_activated": True,
"activation_date": "2100-12-07",
},
def test_response_conditions_matched_expired_cert(self, ssl_engine, substeps, responses):
result = ssl_engine.response_conditions_matched(
substeps.ssl_certificate_expired_vuln, responses.ssl_certificate_expired
)
# ssl_weak_version_vuln
self.assertEqual(
engine.response_conditions_matched(
Substep.ssl_weak_version_vuln, Response.ssl_weak_version_vuln
),
{
"weak_version": True,
"ssl_version": ["TLSv1"],
"issuer": "NA",
"subject": "NA",
"expiration_date": "NA",
},
expected = {
"subject": "component=subject",
"expired": True,
"expiration_date": "2023-12-07",
}
assert result == expected
def test_response_conditions_matched_deactivated_cert(self, ssl_engine, substeps, responses):
result = ssl_engine.response_conditions_matched(
substeps.ssl_certificate_expired_vuln,
responses.ssl_certificate_deactivated,
)
# ssl_* scans with ssl_flag = False
self.assertEqual(
engine.response_conditions_matched(Substep.ssl_weak_version_vuln, Response.ssl_off), []
expected = {
"subject": "component=subject",
"not_activated": True,
"activation_date": "2100-12-07",
}
assert result == expected
def test_response_conditions_matched_weak_version(self, ssl_engine, substeps, responses):
result = ssl_engine.response_conditions_matched(
substeps.ssl_weak_version_vuln, responses.ssl_weak_version_vuln
)
# * scans with response None i.e. TCP connection failed(None)
self.assertEqual(
engine.response_conditions_matched(Substep.ssl_weak_version_vuln, None), []
expected = {
"weak_version": True,
"ssl_version": ["TLSv1"],
"issuer": "NA",
"subject": "NA",
"expiration_date": "NA",
}
assert result == expected
def test_response_conditions_matched_ssl_off(self, ssl_engine, substeps, responses):
result = ssl_engine.response_conditions_matched(
substeps.ssl_weak_version_vuln, responses.ssl_off
)
assert result == []
def test_response_conditions_matched_none_response(self, ssl_engine, substeps):
result = ssl_engine.response_conditions_matched(substeps.ssl_weak_version_vuln, None)
assert result == []

View File

@ -1,107 +1,94 @@
from unittest.mock import patch
from nettacker.core.utils import common as common_utils
from tests.common import TestCase
class TestCommon(TestCase):
def test_arrays_to_matrix(self):
(
self.assertEqual(
sorted(
common_utils.arrays_to_matrix(
{"ports": [1, 2, 3, 4, 5]},
)
),
[[1], [2], [3], [4], [5]],
),
)
def test_arrays_to_matrix():
assert sorted(common_utils.arrays_to_matrix({"ports": [1, 2, 3, 4, 5]})) == [
[1],
[2],
[3],
[4],
[5],
]
self.assertEqual(
sorted(
common_utils.arrays_to_matrix(
{"x": [1, 2], "y": [3, 4], "z": [5, 6]},
)
),
[
[1, 3, 5],
[1, 3, 6],
[1, 4, 5],
[1, 4, 6],
[2, 3, 5],
[2, 3, 6],
[2, 4, 5],
[2, 4, 6],
],
)
assert sorted(common_utils.arrays_to_matrix({"x": [1, 2], "y": [3, 4], "z": [5, 6]})) == [
[1, 3, 5],
[1, 3, 6],
[1, 4, 5],
[1, 4, 6],
[2, 3, 5],
[2, 3, 6],
[2, 4, 5],
[2, 4, 6],
]
def test_generate_target_groups_empty_list(self):
targets = []
set_hardware_usage = 3
result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == []
def test_generate_target_groups_set_hardware_less_than_targets_total(self):
targets = [1, 2, 3, 4, 5]
set_hardware_usage = 2
result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == [[1, 2, 3], [4, 5]]
def test_generate_target_groups_empty_list():
targets = []
set_hardware_usage = 3
result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == []
def test_generate_target_groups_set_hardware_equal_to_targets_total(self):
targets = [1, 2, 3, 4, 5]
set_hardware_usage = 5
result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == [[1], [2], [3], [4], [5]]
def test_generate_target_groups_set_hardware_greater_than_targets_total(self):
targets = [1, 2, 3]
set_hardware_usage = 5
result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == [[1], [2], [3]]
def test_generate_target_groups_set_hardware_less_than_targets_total():
targets = [1, 2, 3, 4, 5]
set_hardware_usage = 2
result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == [[1, 2, 3], [4, 5]]
def test_sort_dictionary(self):
input_dict = {
"a": 1,
"c": 3,
"d": 23,
"b": 2,
}
expected_dict = {
"a": 1,
"b": 2,
"c": 3,
"d": 23,
}
input_dict_keys = tuple(input_dict.keys())
expected_dict_keys = tuple(expected_dict.keys())
self.assertNotEqual(input_dict_keys, expected_dict_keys)
sorted_dict_keys = tuple(common_utils.sort_dictionary(input_dict).keys())
self.assertEqual(sorted_dict_keys, expected_dict_keys)
def test_generate_target_groups_set_hardware_equal_to_targets_total():
targets = [1, 2, 3, 4, 5]
set_hardware_usage = 5
result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == [[1], [2], [3], [4], [5]]
@patch("multiprocessing.cpu_count")
def test_select_maximum_cpu_core(self, cpu_count_mock):
cores_mapping = {
1: {"low": 1, "normal": 1, "high": 1, "maximum": 1},
2: {"low": 1, "normal": 1, "high": 1, "maximum": 1},
4: {"low": 1, "normal": 1, "high": 2, "maximum": 3},
6: {"low": 1, "normal": 1, "high": 3, "maximum": 5},
8: {"low": 1, "normal": 2, "high": 4, "maximum": 7},
10: {"low": 1, "normal": 2, "high": 5, "maximum": 9},
12: {"low": 1, "normal": 3, "high": 6, "maximum": 11},
16: {"low": 2, "normal": 4, "high": 8, "maximum": 15},
32: {"low": 4, "normal": 8, "high": 16, "maximum": 31},
48: {"low": 6, "normal": 12, "high": 24, "maximum": 47},
64: {"low": 8, "normal": 16, "high": 32, "maximum": 63},
}
for num_cores, levels in cores_mapping.items():
cpu_count_mock.return_value = num_cores
for level in ("low", "normal", "high", "maximum"):
self.assertEqual(
common_utils.select_maximum_cpu_core(level),
levels[level],
f"It should be {common_utils.select_maximum_cpu_core(level)} "
"of {num_cores} cores for '{level}' mode",
)
self.assertEqual(common_utils.select_maximum_cpu_core("invalid"), 1)
def test_generate_target_groups_set_hardware_greater_than_targets_total():
targets = [1, 2, 3]
set_hardware_usage = 5
result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == [[1], [2], [3]]
def test_sort_dictionary():
input_dict = {
"a": 1,
"c": 3,
"d": 23,
"b": 2,
}
expected_dict = {
"a": 1,
"b": 2,
"c": 3,
"d": 23,
}
input_dict_keys = tuple(input_dict.keys())
expected_dict_keys = tuple(expected_dict.keys())
assert input_dict_keys != expected_dict_keys
sorted_dict_keys = tuple(common_utils.sort_dictionary(input_dict).keys())
assert sorted_dict_keys == expected_dict_keys
@patch("multiprocessing.cpu_count")
def test_select_maximum_cpu_core(cpu_count_mock):
cores_mapping = {
1: {"low": 1, "normal": 1, "high": 1, "maximum": 1},
2: {"low": 1, "normal": 1, "high": 1, "maximum": 1},
4: {"low": 1, "normal": 1, "high": 2, "maximum": 3},
6: {"low": 1, "normal": 1, "high": 3, "maximum": 5},
8: {"low": 1, "normal": 2, "high": 4, "maximum": 7},
10: {"low": 1, "normal": 2, "high": 5, "maximum": 9},
12: {"low": 1, "normal": 3, "high": 6, "maximum": 11},
16: {"low": 2, "normal": 4, "high": 8, "maximum": 15},
32: {"low": 4, "normal": 8, "high": 16, "maximum": 31},
48: {"low": 6, "normal": 12, "high": 24, "maximum": 47},
64: {"low": 8, "normal": 16, "high": 32, "maximum": 63},
}
for num_cores, levels in cores_mapping.items():
cpu_count_mock.return_value = num_cores
for level in ("low", "normal", "high", "maximum"):
assert common_utils.select_maximum_cpu_core(level) == levels[level]
assert common_utils.select_maximum_cpu_core("invalid") == 1

View File

@ -1,94 +1,95 @@
from datetime import datetime
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from nettacker.database.models import Base, Report, TempEvents, HostsLog
from tests.common import TestCase
class TestModels(TestCase):
def setUp(self):
# Creating an in-memory SQLite database for testing
self.engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(self.engine)
Session = sessionmaker(bind=self.engine)
self.session = Session()
@pytest.fixture
def session():
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
sess = Session()
yield sess
sess.close()
Base.metadata.drop_all(engine)
def tearDown(self):
self.session.close()
Base.metadata.drop_all(self.engine)
def test_report_model(self):
test_date = datetime.now()
test_report = Report(
date=test_date,
scan_unique_id="test123",
report_path_filename="/path/to/report.txt",
options='{"option1": "value1"}',
)
def test_report_model(session):
test_date = datetime.now()
test_report = Report(
date=test_date,
scan_unique_id="test123",
report_path_filename="/path/to/report.txt",
options='{"option1": "value1"}',
)
self.session.add(test_report)
self.session.commit()
session.add(test_report)
session.commit()
retrieved_report = self.session.query(Report).first()
self.assertIsNotNone(retrieved_report)
self.assertEqual(retrieved_report.scan_unique_id, "test123")
self.assertEqual(retrieved_report.report_path_filename, "/path/to/report.txt")
self.assertEqual(retrieved_report.options, '{"option1": "value1"}')
retrieved_report = session.query(Report).first()
assert retrieved_report is not None
assert retrieved_report.scan_unique_id == "test123"
assert retrieved_report.report_path_filename == "/path/to/report.txt"
assert retrieved_report.options == '{"option1": "value1"}'
repr_string = repr(retrieved_report)
self.assertIn("test123", repr_string)
self.assertIn("/path/to/report.txt", repr_string)
repr_string = repr(retrieved_report)
assert "test123" in repr_string
assert "/path/to/report.txt" in repr_string
def test_temp_events_model(self):
test_date = datetime.now()
test_event = TempEvents(
date=test_date,
target="192.168.1.1",
module_name="port_scan",
scan_unique_id="test123",
event_name="open_port",
port="80",
event="Port 80 is open",
data='{"details": "HTTP server running"}',
)
self.session.add(test_event)
self.session.commit()
def test_temp_events_model(session):
test_date = datetime.now()
test_event = TempEvents(
date=test_date,
target="192.168.1.1",
module_name="port_scan",
scan_unique_id="test123",
event_name="open_port",
port="80",
event="Port 80 is open",
data='{"details": "HTTP server running"}',
)
retrieved_event = self.session.query(TempEvents).first()
self.assertIsNotNone(retrieved_event)
self.assertEqual(retrieved_event.target, "192.168.1.1")
self.assertEqual(retrieved_event.module_name, "port_scan")
self.assertEqual(retrieved_event.port, "80")
session.add(test_event)
session.commit()
repr_string = repr(retrieved_event)
self.assertIn("192.168.1.1", repr_string)
self.assertIn("port_scan", repr_string)
retrieved_event = session.query(TempEvents).first()
assert retrieved_event is not None
assert retrieved_event.target == "192.168.1.1"
assert retrieved_event.module_name == "port_scan"
assert retrieved_event.port == "80"
def test_hosts_log_model(self):
test_date = datetime.now()
test_log = HostsLog(
date=test_date,
target="192.168.1.1",
module_name="vulnerability_scan",
scan_unique_id="test123",
port="443",
event="Found vulnerability CVE-2021-12345",
json_event='{"vulnerability": "CVE-2021-12345", "severity": "high"}',
)
repr_string = repr(retrieved_event)
assert "192.168.1.1" in repr_string
assert "port_scan" in repr_string
self.session.add(test_log)
self.session.commit()
retrieved_log = self.session.query(HostsLog).first()
self.assertIsNotNone(retrieved_log)
self.assertEqual(retrieved_log.target, "192.168.1.1")
self.assertEqual(retrieved_log.module_name, "vulnerability_scan")
self.assertEqual(retrieved_log.port, "443")
self.assertEqual(retrieved_log.event, "Found vulnerability CVE-2021-12345")
def test_hosts_log_model(session):
test_date = datetime.now()
test_log = HostsLog(
date=test_date,
target="192.168.1.1",
module_name="vulnerability_scan",
scan_unique_id="test123",
port="443",
event="Found vulnerability CVE-2021-12345",
json_event='{"vulnerability": "CVE-2021-12345", "severity": "high"}',
)
repr_string = repr(retrieved_log)
self.assertIn("192.168.1.1", repr_string)
self.assertIn("vulnerability_scan", repr_string)
session.add(test_log)
session.commit()
retrieved_log = session.query(HostsLog).first()
assert retrieved_log is not None
assert retrieved_log.target == "192.168.1.1"
assert retrieved_log.module_name == "vulnerability_scan"
assert retrieved_log.port == "443"
assert retrieved_log.event == "Found vulnerability CVE-2021-12345"
repr_string = repr(retrieved_log)
assert "192.168.1.1" in repr_string
assert "vulnerability_scan" in repr_string

View File

@ -1,141 +1,89 @@
from unittest.mock import patch, MagicMock
import pytest
from sqlalchemy.exc import SQLAlchemyError
from nettacker.config import Config
from nettacker.database.models import Base
from nettacker.database.mysql import mysql_create_database, mysql_create_tables
from tests.common import TestCase
class TestMySQLFunctions(TestCase):
"""Test cases for mysql.py functions"""
@pytest.fixture(autouse=True)
def setup_config():
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
Config.db.name = "test_db"
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_success(self, mock_create_engine):
"""Test successful database creation"""
# Set up mock config
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
Config.db.name = "test_db"
# Set up mock connection and execution
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine.connect.return_value.__enter__.return_value = mock_conn
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_success(mock_create_engine):
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine.connect.return_value.__enter__.return_value = mock_conn
mock_conn.execute.return_value = [("mysql",), ("information_schema",)]
# Mock database query results - database doesn't exist yet
mock_conn.execute.return_value = [("mysql",), ("information_schema",)]
mysql_create_database()
# Call the function
mock_create_engine.assert_called_once_with(
"mysql+pymysql://test_user:test_pass@localhost:3306"
)
call_args_list = mock_conn.execute.call_args_list
assert len(call_args_list) == 2
first_call_arg = call_args_list[0][0][0]
assert str(first_call_arg) == "SHOW DATABASES;"
second_call_arg = call_args_list[1][0][0]
assert str(second_call_arg) == "CREATE DATABASE test_db "
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_already_exists(mock_create_engine):
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine.connect.return_value.__enter__.return_value = mock_conn
mock_conn.execute.return_value = [("mysql",), ("information_schema",), ("test_db",)]
mysql_create_database()
mock_create_engine.assert_called_once_with(
"mysql+pymysql://test_user:test_pass@localhost:3306"
)
assert mock_conn.execute.call_count == 1
call_arg = mock_conn.execute.call_args[0][0]
assert str(call_arg) == "SHOW DATABASES;"
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_exception(mock_create_engine):
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine.connect.side_effect = SQLAlchemyError("Connection error")
with patch("builtins.print") as mock_print:
mysql_create_database()
mock_print.assert_called_once()
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_tables(mock_create_engine):
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
with patch.object(Base.metadata, "create_all") as mock_create_all:
mysql_create_tables()
# Assertions
mock_create_engine.assert_called_once_with(
"mysql+pymysql://test_user:test_pass@localhost:3306"
"mysql+pymysql://test_user:test_pass@localhost:3306/test_db"
)
# Check that execute was called with any text object that has the expected SQL
call_args_list = mock_conn.execute.call_args_list
self.assertEqual(len(call_args_list), 2) # Two calls to execute
# Check that the first call is SHOW DATABASES
first_call_arg = call_args_list[0][0][0]
self.assertEqual(str(first_call_arg), "SHOW DATABASES;")
# Check that the second call is CREATE DATABASE
second_call_arg = call_args_list[1][0][0]
self.assertEqual(str(second_call_arg), "CREATE DATABASE test_db ")
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_already_exists(self, mock_create_engine):
"""Test when database already exists"""
# Set up mock config
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
Config.db.name = "test_db"
# Set up mock connection and execution
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine.connect.return_value.__enter__.return_value = mock_conn
# Mock database query results - database already exists
mock_conn.execute.return_value = [("mysql",), ("information_schema",), ("test_db",)]
# Call the function
mysql_create_database()
# Assertions
mock_create_engine.assert_called_once_with(
"mysql+pymysql://test_user:test_pass@localhost:3306"
)
# Check that execute was called once with SHOW DATABASES
self.assertEqual(mock_conn.execute.call_count, 1)
call_arg = mock_conn.execute.call_args[0][0]
self.assertEqual(str(call_arg), "SHOW DATABASES;")
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_exception(self, mock_create_engine):
"""Test exception handling in create database"""
# Set up mock config
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
# Set up mock to raise exception
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine.connect.side_effect = SQLAlchemyError("Connection error")
# Call the function (should not raise exception)
with patch("builtins.print") as mock_print:
mysql_create_database()
mock_print.assert_called_once()
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_tables(self, mock_create_engine):
"""Test table creation function"""
# Set up mock config
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
# Set up mock engine
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
# Call the function
with patch.object(Base.metadata, "create_all") as mock_create_all:
mysql_create_tables()
# Assertions
mock_create_engine.assert_called_once_with(
"mysql+pymysql://test_user:test_pass@localhost:3306/test_db"
)
mock_create_all.assert_called_once_with(mock_engine)
mock_create_all.assert_called_once_with(mock_engine)

View File

@ -5,88 +5,90 @@ from sqlalchemy.exc import OperationalError
from nettacker.config import Config
from nettacker.database.models import Base
from nettacker.database.postgresql import postgres_create_database
from tests.common import TestCase
class TestPostgresFunctions(TestCase):
@patch("nettacker.database.postgresql.create_engine")
def test_postgres_create_database_success(self, mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "user",
"password": "pass",
"host": "localhost",
"port": "5432",
"name": "nettacker_db",
}
@patch("nettacker.database.postgresql.create_engine")
def test_postgres_create_database_success(mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "user",
"password": "pass",
"host": "localhost",
"port": "5432",
"name": "nettacker_db",
}
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
with patch.object(Base.metadata, "create_all") as mock_create_all:
with patch.object(Base.metadata, "create_all") as mock_create_all:
postgres_create_database()
mock_create_engine.assert_called_once_with(
"postgresql+psycopg2://user:pass@localhost:5432/nettacker_db"
)
mock_create_all.assert_called_once_with(mock_engine)
@patch("nettacker.database.postgresql.create_engine")
def test_postgres_create_database_if_not_exists(mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "user",
"password": "pass",
"host": "localhost",
"port": "5432",
"name": "nettacker_db",
}
Config.db.name = "nettacker_db"
mock_engine_initial = MagicMock()
mock_engine_fallback = MagicMock()
mock_engine_final = MagicMock()
mock_create_engine.side_effect = [
mock_engine_initial,
mock_engine_fallback,
mock_engine_final,
]
with patch.object(
Base.metadata, "create_all", side_effect=[OperationalError("fail", None, None), None]
):
mock_conn = MagicMock()
mock_engine_fallback.connect.return_value = mock_conn
mock_conn.execution_options.return_value = mock_conn
postgres_create_database()
assert mock_create_engine.call_count == 3
args, _ = mock_conn.execute.call_args
assert str(args[0]) == "CREATE DATABASE nettacker_db"
mock_conn.close.assert_called_once()
@patch("nettacker.database.postgresql.create_engine")
def test_postgres_create_database_create_fail(mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "user",
"password": "pass",
"host": "localhost",
"port": "5432",
"name": "nettacker_db",
}
mock_engine_initial = MagicMock()
mock_engine_fallback = MagicMock()
mock_create_engine.side_effect = [mock_engine_initial, mock_engine_fallback]
mock_engine_fallback.connect.side_effect = OperationalError("fail again", None, None)
with patch.object(
Base.metadata, "create_all", side_effect=OperationalError("fail", None, None)
):
import pytest
with pytest.raises(OperationalError):
postgres_create_database()
mock_create_engine.assert_called_once_with(
"postgresql+psycopg2://user:pass@localhost:5432/nettacker_db"
)
mock_create_all.assert_called_once_with(mock_engine)
@patch("nettacker.database.postgresql.create_engine")
def test_postgres_create_database_if_not_exists(self, mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "user",
"password": "pass",
"host": "localhost",
"port": "5432",
"name": "nettacker_db",
}
Config.db.name = "nettacker_db"
mock_engine_initial = MagicMock()
mock_engine_fallback = MagicMock()
mock_engine_final = MagicMock()
mock_create_engine.side_effect = [
mock_engine_initial,
mock_engine_fallback,
mock_engine_final,
]
with patch.object(
Base.metadata, "create_all", side_effect=[OperationalError("fail", None, None), None]
):
mock_conn = MagicMock()
mock_engine_fallback.connect.return_value = mock_conn
mock_conn.execution_options.return_value = mock_conn
postgres_create_database()
assert mock_create_engine.call_count == 3
args, _ = mock_conn.execute.call_args
assert str(args[0]) == "CREATE DATABASE nettacker_db"
mock_conn.close.assert_called_once()
@patch("nettacker.database.postgresql.create_engine")
def test_postgres_create_database_create_fail(self, mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "user",
"password": "pass",
"host": "localhost",
"port": "5432",
"name": "nettacker_db",
}
mock_engine_initial = MagicMock()
mock_engine_fallback = MagicMock()
mock_create_engine.side_effect = [mock_engine_initial, mock_engine_fallback]
mock_engine_fallback.connect.side_effect = OperationalError("fail again", None, None)
with patch.object(
Base.metadata, "create_all", side_effect=OperationalError("fail", None, None)
):
with self.assertRaises(OperationalError):
postgres_create_database()

View File

@ -1,42 +1,45 @@
from unittest.mock import patch, MagicMock
import pytest
from sqlalchemy import create_engine, inspect
from nettacker.config import Config
from nettacker.database.models import Base
from nettacker.database.sqlite import sqlite_create_tables
from tests.common import TestCase
class TestSQLiteFunctions(TestCase):
@patch("nettacker.database.sqlite.create_engine")
def test_sqlite_create_tables(self, mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {"name": "/path/to/test.db"}
@pytest.fixture
def mock_config():
Config.db = MagicMock()
yield Config.db
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
with patch.object(Base.metadata, "create_all") as mock_create_all:
sqlite_create_tables()
@patch("nettacker.database.sqlite.create_engine")
def test_sqlite_create_tables(mock_create_engine, mock_config):
mock_config.as_dict.return_value = {"name": "/path/to/test.db"}
mock_create_engine.assert_called_once_with(
"sqlite:////path/to/test.db", connect_args={"check_same_thread": False}
)
mock_create_all.assert_called_once_with(mock_engine)
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
def test_sqlite_create_tables_integration(self):
engine = create_engine("sqlite:///:memory:")
with patch.object(Base.metadata, "create_all") as mock_create_all:
sqlite_create_tables()
Config.db = MagicMock()
Config.db.as_dict.return_value = {"name": ":memory:"}
mock_create_engine.assert_called_once_with(
"sqlite:////path/to/test.db", connect_args={"check_same_thread": False}
)
mock_create_all.assert_called_once_with(mock_engine)
with patch("nettacker.database.sqlite.create_engine", return_value=engine):
sqlite_create_tables()
inspector = inspect(engine)
tables = inspector.get_table_names()
def test_sqlite_create_tables_integration(mock_config):
engine = create_engine("sqlite:///:memory:")
mock_config.as_dict.return_value = {"name": ":memory:"}
self.assertIn("reports", tables, "Reports table was not created")
self.assertIn("temp_events", tables, "Temp events table was not created")
self.assertIn("scan_events", tables, "Scan events table was not created")
with patch("nettacker.database.sqlite.create_engine", return_value=engine):
sqlite_create_tables()
inspector = inspect(engine)
tables = inspector.get_table_names()
assert "reports" in tables, "Reports table was not created"
assert "temp_events" in tables, "Temp events table was not created"
assert "scan_events" in tables, "Scan events table was not created"

View File

@ -1,22 +1,18 @@
from collections import Counter
from pathlib import Path
import pytest
from tests.common import TestCase
nettacker_path = Path(__file__).parent.parent.parent.parent
class TestPasswords(TestCase):
top_1000_common_passwords_path = "lib/payloads/passwords/top_1000_common_passwords.txt"
def test_top_1000_common_passwords():
top_1000_passwords_file_path = (
nettacker_path / "nettacker/lib/payloads/passwords/top_1000_common_passwords.txt"
)
with open(top_1000_passwords_file_path) as f:
top_1000_passwords = [line.strip() for line in f.readlines() if line.strip()]
@pytest.mark.xfail(reason="It currently contains 1001 passwords.")
def test_top_1000_common_passwords(self):
with open(self.nettacker_path / self.top_1000_common_passwords_path) as top_1000_file:
top_1000_passwords = [line.strip() for line in top_1000_file.readlines()]
assert len(top_1000_passwords) == 1000, "There should be exactly 1000 passwords"
self.assertEqual(len(top_1000_passwords), 1000, "There should be exactly 1000 passwords")
self.assertEqual(
len(set(top_1000_passwords)),
len(top_1000_passwords),
f"The passwords aren't unique: {Counter(top_1000_passwords).most_common(1)[0][0]}",
)
assert len(set(top_1000_passwords)) == len(
top_1000_passwords
), f"The passwords aren't unique: {Counter(top_1000_passwords).most_common(1)[0][0]}"

View File

@ -1,48 +1,29 @@
from collections import Counter
from pathlib import Path
from tests.common import TestCase
import pytest
wordlists = {
"admin_file": ["lib/payloads/wordlists/admin_wordlist.txt", 533],
"dir_file": ["lib/payloads/wordlists/dir_wordlist.txt", 1966],
"pma_file": ["lib/payloads/wordlists/pma_wordlist.txt", 174],
"wp_plugin_small_file": ["lib/payloads/wordlists/wp_plugin_small.txt", 291],
"wp_theme_small_file": ["lib/payloads/wordlists/wp_theme_small.txt", 41],
"wp_timethumb_file": ["lib/payloads/wordlists/wp_timethumbs.txt", 2424],
"admin_file": ("nettacker/lib/payloads/wordlists/admin_wordlist.txt", 533),
"dir_file": ("nettacker/lib/payloads/wordlists/dir_wordlist.txt", 1966),
"pma_file": ("nettacker/lib/payloads/wordlists/pma_wordlist.txt", 174),
"wp_plugin_small_file": ("nettacker/lib/payloads/wordlists/wp_plugin_small.txt", 291),
"wp_theme_small_file": ("nettacker/lib/payloads/wordlists/wp_theme_small.txt", 41),
"wp_timethumb_file": ("nettacker/lib/payloads/wordlists/wp_timethumbs.txt", 2424),
}
nettacker_path = Path(__file__).parent.parent.parent.parent
class TestWordlists(TestCase):
def test_admin_wordlist(self):
self.run_wordlist_test("admin_file")
def test_dir_wordlist(self):
self.run_wordlist_test("dir_file")
@pytest.mark.parametrize("key", list(wordlists.keys()))
def test_wordlist(key):
wordlist_path, expected_length = wordlists[key]
full_path = nettacker_path / wordlist_path
def test_pma_wordlist(self):
self.run_wordlist_test("pma_file")
with open(full_path) as f:
paths = [line.strip() for line in f.readlines()]
def test_wp_plugin_small_wordlist(self):
self.run_wordlist_test("wp_plugin_small_file")
def test_wp_theme_small_wordlist(self):
self.run_wordlist_test("wp_theme_small_file")
def test_wp_timethumb_wordlist(self):
self.run_wordlist_test("wp_timethumb_file")
def run_wordlist_test(self, key):
wordlist_path = wordlists[key][0]
wordlist_length = wordlists[key][1]
with open(self.nettacker_path / wordlist_path) as wordlist_file:
paths = [line.strip() for line in wordlist_file.readlines()]
self.assertEqual(
len(paths), wordlist_length, f"There are {wordlist_length} paths in {key}"
)
self.assertEqual(
len(set(paths)),
len(paths),
f"The paths aren't unique in {key}: {Counter(paths).most_common(1)[0][0]}",
)
assert len(paths) == expected_length, f"There are {expected_length} paths in {key}"
assert len(set(paths)) == len(
paths
), f"The paths aren't unique in {key}: {Counter(paths).most_common(1)[0][0]}"