# -*- coding: utf-8 -*- # type: ignore from dataclasses import asdict import errno import itertools import json import logging import os from pathlib import Path import sys import tempfile import time from collections import defaultdict from datetime import datetime from typing import Dict, Optional, List, Any, Union, Iterator import click import requests from packaging.specifiers import SpecifierSet from packaging.utils import canonicalize_name from packaging.version import parse as parse_version, Version from pydantic_core import to_jsonable_python from filelock import FileLock from safety_schemas.models import Ecosystem, FileType from tenacity import ( before_sleep_log, retry, retry_if_exception_type, stop_after_attempt, wait_exponential_jitter, ) from .constants import ( API_MIRRORS, DB_CACHE_FILE, OPEN_MIRRORS, REQUEST_TIMEOUT, DATA_API_BASE_URL, JSON_SCHEMA_VERSION, IGNORE_UNPINNED_REQ_REASON, ) from .errors import ( DatabaseFetchError, DatabaseFileNotFoundError, InvalidCredentialError, TooManyRequestsError, NetworkConnectionError, RequestTimeoutError, ServerError, MalformedDatabase, ) from .meta import get_meta_http_headers from .models import ( Vulnerability, CVE, Severity, Fix, is_pinned_requirement, SafetyRequirement, ) from .output_utils import ( print_service, get_applied_msg, prompt_service, get_skipped_msg, get_fix_opt_used_msg, is_using_api_key, get_specifier_range_info, ) from .util import ( build_remediation_info_url, pluralize, read_requirements, Package, build_telemetry_data, sync_safety_context, SafetyContext, validate_expiration_date, is_a_remote_mirror, get_requirements_content, SafetyPolicyFile, get_terminal_size, is_ignore_unpinned_mode, get_hashes, ) LOG = logging.getLogger(__name__) def get_from_cache( db_name: str, cache_valid_seconds: int = 0, skip_time_verification: bool = False ) -> Optional[Dict[str, Any]]: """ Retrieves the database from the cache if it is valid. Args: db_name (str): The name of the database. cache_valid_seconds (int): The validity period of the cache in seconds. skip_time_verification (bool): Whether to skip time verification. Returns: Optional[[Dict[str, Any]]: The cached database if available and valid, otherwise False. """ cache_file_lock = f"{DB_CACHE_FILE}.lock" os.makedirs(os.path.dirname(cache_file_lock), exist_ok=True) lock = FileLock(cache_file_lock, timeout=10) with lock: if os.path.exists(DB_CACHE_FILE): with open(DB_CACHE_FILE) as f: try: data = json.loads(f.read()) if db_name in data: if "cached_at" in data[db_name]: if ( data[db_name]["cached_at"] + cache_valid_seconds > time.time() or skip_time_verification ): LOG.debug( "Getting the database from cache at %s, cache setting: %s", data[db_name]["cached_at"], cache_valid_seconds, ) try: data[db_name]["db"]["meta"]["base_domain"] = ( "https://data.safetycli.com" ) except KeyError: pass return data[db_name]["db"] LOG.debug( "Cached file is too old, it was cached at %s", data[db_name]["cached_at"], ) else: LOG.debug( "There is not the cached_at key in %s database", data[db_name], ) except json.JSONDecodeError: LOG.debug("JSONDecodeError trying to get the cached database.") else: LOG.debug("Cache file doesn't exist...") return None def write_to_cache(db_name: str, data: Dict[str, Any]) -> None: """ Writes the database to the cache. Args: db_name (str): The name of the database. data (Dict[str, Any]): The database data to be cached. """ # cache is in: ~/safety/cache.json # and has the following form: # { # "insecure.json": { # "cached_at": 12345678 # "db": {} # }, # "insecure_full.json": { # "cached_at": 12345678 # "db": {} # }, # } if not os.path.exists(os.path.dirname(DB_CACHE_FILE)): try: os.makedirs(os.path.dirname(DB_CACHE_FILE)) with open(DB_CACHE_FILE, "w") as _: _.write(json.dumps({})) LOG.debug("Cache file created") except OSError as exc: # Guard against race condition LOG.debug("Unable to create the cache file because: %s", exc.errno) if exc.errno != errno.EEXIST: raise cache_file_lock = f"{DB_CACHE_FILE}.lock" lock = FileLock(cache_file_lock, timeout=10) with lock: if os.path.exists(DB_CACHE_FILE): with open(DB_CACHE_FILE, "r") as f: try: cache = json.loads(f.read()) except json.JSONDecodeError: LOG.debug( "JSONDecodeError in the local cache, dumping the full cache file." ) cache = {} else: cache = {} with open(DB_CACHE_FILE, "w") as f: cache[db_name] = {"cached_at": time.time(), "db": data} f.write(json.dumps(cache)) LOG.debug("Safety updated the cache file for %s database.", db_name) @retry( stop=stop_after_attempt(3), wait=wait_exponential_jitter(initial=0.2, max=8.0, exp_base=3, jitter=0.3), reraise=True, retry=retry_if_exception_type( ( NetworkConnectionError, RequestTimeoutError, TooManyRequestsError, MalformedDatabase, ServerError, ) ), before_sleep=before_sleep_log(logging.getLogger("api_client"), logging.WARNING), ) def fetch_database_url( session: requests.Session, mirror: str, db_name: str, cached: int, telemetry: bool = True, ecosystem: Ecosystem = Ecosystem.PYTHON, from_cache: bool = True, ) -> Dict[str, Any]: """ Fetches the database from a URL. Args: session (requests.Session): The requests session. mirror (str): The URL of the mirror. db_name (str): The name of the database. cached (int): The cache validity in seconds. telemetry (bool): Whether to include telemetry data. ecosystem (Ecosystem): The ecosystem. from_cache (bool): Whether to fetch from cache. Returns: Dict[str, Any]: The fetched database. """ headers = {"schema-version": JSON_SCHEMA_VERSION, "ecosystem": ecosystem.value} headers.update(get_meta_http_headers()) if cached and from_cache: cached_data = get_from_cache(db_name=db_name, cache_valid_seconds=cached) if cached_data: LOG.info("Database %s returned from cache.", db_name) return cached_data url = mirror + db_name telemetry_data = { "telemetry": json.dumps( build_telemetry_data(telemetry=telemetry), default=to_jsonable_python ) } try: r = session.get( url=url, timeout=REQUEST_TIMEOUT, headers=headers, params=telemetry_data ) except requests.exceptions.ConnectionError: raise NetworkConnectionError() except requests.exceptions.Timeout: raise RequestTimeoutError() except requests.exceptions.RequestException: raise DatabaseFetchError() if r.status_code == 403: raise InvalidCredentialError(credential=session.get_credential(), reason=r.text) if r.status_code == 429: raise TooManyRequestsError(reason=r.text) if r.status_code >= 500 and r.status_code < 600: raise ServerError(reason=r.reason) try: data = r.json() except json.JSONDecodeError as e: raise MalformedDatabase(reason=e) if cached: LOG.info("Writing %s to cache because cached value was %s", db_name, cached) write_to_cache(db_name, data) return data def fetch_policy(session: requests.Session) -> Dict[str, Any]: """ Fetches the policy from the server. Args: session (requests.Session): The requests session. Returns: Dict[str, Any]: The fetched policy. """ url = f"{DATA_API_BASE_URL}policy/" try: LOG.debug("Getting policy") headers = get_meta_http_headers() r = session.get(url=url, timeout=REQUEST_TIMEOUT, headers=headers) LOG.debug(r.text) return r.json() except Exception: LOG.exception("Error fetching policy") return {"safety_policy": "", "audit_and_monitor": False} def fetch_database_file( path: str, db_name: str, cached: int = 0, ecosystem: Optional[Ecosystem] = None ) -> Dict[str, Any]: """ Fetches the database from a local file. Args: path (str): The path to the local file. db_name (str): The name of the database. cached (int): The cache validity in seconds. ecosystem (Optional[Ecosystem]): The ecosystem. Returns: Dict[str, Any]: The fetched database. """ full_path = ( (Path(path) / (ecosystem.value if ecosystem else "") / db_name) .expanduser() .resolve() ) if not full_path.exists(): raise DatabaseFileNotFoundError(db=path) with open(full_path) as f: data = json.loads(f.read()) if cached: LOG.info("Writing %s to cache because cached value was %s", db_name, cached) write_to_cache(db_name, data) return data def is_valid_database(db: Dict[str, Any]) -> bool: """ Checks if the database is valid. Args: db (Dict[str, Any]): The database. Returns: bool: True if the database is valid, False otherwise. """ try: if db["meta"]["schema_version"] == JSON_SCHEMA_VERSION: return True except Exception: return False return False def fetch_database( session: requests.Session, full: bool = False, db: Union[Optional[str], bool] = False, cached: int = 0, telemetry: bool = True, ecosystem: Optional[Ecosystem] = None, from_cache: bool = True, ) -> Dict[str, Any]: """ Fetches the database from a mirror or a local file. Args: session (requests.Session): The requests session. full (bool): Whether to fetch the full database. db (Optional[str]): The path to the local database file. cached (int): The cache validity in seconds. telemetry (bool): Whether to include telemetry data. ecosystem (Optional[Ecosystem]): The ecosystem. from_cache (bool): Whether to fetch from cache. Returns: Dict[str, Any]: The fetched database. """ if session.is_using_auth_credentials(): mirrors = API_MIRRORS elif db: mirrors = [db] else: mirrors = OPEN_MIRRORS db_name = "insecure_full.json" if full else "insecure.json" for mirror in mirrors: # mirror can either be a local path or a URL if is_a_remote_mirror(mirror): if ecosystem is None: ecosystem = Ecosystem.PYTHON data = fetch_database_url( session, mirror, db_name=db_name, cached=cached, telemetry=telemetry, ecosystem=ecosystem, from_cache=from_cache, ) else: data = fetch_database_file( mirror, db_name=db_name, cached=cached, ecosystem=ecosystem ) if data: if is_valid_database(data): return data raise MalformedDatabase( fetched_from=mirror, reason=f"Not supported schema version. " f"This Safety version supports only schema version {JSON_SCHEMA_VERSION}", ) raise DatabaseFetchError() def get_vulnerabilities( pkg: str, spec: str, db: Dict[str, Any] ) -> Iterator[Dict[str, Any]]: """ Retrieves vulnerabilities for a package from the database. Args: pkg (str): The package name. spec (str): The specifier set. db (Dict[str, Any]): The database. Returns: Iterator[Dict[str, Any]]: An iterator of vulnerabilities. """ for entry in db["vulnerable_packages"][pkg]: for entry_spec in entry["specs"]: if entry_spec == spec: yield entry def get_vulnerability_from( vuln_id: str, cve: Optional[CVE], data: Dict[str, Any], specifier: str, db: Dict[str, Any], name: str, pkg: Package, ignore_vulns: Dict[str, Any], affected: SafetyRequirement, ) -> Vulnerability: """ Constructs a Vulnerability object from the provided data. Args: vuln_id (str): The vulnerability ID. cve (Optional[CVE]): The CVE object. data (Dict[str, Any]): The vulnerability data. specifier (str): The specifier set. db (Dict[str, Any]): The database. name (str): The package name. pkg (Package): The Package object. ignore_vulns (Dict[str, Any]): The ignored vulnerabilities. affected (SafetyRequirement): The affected requirement. Returns: Vulnerability: The constructed Vulnerability object. """ base_domain = db.get("meta", {}).get("base_domain") unpinned_ignored = ignore_vulns.get(vuln_id, {}).get("requirements", None) should_ignore = not unpinned_ignored or str(affected.specifier) in unpinned_ignored ignored = ( ignore_vulns and vuln_id in ignore_vulns and should_ignore and ( not ignore_vulns[vuln_id]["expires"] or ignore_vulns[vuln_id]["expires"] > datetime.utcnow() ) ) more_info_url = f"{base_domain}{data.get('more_info_path', '')}" severity = None if cve and (cve.cvssv2 or cve.cvssv3): severity = Severity(source=cve.name, cvssv2=cve.cvssv2, cvssv3=cve.cvssv3) analyzed_requirement = affected analyzed_version = ( next(iter(analyzed_requirement.specifier)).version if is_pinned_requirement(analyzed_requirement.specifier) else None ) vulnerable_spec = set() vulnerable_spec.add(specifier) return Vulnerability( vulnerability_id=vuln_id, package_name=name, pkg=pkg, ignored=ignored, ignored_reason=ignore_vulns.get(vuln_id, {}).get("reason", None) if ignore_vulns else None, ignored_expires=ignore_vulns.get(vuln_id, {}).get("expires", None) if ignore_vulns else None, vulnerable_spec=vulnerable_spec, all_vulnerable_specs=data.get("specs", []), analyzed_version=analyzed_version, analyzed_requirement=analyzed_requirement, advisory=data.get("advisory"), is_transitive=data.get("transitive", False), published_date=data.get("published_date"), fixed_versions=[ver for ver in data.get("fixed_versions", []) if ver], closest_versions_without_known_vulnerabilities=data.get( "closest_secure_versions", [] ), resources=data.get("vulnerability_resources"), CVE=cve, severity=severity, affected_versions=data.get("affected_versions", []), more_info_url=more_info_url, ) def get_cve_from(data: Dict[str, Any], db_full: Dict[str, Any]) -> Optional[CVE]: """ Retrieves the CVE object from the provided data. Args: data (Dict[str, Any]): The vulnerability data. db_full (Dict[str, Any]): The full database. Returns: Optional[CVE]: The CVE object if found, otherwise None. """ try: xve_id: str = str( next( filter( lambda i: i.get("type", None) in ["cve", "pve"], data.get("ids", []) ) ).get("id", "") ) except StopIteration: xve_id: str = "" if not xve_id: return None cve_meta = db_full.get("meta", {}).get("severities", {}).get(xve_id, {}) return CVE( name=xve_id, cvssv2=cve_meta.get("cvssv2", None), cvssv3=cve_meta.get("cvssv3", None), ) def ignore_vuln_if_needed( pkg: Package, vuln_id: str, cve: Optional[CVE], ignore_vulns: Dict[str, Any], ignore_severity_rules: Dict[str, Any], req: SafetyRequirement, ) -> None: """ Determines if a vulnerability should be ignored based on severity rules and updates the ignore_vulns dictionary. Args: pkg (Package): The package. vuln_id (str): The vulnerability ID. cve (Optional[CVE]): The CVE object. ignore_vulns (Dict[str, Any]): The ignored vulnerabilities. ignore_severity_rules (Dict[str, Any]): The severity rules for ignoring vulnerabilities. req (SafetyRequirement): The affected requirement. """ if not ignore_severity_rules: ignore_severity_rules = {} if not isinstance(ignore_vulns, dict): return severity = None if cve: if cve.cvssv2 and cve.cvssv2.get("base_score", None): severity = cve.cvssv2.get("base_score", None) if cve.cvssv3 and cve.cvssv3.get("base_score", None): severity = cve.cvssv3.get("base_score", None) ignore_severity_below = float( ignore_severity_rules.get("ignore-cvss-severity-below", 0.0) ) ignore_unknown_severity = bool( ignore_severity_rules.get("ignore-cvss-unknown-severity", False) ) if severity: if float(severity) < ignore_severity_below: reason = "Ignored by severity rule in policy file, {0} < {1}".format( float(severity), ignore_severity_below ) ignore_vulns[vuln_id] = {"reason": reason, "expires": None} elif ignore_unknown_severity: reason = "Unknown CVSS severity, ignored by severity rule in policy file." ignore_vulns[vuln_id] = {"reason": reason, "expires": None} version = ( next(iter(req.specifier)).version if is_pinned_requirement(req.specifier) else pkg.version ) is_prev_not_ignored: bool = vuln_id not in ignore_vulns is_req_not_ignored: bool = "requirements" in ignore_vulns.get(vuln_id, {}) and str( req.specifier ) not in ignore_vulns.get(vuln_id, {}).get("requirements", set()) if (is_prev_not_ignored or is_req_not_ignored) and is_ignore_unpinned_mode(version): reason = IGNORE_UNPINNED_REQ_REASON requirements = set() requirements.add(str(req.specifier)) ignore_vulns[vuln_id] = { "reason": reason, "expires": None, "requirements": requirements, } def is_vulnerable( vulnerable_spec: SpecifierSet, requirement: SafetyRequirement, package: Package ) -> bool: """ Checks if a package version is vulnerable. Args: vulnerable_spec (SpecifierSet): The specifier set for vulnerable versions. requirement (SafetyRequirement): The package requirement. package (Package): The package. Returns: bool: True if the package version is vulnerable, False otherwise. """ if is_pinned_requirement(requirement.specifier): try: return vulnerable_spec.contains(next(iter(requirement.specifier)).version) except Exception: # Ugly for now... message = f"Version {requirement.specifier} for {package.name} is invalid and is ignored by Safety. Please See PEP 440." if message not in [a["message"] for a in SafetyContext.local_announcements]: SafetyContext.local_announcements.append( {"message": message, "type": "warning", "local": True} ) return False return any( requirement.specifier.filter( vulnerable_spec.filter(package.insecure_versions, prereleases=True), prereleases=True, ) ) @sync_safety_context def check( *, session: requests.Session, packages: List[Package] = [], db_mirror: Union[Optional[str], bool] = False, cached: int = 0, ignore_vulns: Optional[Dict[str, Any]] = None, ignore_severity_rules: Optional[Dict[str, Any]] = None, proxy: Optional[Dict[str, Any]] = None, include_ignored: bool = False, is_env_scan: bool = True, telemetry: bool = True, params: Optional[Dict[str, Any]] = None, project: Optional[str] = None, ) -> tuple: """ Performs a vulnerability check on the provided packages. Args: session (requests.Session): The requests session. packages (List[Package]): The list of packages to check. db_mirror (Union[Optional[str], bool]): The database mirror. cached (int): The cache validity in seconds. ignore_vulns (Optional[Dict[str, Any]]): The ignored vulnerabilities. ignore_severity_rules (Optional[Dict[str, Any]]): The severity rules for ignoring vulnerabilities. proxy (Optional[Dict[str, Any]]): The proxy settings. include_ignored (bool): Whether to include ignored vulnerabilities. is_env_scan (bool): Whether it is an environment scan. telemetry (bool): Whether to include telemetry data. params (Optional[Dict[str, Any]]): Additional parameters. project (Optional[str]): The project name. Returns: tuple: A tuple containing the list of vulnerabilities and the full database. """ SafetyContext().command = "check" db = fetch_database(session, db=db_mirror, cached=cached, telemetry=telemetry) db_full = None vulnerable_packages = frozenset(db.get("vulnerable_packages", [])) vulnerabilities = [] found_pkgs = {} requirements = iter([]) for p in packages: requirements = itertools.chain(p.requirements, requirements) found_pkgs[canonicalize_name(p.name)] = p # Let's report by req, pinned in environment will be ==version for req in requirements: vuln_per_req = {} name = canonicalize_name(req.name) pkg = found_pkgs.get(name, None) if not pkg.version: if not db_full: db_full = fetch_database( session, full=True, db=db_mirror, cached=cached, telemetry=telemetry ) pkg.refresh_from(db_full) if name in vulnerable_packages: # we have a candidate here, build the spec set for specifier in db["vulnerable_packages"][name]: spec_set = SpecifierSet(specifiers=specifier) if is_vulnerable(spec_set, req, pkg): if not db_full: db_full = fetch_database( session, full=True, db=db_mirror, cached=cached, telemetry=telemetry, ) if not pkg.latest_version: pkg.refresh_from(db_full) for data in get_vulnerabilities( pkg=name, spec=specifier, db=db_full ): try: vuln_id: str = str( next( filter( lambda i: i.get("type", None) == "pyup", data.get("ids", []), ) ).get("id", "") ) except StopIteration: vuln_id: str = "" if vuln_id in vuln_per_req: vuln_per_req[vuln_id].vulnerable_spec.add(specifier) continue cve = get_cve_from(data, db_full) ignore_vuln_if_needed( pkg, vuln_id, cve, ignore_vulns, ignore_severity_rules, req ) vulnerability = get_vulnerability_from( vuln_id, cve, data, specifier, db_full, name, pkg, ignore_vulns, req, ) should_add_vuln = not ( vulnerability.is_transitive and is_env_scan ) if ( include_ignored or vulnerability.vulnerability_id not in ignore_vulns ) and should_add_vuln: vuln_per_req[vulnerability.vulnerability_id] = vulnerability vulnerabilities.append(vulnerability) return vulnerabilities, db_full def precompute_remediations( remediations: Dict[str, Dict[str, Any]], packages: Dict[str, Package], vulns: List[Vulnerability], secure_vulns_by_user: set, ) -> None: """ Precomputes the remediations for the given vulnerabilities. Args: remediations (Dict[str, Dict[str, Any]]): The remediations dictionary. packages (Dict[str, Package]): The packages dictionary. vulns (List[Vulnerability]): The list of vulnerabilities. secure_vulns_by_user (set): The set of vulnerabilities secured by the user. """ for vuln in vulns: if vuln.ignored and vuln.ignored_reason != IGNORE_UNPINNED_REQ_REASON: secure_vulns_by_user.add(vuln.vulnerability_id) continue if ( vuln.package_name in remediations.keys() and str(vuln.analyzed_requirement.specifier) in remediations[vuln.package_name] ): spec = remediations[vuln.package_name][ str(vuln.analyzed_requirement.specifier) ] spec["vulnerabilities_found"] = spec.get("vulnerabilities_found", 0) + 1 else: version = None is_pinned = is_pinned_requirement(vuln.analyzed_requirement.specifier) if is_pinned: version = next(iter(vuln.analyzed_requirement.specifier)).version if not is_pinned and is_ignore_unpinned_mode(version): # Let's ignore this requirement continue vulns_count = 1 packages[vuln.package_name] = vuln.pkg remediations[vuln.package_name][ str(vuln.analyzed_requirement.specifier) ] = { "vulnerabilities_found": vulns_count, "version": version, "requirement": vuln.analyzed_requirement, "more_info_url": vuln.pkg.more_info_url, } def get_closest_ver( versions: List[str], version: Optional[str], spec: SpecifierSet ) -> Dict[str, Optional[Union[str, Version]]]: """ Retrieves the closest versions for the given version and specifier set. Args: versions (List[str]): The list of versions. version (Optional[str]): The current version. spec (SpecifierSet): The specifier set. Returns: Dict[str, Optional[Union[str, Version]]]: The closest versions. """ results: Dict[str, Optional[Union[str, Version]]] = {"upper": None, "lower": None} if (not version and not spec) or not versions: return results sorted_versions = sorted(versions, key=lambda ver: parse_version(ver), reverse=True) if not version: sorted_versions = spec.filter(sorted_versions, prereleases=False) upper = None lower = None try: sorted_versions = list(sorted_versions) upper = sorted_versions[0] lower = sorted_versions[-1] results["upper"] = upper results["lower"] = lower if upper != lower else None except IndexError: pass return results current_v = parse_version(version) for v in sorted_versions: index = parse_version(v) if index > current_v: results["upper"] = index if index < current_v: results["lower"] = index break return results def compute_sec_ver_for_user( package: Package, secure_vulns_by_user: set, db_full: Dict[str, Any] ) -> List[str]: """ Computes the secure versions for the user. Args: package (Package): The package. secure_vulns_by_user (set): The set of vulnerabilities secured by the user. db_full (Dict[str, Any]): The full database. Returns: List[str]: The list of secure versions. """ versions = package.get_versions(db_full) affected_versions = [] for vuln in db_full.get("vulnerable_packages", {}).get(package.name, []): vuln_id: str = str( next( filter(lambda i: i.get("type", None) == "pyup", vuln.get("ids", [])) ).get("id", "") ) if vuln_id and vuln_id not in secure_vulns_by_user: affected_versions += vuln.get("affected_versions", []) affected_v = set(affected_versions) sec_ver_for_user = list(versions.difference(affected_v)) return sorted(sec_ver_for_user, key=lambda ver: parse_version(ver), reverse=True) def compute_sec_ver( remediations: Dict[str, Dict[str, Any]], packages: Dict[str, Package], secure_vulns_by_user: set, db_full: Dict[str, Any], ) -> None: """ Computes the secure versions and the closest secure version for each remediation. Uses the affected_versions of each no ignored vulnerability of the same package, there is only a remediation for each package. Args: remediations (Dict[str, Dict[str, Any]]): The remediations dictionary. packages (Dict[str, Package]): The packages dictionary. secure_vulns_by_user (set): The set of vulnerabilities secured by the user. db_full (Dict[str, Any]): The full database. """ for pkg_name in remediations.keys(): pkg: Package = packages.get(pkg_name, None) secure_versions = [] if pkg: secure_versions = pkg.secure_versions analyzed = set(remediations[pkg_name].keys()) if not is_using_api_key(): continue for analyzed_requirement in analyzed: rem = remediations[pkg_name][analyzed_requirement] spec = rem.get("requirement").specifier version = rem["version"] if not secure_vulns_by_user: secure_v = sorted( secure_versions, key=lambda ver: parse_version(ver), reverse=True ) else: secure_v = compute_sec_ver_for_user( package=pkg, secure_vulns_by_user=secure_vulns_by_user, db_full=db_full, ) rem["closest_secure_version"] = get_closest_ver(secure_v, version, spec) upgrade = rem["closest_secure_version"].get("upper", None) downgrade = rem["closest_secure_version"].get("lower", None) recommended_version = None if upgrade: recommended_version = upgrade elif downgrade: recommended_version = downgrade rem["recommended_version"] = recommended_version rem["other_recommended_versions"] = [ other_v for other_v in secure_v if other_v != str(recommended_version) ] # Refresh the URL with the recommended version. spec = str(rem["requirement"].specifier) if is_using_api_key(): rem["more_info_url"] = build_remediation_info_url( base_url=rem["more_info_url"], version=version, spec=spec, target_version=recommended_version, ) def calculate_remediations( vulns: List[Vulnerability], db_full: Dict[str, Any] ) -> Dict[str, Dict[str, Any]]: """ Calculates the remediations for the given vulnerabilities. Args: vulns (List[Vulnerability]): The list of vulnerabilities. db_full (Dict[str, Any]): The full database. Returns: Dict[str, Dict[str, Any]]: The calculated remediations. """ remediations = defaultdict(dict) package_metadata = {} secure_vulns_by_user = set() if not db_full: return remediations precompute_remediations(remediations, package_metadata, vulns, secure_vulns_by_user) compute_sec_ver(remediations, package_metadata, secure_vulns_by_user, db_full) return remediations def should_apply_auto_fix( from_ver: Optional[Version], to_ver: Version, allowed_automatic: List[str] ) -> bool: """ Determines if an automatic fix should be applied. Args: from_ver (Optional[Version]): The current version. to_ver (Version): The target version. allowed_automatic (List[str]): The allowed automatic update types. Returns: bool: True if an automatic fix should be applied, False otherwise. """ if not from_ver: return False if "major" in allowed_automatic: return True major_change = to_ver.major - from_ver.major minor_change = to_ver.minor - from_ver.minor if "minor" in allowed_automatic: if major_change != 0: return False return True if "patch" in allowed_automatic: if major_change != 0 or minor_change != 0: return False return True return False def get_update_type(from_ver: Optional[Version], to_ver: Version) -> str: """ Determines the update type. Args: from_ver (Optional[Version]): The current version. to_ver (Version): The target version. Returns: str: The update type. """ if not from_ver or (to_ver.major - from_ver.major) != 0: return "major" if (to_ver.minor - from_ver.minor) != 0: return "minor" return "patch" def process_fixes( files: List[str], remediations: Dict[str, Dict[str, Any]], auto_remediation_limit: List[str], output: str, no_output: bool = True, prompt: bool = False, ) -> List[Fix]: """ Processes the fixes for the given files and remediations. Args: files (List[str]): The list of files. remediations (Dict[str, Dict[str, Any]]): The remediations dictionary. auto_remediation_limit (List[str]): The automatic remediation limits. output (str): The output format. no_output (bool): Whether to suppress output. prompt (bool): Whether to prompt for confirmation. Returns: List[Fix]: The list of applied fixes. """ req_remediations = itertools.chain.from_iterable( rem.values() for pkg_name, rem in remediations.items() ) requirements = compute_fixes_per_requirements( files, req_remediations, auto_remediation_limit, prompt=prompt ) fixes = apply_fixes(requirements, output, no_output, prompt) return fixes def process_fixes_scan( file_to_fix: SafetyPolicyFile, to_fix_spec: List[SafetyRequirement], auto_remediation_limit: List[str], output: str, no_output: bool = True, prompt: bool = False, ) -> List[Fix]: """ Processes the fixes for the given file and specifications in scan mode. Args: file_to_fix (SafetyPolicyFile): The file to fix. to_fix_spec (List[SafetyRequirement]): The specifications to fix. auto_remediation_limit (List[str]): The automatic remediation limits. output (str): The output format. no_output (bool): Whether to suppress output. prompt (bool): Whether to prompt for confirmation. Returns: List[Fix]: The list of applied fixes. """ def get_remmediation_from(spec): upper = None lower = None recommended = None try: upper = ( Version(spec.remediation.closest_secure.upper) if spec.remediation.closest_secure.upper else None ) except Exception: LOG.error( "Error getting upper remediation version, ignoring", exc_info=True ) try: lower = ( Version(spec.remediation.closest_secure.lower) if spec.remediation.closest_secure.lower else None ) except Exception: LOG.error( "Error getting lower remediation version, ignoring", exc_info=True ) try: recommended = Version(spec.remediation.recommended) except Exception: LOG.error( "Error getting recommended version for remediation, ignoring", exc_info=True, ) return { "vulnerabilities_found": spec.remediation.vulnerabilities_found, "version": next(iter(spec.specifier)).version if spec.is_pinned() else None, "requirement": spec, "more_info_url": spec.remediation.more_info_url, "closest_secure_version": {"upper": upper, "lower": lower}, "recommended_version": recommended, "other_recommended_versions": spec.remediation.other_recommended, } req_remediations = iter(get_remmediation_from(spec) for spec in to_fix_spec) SUPPORTED_FILE_TYPES = [FileType.REQUIREMENTS_TXT] if file_to_fix.file_type in SUPPORTED_FILE_TYPES: files = (open(file_to_fix.location),) requirements = compute_fixes_per_requirements( files, req_remediations, auto_remediation_limit, prompt=prompt ) else: requirements = { "files": { str(file_to_fix.location): { "content": None, "fixes": {"TO_SKIP": [], "TO_APPLY": [], "TO_CONFIRM": []}, "supported": False, "filename": file_to_fix.location.name, } }, "dependencies": defaultdict(dict), } fixes = apply_fixes( requirements, output, no_output, prompt, scan_flow=True, auto_remediation_limit=auto_remediation_limit, ) return fixes def compute_fixes_per_requirements( files: List[str], req_remediations: Iterator[Dict[str, Any]], auto_remediation_limit: List[str], prompt: bool = False, ) -> Dict[str, Any]: """ Computes the fixes per requirements. Args: files (List[str]): The list of files. req_remediations (Iterator[Dict[str, Any]]): The remediations iterator. auto_remediation_limit (List[str]): The automatic remediation limits. prompt (bool): Whether to prompt for confirmation. Returns: Dict[str, Any]: The computed requirements with fixes. """ requirements_files = get_requirements_content(files) from dparse.parser import parse, filetypes from packaging.version import Version, InvalidVersion requirements = { "files": {}, "dependencies": defaultdict(dict), } for name, contents in requirements_files.items(): dependency_file = parse( contents, path=name, file_type=filetypes.requirements_txt, resolve=True ) dependency_files = dependency_file.resolved_files + [dependency_file] empty_spec = SpecifierSet("") default_spec = SpecifierSet(">=0") # Support recursive requirements in the multiple requirement files provided for resolved_f in dependency_files: if not resolved_f or isinstance(resolved_f, str): continue file = { "content": resolved_f.content, "fixes": {"TO_SKIP": [], "TO_APPLY": [], "TO_CONFIRM": []}, } requirements["files"][resolved_f.path] = file for d in resolved_f.dependencies: if d.specs == empty_spec: d.specs = default_spec requirements["dependencies"][d.name][str(d.specs)] = ( d, resolved_f.path, ) for remediation in req_remediations: req: SafetyRequirement = remediation.get("requirement") pkg: str = req.name dry_fix = Fix( package=pkg, more_info_url=remediation.get("more_info_url", ""), previous_spec=str(req.specifier), other_options=remediation.get("other_recommended_versions", []), ) from_ver: Optional[str] = remediation.get("version", None) if ( pkg not in requirements["dependencies"] or dry_fix.previous_spec not in requirements["dependencies"][pkg] ): # Let's attach it to the first file scanned. file = next(iter(requirements["files"])) # Let's use the no parsed version. dry_fix.previous_version = from_ver dry_fix.status = "AUTOMATICALLY_SKIPPED_NOT_FOUND_IN_FILE" dry_fix.applied_at = file requirements["files"][file]["fixes"]["TO_SKIP"].append(dry_fix) continue dependency, name = requirements["dependencies"][pkg][dry_fix.previous_spec] dry_fix.applied_at = name fixes = requirements["files"][name]["fixes"] to_ver: Version = remediation["recommended_version"] try: from_ver = parse_version(from_ver) except (InvalidVersion, TypeError): if not dry_fix.previous_spec: dry_fix.status = "AUTOMATICALLY_SKIPPED_INVALID_VERSION" fixes["TO_SKIP"].append(dry_fix) continue dry_fix.previous_version = str(from_ver) if from_ver else from_ver if remediation["recommended_version"] is None: dry_fix.status = "AUTOMATICALLY_SKIPPED_NO_RECOMMENDED_VERSION" fixes["TO_SKIP"].append(dry_fix) continue dry_fix.updated_version = str(to_ver) is_fixed = from_ver == to_ver if is_fixed: dry_fix.status = "AUTOMATICALLY_SKIPPED_ALREADY_FIXED" fixes["TO_SKIP"].append(dry_fix) continue update_type = get_update_type(from_ver, to_ver) dry_fix.update_type = update_type dry_fix.dependency = dependency auto_fix = should_apply_auto_fix(from_ver, to_ver, auto_remediation_limit) TARGET = "TO_APPLY" if auto_fix: dry_fix.status = "PENDING_TO_APPLY" dry_fix.fix_type = "AUTOMATIC" elif prompt: TARGET = "TO_CONFIRM" dry_fix.status = "PENDING_TO_CONFIRM" dry_fix.fix_type = "MANUAL" else: TARGET = "TO_SKIP" dry_fix.status = "AUTOMATICALLY_SKIPPED_UNABLE_TO_CONFIRM" fixes[TARGET].append(dry_fix) return requirements def apply_fixes( requirements: Dict[str, Any], out_type: str, no_output: bool, prompt: bool, scan_flow: bool = False, auto_remediation_limit: List[str] = None, ) -> List[Fix]: """ Applies the fixes to the requirements. Args: requirements (Dict[str, Any]): The requirements with fixes. out_type (str): The output format. no_output (bool): Whether to suppress output. prompt (bool): Whether to prompt for confirmation. scan_flow (bool): Whether it is in scan flow mode. auto_remediation_limit (List[str]): The automatic remediation limits. Returns: List[Fix]: The list of applied fixes. """ from dparse.updater import RequirementsTXTUpdater skip = [] apply = [] confirm = [] brief = [] if not no_output: style_kwargs = {} if not scan_flow: brief.append(("", {})) brief.append(("Safety fix running", style_kwargs)) print_service(brief, out_type) for name, data in requirements["files"].items(): output = [ ("", {}), ( f"Analyzing {name}... [{get_fix_opt_used_msg(auto_remediation_limit)} limit]", { "styling": {"bold": True}, "start_line_decorator": "->", "indent": " ", }, ), ] r_skip = data["fixes"]["TO_SKIP"] r_apply = data["fixes"]["TO_APPLY"] r_confirm = data["fixes"]["TO_CONFIRM"] if data.get("supported", True): new_content = data["content"] updated: bool = False for f in r_apply: new_content = RequirementsTXTUpdater.update( content=new_content, version=f.updated_version, dependency=f.dependency, hashes=get_hashes(f.dependency), ) f.status = "APPLIED" updated = True output.append(("", {})) output.append((f"- {get_applied_msg(f, mode='auto')}", {})) for f in r_skip: output.append(("", {})) output.append((f"- {get_skipped_msg(f)}", {})) if not no_output: print_service(output, out_type) if prompt and not no_output: for f in r_confirm: options = [ f"({index}) =={option}" for index, option in enumerate(f.other_options) ] input_hint = f"Enter ā€œyā€ to update to {f.package}=={f.updated_version}, ā€œnā€ to skip this package upgrade" if len(options) > 0: input_hint += f", or enter the index from these secure versions to upgrade to that version: {', '.join(options)}" print_service([("", {})], out_type) confirmed: str = prompt_service( ( f"- {f.package}{f.previous_spec} requires at least a {f.update_type} version update. Do you want to update {f.package} from {f.previous_spec} to =={f.updated_version}, which is the closest secure version? {input_hint}", {}, ), out_type, ).lower() try: index: int = int(confirmed) if index <= len(f.other_options): confirmed = "y" except ValueError: index = -1 if confirmed == "y" or index > -1: f.status = "APPLIED" updated = True if index > -1: f.updated_version = f.other_options[index] new_content = RequirementsTXTUpdater.update( content=new_content, version=f.updated_version, dependency=f.dependency, hashes=get_hashes(f.dependency), ) output.append( (get_applied_msg(f, mode="manual"), {"indent": " " * 5}) ) else: f.status = "MANUALLY_SKIPPED" output.append((get_skipped_msg(f), {"indent": " " * 5})) if not no_output: print_service(output, out_type) if updated: output.append(("", {})) output.append((f"Updating {name}...", {})) with open(name, mode="w") as r_file: r_file.write(new_content) output.append((f"Changes applied to {name}.", {})) count = len(r_apply) + len( [1 for fix in r_confirm if fix.status == "APPLIED"] ) output.append( ( f"{count} package {pluralize('version', count)} {pluralize('has', count)} been updated to secure versions in {Path(name).name}", {}, ) ) output.append( ("Always check for breaking changes after updating packages.", {}) ) else: output.append((f"No fixes to be made in {name}.", {})) output.append(("", {})) else: not_supported_filename = data.get("filename", name) output.append( ( f"{not_supported_filename} updates not supported: Please update these dependencies using your package manager.", {"start_line_decorator": " -", "indent": " "}, ) ) output.append(("", {})) if not no_output: print_service(output, out_type) skip.extend(r_skip) apply.extend(r_apply) confirm.extend(r_confirm) # The scan flow will handle the header and divider, because the scan flow can be called multiple times. if not no_output and not scan_flow: divider = ( f"{'=' * 78}" if out_type == "text" else f"{'=' * (get_terminal_size().columns - 2)}" ) format_text = { "start_line_decorator": "+", "end_line_decorator": "+", "indent": "", } print_service([(divider, {})], out_type, format_text=format_text) return skip + apply + confirm def find_vulnerabilities_fixed( vulnerabilities: Dict[str, Any], fixes: List[Fix] ) -> List[Vulnerability]: """ Finds the vulnerabilities that have been fixed. Args: vulnerabilities (Dict[str, Any]): The dictionary of vulnerabilities. fixes (List[Fix]): The list of applied fixes. Returns: List[Vulnerability]: The list of fixed vulnerabilities. """ fixed_specs = set(fix.previous_spec for fix in fixes) if not fixed_specs: return [] return [ vulnerability for vulnerability in vulnerabilities if str(vulnerability["analyzed_requirement"].specifier) in fixed_specs ] @sync_safety_context def review( *, report: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None ) -> tuple: """ Reviews the report and returns the vulnerabilities and remediations. Args: report (Optional[Dict[str, Any]]): The report. params (Optional[Dict[str, Any]]): Additional parameters. Returns: tuple: A tuple containing the list of vulnerabilities, the remediations, and the found packages. """ SafetyContext().command = "review" vulnerable = [] vulnerabilities = report.get("vulnerabilities", []) + report.get( "ignored_vulnerabilities", [] ) remediations = defaultdict(dict) for key, pkg_rem in report.get("remediations", {}).items(): analyzed = set(pkg_rem["requirements"].keys()) for req in analyzed: req_rem = pkg_rem["requirements"][req] recommended = req_rem.get("recommended_version", None) secure_v = req_rem.get("other_recommended_versions", []) remediations[key][req] = { "vulnerabilities_found": req_rem.get("vulnerabilities_found", 0), "version": req_rem.get("version"), "requirement": SafetyRequirement(req_rem["requirement"]["raw"]), "other_recommended_versions": secure_v, "recommended_version": parse_version(recommended) if recommended else None, # minor isn't supported in review "more_info_url": req_rem.get("more_info_url"), } packages = report.get("scanned_packages", []) pkgs = {} for name, values in packages.items(): requirements = [ SafetyRequirement(r["raw"]) for r in values.get("requirements", []) ] values.update({"requirements": requirements}) pkgs[name] = Package(**values) ctx = SafetyContext() found_packages = list(pkgs.values()) ctx.packages = found_packages ctx.review = report.get("report_meta", []) ctx.key = ctx.review.get("api_key", False) cvssv2 = None cvssv3 = None for vuln in vulnerabilities: vuln["pkg"] = pkgs.get(vuln.get("package_name", None)) XVE_ID = vuln.get("CVE", None) # Trying to get first the CVE ID severity = vuln.get("severity", None) if severity and severity.get("source", False): cvssv2 = severity.get("cvssv2", None) cvssv3 = severity.get("cvssv3", None) # Trying to get the PVE ID if it exists, otherwise it will be the same CVE ID of above XVE_ID = severity.get("source", False) vuln["severity"] = Severity(source=XVE_ID, cvssv2=cvssv2, cvssv3=cvssv3) else: vuln["severity"] = None ignored_expires = vuln.get("ignored_expires", None) if ignored_expires: vuln["ignored_expires"] = validate_expiration_date(ignored_expires) vuln["CVE"] = CVE(name=XVE_ID, cvssv2=cvssv2, cvssv3=cvssv3) if XVE_ID else None vuln["analyzed_requirement"] = SafetyRequirement( vuln["analyzed_requirement"]["raw"] ) vulnerable.append(Vulnerability(**vuln)) return vulnerable, remediations, found_packages @sync_safety_context def get_licenses( *, session: requests.Session, db_mirror: Union[Optional[str], bool] = False, cached: int = 0, telemetry: bool = True, ) -> Dict[str, Any]: """ Retrieves the licenses from the database. Args: session (requests.Session): The requests session. db_mirror (Union[Optional[str], bool]): The database mirror. cached (int): The cache validity in seconds. telemetry (bool): Whether to include telemetry data. Returns: Dict[str, Any]: The licenses dictionary. """ if db_mirror: mirrors = [db_mirror] else: mirrors = API_MIRRORS db_name = "licenses.json" for mirror in mirrors: # mirror can either be a local path or a URL if is_a_remote_mirror(mirror): licenses = fetch_database_url( session, mirror, db_name=db_name, cached=cached, telemetry=telemetry ) else: licenses = fetch_database_file(mirror, db_name=db_name, ecosystem=None) if licenses: return licenses raise DatabaseFetchError() def add_local_notifications( packages: List[Package], ignore_unpinned_requirements: Optional[bool] ) -> List[Dict[str, str]]: """ Adds local notifications for unpinned packages. Args: packages (List[Package]): The list of packages. ignore_unpinned_requirements (Optional[bool]): Whether to ignore unpinned requirements. Returns: List[Dict[str, str]]: The list of notifications. """ announcements = [] unpinned_packages: List[str] = [ f"{pkg.name}" for pkg in packages if pkg.has_unpinned_req() ] if unpinned_packages and ignore_unpinned_requirements is not False: found = len(unpinned_packages) and_msg = "" if found >= 2: last = unpinned_packages.pop() and_msg = f" and {last}" pkgs: str = ( f"{', '.join(unpinned_packages)}{and_msg} {'are' if found > 1 else 'is'}" ) doc_msg: str = get_specifier_range_info(style=False, pin_hint=True) if ignore_unpinned_requirements is None: msg = ( f"Warning: {pkgs} unpinned. Safety by default does not " f"report on potential vulnerabilities in unpinned packages. {doc_msg}" ) else: msg = ( f"Warning: {pkgs} unpinned and potential vulnerabilities are " f"being ignored given `ignore-unpinned-requirements` is True in your config. {doc_msg}" ) announcements.append({"message": msg, "type": "warning", "local": True}) announcements.extend(SafetyContext().local_announcements) return announcements def get_announcements( session: requests.Session, telemetry: bool = True, with_telemetry: Any = None ) -> List[Dict[str, str]]: """ Retrieves announcements from the server. Args: session (requests.Session): The requests session. telemetry (bool): Whether to include telemetry data. with_telemetry (Optional[Dict[str, Any]]): The telemetry data. Returns: List[Dict[str, str]]: The list of announcements. """ LOG.info("Getting announcements") announcements = [] url = f"{DATA_API_BASE_URL}announcements/" method = "post" telemetry_data = ( with_telemetry if with_telemetry else build_telemetry_data(telemetry=telemetry) ) data = asdict(telemetry_data) request_kwargs = {"timeout": 3} data_keyword = "json" source = os.environ.get("SAFETY_ANNOUNCEMENTS_URL", None) if source: LOG.debug(f"Getting the announcement from a different source: {source}") url = source method = "get" data = {"telemetry": json.dumps(data)} data_keyword = "params" request_kwargs[data_keyword] = data request_kwargs["url"] = url request_kwargs["headers"] = get_meta_http_headers() LOG.debug(f"Telemetry data sent: {data}") try: request_func = getattr(session, method) r = request_func(**request_kwargs) LOG.debug(r.text) except Exception as e: LOG.info( "Unexpected but HANDLED Exception happened getting the announcements: %s", e ) return announcements if r.status_code == 200: try: announcements = r.json() if "announcements" in announcements.keys(): announcements = announcements.get("announcements", []) else: LOG.info( "There is not announcements key in the JSON response, is this a wrong structure?" ) announcements = [] except json.JSONDecodeError as e: LOG.info( "Unexpected but HANDLED Exception happened decoding the announcement response: %s", e, ) LOG.info("Announcements fetched") return announcements def get_distribution_location(distribution): """ Get the installation location of a distribution. """ if hasattr(distribution, "_path") and distribution._path: return str(distribution._path) # Fallback: try to get location from metadata try: location = distribution.locate_file("") if location: return str(location) except (AttributeError, Exception): pass return "" def get_distribution_version(distribution): """ Safely get the version of a distribution (Python 3.9+ compatible). """ try: return distribution.version except AttributeError: return distribution.metadata.get("Version", "") def get_distribution_name(distribution): """ Safely get the name of a distribution (Python 3.9+ compatible). """ try: # For Python 3.10+ return distribution.name except AttributeError: # Fallback for Python 3.9 PathDistribution objects return distribution.metadata.get("Name", "") def get_packages( files: Optional[List[str]] = None, stdin: bool = False ) -> List[Package]: """ Retrieves the packages from the given files or standard input. Args: files (Optional[List[str]]): The list of files. stdin (bool): Whether to read from standard input. Returns: List[Package]: The list of packages. """ if files: return list( itertools.chain.from_iterable( read_requirements(f, resolve=True) for f in files ) ) if stdin: return list(read_requirements(sys.stdin)) # Migrated from pkg_resources to importlib.metadata import importlib.metadata def allowed_version(pkg: str, version: str) -> bool: try: parse_version(version) except Exception: SafetyContext.local_announcements.append( { "message": f"Version {version} for {pkg} is invalid and is ignored by Safety. Please See PEP 440.", "type": "warning", "local": True, } ) return False return True distributions = list(importlib.metadata.distributions()) paths = { str(dist._path.parent) for dist in distributions if hasattr(dist, "_path") and dist._path } SafetyContext().scanned_full_path.extend(paths) packages = [] for d in distributions: name = get_distribution_name(d) version = get_distribution_version(d) # Skip if name or version couldn't be determined if not name or not version: continue # Skip excluded packages if name in {"python", "wsgiref", "argparse"} or not allowed_version( name, version ): continue location = get_distribution_location(d) packages.append( Package( name=name, version=version, requirements=[SafetyRequirement(f"{name}=={version}", found=location)], found=location, insecure_versions=[], secure_versions=[], latest_version=None, latest_version_without_known_vulnerabilities=None, more_info_url=None, ) ) return packages def read_vulnerabilities(fh: Any) -> Dict[str, Any]: """ Reads vulnerabilities from a file handle. Args: fh (Any): The file handle. Returns: Dict[str, Any]: The vulnerabilities data. """ try: data = json.load(fh) except json.JSONDecodeError as e: raise MalformedDatabase(reason=e, fetched_from=fh.name) except TypeError as e: raise MalformedDatabase(reason=e, fetched_from=fh.name) return data def get_server_policies( session: requests.Session, policy_file: SafetyPolicyFile, proxy_dictionary: Dict[str, str], ) -> tuple: """ Retrieves the server policies. Args: session (requests.Session): The requests session. policy_file (SafetyPolicyFile): The policy file. proxy_dictionary (Dict[str, str]): The proxy dictionary. Returns: tuple: A tuple containing the policy file and the audit and monitor flag. """ if session.api_key: server_policies = fetch_policy(session) server_audit_and_monitor = server_policies["audit_and_monitor"] server_safety_policy = server_policies["safety_policy"] else: server_audit_and_monitor = False server_safety_policy = "" if server_safety_policy and policy_file: click.secho( "Warning: both a local policy file '{policy_filename}' and a server sent policy are present. " "Continuing with the local policy file.".format( policy_filename=policy_file["filename"] ), fg="yellow", file=sys.stderr, ) elif server_safety_policy: with tempfile.NamedTemporaryFile(prefix="server-safety-policy-") as tmp: tmp.write(server_safety_policy.encode("utf-8")) tmp.seek(0) policy_file = SafetyPolicyFile().convert(tmp.name, param=None, ctx=None) LOG.info("Using server side policy file") return policy_file, server_audit_and_monitor def save_report(path: str, default_name: str, report: str) -> None: """ Saves the report to a file. Args: path (str): The path to save the report. default_name (str): The default name of the report file. report (str): The report content. """ if path: save_at = path if os.path.isdir(path): save_at = os.path.join(path, default_name) with open(save_at, "w+") as report_file: report_file.write(report)