refactor tests and migrate to pytest (#1081)

* refactor tests and migrate to pytest

* Update tests

---------

Co-authored-by: Arkadii Yakovets <arkadii.yakovets@owasp.org>
This commit is contained in:
Achintya Jai 2025-06-11 06:57:16 +05:30 committed by GitHub
parent 8748df910b
commit 74e494dd1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 746 additions and 679 deletions

View File

@ -1,7 +1,8 @@
from unittest.mock import patch from unittest.mock import patch
import pytest
from nettacker.core.lib.socket import create_tcp_socket, SocketEngine from nettacker.core.lib.socket import create_tcp_socket, SocketEngine
from tests.common import TestCase
class Responses: class Responses:
@ -123,7 +124,22 @@ class Substeps:
} }
class TestSocketMethod(TestCase): @pytest.fixture
def socket_engine():
return SocketEngine()
@pytest.fixture
def substeps():
return Substeps()
@pytest.fixture
def responses():
return Responses()
class TestSocketMethod:
@patch("socket.socket") @patch("socket.socket")
@patch("ssl.wrap_socket") @patch("ssl.wrap_socket")
def test_create_tcp_socket(self, mock_wrap, mock_socket): def test_create_tcp_socket(self, mock_wrap, mock_socket):
@ -137,50 +153,43 @@ class TestSocketMethod(TestCase):
socket_instance.connect.assert_called_with((HOST, PORT)) socket_instance.connect.assert_called_with((HOST, PORT))
mock_wrap.assert_called_with(socket_instance) mock_wrap.assert_called_with(socket_instance)
def test_response_conditions_matched(self): def test_response_conditions_matched_socket_icmp(self, socket_engine, substeps, responses):
# tests the response conditions matched for different scan methods result = socket_engine.response_conditions_matched(
engine = SocketEngine() substeps.socket_icmp, responses.socket_icmp
Substep = Substeps() )
Response = Responses() assert result == responses.socket_icmp
# socket_icmp def test_response_conditions_matched_tcp_connect_send_and_receive(
self.assertEqual( self, socket_engine, substeps, responses
engine.response_conditions_matched(Substep.socket_icmp, Response.socket_icmp), ):
Response.socket_icmp, result = socket_engine.response_conditions_matched(
substeps.tcp_connect_send_and_receive, responses.tcp_connect_send_and_receive
) )
# tcp_connect_send_and_receive, Port scan's substeps are taken for the test expected = {
self.assertEqual( "http": ["Content-Type: ", "Content-Length: 302", "HTTP/1.1 400", "Server: "],
sorted( "log": [
engine.response_conditions_matched( "{'running_service': 'http', 'matched_regex': ['Server: ', 'HTTP/1.1 400', 'Content-Length: 302', 'Content-Type: '], 'default_service': 'http', 'ssl_flag': True}"
Substep.tcp_connect_send_and_receive, Response.tcp_connect_send_and_receive ],
) "service": [
), "{'running_service': 'http', 'matched_regex': ['Server: ', 'HTTP/1.1 400', 'Content-Length: 302', 'Content-Type: '], 'default_service': 'http', 'ssl_flag': True}"
sorted( ],
{ }
"http": ["Content-Type: ", "Content-Length: 302", "HTTP/1.1 400", "Server: "],
"log": [
"{'running_service': 'http', 'matched_regex': ['Server: ', 'HTTP/1.1 400', 'Content-Length: 302', 'Content-Type: '], 'default_service': 'http', 'ssl_flag': True}"
],
"service": [
"{'running_service': 'http', 'matched_regex': ['Server: ', 'HTTP/1.1 400', 'Content-Length: 302', 'Content-Type: '], 'default_service': 'http', 'ssl_flag': True}"
],
}
),
)
# tcp_connect_only assert sorted(result) == sorted(expected)
self.assertEqual(
engine.response_conditions_matched(
Substep.tcp_connect_only, Response.tcp_connect_only
),
Response.tcp_connect_only,
)
# * scans with response None i.e. TCP connection failed(None) def test_response_conditions_matched_tcp_connect_only(
self.assertEqual( self, socket_engine, substeps, responses
engine.response_conditions_matched( ):
Substep.tcp_connect_send_and_receive, Response.none result = socket_engine.response_conditions_matched(
), substeps.tcp_connect_only, responses.tcp_connect_only
[],
) )
assert result == responses.tcp_connect_only
def test_response_conditions_matched_with_none_response(
self, socket_engine, substeps, responses
):
result = socket_engine.response_conditions_matched(
substeps.tcp_connect_send_and_receive, responses.none
)
assert result == []

View File

