This commit is contained in:
PreistlyPython 2025-11-17 02:47:52 +00:00 committed by GitHub
commit 18cb5202a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 45 additions and 44 deletions

View File

@ -1,4 +1,4 @@
import os
from pathlib import Path
from flask import abort
@ -117,13 +117,16 @@ def get_file(filename):
Returns:
content of the file or abort(404)
"""
if not os.path.normpath(filename).startswith(str(Config.path.web_static_dir)):
base = Config.path.web_static_dir.resolve()
try:
target = Path(filename).resolve(strict=True)
except FileNotFoundError:
abort(404)
if not target.is_relative_to(base):
abort(404)
try:
return open(filename, "rb").read()
except ValueError:
abort(404)
except IOError:
return target.read_bytes()
except OSError:
abort(404)

View File

@ -1,10 +1,10 @@
import csv
import json
import multiprocessing
import os
import random
import string
import time
from pathlib import Path
from threading import Thread
from types import SimpleNamespace
@ -187,8 +187,8 @@ def get_statics(path):
"""
static_types = mime_types()
return Response(
get_file(os.path.join(Config.path.web_static_dir, path)),
mimetype=static_types.get(os.path.splitext(path)[1], "text/html"),
get_file(Config.path.web_static_dir / path),
mimetype=static_types.get(Path(path).suffix, "text/html"),
)
@ -220,7 +220,7 @@ def sanitize_report_path_filename(report_path_filename):
Returns:
the sanitized report path filename
"""
filename = secure_filename(os.path.basename(report_path_filename))
filename = secure_filename(Path(report_path_filename).name)
if not filename:
return False
# Define a list or tuple of valid extensions
@ -391,8 +391,8 @@ def get_result_content():
return Response(
file_content,
mimetype=mime_types().get(os.path.splitext(filename)[1], "text/plain"),
headers={"Content-Disposition": "attachment;filename=" + filename.split("/")[-1]},
mimetype=mime_types().get(Path(filename).suffix, "text/plain"),
headers={"Content-Disposition": f'attachment; filename="{Path(filename).name}"'},
)

View File

@ -2,7 +2,6 @@ import csv
import html
import importlib
import json
import os
import uuid
from datetime import datetime
from pathlib import Path
@ -371,31 +370,34 @@ def create_compare_report(options, scan_id):
else generate_compare_filepath(scan_id)
)
base_path = str(nettacker_path_config.results_dir)
base_path = nettacker_path_config.results_dir
compare_report_path_filename = sanitize_path(compare_report_path_filename)
fullpath = os.path.normpath(os.path.join(base_path, compare_report_path_filename))
fullpath = (base_path / compare_report_path_filename).resolve()
if not fullpath.startswith(base_path):
if not fullpath.is_relative_to(base_path.resolve()):
raise PermissionError
if (len(fullpath) >= 5 and fullpath[-5:] == ".html") or (
len(fullpath) >= 4 and fullpath[-4:] == ".htm"
):
suffix = fullpath.suffix.lower()
suffixes = [s.lower() for s in fullpath.suffixes]
if suffix in (".html", ".htm"):
html_report = build_compare_report(compare_results)
with Path(fullpath).open("w", encoding="utf-8") as compare_report:
with fullpath.open("w", encoding="utf-8") as compare_report:
compare_report.write(html_report + "\n")
elif len(fullpath) >= 5 and fullpath[-5:] == ".json":
with Path(fullpath).open("w", encoding="utf-8") as compare_report:
compare_report.write(str(json.dumps(compare_results)) + "\n")
elif len(fullpath) >= 5 and fullpath[-4:] == ".csv":
elif suffixes[-2:] == [".dd", ".json"]:
with fullpath.open("w", encoding="utf-8") as compare_report:
compare_report.write(json.dumps(compare_results) + "\n")
elif suffix == ".json":
with fullpath.open("w", encoding="utf-8") as compare_report:
compare_report.write(json.dumps(compare_results) + "\n")
elif suffix == ".csv":
keys = compare_results.keys()
with Path(fullpath).open("a") as csvfile:
with fullpath.open("a") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=keys)
if csvfile.tell() == 0:
writer.writeheader()
writer.writerow(compare_results)
else:
with Path(fullpath).open("w", encoding="utf-8") as compare_report:
with fullpath.open("w", encoding="utf-8") as compare_report:
compare_report.write(create_compare_text_table(compare_results))
log.write(create_compare_text_table(compare_results))

View File

@ -1,9 +1,9 @@
import sys
from os.path import abspath, dirname, join
from pathlib import Path
project_root = dirname(dirname(__file__))
nettacker_dir = abspath(join(project_root, "nettacker"))
tests_dir = abspath(join(project_root, "tests"))
project_root = Path(__file__).parent.parent
nettacker_dir = (project_root / "nettacker").resolve()
tests_dir = (project_root / "tests").resolve()
sys.path.insert(0, nettacker_dir)
sys.path.insert(1, tests_dir)
sys.path.insert(0, str(nettacker_dir))
sys.path.insert(1, str(tests_dir))

View File

@ -183,13 +183,9 @@ def test_text_report(mock_submit, mock_open_file, mock_build_text, mock_get_logs
@patch("nettacker.core.graph.get_options_by_scan_id")
@patch("nettacker.core.graph.build_compare_report", return_value="<html-report>")
@patch("nettacker.core.graph.Path.open", new_callable=mock_open)
@patch("nettacker.core.graph.os.path.normpath", side_effect=lambda x: x)
@patch("nettacker.core.graph.os.path.join", side_effect=lambda *args: "/".join(args))
@patch("nettacker.core.graph.create_compare_text_table", return_value="text-report")
def test_html_json_csv_text(
mock_text_table,
mock_join,
mock_norm,
mock_open_file,
mock_build_html,
mock_get_opts,
@ -247,9 +243,8 @@ def test_no_comparison_logs(mock_logs):
@patch("nettacker.core.graph.get_logs_by_scan_id")
@patch("nettacker.core.graph.get_options_by_scan_id")
@patch("nettacker.core.graph.os.path.normpath", side_effect=lambda x: "/etc/passwd")
@patch("nettacker.core.graph.os.path.join", side_effect=lambda *args: "/etc/passwd")
def test_permission_error(mock_join, mock_norm, mock_opts, mock_logs):
@patch("nettacker.core.graph.sanitize_path", return_value="../../etc/passwd")
def test_permission_error(mock_sanitize, mock_opts, mock_logs):
dummy_log = {
"target": "1.1.1.1",
"module_name": "mod",

View File

@ -1,5 +1,5 @@
import os
import re
from pathlib import Path
import pytest
import yaml
@ -12,9 +12,10 @@ DUMMY_TEST_STRING = (
def get_yaml_files():
for base in BASE_DIRS:
for file in os.listdir(base):
if file.endswith(".yaml"):
yield os.path.join(base, file)
base_path = Path(base)
for file in base_path.iterdir():
if file.suffix == ".yaml":
yield str(file)
def load_yaml(file_path):
@ -77,7 +78,7 @@ def test_yaml_regexes_valid(yaml_file):
if payloads[0].get("library") == "http":
regexes = extract_http_regexes(payloads)
elif payloads[0].get("library") == "socket":
file_name = os.path.basename(yaml_file)
file_name = Path(yaml_file).name
regexes = extract_socket_regexes(file_name, payloads)
else:
pytest.skip(f"Unknown library type in {yaml_file}")