@ -1,6 +1,8 @@
import ssl import ssl
from unittest.mock import patch from unittest.mock import patch
import pytest
from nettacker.core.lib.ssl import ( from nettacker.core.lib.ssl import (
SslEngine, SslEngine,
SslLibrary, SslLibrary,
@ -9,7 +11,6 @@ from nettacker.core.lib.ssl import (
is_weak_ssl_version, is_weak_ssl_version,
is_weak_cipher_suite, is_weak_cipher_suite,
) )
from tests.common import TestCase
class MockConnectionObject: class MockConnectionObject:
@ -151,92 +152,148 @@ class Substeps:
} }
class TestSocketMethod(TestCase): @pytest.fixture
def ssl_engine():
return SslEngine()
@pytest.fixture
def ssl_library():
return SslLibrary()
@pytest.fixture
def substeps():
return Substeps()
@pytest.fixture
def responses():
return Responses()
@pytest.fixture
def connection_params():
return {"HOST": "example.com", "PORT": 80, "TIMEOUT": 60}
class TestSslMethod:
@patch("socket.socket") @patch("socket.socket")
@patch("ssl.wrap_socket") @patch("ssl.wrap_socket")
def test_create_tcp_socket(self, mock_wrap, mock_socket): def test_create_tcp_socket(self, mock_wrap, mock_socket, connection_params):
HOST = "example.com" create_tcp_socket(
PORT = 80 connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
TIMEOUT = 60 )
create_tcp_socket(HOST, PORT, TIMEOUT)
socket_instance = mock_socket.return_value socket_instance = mock_socket.return_value
socket_instance.settimeout.assert_called_with(TIMEOUT) socket_instance.settimeout.assert_called_with(connection_params["TIMEOUT"])
socket_instance.connect.assert_called_with((HOST, PORT)) socket_instance.connect.assert_called_with(
(connection_params["HOST"], connection_params["PORT"])
)
mock_wrap.assert_called_with(socket_instance) mock_wrap.assert_called_with(socket_instance)
@patch("nettacker.core.lib.ssl.is_weak_cipher_suite") @patch("nettacker.core.lib.ssl.is_weak_cipher_suite")
@patch("nettacker.core.lib.ssl.is_weak_ssl_version") @patch("nettacker.core.lib.ssl.is_weak_ssl_version")
@patch("nettacker.core.lib.ssl.create_tcp_socket") @patch("nettacker.core.lib.ssl.create_tcp_socket")
def test_ssl_version_and_cipher_scan(self, mock_connection, mock_ssl_check, mock_cipher_check): def test_ssl_version_and_cipher_scan_secure(
library = SslLibrary() self, mock_connection, mock_ssl_check, mock_cipher_check, ssl_library, connection_params
HOST = "example.com" ):
PORT = 80 mock_connection.return_value = (
TIMEOUT = 60 MockConnectionObject(connection_params["HOST"], "TLSv1.3"),
True,
mock_connection.return_value = (MockConnectionObject(HOST, "TLSv1.3"), True) )
mock_ssl_check.return_value = ("TLSv1.3", False) mock_ssl_check.return_value = ("TLSv1.3", False)
mock_cipher_check.return_value = (["HIGH"], False) mock_cipher_check.return_value = (["HIGH"], False)
self.assertEqual(
library.ssl_version_and_cipher_scan(HOST, PORT, TIMEOUT), result = ssl_library.ssl_version_and_cipher_scan(
{ connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
"ssl_flag": True,
"service": "http",
"weak_version": False,
"ssl_version": "TLSv1.3",
"peer_name": "example.com",
"cipher_suite": ["HIGH"],
"weak_cipher_suite": False,
"issuer": "NA",
"subject": "NA",
"expiration_date": "NA",
},
) )
mock_connection.return_value = (MockConnectionObject(HOST, "TLSv1.1"), True) expected = {
"ssl_flag": True,
"service": "http",
"weak_version": False,
"ssl_version": "TLSv1.3",
"peer_name": "example.com",
"cipher_suite": ["HIGH"],
"weak_cipher_suite": False,
"issuer": "NA",
"subject": "NA",
"expiration_date": "NA",
}
assert result == expected
@patch("nettacker.core.lib.ssl.is_weak_cipher_suite")
@patch("nettacker.core.lib.ssl.is_weak_ssl_version")
@patch("nettacker.core.lib.ssl.create_tcp_socket")
def test_ssl_version_and_cipher_scan_weak(
self, mock_connection, mock_ssl_check, mock_cipher_check, ssl_library, connection_params
):
mock_connection.return_value = (
MockConnectionObject(connection_params["HOST"], "TLSv1.1"),
True,
)
mock_ssl_check.return_value = ("TLSv1.1", True) mock_ssl_check.return_value = ("TLSv1.1", True)
mock_cipher_check.return_value = (["LOW"], True) mock_cipher_check.return_value = (["LOW"], True)
self.assertEqual(
library.ssl_version_and_cipher_scan(HOST, PORT, TIMEOUT), result = ssl_library.ssl_version_and_cipher_scan(
{ connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
"ssl_flag": True,
"service": "http",
"weak_version": True,
"ssl_version": "TLSv1.1",
"peer_name": "example.com",
"cipher_suite": ["LOW"],
"weak_cipher_suite": True,
"issuer": "NA",
"subject": "NA",
"expiration_date": "NA",
},
) )
mock_connection.return_value = (MockConnectionObject(HOST), False) expected = {
self.assertEqual( "ssl_flag": True,
library.ssl_version_and_cipher_scan(HOST, PORT, TIMEOUT), "service": "http",
{ "weak_version": True,
"ssl_flag": False, "ssl_version": "TLSv1.1",
"service": "http", "peer_name": "example.com",
"peer_name": "example.com", "cipher_suite": ["LOW"],
}, "weak_cipher_suite": True,
"issuer": "NA",
"subject": "NA",
"expiration_date": "NA",
}
assert result == expected
@patch("nettacker.core.lib.ssl.is_weak_cipher_suite")
@patch("nettacker.core.lib.ssl.is_weak_ssl_version")
@patch("nettacker.core.lib.ssl.create_tcp_socket")
def test_ssl_version_and_cipher_scan_no_ssl(
self, mock_connection, mock_ssl_check, mock_cipher_check, ssl_library, connection_params
):
mock_connection.return_value = (MockConnectionObject(connection_params["HOST"]), False)
result = ssl_library.ssl_version_and_cipher_scan(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
) )
expected = {
"ssl_flag": False,
"service": "http",
"peer_name": "example.com",
}
assert result == expected
@patch("nettacker.core.lib.ssl.create_tcp_socket") @patch("nettacker.core.lib.ssl.create_tcp_socket")
@patch("nettacker.core.lib.ssl.is_weak_hash_algo") @patch("nettacker.core.lib.ssl.is_weak_hash_algo")
@patch("nettacker.core.lib.ssl.crypto.load_certificate") @patch("nettacker.core.lib.ssl.crypto.load_certificate")
@patch("nettacker.core.lib.ssl.ssl.get_server_certificate") @patch("nettacker.core.lib.ssl.ssl.get_server_certificate")
def test_ssl_certificate_scan( def test_ssl_certificate_scan_valid_cert(
self, mock_certificate, mock_x509, mock_hash_check, mock_connection self,
mock_certificate,
mock_x509,
mock_hash_check,
mock_connection,
ssl_library,
connection_params,
): ):
library = SslLibrary()
HOST = "example.com"
PORT = 80
TIMEOUT = 60
# TESTING AGAINST A CORRECT CERTIFICATE
mock_hash_check.return_value = False mock_hash_check.return_value = False
mock_connection.return_value = (MockConnectionObject(HOST, "TLSv1.3"), True) mock_connection.return_value = (
MockConnectionObject(connection_params["HOST"], "TLSv1.3"),
True,
)
mock_x509.return_value = Mockx509Object( mock_x509.return_value = Mockx509Object(
is_expired=False, is_expired=False,
issuer="test_issuer", issuer="test_issuer",
@ -246,28 +303,46 @@ class TestSocketMethod(TestCase):
activation_date=b"20231207153045Z", activation_date=b"20231207153045Z",
) )
self.assertEqual( result = ssl_library.ssl_certificate_scan(
library.ssl_certificate_scan(HOST, PORT, TIMEOUT), connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
{
"expired": False,
"ssl_flag": True,
"service": "http",
"self_signed": False,
"issuer": "component=test_issuer",
"subject": "component=test_subject",
"expiring_soon": False,
"expiration_date": "2100-12-07",
"not_activated": False,
"activation_date": "2023-12-07",
"signing_algo": "test_algo",
"weak_signing_algo": False,
"peer_name": "example.com",
},
) )
# TESTING AGAINST A SELF-SIGNED CERTIFICATE expected = {
"expired": False,
"ssl_flag": True,
"service": "http",
"self_signed": False,
"issuer": "component=test_issuer",
"subject": "component=test_subject",
"expiring_soon": False,
"expiration_date": "2100-12-07",
"not_activated": False,
"activation_date": "2023-12-07",
"signing_algo": "test_algo",
"weak_signing_algo": False,
"peer_name": "example.com",
}
assert result == expected
@patch("nettacker.core.lib.ssl.create_tcp_socket")
@patch("nettacker.core.lib.ssl.is_weak_hash_algo")
@patch("nettacker.core.lib.ssl.crypto.load_certificate")
@patch("nettacker.core.lib.ssl.ssl.get_server_certificate")
def test_ssl_certificate_scan_self_signed(
self,
mock_certificate,
mock_x509,
mock_hash_check,
mock_connection,
ssl_library,
connection_params,
):
mock_hash_check.return_value = True mock_hash_check.return_value = True
mock_connection.return_value = (MockConnectionObject(HOST, "TLSv1.3"), True) mock_connection.return_value = (
MockConnectionObject(connection_params["HOST"], "TLSv1.3"),
True,
)
mock_x509.return_value = Mockx509Object( mock_x509.return_value = Mockx509Object(
is_expired=True, is_expired=True,
issuer="test_issuer_subject", issuer="test_issuer_subject",
@ -276,46 +351,62 @@ class TestSocketMethod(TestCase):
expire_date=b"21001207153045Z", expire_date=b"21001207153045Z",
activation_date=b"21001207153045Z", activation_date=b"21001207153045Z",
) )
self.assertEqual(
library.ssl_certificate_scan(HOST, PORT, TIMEOUT), result = ssl_library.ssl_certificate_scan(
{ connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
"expired": True,
"ssl_flag": True,
"service": "http",
"self_signed": True,
"issuer": "component=test_issuer_subject",
"subject": "component=test_issuer_subject",
"expiring_soon": False,
"expiration_date": "2100-12-07",
"not_activated": True,
"activation_date": "2100-12-07",
"signing_algo": "test_algo",
"weak_signing_algo": True,
"peer_name": "example.com",
},
) )
# TESTING IF ssl_flag is False expected = {
mock_connection.return_value = (MockConnectionObject(HOST), False) "expired": True,
self.assertEqual( "ssl_flag": True,
library.ssl_certificate_scan(HOST, PORT, TIMEOUT), "service": "http",
{ "self_signed": True,
"service": "http", "issuer": "component=test_issuer_subject",
"ssl_flag": False, "subject": "component=test_issuer_subject",
"peer_name": "example.com", "expiring_soon": False,
}, "expiration_date": "2100-12-07",
"not_activated": True,
"activation_date": "2100-12-07",
"signing_algo": "test_algo",
"weak_signing_algo": True,
"peer_name": "example.com",
}
assert result == expected
@patch("nettacker.core.lib.ssl.create_tcp_socket")
@patch("nettacker.core.lib.ssl.is_weak_hash_algo")
@patch("nettacker.core.lib.ssl.crypto.load_certificate")
@patch("nettacker.core.lib.ssl.ssl.get_server_certificate")
def test_ssl_certificate_scan_no_ssl(
self,
mock_certificate,
mock_x509,
mock_hash_check,
mock_connection,
ssl_library,
connection_params,
):
mock_connection.return_value = (MockConnectionObject(connection_params["HOST"]), False)
result = ssl_library.ssl_certificate_scan(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
) )
mock_certificate.assert_called_with((HOST, PORT))
expected = {
"service": "http",
"ssl_flag": False,
"peer_name": "example.com",
}
assert result == expected
@patch("socket.socket") @patch("socket.socket")
@patch("ssl.create_default_context") @patch("ssl.create_default_context")
def test_is_weak_cipher_suite(self, mock_context, mock_socket): def test_is_weak_cipher_suite_success(self, mock_context, mock_socket, connection_params):
HOST = "example.com"
PORT = 80
TIMEOUT = 60
socket_instance = mock_socket.return_value socket_instance = mock_socket.return_value
context_instance = mock_context.return_value context_instance = mock_context.return_value
cipher_list = [ cipher_list = [
"HIGH", "HIGH",
"MEDIUM", "MEDIUM",
@ -337,98 +428,147 @@ class TestSocketMethod(TestCase):
"TLSv1.2", "TLSv1.2",
"TLSv1.3", "TLSv1.3",
] ]
self.assertEqual(is_weak_cipher_suite(HOST, PORT, TIMEOUT), (cipher_list, True))
context_instance.wrap_socket.assert_called_with(socket_instance, server_hostname=HOST)
socket_instance.settimeout.assert_called_with(TIMEOUT)
socket_instance.connect.assert_called_with((HOST, PORT))
result = is_weak_cipher_suite(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
assert result == (cipher_list, True)
context_instance.wrap_socket.assert_called_with(
socket_instance, server_hostname=connection_params["HOST"]
)
socket_instance.settimeout.assert_called_with(connection_params["TIMEOUT"])
socket_instance.connect.assert_called_with(
(connection_params["HOST"], connection_params["PORT"])
)
@patch("socket.socket")
@patch("ssl.create_default_context")
def test_is_weak_cipher_suite_ssl_error(self, mock_context, mock_socket, connection_params):
context_instance = mock_context.return_value
context_instance.wrap_socket.side_effect = ssl.SSLError context_instance.wrap_socket.side_effect = ssl.SSLError
self.assertEqual(is_weak_cipher_suite(HOST, PORT, TIMEOUT), ([], False))
def test_is_weak_hash_algo(self): result = is_weak_cipher_suite(
for algo in ("md2", "md4", "md5", "sha1"): connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
self.assertTrue(is_weak_hash_algo(algo)) )
self.assertFalse(is_weak_hash_algo("test_aglo"))
assert result == ([], False)
@pytest.mark.parametrize(
"algo,expected",
[
("md2", True),
("md4", True),
("md5", True),
("sha1", True),
("test_algo", False),
("sha256", False),
],
)
def test_is_weak_hash_algo(self, algo, expected):
assert is_weak_hash_algo(algo) == expected
@patch("socket.socket") @patch("socket.socket")
@patch("ssl.SSLContext") @patch("ssl.SSLContext")
def test_is_weak_ssl_version(self, mock_context, mock_socket): def test_is_weak_ssl_version_secure(self, mock_context, mock_socket, connection_params):
HOST = "example.com" context_instance = mock_context.return_value
PORT = 80 context_instance.wrap_socket.return_value = MockConnectionObject(
TIMEOUT = 60 connection_params["HOST"], "TLSv1.3"
)
result = is_weak_ssl_version(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
assert result == (["TLSv1.3", "TLSv1.3", "TLSv1.3", "TLSv1.3"], False)
@patch("socket.socket")
@patch("ssl.SSLContext")
def test_is_weak_ssl_version_weak(self, mock_context, mock_socket, connection_params):
context_instance = mock_context.return_value
context_instance.wrap_socket.return_value = MockConnectionObject(
connection_params["HOST"], "TLSv1.1"
)
result = is_weak_ssl_version(
connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
)
assert result == (["TLSv1.1", "TLSv1.1", "TLSv1.1", "TLSv1.1"], True)
@pytest.mark.parametrize("exception", [ssl.SSLError, ConnectionRefusedError])
@patch("socket.socket")
@patch("ssl.SSLContext")
def test_is_weak_ssl_version_exceptions(
self, mock_context, mock_socket, exception, connection_params
):
socket_instance = mock_socket.return_value socket_instance = mock_socket.return_value
context_instance = mock_context.return_value context_instance = mock_context.return_value
context_instance.wrap_socket.side_effect = exception
context_instance.wrap_socket.return_value = MockConnectionObject(HOST, "TLSv1.3") result = is_weak_ssl_version(
self.assertEqual( connection_params["HOST"], connection_params["PORT"], connection_params["TIMEOUT"]
is_weak_ssl_version(HOST, PORT, TIMEOUT),
(["TLSv1.3", "TLSv1.3", "TLSv1.3", "TLSv1.3"], False),
) )
context_instance.wrap_socket.return_value = MockConnectionObject(HOST, "TLSv1.1") assert result == ([], True)
self.assertEqual( socket_instance.settimeout.assert_called_with(connection_params["TIMEOUT"])
is_weak_ssl_version(HOST, PORT, TIMEOUT), socket_instance.connect.assert_called_with(
(["TLSv1.1", "TLSv1.1", "TLSv1.1", "TLSv1.1"], True), (connection_params["HOST"], connection_params["PORT"])
)
context_instance.wrap_socket.assert_called_with(
socket_instance, server_hostname=connection_params["HOST"]
) )
context_instance.wrap_socket.side_effect = ssl.SSLError def test_response_conditions_matched_expired_cert(self, ssl_engine, substeps, responses):
self.assertEqual(is_weak_ssl_version(HOST, PORT, TIMEOUT), ([], True)) result = ssl_engine.response_conditions_matched(
substeps.ssl_certificate_expired_vuln, responses.ssl_certificate_expired
context_instance.wrap_socket.side_effect = ConnectionRefusedError
self.assertEqual(is_weak_ssl_version(HOST, PORT, TIMEOUT), ([], True))
socket_instance.settimeout.assert_called_with(TIMEOUT)
socket_instance.connect.assert_called_with((HOST, PORT))
context_instance.wrap_socket.assert_called_with(socket_instance, server_hostname=HOST)
def test_response_conditions_matched(self):
# tests the response conditions matched for different scan methods
engine = SslEngine()
Substep = Substeps()
Response = Responses()
# ssl_certificate_expired_vuln
self.assertEqual(
engine.response_conditions_matched(
Substep.ssl_certificate_expired_vuln, Response.ssl_certificate_expired
),
{"subject": "component=subject", "expired": True, "expiration_date": "2023-12-07"},
)
# ssl_certificate_expired_vuln(not activated)
self.assertEqual(
engine.response_conditions_matched(
Substep.ssl_certificate_expired_vuln,
Response.ssl_certificate_deactivated,
),
{
"subject": "component=subject",
"not_activated": True,
"activation_date": "2100-12-07",
},
) )
# ssl_weak_version_vuln expected = {
self.assertEqual( "subject": "component=subject",
engine.response_conditions_matched( "expired": True,
Substep.ssl_weak_version_vuln, Response.ssl_weak_version_vuln "expiration_date": "2023-12-07",
), }
{
"weak_version": True, assert result == expected
"ssl_version": ["TLSv1"],
"issuer": "NA", def test_response_conditions_matched_deactivated_cert(self, ssl_engine, substeps, responses):
"subject": "NA", result = ssl_engine.response_conditions_matched(
"expiration_date": "NA", substeps.ssl_certificate_expired_vuln,
}, responses.ssl_certificate_deactivated,
) )
# ssl_* scans with ssl_flag = False expected = {
self.assertEqual( "subject": "component=subject",
engine.response_conditions_matched(Substep.ssl_weak_version_vuln, Response.ssl_off), [] "not_activated": True,
"activation_date": "2100-12-07",
}
assert result == expected
def test_response_conditions_matched_weak_version(self, ssl_engine, substeps, responses):
result = ssl_engine.response_conditions_matched(
substeps.ssl_weak_version_vuln, responses.ssl_weak_version_vuln
) )
# * scans with response None i.e. TCP connection failed(None) expected = {
self.assertEqual( "weak_version": True,
engine.response_conditions_matched(Substep.ssl_weak_version_vuln, None), [] "ssl_version": ["TLSv1"],
"issuer": "NA",
"subject": "NA",
"expiration_date": "NA",
}
assert result == expected
def test_response_conditions_matched_ssl_off(self, ssl_engine, substeps, responses):
result = ssl_engine.response_conditions_matched(
substeps.ssl_weak_version_vuln, responses.ssl_off
) )
assert result == []
def test_response_conditions_matched_none_response(self, ssl_engine, substeps):
result = ssl_engine.response_conditions_matched(substeps.ssl_weak_version_vuln, None)
assert result == []

View File

@ -1,107 +1,94 @@
from unittest.mock import patch from unittest.mock import patch
from nettacker.core.utils import common as common_utils from nettacker.core.utils import common as common_utils
from tests.common import TestCase
class TestCommon(TestCase): def test_arrays_to_matrix():
def test_arrays_to_matrix(self): assert sorted(common_utils.arrays_to_matrix({"ports": [1, 2, 3, 4, 5]})) == [
( [1],
self.assertEqual( [2],
sorted( [3],
common_utils.arrays_to_matrix( [4],
{"ports": [1, 2, 3, 4, 5]}, [5],
) ]
),
[[1], [2], [3], [4], [5]],
),
)
self.assertEqual( assert sorted(common_utils.arrays_to_matrix({"x": [1, 2], "y": [3, 4], "z": [5, 6]})) == [
sorted( [1, 3, 5],
common_utils.arrays_to_matrix( [1, 3, 6],
{"x": [1, 2], "y": [3, 4], "z": [5, 6]}, [1, 4, 5],
) [1, 4, 6],
), [2, 3, 5],
[ [2, 3, 6],
[1, 3, 5], [2, 4, 5],
[1, 3, 6], [2, 4, 6],
[1, 4, 5], ]
[1, 4, 6],
[2, 3, 5],
[2, 3, 6],
[2, 4, 5],
[2, 4, 6],
],
)
def test_generate_target_groups_empty_list(self):
targets = []
set_hardware_usage = 3
result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == []
def test_generate_target_groups_set_hardware_less_than_targets_total(self): def test_generate_target_groups_empty_list():
targets = [1, 2, 3, 4, 5] targets = []
set_hardware_usage = 2 set_hardware_usage = 3
result = common_utils.generate_target_groups(targets, set_hardware_usage) result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == [[1, 2, 3], [4, 5]] assert result == []
def test_generate_target_groups_set_hardware_equal_to_targets_total(self):
targets = [1, 2, 3, 4, 5]
set_hardware_usage = 5
result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == [[1], [2], [3], [4], [5]]
def test_generate_target_groups_set_hardware_greater_than_targets_total(self): def test_generate_target_groups_set_hardware_less_than_targets_total():
targets = [1, 2, 3] targets = [1, 2, 3, 4, 5]
set_hardware_usage = 5 set_hardware_usage = 2
result = common_utils.generate_target_groups(targets, set_hardware_usage) result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == [[1], [2], [3]] assert result == [[1, 2, 3], [4, 5]]
def test_sort_dictionary(self):
input_dict = {
"a": 1,
"c": 3,
"d": 23,
"b": 2,
}
expected_dict = {
"a": 1,
"b": 2,
"c": 3,
"d": 23,
}
input_dict_keys = tuple(input_dict.keys())
expected_dict_keys = tuple(expected_dict.keys())
self.assertNotEqual(input_dict_keys, expected_dict_keys)
sorted_dict_keys = tuple(common_utils.sort_dictionary(input_dict).keys()) def test_generate_target_groups_set_hardware_equal_to_targets_total():
self.assertEqual(sorted_dict_keys, expected_dict_keys) targets = [1, 2, 3, 4, 5]
set_hardware_usage = 5
result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == [[1], [2], [3], [4], [5]]
@patch("multiprocessing.cpu_count")
def test_select_maximum_cpu_core(self, cpu_count_mock):
cores_mapping = {
1: {"low": 1, "normal": 1, "high": 1, "maximum": 1},
2: {"low": 1, "normal": 1, "high": 1, "maximum": 1},
4: {"low": 1, "normal": 1, "high": 2, "maximum": 3},
6: {"low": 1, "normal": 1, "high": 3, "maximum": 5},
8: {"low": 1, "normal": 2, "high": 4, "maximum": 7},
10: {"low": 1, "normal": 2, "high": 5, "maximum": 9},
12: {"low": 1, "normal": 3, "high": 6, "maximum": 11},
16: {"low": 2, "normal": 4, "high": 8, "maximum": 15},
32: {"low": 4, "normal": 8, "high": 16, "maximum": 31},
48: {"low": 6, "normal": 12, "high": 24, "maximum": 47},
64: {"low": 8, "normal": 16, "high": 32, "maximum": 63},
}
for num_cores, levels in cores_mapping.items():
cpu_count_mock.return_value = num_cores
for level in ("low", "normal", "high", "maximum"):
self.assertEqual(
common_utils.select_maximum_cpu_core(level),
levels[level],
f"It should be {common_utils.select_maximum_cpu_core(level)} "
"of {num_cores} cores for '{level}' mode",
)
self.assertEqual(common_utils.select_maximum_cpu_core("invalid"), 1) def test_generate_target_groups_set_hardware_greater_than_targets_total():
targets = [1, 2, 3]
set_hardware_usage = 5
result = common_utils.generate_target_groups(targets, set_hardware_usage)
assert result == [[1], [2], [3]]
def test_sort_dictionary():
input_dict = {
"a": 1,
"c": 3,
"d": 23,
"b": 2,
}
expected_dict = {
"a": 1,
"b": 2,
"c": 3,
"d": 23,
}
input_dict_keys = tuple(input_dict.keys())
expected_dict_keys = tuple(expected_dict.keys())
assert input_dict_keys != expected_dict_keys
sorted_dict_keys = tuple(common_utils.sort_dictionary(input_dict).keys())
assert sorted_dict_keys == expected_dict_keys
@patch("multiprocessing.cpu_count")
def test_select_maximum_cpu_core(cpu_count_mock):
cores_mapping = {
1: {"low": 1, "normal": 1, "high": 1, "maximum": 1},
2: {"low": 1, "normal": 1, "high": 1, "maximum": 1},
4: {"low": 1, "normal": 1, "high": 2, "maximum": 3},
6: {"low": 1, "normal": 1, "high": 3, "maximum": 5},
8: {"low": 1, "normal": 2, "high": 4, "maximum": 7},
10: {"low": 1, "normal": 2, "high": 5, "maximum": 9},
12: {"low": 1, "normal": 3, "high": 6, "maximum": 11},
16: {"low": 2, "normal": 4, "high": 8, "maximum": 15},
32: {"low": 4, "normal": 8, "high": 16, "maximum": 31},
48: {"low": 6, "normal": 12, "high": 24, "maximum": 47},
64: {"low": 8, "normal": 16, "high": 32, "maximum": 63},
}
for num_cores, levels in cores_mapping.items():
cpu_count_mock.return_value = num_cores
for level in ("low", "normal", "high", "maximum"):
assert common_utils.select_maximum_cpu_core(level) == levels[level]
assert common_utils.select_maximum_cpu_core("invalid") == 1

View File

@ -1,94 +1,95 @@
from datetime import datetime from datetime import datetime
import pytest
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from nettacker.database.models import Base, Report, TempEvents, HostsLog from nettacker.database.models import Base, Report, TempEvents, HostsLog
from tests.common import TestCase
class TestModels(TestCase): @pytest.fixture
def setUp(self): def session():
# Creating an in-memory SQLite database for testing engine = create_engine("sqlite:///:memory:")
self.engine = create_engine("sqlite:///:memory:") Base.metadata.create_all(engine)
Base.metadata.create_all(self.engine) Session = sessionmaker(bind=engine)
Session = sessionmaker(bind=self.engine) sess = Session()
self.session = Session() yield sess
sess.close()
Base.metadata.drop_all(engine)
def tearDown(self):
self.session.close()
Base.metadata.drop_all(self.engine)
def test_report_model(self): def test_report_model(session):
test_date = datetime.now() test_date = datetime.now()
test_report = Report( test_report = Report(
date=test_date, date=test_date,
scan_unique_id="test123", scan_unique_id="test123",
report_path_filename="/path/to/report.txt", report_path_filename="/path/to/report.txt",
options='{"option1": "value1"}', options='{"option1": "value1"}',
) )
self.session.add(test_report) session.add(test_report)
self.session.commit() session.commit()
retrieved_report = self.session.query(Report).first() retrieved_report = session.query(Report).first()
self.assertIsNotNone(retrieved_report) assert retrieved_report is not None
self.assertEqual(retrieved_report.scan_unique_id, "test123") assert retrieved_report.scan_unique_id == "test123"
self.assertEqual(retrieved_report.report_path_filename, "/path/to/report.txt") assert retrieved_report.report_path_filename == "/path/to/report.txt"
self.assertEqual(retrieved_report.options, '{"option1": "value1"}') assert retrieved_report.options == '{"option1": "value1"}'
repr_string = repr(retrieved_report) repr_string = repr(retrieved_report)
self.assertIn("test123", repr_string) assert "test123" in repr_string
self.assertIn("/path/to/report.txt", repr_string) assert "/path/to/report.txt" in repr_string
def test_temp_events_model(self):
test_date = datetime.now()
test_event = TempEvents(
date=test_date,
target="192.168.1.1",
module_name="port_scan",
scan_unique_id="test123",
event_name="open_port",
port="80",
event="Port 80 is open",
data='{"details": "HTTP server running"}',
)
self.session.add(test_event) def test_temp_events_model(session):
self.session.commit() test_date = datetime.now()
test_event = TempEvents(
date=test_date,
target="192.168.1.1",
module_name="port_scan",
scan_unique_id="test123",
event_name="open_port",
port="80",
event="Port 80 is open",
data='{"details": "HTTP server running"}',
)
retrieved_event = self.session.query(TempEvents).first() session.add(test_event)
self.assertIsNotNone(retrieved_event) session.commit()
self.assertEqual(retrieved_event.target, "192.168.1.1")
self.assertEqual(retrieved_event.module_name, "port_scan")
self.assertEqual(retrieved_event.port, "80")
repr_string = repr(retrieved_event) retrieved_event = session.query(TempEvents).first()
self.assertIn("192.168.1.1", repr_string) assert retrieved_event is not None
self.assertIn("port_scan", repr_string) assert retrieved_event.target == "192.168.1.1"
assert retrieved_event.module_name == "port_scan"
assert retrieved_event.port == "80"
def test_hosts_log_model(self): repr_string = repr(retrieved_event)
test_date = datetime.now() assert "192.168.1.1" in repr_string
test_log = HostsLog( assert "port_scan" in repr_string
date=test_date,
target="192.168.1.1",
module_name="vulnerability_scan",
scan_unique_id="test123",
port="443",
event="Found vulnerability CVE-2021-12345",
json_event='{"vulnerability": "CVE-2021-12345", "severity": "high"}',
)
self.session.add(test_log)
self.session.commit()
retrieved_log = self.session.query(HostsLog).first() def test_hosts_log_model(session):
self.assertIsNotNone(retrieved_log) test_date = datetime.now()
self.assertEqual(retrieved_log.target, "192.168.1.1") test_log = HostsLog(
self.assertEqual(retrieved_log.module_name, "vulnerability_scan") date=test_date,
self.assertEqual(retrieved_log.port, "443") target="192.168.1.1",
self.assertEqual(retrieved_log.event, "Found vulnerability CVE-2021-12345") module_name="vulnerability_scan",
scan_unique_id="test123",
port="443",
event="Found vulnerability CVE-2021-12345",
json_event='{"vulnerability": "CVE-2021-12345", "severity": "high"}',
)
repr_string = repr(retrieved_log) session.add(test_log)
self.assertIn("192.168.1.1", repr_string) session.commit()
self.assertIn("vulnerability_scan", repr_string)
retrieved_log = session.query(HostsLog).first()
assert retrieved_log is not None
assert retrieved_log.target == "192.168.1.1"
assert retrieved_log.module_name == "vulnerability_scan"
assert retrieved_log.port == "443"
assert retrieved_log.event == "Found vulnerability CVE-2021-12345"
repr_string = repr(retrieved_log)
assert "192.168.1.1" in repr_string
assert "vulnerability_scan" in repr_string

View File

@ -1,141 +1,89 @@
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
import pytest
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from nettacker.config import Config from nettacker.config import Config
from nettacker.database.models import Base from nettacker.database.models import Base
from nettacker.database.mysql import mysql_create_database, mysql_create_tables from nettacker.database.mysql import mysql_create_database, mysql_create_tables
from tests.common import TestCase
class TestMySQLFunctions(TestCase): @pytest.fixture(autouse=True)
"""Test cases for mysql.py functions""" def setup_config():
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
Config.db.name = "test_db"
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_success(self, mock_create_engine):
"""Test successful database creation"""
# Set up mock config
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
Config.db.name = "test_db"
# Set up mock connection and execution @patch("nettacker.database.mysql.create_engine")
mock_conn = MagicMock() def test_mysql_create_database_success(mock_create_engine):
mock_engine = MagicMock() mock_conn = MagicMock()
mock_create_engine.return_value = mock_engine mock_engine = MagicMock()
mock_engine.connect.return_value.__enter__.return_value = mock_conn mock_create_engine.return_value = mock_engine
mock_engine.connect.return_value.__enter__.return_value = mock_conn
mock_conn.execute.return_value = [("mysql",), ("information_schema",)]
# Mock database query results - database doesn't exist yet mysql_create_database()
mock_conn.execute.return_value = [("mysql",), ("information_schema",)]
# Call the function mock_create_engine.assert_called_once_with(
"mysql+pymysql://test_user:test_pass@localhost:3306"
)
call_args_list = mock_conn.execute.call_args_list
assert len(call_args_list) == 2
first_call_arg = call_args_list[0][0][0]
assert str(first_call_arg) == "SHOW DATABASES;"
second_call_arg = call_args_list[1][0][0]
assert str(second_call_arg) == "CREATE DATABASE test_db "
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_already_exists(mock_create_engine):
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine.connect.return_value.__enter__.return_value = mock_conn
mock_conn.execute.return_value = [("mysql",), ("information_schema",), ("test_db",)]
mysql_create_database()
mock_create_engine.assert_called_once_with(
"mysql+pymysql://test_user:test_pass@localhost:3306"
)
assert mock_conn.execute.call_count == 1
call_arg = mock_conn.execute.call_args[0][0]
assert str(call_arg) == "SHOW DATABASES;"
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_exception(mock_create_engine):
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine.connect.side_effect = SQLAlchemyError("Connection error")
with patch("builtins.print") as mock_print:
mysql_create_database() mysql_create_database()
mock_print.assert_called_once()
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_tables(mock_create_engine):
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
with patch.object(Base.metadata, "create_all") as mock_create_all:
mysql_create_tables()
# Assertions
mock_create_engine.assert_called_once_with( mock_create_engine.assert_called_once_with(
"mysql+pymysql://test_user:test_pass@localhost:3306" "mysql+pymysql://test_user:test_pass@localhost:3306/test_db"
) )
mock_create_all.assert_called_once_with(mock_engine)
# Check that execute was called with any text object that has the expected SQL
call_args_list = mock_conn.execute.call_args_list
self.assertEqual(len(call_args_list), 2) # Two calls to execute
# Check that the first call is SHOW DATABASES
first_call_arg = call_args_list[0][0][0]
self.assertEqual(str(first_call_arg), "SHOW DATABASES;")
# Check that the second call is CREATE DATABASE
second_call_arg = call_args_list[1][0][0]
self.assertEqual(str(second_call_arg), "CREATE DATABASE test_db ")
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_already_exists(self, mock_create_engine):
"""Test when database already exists"""
# Set up mock config
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
Config.db.name = "test_db"
# Set up mock connection and execution
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine.connect.return_value.__enter__.return_value = mock_conn
# Mock database query results - database already exists
mock_conn.execute.return_value = [("mysql",), ("information_schema",), ("test_db",)]
# Call the function
mysql_create_database()
# Assertions
mock_create_engine.assert_called_once_with(
"mysql+pymysql://test_user:test_pass@localhost:3306"
)
# Check that execute was called once with SHOW DATABASES
self.assertEqual(mock_conn.execute.call_count, 1)
call_arg = mock_conn.execute.call_args[0][0]
self.assertEqual(str(call_arg), "SHOW DATABASES;")
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_database_exception(self, mock_create_engine):
"""Test exception handling in create database"""
# Set up mock config
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
# Set up mock to raise exception
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
mock_engine.connect.side_effect = SQLAlchemyError("Connection error")
# Call the function (should not raise exception)
with patch("builtins.print") as mock_print:
mysql_create_database()
mock_print.assert_called_once()
@patch("nettacker.database.mysql.create_engine")
def test_mysql_create_tables(self, mock_create_engine):
"""Test table creation function"""
# Set up mock config
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "test_user",
"password": "test_pass",
"host": "localhost",
"port": "3306",
"name": "test_db",
}
# Set up mock engine
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
# Call the function
with patch.object(Base.metadata, "create_all") as mock_create_all:
mysql_create_tables()
# Assertions
mock_create_engine.assert_called_once_with(
"mysql+pymysql://test_user:test_pass@localhost:3306/test_db"
)
mock_create_all.assert_called_once_with(mock_engine)

View File

@ -5,88 +5,90 @@ from sqlalchemy.exc import OperationalError
from nettacker.config import Config from nettacker.config import Config
from nettacker.database.models import Base from nettacker.database.models import Base
from nettacker.database.postgresql import postgres_create_database from nettacker.database.postgresql import postgres_create_database
from tests.common import TestCase
class TestPostgresFunctions(TestCase): @patch("nettacker.database.postgresql.create_engine")
@patch("nettacker.database.postgresql.create_engine") def test_postgres_create_database_success(mock_create_engine):
def test_postgres_create_database_success(self, mock_create_engine): Config.db = MagicMock()
Config.db = MagicMock() Config.db.as_dict.return_value = {
Config.db.as_dict.return_value = { "username": "user",
"username": "user", "password": "pass",
"password": "pass", "host": "localhost",
"host": "localhost", "port": "5432",
"port": "5432", "name": "nettacker_db",
"name": "nettacker_db", }
}
mock_engine = MagicMock() mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine mock_create_engine.return_value = mock_engine
with patch.object(Base.metadata, "create_all") as mock_create_all: with patch.object(Base.metadata, "create_all") as mock_create_all:
postgres_create_database()
mock_create_engine.assert_called_once_with(
"postgresql+psycopg2://user:pass@localhost:5432/nettacker_db"
)
mock_create_all.assert_called_once_with(mock_engine)
@patch("nettacker.database.postgresql.create_engine")
def test_postgres_create_database_if_not_exists(mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "user",
"password": "pass",
"host": "localhost",
"port": "5432",
"name": "nettacker_db",
}
Config.db.name = "nettacker_db"
mock_engine_initial = MagicMock()
mock_engine_fallback = MagicMock()
mock_engine_final = MagicMock()
mock_create_engine.side_effect = [
mock_engine_initial,
mock_engine_fallback,
mock_engine_final,
]
with patch.object(
Base.metadata, "create_all", side_effect=[OperationalError("fail", None, None), None]
):
mock_conn = MagicMock()
mock_engine_fallback.connect.return_value = mock_conn
mock_conn.execution_options.return_value = mock_conn
postgres_create_database()
assert mock_create_engine.call_count == 3
args, _ = mock_conn.execute.call_args
assert str(args[0]) == "CREATE DATABASE nettacker_db"
mock_conn.close.assert_called_once()
@patch("nettacker.database.postgresql.create_engine")
def test_postgres_create_database_create_fail(mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "user",
"password": "pass",
"host": "localhost",
"port": "5432",
"name": "nettacker_db",
}
mock_engine_initial = MagicMock()
mock_engine_fallback = MagicMock()
mock_create_engine.side_effect = [mock_engine_initial, mock_engine_fallback]
mock_engine_fallback.connect.side_effect = OperationalError("fail again", None, None)
with patch.object(
Base.metadata, "create_all", side_effect=OperationalError("fail", None, None)
):
import pytest
with pytest.raises(OperationalError):
postgres_create_database() postgres_create_database()
mock_create_engine.assert_called_once_with(
"postgresql+psycopg2://user:pass@localhost:5432/nettacker_db"
)
mock_create_all.assert_called_once_with(mock_engine)
@patch("nettacker.database.postgresql.create_engine")
def test_postgres_create_database_if_not_exists(self, mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "user",
"password": "pass",
"host": "localhost",
"port": "5432",
"name": "nettacker_db",
}
Config.db.name = "nettacker_db"
mock_engine_initial = MagicMock()
mock_engine_fallback = MagicMock()
mock_engine_final = MagicMock()
mock_create_engine.side_effect = [
mock_engine_initial,
mock_engine_fallback,
mock_engine_final,
]
with patch.object(
Base.metadata, "create_all", side_effect=[OperationalError("fail", None, None), None]
):
mock_conn = MagicMock()
mock_engine_fallback.connect.return_value = mock_conn
mock_conn.execution_options.return_value = mock_conn
postgres_create_database()
assert mock_create_engine.call_count == 3
args, _ = mock_conn.execute.call_args
assert str(args[0]) == "CREATE DATABASE nettacker_db"
mock_conn.close.assert_called_once()
@patch("nettacker.database.postgresql.create_engine")
def test_postgres_create_database_create_fail(self, mock_create_engine):
Config.db = MagicMock()
Config.db.as_dict.return_value = {
"username": "user",
"password": "pass",
"host": "localhost",
"port": "5432",
"name": "nettacker_db",
}
mock_engine_initial = MagicMock()
mock_engine_fallback = MagicMock()
mock_create_engine.side_effect = [mock_engine_initial, mock_engine_fallback]
mock_engine_fallback.connect.side_effect = OperationalError("fail again", None, None)
with patch.object(
Base.metadata, "create_all", side_effect=OperationalError("fail", None, None)
):
with self.assertRaises(OperationalError):
postgres_create_database()

View File

@ -1,42 +1,45 @@
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
import pytest
from sqlalchemy import create_engine, inspect from sqlalchemy import create_engine, inspect
from nettacker.config import Config from nettacker.config import Config
from nettacker.database.models import Base from nettacker.database.models import Base
from nettacker.database.sqlite import sqlite_create_tables from nettacker.database.sqlite import sqlite_create_tables
from tests.common import TestCase
class TestSQLiteFunctions(TestCase): @pytest.fixture
@patch("nettacker.database.sqlite.create_engine") def mock_config():
def test_sqlite_create_tables(self, mock_create_engine): Config.db = MagicMock()
Config.db = MagicMock() yield Config.db
Config.db.as_dict.return_value = {"name": "/path/to/test.db"}
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
with patch.object(Base.metadata, "create_all") as mock_create_all: @patch("nettacker.database.sqlite.create_engine")
sqlite_create_tables() def test_sqlite_create_tables(mock_create_engine, mock_config):
mock_config.as_dict.return_value = {"name": "/path/to/test.db"}
mock_create_engine.assert_called_once_with( mock_engine = MagicMock()
"sqlite:////path/to/test.db", connect_args={"check_same_thread": False} mock_create_engine.return_value = mock_engine
)
mock_create_all.assert_called_once_with(mock_engine)
def test_sqlite_create_tables_integration(self): with patch.object(Base.metadata, "create_all") as mock_create_all:
engine = create_engine("sqlite:///:memory:") sqlite_create_tables()
Config.db = MagicMock() mock_create_engine.assert_called_once_with(
Config.db.as_dict.return_value = {"name": ":memory:"} "sqlite:////path/to/test.db", connect_args={"check_same_thread": False}
)
mock_create_all.assert_called_once_with(mock_engine)
with patch("nettacker.database.sqlite.create_engine", return_value=engine):
sqlite_create_tables()
inspector = inspect(engine) def test_sqlite_create_tables_integration(mock_config):
tables = inspector.get_table_names() engine = create_engine("sqlite:///:memory:")
mock_config.as_dict.return_value = {"name": ":memory:"}
self.assertIn("reports", tables, "Reports table was not created") with patch("nettacker.database.sqlite.create_engine", return_value=engine):
self.assertIn("temp_events", tables, "Temp events table was not created") sqlite_create_tables()
self.assertIn("scan_events", tables, "Scan events table was not created")
inspector = inspect(engine)
tables = inspector.get_table_names()
assert "reports" in tables, "Reports table was not created"
assert "temp_events" in tables, "Temp events table was not created"
assert "scan_events" in tables, "Scan events table was not created"

View File

@ -1,22 +1,18 @@
from collections import Counter from collections import Counter
from pathlib import Path
import pytest nettacker_path = Path(__file__).parent.parent.parent.parent
from tests.common import TestCase
class TestPasswords(TestCase): def test_top_1000_common_passwords():
top_1000_common_passwords_path = "lib/payloads/passwords/top_1000_common_passwords.txt" top_1000_passwords_file_path = (
nettacker_path / "nettacker/lib/payloads/passwords/top_1000_common_passwords.txt"
)
with open(top_1000_passwords_file_path) as f:
top_1000_passwords = [line.strip() for line in f.readlines() if line.strip()]
@pytest.mark.xfail(reason="It currently contains 1001 passwords.") assert len(top_1000_passwords) == 1000, "There should be exactly 1000 passwords"
def test_top_1000_common_passwords(self):
with open(self.nettacker_path / self.top_1000_common_passwords_path) as top_1000_file:
top_1000_passwords = [line.strip() for line in top_1000_file.readlines()]
self.assertEqual(len(top_1000_passwords), 1000, "There should be exactly 1000 passwords") assert len(set(top_1000_passwords)) == len(
top_1000_passwords
self.assertEqual( ), f"The passwords aren't unique: {Counter(top_1000_passwords).most_common(1)[0][0]}"
len(set(top_1000_passwords)),
len(top_1000_passwords),
f"The passwords aren't unique: {Counter(top_1000_passwords).most_common(1)[0][0]}",
)

View File

@ -1,48 +1,29 @@
from collections import Counter from collections import Counter
from pathlib import Path
from tests.common import TestCase import pytest
wordlists = { wordlists = {
"admin_file": ["lib/payloads/wordlists/admin_wordlist.txt", 533], "admin_file": ("nettacker/lib/payloads/wordlists/admin_wordlist.txt", 533),
"dir_file": ["lib/payloads/wordlists/dir_wordlist.txt", 1966], "dir_file": ("nettacker/lib/payloads/wordlists/dir_wordlist.txt", 1966),
"pma_file": ["lib/payloads/wordlists/pma_wordlist.txt", 174], "pma_file": ("nettacker/lib/payloads/wordlists/pma_wordlist.txt", 174),
"wp_plugin_small_file": ["lib/payloads/wordlists/wp_plugin_small.txt", 291], "wp_plugin_small_file": ("nettacker/lib/payloads/wordlists/wp_plugin_small.txt", 291),
"wp_theme_small_file": ["lib/payloads/wordlists/wp_theme_small.txt", 41], "wp_theme_small_file": ("nettacker/lib/payloads/wordlists/wp_theme_small.txt", 41),
"wp_timethumb_file": ["lib/payloads/wordlists/wp_timethumbs.txt", 2424], "wp_timethumb_file": ("nettacker/lib/payloads/wordlists/wp_timethumbs.txt", 2424),
} }
nettacker_path = Path(__file__).parent.parent.parent.parent
class TestWordlists(TestCase):
def test_admin_wordlist(self):
self.run_wordlist_test("admin_file")
def test_dir_wordlist(self): @pytest.mark.parametrize("key", list(wordlists.keys()))
self.run_wordlist_test("dir_file") def test_wordlist(key):
wordlist_path, expected_length = wordlists[key]
full_path = nettacker_path / wordlist_path
def test_pma_wordlist(self): with open(full_path) as f:
self.run_wordlist_test("pma_file") paths = [line.strip() for line in f.readlines()]
def test_wp_plugin_small_wordlist(self): assert len(paths) == expected_length, f"There are {expected_length} paths in {key}"
self.run_wordlist_test("wp_plugin_small_file") assert len(set(paths)) == len(
paths
def test_wp_theme_small_wordlist(self): ), f"The paths aren't unique in {key}: {Counter(paths).most_common(1)[0][0]}"
self.run_wordlist_test("wp_theme_small_file")
def test_wp_timethumb_wordlist(self):
self.run_wordlist_test("wp_timethumb_file")
def run_wordlist_test(self, key):
wordlist_path = wordlists[key][0]
wordlist_length = wordlists[key][1]
with open(self.nettacker_path / wordlist_path) as wordlist_file:
paths = [line.strip() for line in wordlist_file.readlines()]
self.assertEqual(
len(paths), wordlist_length, f"There are {wordlist_length} paths in {key}"
)
self.assertEqual(
len(set(paths)),
len(paths),
f"The paths aren't unique in {key}: {Counter(paths).most_common(1)[0][0]}",
)