# -*- coding: utf-8 -*-
# type: ignore
from dataclasses import asdict
import errno
import itertools
import json
import logging
import os
from pathlib import Path
import sys
import tempfile
import time
from collections import defaultdict
from datetime import datetime
from typing import Dict, Optional, List, Any, Union, Iterator
import click
import requests
from packaging.specifiers import SpecifierSet
from packaging.utils import canonicalize_name
from packaging.version import parse as parse_version, Version
from pydantic_core import to_jsonable_python
from filelock import FileLock
from safety_schemas.models import Ecosystem, FileType
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from .constants import (
API_MIRRORS,
DB_CACHE_FILE,
OPEN_MIRRORS,
REQUEST_TIMEOUT,
DATA_API_BASE_URL,
JSON_SCHEMA_VERSION,
IGNORE_UNPINNED_REQ_REASON,
)
from .errors import (
DatabaseFetchError,
DatabaseFileNotFoundError,
InvalidCredentialError,
TooManyRequestsError,
NetworkConnectionError,
RequestTimeoutError,
ServerError,
MalformedDatabase,
)
from .meta import get_meta_http_headers
from .models import (
Vulnerability,
CVE,
Severity,
Fix,
is_pinned_requirement,
SafetyRequirement,
)
from .output_utils import (
print_service,
get_applied_msg,
prompt_service,
get_skipped_msg,
get_fix_opt_used_msg,
is_using_api_key,
get_specifier_range_info,
)
from .util import (
build_remediation_info_url,
pluralize,
read_requirements,
Package,
build_telemetry_data,
sync_safety_context,
SafetyContext,
validate_expiration_date,
is_a_remote_mirror,
get_requirements_content,
SafetyPolicyFile,
get_terminal_size,
is_ignore_unpinned_mode,
get_hashes,
)
LOG = logging.getLogger(__name__)
def get_from_cache(
db_name: str, cache_valid_seconds: int = 0, skip_time_verification: bool = False
) -> Optional[Dict[str, Any]]:
"""
Retrieves the database from the cache if it is valid.
Args:
db_name (str): The name of the database.
cache_valid_seconds (int): The validity period of the cache in seconds.
skip_time_verification (bool): Whether to skip time verification.
Returns:
        Optional[Dict[str, Any]]: The cached database if available and valid, otherwise None.
"""
cache_file_lock = f"{DB_CACHE_FILE}.lock"
os.makedirs(os.path.dirname(cache_file_lock), exist_ok=True)
lock = FileLock(cache_file_lock, timeout=10)
with lock:
if os.path.exists(DB_CACHE_FILE):
with open(DB_CACHE_FILE) as f:
try:
data = json.loads(f.read())
if db_name in data:
if "cached_at" in data[db_name]:
if (
data[db_name]["cached_at"] + cache_valid_seconds
> time.time()
or skip_time_verification
):
LOG.debug(
"Getting the database from cache at %s, cache setting: %s",
data[db_name]["cached_at"],
cache_valid_seconds,
)
try:
data[db_name]["db"]["meta"]["base_domain"] = (
"https://data.safetycli.com"
)
except KeyError:
pass
return data[db_name]["db"]
LOG.debug(
"Cached file is too old, it was cached at %s",
data[db_name]["cached_at"],
)
else:
LOG.debug(
"There is not the cached_at key in %s database",
data[db_name],
)
except json.JSONDecodeError:
LOG.debug("JSONDecodeError trying to get the cached database.")
else:
LOG.debug("Cache file doesn't exist...")
return None
def write_to_cache(db_name: str, data: Dict[str, Any]) -> None:
"""
Writes the database to the cache.
Args:
db_name (str): The name of the database.
data (Dict[str, Any]): The database data to be cached.
"""
# cache is in: ~/safety/cache.json
# and has the following form:
# {
# "insecure.json": {
# "cached_at": 12345678
# "db": {}
# },
# "insecure_full.json": {
# "cached_at": 12345678
# "db": {}
# },
# }
if not os.path.exists(os.path.dirname(DB_CACHE_FILE)):
try:
os.makedirs(os.path.dirname(DB_CACHE_FILE))
with open(DB_CACHE_FILE, "w") as _:
_.write(json.dumps({}))
LOG.debug("Cache file created")
except OSError as exc: # Guard against race condition
LOG.debug("Unable to create the cache file because: %s", exc.errno)
if exc.errno != errno.EEXIST:
raise
cache_file_lock = f"{DB_CACHE_FILE}.lock"
lock = FileLock(cache_file_lock, timeout=10)
with lock:
if os.path.exists(DB_CACHE_FILE):
with open(DB_CACHE_FILE, "r") as f:
try:
cache = json.loads(f.read())
except json.JSONDecodeError:
LOG.debug(
"JSONDecodeError in the local cache, dumping the full cache file."
)
cache = {}
else:
cache = {}
with open(DB_CACHE_FILE, "w") as f:
cache[db_name] = {"cached_at": time.time(), "db": data}
f.write(json.dumps(cache))
LOG.debug("Safety updated the cache file for %s database.", db_name)
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential_jitter(initial=0.2, max=8.0, exp_base=3, jitter=0.3),
reraise=True,
retry=retry_if_exception_type(
(
NetworkConnectionError,
RequestTimeoutError,
TooManyRequestsError,
MalformedDatabase,
ServerError,
)
),
before_sleep=before_sleep_log(logging.getLogger("api_client"), logging.WARNING),
)
def fetch_database_url(
session: requests.Session,
mirror: str,
db_name: str,
cached: int,
telemetry: bool = True,
ecosystem: Ecosystem = Ecosystem.PYTHON,
from_cache: bool = True,
) -> Dict[str, Any]:
"""
Fetches the database from a URL.
Args:
session (requests.Session): The requests session.
mirror (str): The URL of the mirror.
db_name (str): The name of the database.
cached (int): The cache validity in seconds.
telemetry (bool): Whether to include telemetry data.
ecosystem (Ecosystem): The ecosystem.
from_cache (bool): Whether to fetch from cache.
Returns:
Dict[str, Any]: The fetched database.
"""
headers = {"schema-version": JSON_SCHEMA_VERSION, "ecosystem": ecosystem.value}
headers.update(get_meta_http_headers())
if cached and from_cache:
cached_data = get_from_cache(db_name=db_name, cache_valid_seconds=cached)
if cached_data:
LOG.info("Database %s returned from cache.", db_name)
return cached_data
url = mirror + db_name
telemetry_data = {
"telemetry": json.dumps(
build_telemetry_data(telemetry=telemetry), default=to_jsonable_python
)
}
try:
r = session.get(
url=url, timeout=REQUEST_TIMEOUT, headers=headers, params=telemetry_data
)
except requests.exceptions.ConnectionError:
raise NetworkConnectionError()
except requests.exceptions.Timeout:
raise RequestTimeoutError()
except requests.exceptions.RequestException:
raise DatabaseFetchError()
if r.status_code == 403:
raise InvalidCredentialError(credential=session.get_credential(), reason=r.text)
if r.status_code == 429:
raise TooManyRequestsError(reason=r.text)
if r.status_code >= 500 and r.status_code < 600:
raise ServerError(reason=r.reason)
try:
data = r.json()
except json.JSONDecodeError as e:
raise MalformedDatabase(reason=e)
if cached:
LOG.info("Writing %s to cache because cached value was %s", db_name, cached)
write_to_cache(db_name, data)
return data
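# Illustrative call sketch for fetch_database_url(). The tenacity decorator
# above retries transient failures (connection errors, timeouts, 429s, 5xx
# responses, malformed payloads) up to three attempts with jittered
# exponential backoff before re-raising. The mirror URL is a placeholder;
# real calls use the entries from API_MIRRORS or OPEN_MIRRORS and a session
# prepared by Safety's auth layer:
#
#   db = fetch_database_url(
#       session,
#       mirror="https://mirror.example/safety-db/",  # placeholder
#       db_name="insecure.json",
#       cached=3600,
#   )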
def fetch_policy(session: requests.Session) -> Dict[str, Any]:
"""
Fetches the policy from the server.
Args:
session (requests.Session): The requests session.
Returns:
Dict[str, Any]: The fetched policy.
"""
url = f"{DATA_API_BASE_URL}policy/"
try:
LOG.debug("Getting policy")
headers = get_meta_http_headers()
r = session.get(url=url, timeout=REQUEST_TIMEOUT, headers=headers)
LOG.debug(r.text)
return r.json()
except Exception:
LOG.exception("Error fetching policy")
return {"safety_policy": "", "audit_and_monitor": False}
def fetch_database_file(
path: str, db_name: str, cached: int = 0, ecosystem: Optional[Ecosystem] = None
) -> Dict[str, Any]:
"""
Fetches the database from a local file.
Args:
path (str): The path to the local file.
db_name (str): The name of the database.
cached (int): The cache validity in seconds.
ecosystem (Optional[Ecosystem]): The ecosystem.
Returns:
Dict[str, Any]: The fetched database.
"""
full_path = (
(Path(path) / (ecosystem.value if ecosystem else "") / db_name)
.expanduser()
.resolve()
)
if not full_path.exists():
raise DatabaseFileNotFoundError(db=path)
with open(full_path) as f:
data = json.loads(f.read())
if cached:
LOG.info("Writing %s to cache because cached value was %s", db_name, cached)
write_to_cache(db_name, data)
return data
def is_valid_database(db: Dict[str, Any]) -> bool:
"""
Checks if the database is valid.
Args:
db (Dict[str, Any]): The database.
Returns:
bool: True if the database is valid, False otherwise.
"""
try:
if db["meta"]["schema_version"] == JSON_SCHEMA_VERSION:
return True
except Exception:
return False
return False
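# Minimal behaviour sketch for is_valid_database(): a database dict is valid
# only when its meta schema_version matches the JSON_SCHEMA_VERSION this
# client expects.
#
#   is_valid_database({"meta": {"schema_version": JSON_SCHEMA_VERSION}})  # True
#   is_valid_database({"meta": {}})                                       # False
#   is_valid_database({})                                                 # False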
def fetch_database(
session: requests.Session,
full: bool = False,
db: Union[Optional[str], bool] = False,
cached: int = 0,
telemetry: bool = True,
ecosystem: Optional[Ecosystem] = None,
from_cache: bool = True,
) -> Dict[str, Any]:
"""
Fetches the database from a mirror or a local file.
Args:
session (requests.Session): The requests session.
full (bool): Whether to fetch the full database.
        db (Union[Optional[str], bool]): The path to a local database mirror, or False to use the default mirrors.
cached (int): The cache validity in seconds.
telemetry (bool): Whether to include telemetry data.
ecosystem (Optional[Ecosystem]): The ecosystem.
from_cache (bool): Whether to fetch from cache.
Returns:
Dict[str, Any]: The fetched database.
"""
if session.is_using_auth_credentials():
mirrors = API_MIRRORS
elif db:
mirrors = [db]
else:
mirrors = OPEN_MIRRORS
db_name = "insecure_full.json" if full else "insecure.json"
for mirror in mirrors:
# mirror can either be a local path or a URL
if is_a_remote_mirror(mirror):
if ecosystem is None:
ecosystem = Ecosystem.PYTHON
data = fetch_database_url(
session,
mirror,
db_name=db_name,
cached=cached,
telemetry=telemetry,
ecosystem=ecosystem,
from_cache=from_cache,
)
else:
data = fetch_database_file(
mirror, db_name=db_name, cached=cached, ecosystem=ecosystem
)
if data:
if is_valid_database(data):
return data
raise MalformedDatabase(
fetched_from=mirror,
reason=f"Not supported schema version. "
f"This Safety version supports only schema version {JSON_SCHEMA_VERSION}",
)
raise DatabaseFetchError()
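# Usage sketch for fetch_database(). Mirror selection follows the branches
# above: an authenticated session uses API_MIRRORS, an explicit `db` argument
# takes precedence otherwise, and OPEN_MIRRORS is the fallback. The local
# path is a placeholder:
#
#   db = fetch_database(session, db="/path/to/local/db-dir", cached=0)
#   full_db = fetch_database(session, full=True, cached=3600)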
def get_vulnerabilities(
pkg: str, spec: str, db: Dict[str, Any]
) -> Iterator[Dict[str, Any]]:
"""
Retrieves vulnerabilities for a package from the database.
Args:
pkg (str): The package name.
spec (str): The specifier set.
db (Dict[str, Any]): The database.
Returns:
Iterator[Dict[str, Any]]: An iterator of vulnerabilities.
"""
for entry in db["vulnerable_packages"][pkg]:
for entry_spec in entry["specs"]:
if entry_spec == spec:
yield entry
def get_vulnerability_from(
vuln_id: str,
cve: Optional[CVE],
data: Dict[str, Any],
specifier: str,
db: Dict[str, Any],
name: str,
pkg: Package,
ignore_vulns: Dict[str, Any],
affected: SafetyRequirement,
) -> Vulnerability:
"""
Constructs a Vulnerability object from the provided data.
Args:
vuln_id (str): The vulnerability ID.
cve (Optional[CVE]): The CVE object.
data (Dict[str, Any]): The vulnerability data.
specifier (str): The specifier set.
db (Dict[str, Any]): The database.
name (str): The package name.
pkg (Package): The Package object.
ignore_vulns (Dict[str, Any]): The ignored vulnerabilities.
affected (SafetyRequirement): The affected requirement.
Returns:
Vulnerability: The constructed Vulnerability object.
"""
base_domain = db.get("meta", {}).get("base_domain")
unpinned_ignored = ignore_vulns.get(vuln_id, {}).get("requirements", None)
should_ignore = not unpinned_ignored or str(affected.specifier) in unpinned_ignored
ignored = (
ignore_vulns
and vuln_id in ignore_vulns
and should_ignore
and (
not ignore_vulns[vuln_id]["expires"]
or ignore_vulns[vuln_id]["expires"] > datetime.utcnow()
)
)
more_info_url = f"{base_domain}{data.get('more_info_path', '')}"
severity = None
if cve and (cve.cvssv2 or cve.cvssv3):
severity = Severity(source=cve.name, cvssv2=cve.cvssv2, cvssv3=cve.cvssv3)
analyzed_requirement = affected
analyzed_version = (
next(iter(analyzed_requirement.specifier)).version
if is_pinned_requirement(analyzed_requirement.specifier)
else None
)
vulnerable_spec = set()
vulnerable_spec.add(specifier)
return Vulnerability(
vulnerability_id=vuln_id,
package_name=name,
pkg=pkg,
ignored=ignored,
ignored_reason=ignore_vulns.get(vuln_id, {}).get("reason", None)
if ignore_vulns
else None,
ignored_expires=ignore_vulns.get(vuln_id, {}).get("expires", None)
if ignore_vulns
else None,
vulnerable_spec=vulnerable_spec,
all_vulnerable_specs=data.get("specs", []),
analyzed_version=analyzed_version,
analyzed_requirement=analyzed_requirement,
advisory=data.get("advisory"),
is_transitive=data.get("transitive", False),
published_date=data.get("published_date"),
fixed_versions=[ver for ver in data.get("fixed_versions", []) if ver],
closest_versions_without_known_vulnerabilities=data.get(
"closest_secure_versions", []
),
resources=data.get("vulnerability_resources"),
CVE=cve,
severity=severity,
affected_versions=data.get("affected_versions", []),
more_info_url=more_info_url,
)
def get_cve_from(data: Dict[str, Any], db_full: Dict[str, Any]) -> Optional[CVE]:
"""
Retrieves the CVE object from the provided data.
Args:
data (Dict[str, Any]): The vulnerability data.
db_full (Dict[str, Any]): The full database.
Returns:
Optional[CVE]: The CVE object if found, otherwise None.
"""
try:
xve_id: str = str(
next(
filter(
lambda i: i.get("type", None) in ["cve", "pve"], data.get("ids", [])
)
).get("id", "")
)
except StopIteration:
xve_id: str = ""
if not xve_id:
return None
cve_meta = db_full.get("meta", {}).get("severities", {}).get(xve_id, {})
return CVE(
name=xve_id,
cvssv2=cve_meta.get("cvssv2", None),
cvssv3=cve_meta.get("cvssv3", None),
)
def ignore_vuln_if_needed(
pkg: Package,
vuln_id: str,
cve: Optional[CVE],
ignore_vulns: Dict[str, Any],
ignore_severity_rules: Dict[str, Any],
req: SafetyRequirement,
) -> None:
"""
Determines if a vulnerability should be ignored based on severity rules and updates the ignore_vulns dictionary.
Args:
pkg (Package): The package.
vuln_id (str): The vulnerability ID.
cve (Optional[CVE]): The CVE object.
ignore_vulns (Dict[str, Any]): The ignored vulnerabilities.
ignore_severity_rules (Dict[str, Any]): The severity rules for ignoring vulnerabilities.
req (SafetyRequirement): The affected requirement.
"""
if not ignore_severity_rules:
ignore_severity_rules = {}
if not isinstance(ignore_vulns, dict):
return
severity = None
if cve:
if cve.cvssv2 and cve.cvssv2.get("base_score", None):
severity = cve.cvssv2.get("base_score", None)
if cve.cvssv3 and cve.cvssv3.get("base_score", None):
severity = cve.cvssv3.get("base_score", None)
ignore_severity_below = float(
ignore_severity_rules.get("ignore-cvss-severity-below", 0.0)
)
ignore_unknown_severity = bool(
ignore_severity_rules.get("ignore-cvss-unknown-severity", False)
)
if severity:
if float(severity) < ignore_severity_below:
reason = "Ignored by severity rule in policy file, {0} < {1}".format(
float(severity), ignore_severity_below
)
ignore_vulns[vuln_id] = {"reason": reason, "expires": None}
elif ignore_unknown_severity:
reason = "Unknown CVSS severity, ignored by severity rule in policy file."
ignore_vulns[vuln_id] = {"reason": reason, "expires": None}
version = (
next(iter(req.specifier)).version
if is_pinned_requirement(req.specifier)
else pkg.version
)
is_prev_not_ignored: bool = vuln_id not in ignore_vulns
is_req_not_ignored: bool = "requirements" in ignore_vulns.get(vuln_id, {}) and str(
req.specifier
) not in ignore_vulns.get(vuln_id, {}).get("requirements", set())
if (is_prev_not_ignored or is_req_not_ignored) and is_ignore_unpinned_mode(version):
reason = IGNORE_UNPINNED_REQ_REASON
requirements = set()
requirements.add(str(req.specifier))
ignore_vulns[vuln_id] = {
"reason": reason,
"expires": None,
"requirements": requirements,
}
def is_vulnerable(
vulnerable_spec: SpecifierSet, requirement: SafetyRequirement, package: Package
) -> bool:
"""
Checks if a package version is vulnerable.
Args:
vulnerable_spec (SpecifierSet): The specifier set for vulnerable versions.
requirement (SafetyRequirement): The package requirement.
package (Package): The package.
Returns:
bool: True if the package version is vulnerable, False otherwise.
"""
if is_pinned_requirement(requirement.specifier):
try:
return vulnerable_spec.contains(next(iter(requirement.specifier)).version)
except Exception:
# Ugly for now...
message = f"Version {requirement.specifier} for {package.name} is invalid and is ignored by Safety. Please See PEP 440."
if message not in [a["message"] for a in SafetyContext.local_announcements]:
SafetyContext.local_announcements.append(
{"message": message, "type": "warning", "local": True}
)
return False
return any(
requirement.specifier.filter(
vulnerable_spec.filter(package.insecure_versions, prereleases=True),
prereleases=True,
)
)
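# Sketch of the two matching paths above using packaging's SpecifierSet
# directly (the version values are examples):
#
#   vulnerable = SpecifierSet(">=2.0,<2.6")
#   # Pinned requirement: the check reduces to a plain containment test.
#   vulnerable.contains("2.2")   # True -> version 2.2 is affected
#   # Unpinned requirement: the insecure versions reported by the database
#   # are filtered through the vulnerable spec and through the requirement's
#   # own specifier, and any survivor marks the requirement as vulnerable.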
@sync_safety_context
def check(
*,
session: requests.Session,
packages: List[Package] = [],
db_mirror: Union[Optional[str], bool] = False,
cached: int = 0,
ignore_vulns: Optional[Dict[str, Any]] = None,
ignore_severity_rules: Optional[Dict[str, Any]] = None,
proxy: Optional[Dict[str, Any]] = None,
include_ignored: bool = False,
is_env_scan: bool = True,
telemetry: bool = True,
params: Optional[Dict[str, Any]] = None,
project: Optional[str] = None,
) -> tuple:
"""
Performs a vulnerability check on the provided packages.
Args:
session (requests.Session): The requests session.
packages (List[Package]): The list of packages to check.
db_mirror (Union[Optional[str], bool]): The database mirror.
cached (int): The cache validity in seconds.
ignore_vulns (Optional[Dict[str, Any]]): The ignored vulnerabilities.
ignore_severity_rules (Optional[Dict[str, Any]]): The severity rules for ignoring vulnerabilities.
proxy (Optional[Dict[str, Any]]): The proxy settings.
include_ignored (bool): Whether to include ignored vulnerabilities.
is_env_scan (bool): Whether it is an environment scan.
telemetry (bool): Whether to include telemetry data.
params (Optional[Dict[str, Any]]): Additional parameters.
project (Optional[str]): The project name.
Returns:
tuple: A tuple containing the list of vulnerabilities and the full database.
"""
SafetyContext().command = "check"
db = fetch_database(session, db=db_mirror, cached=cached, telemetry=telemetry)
db_full = None
vulnerable_packages = frozenset(db.get("vulnerable_packages", []))
vulnerabilities = []
found_pkgs = {}
requirements = iter([])
for p in packages:
requirements = itertools.chain(p.requirements, requirements)
found_pkgs[canonicalize_name(p.name)] = p
    # Let's report by requirement; packages pinned in the environment will appear as ==version
for req in requirements:
vuln_per_req = {}
name = canonicalize_name(req.name)
pkg = found_pkgs.get(name, None)
if not pkg.version:
if not db_full:
db_full = fetch_database(
session, full=True, db=db_mirror, cached=cached, telemetry=telemetry
)
pkg.refresh_from(db_full)
if name in vulnerable_packages:
# we have a candidate here, build the spec set
for specifier in db["vulnerable_packages"][name]:
spec_set = SpecifierSet(specifiers=specifier)
if is_vulnerable(spec_set, req, pkg):
if not db_full:
db_full = fetch_database(
session,
full=True,
db=db_mirror,
cached=cached,
telemetry=telemetry,
)
if not pkg.latest_version:
pkg.refresh_from(db_full)
for data in get_vulnerabilities(
pkg=name, spec=specifier, db=db_full
):
try:
vuln_id: str = str(
next(
filter(
lambda i: i.get("type", None) == "pyup",
data.get("ids", []),
)
).get("id", "")
)
except StopIteration:
vuln_id: str = ""
if vuln_id in vuln_per_req:
vuln_per_req[vuln_id].vulnerable_spec.add(specifier)
continue
cve = get_cve_from(data, db_full)
ignore_vuln_if_needed(
pkg, vuln_id, cve, ignore_vulns, ignore_severity_rules, req
)
vulnerability = get_vulnerability_from(
vuln_id,
cve,
data,
specifier,
db_full,
name,
pkg,
ignore_vulns,
req,
)
should_add_vuln = not (
vulnerability.is_transitive and is_env_scan
)
if (
include_ignored
or vulnerability.vulnerability_id not in ignore_vulns
) and should_add_vuln:
vuln_per_req[vulnerability.vulnerability_id] = vulnerability
vulnerabilities.append(vulnerability)
return vulnerabilities, db_full
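# Minimal call sketch for check(). The session is not a plain
# requests.Session: fetch_database() calls is_using_auth_credentials() on it,
# so the authenticated session object Safety builds elsewhere is assumed
# here. Packages typically come from get_packages() further below.
#
#   vulns, full_db = check(
#       session=session,
#       packages=get_packages(),
#       cached=0,
#       ignore_vulns={},
#       ignore_severity_rules={},
#       is_env_scan=True,
#   )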
def precompute_remediations(
remediations: Dict[str, Dict[str, Any]],
packages: Dict[str, Package],
vulns: List[Vulnerability],
secure_vulns_by_user: set,
) -> None:
"""
Precomputes the remediations for the given vulnerabilities.
Args:
remediations (Dict[str, Dict[str, Any]]): The remediations dictionary.
packages (Dict[str, Package]): The packages dictionary.
vulns (List[Vulnerability]): The list of vulnerabilities.
secure_vulns_by_user (set): The set of vulnerabilities secured by the user.
"""
for vuln in vulns:
if vuln.ignored and vuln.ignored_reason != IGNORE_UNPINNED_REQ_REASON:
secure_vulns_by_user.add(vuln.vulnerability_id)
continue
if (
vuln.package_name in remediations.keys()
and str(vuln.analyzed_requirement.specifier)
in remediations[vuln.package_name]
):
spec = remediations[vuln.package_name][
str(vuln.analyzed_requirement.specifier)
]
spec["vulnerabilities_found"] = spec.get("vulnerabilities_found", 0) + 1
else:
version = None
is_pinned = is_pinned_requirement(vuln.analyzed_requirement.specifier)
if is_pinned:
version = next(iter(vuln.analyzed_requirement.specifier)).version
if not is_pinned and is_ignore_unpinned_mode(version):
# Let's ignore this requirement
continue
vulns_count = 1
packages[vuln.package_name] = vuln.pkg
remediations[vuln.package_name][
str(vuln.analyzed_requirement.specifier)
] = {
"vulnerabilities_found": vulns_count,
"version": version,
"requirement": vuln.analyzed_requirement,
"more_info_url": vuln.pkg.more_info_url,
}
def get_closest_ver(
versions: List[str], version: Optional[str], spec: SpecifierSet
) -> Dict[str, Optional[Union[str, Version]]]:
"""
Retrieves the closest versions for the given version and specifier set.
Args:
versions (List[str]): The list of versions.
version (Optional[str]): The current version.
spec (SpecifierSet): The specifier set.
Returns:
Dict[str, Optional[Union[str, Version]]]: The closest versions.
"""
results: Dict[str, Optional[Union[str, Version]]] = {"upper": None, "lower": None}
if (not version and not spec) or not versions:
return results
sorted_versions = sorted(versions, key=lambda ver: parse_version(ver), reverse=True)
if not version:
sorted_versions = spec.filter(sorted_versions, prereleases=False)
upper = None
lower = None
try:
sorted_versions = list(sorted_versions)
upper = sorted_versions[0]
lower = sorted_versions[-1]
results["upper"] = upper
results["lower"] = lower if upper != lower else None
except IndexError:
pass
return results
current_v = parse_version(version)
for v in sorted_versions:
index = parse_version(v)
if index > current_v:
results["upper"] = index
if index < current_v:
results["lower"] = index
break
return results
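# Worked example for get_closest_ver(): with a pinned version of 1.2.0 and
# candidates on both sides, "upper" is the closest higher release and "lower"
# the closest lower one (both returned as packaging Version objects):
#
#   get_closest_ver(
#       ["1.0.0", "1.1.0", "1.3.0", "2.0.0"], "1.2.0", SpecifierSet("==1.2.0")
#   )
#   # -> {"upper": Version("1.3.0"), "lower": Version("1.1.0")}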
def compute_sec_ver_for_user(
package: Package, secure_vulns_by_user: set, db_full: Dict[str, Any]
) -> List[str]:
"""
Computes the secure versions for the user.
Args:
package (Package): The package.
secure_vulns_by_user (set): The set of vulnerabilities secured by the user.
db_full (Dict[str, Any]): The full database.
Returns:
List[str]: The list of secure versions.
"""
versions = package.get_versions(db_full)
affected_versions = []
for vuln in db_full.get("vulnerable_packages", {}).get(package.name, []):
vuln_id: str = str(
next(
filter(lambda i: i.get("type", None) == "pyup", vuln.get("ids", []))
).get("id", "")
)
if vuln_id and vuln_id not in secure_vulns_by_user:
affected_versions += vuln.get("affected_versions", [])
affected_v = set(affected_versions)
sec_ver_for_user = list(versions.difference(affected_v))
return sorted(sec_ver_for_user, key=lambda ver: parse_version(ver), reverse=True)
def compute_sec_ver(
remediations: Dict[str, Dict[str, Any]],
packages: Dict[str, Package],
secure_vulns_by_user: set,
db_full: Dict[str, Any],
) -> None:
"""
Computes the secure versions and the closest secure version for each remediation.
    Uses the affected_versions of each non-ignored vulnerability of the same package; there is only one remediation per package.
Args:
remediations (Dict[str, Dict[str, Any]]): The remediations dictionary.
packages (Dict[str, Package]): The packages dictionary.
secure_vulns_by_user (set): The set of vulnerabilities secured by the user.
db_full (Dict[str, Any]): The full database.
"""
for pkg_name in remediations.keys():
pkg: Package = packages.get(pkg_name, None)
secure_versions = []
if pkg:
secure_versions = pkg.secure_versions
analyzed = set(remediations[pkg_name].keys())
if not is_using_api_key():
continue
for analyzed_requirement in analyzed:
rem = remediations[pkg_name][analyzed_requirement]
spec = rem.get("requirement").specifier
version = rem["version"]
if not secure_vulns_by_user:
secure_v = sorted(
secure_versions, key=lambda ver: parse_version(ver), reverse=True
)
else:
secure_v = compute_sec_ver_for_user(
package=pkg,
secure_vulns_by_user=secure_vulns_by_user,
db_full=db_full,
)
rem["closest_secure_version"] = get_closest_ver(secure_v, version, spec)
upgrade = rem["closest_secure_version"].get("upper", None)
downgrade = rem["closest_secure_version"].get("lower", None)
recommended_version = None
if upgrade:
recommended_version = upgrade
elif downgrade:
recommended_version = downgrade
rem["recommended_version"] = recommended_version
rem["other_recommended_versions"] = [
other_v for other_v in secure_v if other_v != str(recommended_version)
]
# Refresh the URL with the recommended version.
spec = str(rem["requirement"].specifier)
if is_using_api_key():
rem["more_info_url"] = build_remediation_info_url(
base_url=rem["more_info_url"],
version=version,
spec=spec,
target_version=recommended_version,
)
def calculate_remediations(
vulns: List[Vulnerability], db_full: Dict[str, Any]
) -> Dict[str, Dict[str, Any]]:
"""
Calculates the remediations for the given vulnerabilities.
Args:
vulns (List[Vulnerability]): The list of vulnerabilities.
db_full (Dict[str, Any]): The full database.
Returns:
Dict[str, Dict[str, Any]]: The calculated remediations.
"""
remediations = defaultdict(dict)
package_metadata = {}
secure_vulns_by_user = set()
if not db_full:
return remediations
precompute_remediations(remediations, package_metadata, vulns, secure_vulns_by_user)
compute_sec_ver(remediations, package_metadata, secure_vulns_by_user, db_full)
return remediations
def should_apply_auto_fix(
from_ver: Optional[Version], to_ver: Version, allowed_automatic: List[str]
) -> bool:
"""
Determines if an automatic fix should be applied.
Args:
from_ver (Optional[Version]): The current version.
to_ver (Version): The target version.
allowed_automatic (List[str]): The allowed automatic update types.
Returns:
bool: True if an automatic fix should be applied, False otherwise.
"""
if not from_ver:
return False
if "major" in allowed_automatic:
return True
major_change = to_ver.major - from_ver.major
minor_change = to_ver.minor - from_ver.minor
if "minor" in allowed_automatic:
if major_change != 0:
return False
return True
if "patch" in allowed_automatic:
if major_change != 0 or minor_change != 0:
return False
return True
return False
def get_update_type(from_ver: Optional[Version], to_ver: Version) -> str:
"""
Determines the update type.
Args:
from_ver (Optional[Version]): The current version.
to_ver (Version): The target version.
Returns:
str: The update type.
"""
if not from_ver or (to_ver.major - from_ver.major) != 0:
return "major"
if (to_ver.minor - from_ver.minor) != 0:
return "minor"
return "patch"
def process_fixes(
files: List[str],
remediations: Dict[str, Dict[str, Any]],
auto_remediation_limit: List[str],
output: str,
no_output: bool = True,
prompt: bool = False,
) -> List[Fix]:
"""
Processes the fixes for the given files and remediations.
Args:
files (List[str]): The list of files.
remediations (Dict[str, Dict[str, Any]]): The remediations dictionary.
auto_remediation_limit (List[str]): The automatic remediation limits.
output (str): The output format.
no_output (bool): Whether to suppress output.
prompt (bool): Whether to prompt for confirmation.
Returns:
List[Fix]: The list of applied fixes.
"""
req_remediations = itertools.chain.from_iterable(
rem.values() for pkg_name, rem in remediations.items()
)
requirements = compute_fixes_per_requirements(
files, req_remediations, auto_remediation_limit, prompt=prompt
)
fixes = apply_fixes(requirements, output, no_output, prompt)
return fixes
def process_fixes_scan(
file_to_fix: SafetyPolicyFile,
to_fix_spec: List[SafetyRequirement],
auto_remediation_limit: List[str],
output: str,
no_output: bool = True,
prompt: bool = False,
) -> List[Fix]:
"""
Processes the fixes for the given file and specifications in scan mode.
Args:
file_to_fix (SafetyPolicyFile): The file to fix.
to_fix_spec (List[SafetyRequirement]): The specifications to fix.
auto_remediation_limit (List[str]): The automatic remediation limits.
output (str): The output format.
no_output (bool): Whether to suppress output.
prompt (bool): Whether to prompt for confirmation.
Returns:
List[Fix]: The list of applied fixes.
"""
    def get_remediation_from(spec):
upper = None
lower = None
recommended = None
try:
upper = (
Version(spec.remediation.closest_secure.upper)
if spec.remediation.closest_secure.upper
else None
)
except Exception:
LOG.error(
"Error getting upper remediation version, ignoring", exc_info=True
)
try:
lower = (
Version(spec.remediation.closest_secure.lower)
if spec.remediation.closest_secure.lower
else None
)
except Exception:
LOG.error(
"Error getting lower remediation version, ignoring", exc_info=True
)
try:
recommended = Version(spec.remediation.recommended)
except Exception:
LOG.error(
"Error getting recommended version for remediation, ignoring",
exc_info=True,
)
return {
"vulnerabilities_found": spec.remediation.vulnerabilities_found,
"version": next(iter(spec.specifier)).version if spec.is_pinned() else None,
"requirement": spec,
"more_info_url": spec.remediation.more_info_url,
"closest_secure_version": {"upper": upper, "lower": lower},
"recommended_version": recommended,
"other_recommended_versions": spec.remediation.other_recommended,
}
    req_remediations = iter(get_remediation_from(spec) for spec in to_fix_spec)
SUPPORTED_FILE_TYPES = [FileType.REQUIREMENTS_TXT]
if file_to_fix.file_type in SUPPORTED_FILE_TYPES:
files = (open(file_to_fix.location),)
requirements = compute_fixes_per_requirements(
files, req_remediations, auto_remediation_limit, prompt=prompt
)
else:
requirements = {
"files": {
str(file_to_fix.location): {
"content": None,
"fixes": {"TO_SKIP": [], "TO_APPLY": [], "TO_CONFIRM": []},
"supported": False,
"filename": file_to_fix.location.name,
}
},
"dependencies": defaultdict(dict),
}
fixes = apply_fixes(
requirements,
output,
no_output,
prompt,
scan_flow=True,
auto_remediation_limit=auto_remediation_limit,
)
return fixes
def compute_fixes_per_requirements(
files: List[str],
req_remediations: Iterator[Dict[str, Any]],
auto_remediation_limit: List[str],
prompt: bool = False,
) -> Dict[str, Any]:
"""
Computes the fixes per requirements.
Args:
files (List[str]): The list of files.
req_remediations (Iterator[Dict[str, Any]]): The remediations iterator.
auto_remediation_limit (List[str]): The automatic remediation limits.
prompt (bool): Whether to prompt for confirmation.
Returns:
Dict[str, Any]: The computed requirements with fixes.
"""
requirements_files = get_requirements_content(files)
from dparse.parser import parse, filetypes
from packaging.version import Version, InvalidVersion
requirements = {
"files": {},
"dependencies": defaultdict(dict),
}
for name, contents in requirements_files.items():
dependency_file = parse(
contents, path=name, file_type=filetypes.requirements_txt, resolve=True
)
dependency_files = dependency_file.resolved_files + [dependency_file]
empty_spec = SpecifierSet("")
default_spec = SpecifierSet(">=0")
# Support recursive requirements in the multiple requirement files provided
for resolved_f in dependency_files:
if not resolved_f or isinstance(resolved_f, str):
continue
file = {
"content": resolved_f.content,
"fixes": {"TO_SKIP": [], "TO_APPLY": [], "TO_CONFIRM": []},
}
requirements["files"][resolved_f.path] = file
for d in resolved_f.dependencies:
if d.specs == empty_spec:
d.specs = default_spec
requirements["dependencies"][d.name][str(d.specs)] = (
d,
resolved_f.path,
)
for remediation in req_remediations:
req: SafetyRequirement = remediation.get("requirement")
pkg: str = req.name
dry_fix = Fix(
package=pkg,
more_info_url=remediation.get("more_info_url", ""),
previous_spec=str(req.specifier),
other_options=remediation.get("other_recommended_versions", []),
)
from_ver: Optional[str] = remediation.get("version", None)
if (
pkg not in requirements["dependencies"]
or dry_fix.previous_spec not in requirements["dependencies"][pkg]
):
# Let's attach it to the first file scanned.
file = next(iter(requirements["files"]))
            # Let's use the unparsed version.
dry_fix.previous_version = from_ver
dry_fix.status = "AUTOMATICALLY_SKIPPED_NOT_FOUND_IN_FILE"
dry_fix.applied_at = file
requirements["files"][file]["fixes"]["TO_SKIP"].append(dry_fix)
continue
dependency, name = requirements["dependencies"][pkg][dry_fix.previous_spec]
dry_fix.applied_at = name
fixes = requirements["files"][name]["fixes"]
to_ver: Version = remediation["recommended_version"]
try:
from_ver = parse_version(from_ver)
except (InvalidVersion, TypeError):
if not dry_fix.previous_spec:
dry_fix.status = "AUTOMATICALLY_SKIPPED_INVALID_VERSION"
fixes["TO_SKIP"].append(dry_fix)
continue
dry_fix.previous_version = str(from_ver) if from_ver else from_ver
if remediation["recommended_version"] is None:
dry_fix.status = "AUTOMATICALLY_SKIPPED_NO_RECOMMENDED_VERSION"
fixes["TO_SKIP"].append(dry_fix)
continue
dry_fix.updated_version = str(to_ver)
is_fixed = from_ver == to_ver
if is_fixed:
dry_fix.status = "AUTOMATICALLY_SKIPPED_ALREADY_FIXED"
fixes["TO_SKIP"].append(dry_fix)
continue
update_type = get_update_type(from_ver, to_ver)
dry_fix.update_type = update_type
dry_fix.dependency = dependency
auto_fix = should_apply_auto_fix(from_ver, to_ver, auto_remediation_limit)
TARGET = "TO_APPLY"
if auto_fix:
dry_fix.status = "PENDING_TO_APPLY"
dry_fix.fix_type = "AUTOMATIC"
elif prompt:
TARGET = "TO_CONFIRM"
dry_fix.status = "PENDING_TO_CONFIRM"
dry_fix.fix_type = "MANUAL"
else:
TARGET = "TO_SKIP"
dry_fix.status = "AUTOMATICALLY_SKIPPED_UNABLE_TO_CONFIRM"
fixes[TARGET].append(dry_fix)
return requirements
def apply_fixes(
requirements: Dict[str, Any],
out_type: str,
no_output: bool,
prompt: bool,
scan_flow: bool = False,
    auto_remediation_limit: Optional[List[str]] = None,
) -> List[Fix]:
"""
Applies the fixes to the requirements.
Args:
requirements (Dict[str, Any]): The requirements with fixes.
out_type (str): The output format.
no_output (bool): Whether to suppress output.
prompt (bool): Whether to prompt for confirmation.
scan_flow (bool): Whether it is in scan flow mode.
auto_remediation_limit (List[str]): The automatic remediation limits.
Returns:
List[Fix]: The list of applied fixes.
"""
from dparse.updater import RequirementsTXTUpdater
skip = []
apply = []
confirm = []
brief = []
if not no_output:
style_kwargs = {}
if not scan_flow:
brief.append(("", {}))
brief.append(("Safety fix running", style_kwargs))
print_service(brief, out_type)
for name, data in requirements["files"].items():
output = [
("", {}),
(
f"Analyzing {name}... [{get_fix_opt_used_msg(auto_remediation_limit)} limit]",
{
"styling": {"bold": True},
"start_line_decorator": "->",
"indent": " ",
},
),
]
r_skip = data["fixes"]["TO_SKIP"]
r_apply = data["fixes"]["TO_APPLY"]
r_confirm = data["fixes"]["TO_CONFIRM"]
if data.get("supported", True):
new_content = data["content"]
updated: bool = False
for f in r_apply:
new_content = RequirementsTXTUpdater.update(
content=new_content,
version=f.updated_version,
dependency=f.dependency,
hashes=get_hashes(f.dependency),
)
f.status = "APPLIED"
updated = True
output.append(("", {}))
output.append((f"- {get_applied_msg(f, mode='auto')}", {}))
for f in r_skip:
output.append(("", {}))
output.append((f"- {get_skipped_msg(f)}", {}))
if not no_output:
print_service(output, out_type)
if prompt and not no_output:
for f in r_confirm:
options = [
f"({index}) =={option}"
for index, option in enumerate(f.other_options)
]
input_hint = f"Enter “y” to update to {f.package}=={f.updated_version}, “n” to skip this package upgrade"
if len(options) > 0:
input_hint += f", or enter the index from these secure versions to upgrade to that version: {', '.join(options)}"
print_service([("", {})], out_type)
confirmed: str = prompt_service(
(
f"- {f.package}{f.previous_spec} requires at least a {f.update_type} version update. Do you want to update {f.package} from {f.previous_spec} to =={f.updated_version}, which is the closest secure version? {input_hint}",
{},
),
out_type,
).lower()
try:
index: int = int(confirmed)
                        if 0 <= index < len(f.other_options):
confirmed = "y"
except ValueError:
index = -1
if confirmed == "y" or index > -1:
f.status = "APPLIED"
updated = True
if index > -1:
f.updated_version = f.other_options[index]
new_content = RequirementsTXTUpdater.update(
content=new_content,
version=f.updated_version,
dependency=f.dependency,
hashes=get_hashes(f.dependency),
)
output.append(
(get_applied_msg(f, mode="manual"), {"indent": " " * 5})
)
else:
f.status = "MANUALLY_SKIPPED"
output.append((get_skipped_msg(f), {"indent": " " * 5}))
if not no_output:
print_service(output, out_type)
if updated:
output.append(("", {}))
output.append((f"Updating {name}...", {}))
with open(name, mode="w") as r_file:
r_file.write(new_content)
output.append((f"Changes applied to {name}.", {}))
count = len(r_apply) + len(
[1 for fix in r_confirm if fix.status == "APPLIED"]
)
output.append(
(
f"{count} package {pluralize('version', count)} {pluralize('has', count)} been updated to secure versions in {Path(name).name}",
{},
)
)
output.append(
("Always check for breaking changes after updating packages.", {})
)
else:
output.append((f"No fixes to be made in {name}.", {}))
output.append(("", {}))
else:
not_supported_filename = data.get("filename", name)
output.append(
(
f"{not_supported_filename} updates not supported: Please update these dependencies using your package manager.",
{"start_line_decorator": " -", "indent": " "},
)
)
output.append(("", {}))
if not no_output:
print_service(output, out_type)
skip.extend(r_skip)
apply.extend(r_apply)
confirm.extend(r_confirm)
# The scan flow will handle the header and divider, because the scan flow can be called multiple times.
if not no_output and not scan_flow:
divider = (
f"{'=' * 78}"
if out_type == "text"
else f"{'=' * (get_terminal_size().columns - 2)}"
)
format_text = {
"start_line_decorator": "+",
"end_line_decorator": "+",
"indent": "",
}
print_service([(divider, {})], out_type, format_text=format_text)
return skip + apply + confirm
def find_vulnerabilities_fixed(
vulnerabilities: Dict[str, Any], fixes: List[Fix]
) -> List[Vulnerability]:
"""
Finds the vulnerabilities that have been fixed.
Args:
vulnerabilities (Dict[str, Any]): The dictionary of vulnerabilities.
fixes (List[Fix]): The list of applied fixes.
Returns:
List[Vulnerability]: The list of fixed vulnerabilities.
"""
fixed_specs = set(fix.previous_spec for fix in fixes)
if not fixed_specs:
return []
return [
vulnerability
for vulnerability in vulnerabilities
if str(vulnerability["analyzed_requirement"].specifier) in fixed_specs
]
@sync_safety_context
def review(
*, report: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None
) -> tuple:
"""
Reviews the report and returns the vulnerabilities and remediations.
Args:
report (Optional[Dict[str, Any]]): The report.
params (Optional[Dict[str, Any]]): Additional parameters.
Returns:
tuple: A tuple containing the list of vulnerabilities, the remediations, and the found packages.
"""
SafetyContext().command = "review"
vulnerable = []
vulnerabilities = report.get("vulnerabilities", []) + report.get(
"ignored_vulnerabilities", []
)
remediations = defaultdict(dict)
for key, pkg_rem in report.get("remediations", {}).items():
analyzed = set(pkg_rem["requirements"].keys())
for req in analyzed:
req_rem = pkg_rem["requirements"][req]
recommended = req_rem.get("recommended_version", None)
secure_v = req_rem.get("other_recommended_versions", [])
remediations[key][req] = {
"vulnerabilities_found": req_rem.get("vulnerabilities_found", 0),
"version": req_rem.get("version"),
"requirement": SafetyRequirement(req_rem["requirement"]["raw"]),
"other_recommended_versions": secure_v,
"recommended_version": parse_version(recommended)
if recommended
else None,
# minor isn't supported in review
"more_info_url": req_rem.get("more_info_url"),
}
packages = report.get("scanned_packages", [])
pkgs = {}
for name, values in packages.items():
requirements = [
SafetyRequirement(r["raw"]) for r in values.get("requirements", [])
]
values.update({"requirements": requirements})
pkgs[name] = Package(**values)
ctx = SafetyContext()
found_packages = list(pkgs.values())
ctx.packages = found_packages
ctx.review = report.get("report_meta", [])
ctx.key = ctx.review.get("api_key", False)
cvssv2 = None
cvssv3 = None
for vuln in vulnerabilities:
vuln["pkg"] = pkgs.get(vuln.get("package_name", None))
XVE_ID = vuln.get("CVE", None) # Trying to get first the CVE ID
severity = vuln.get("severity", None)
if severity and severity.get("source", False):
cvssv2 = severity.get("cvssv2", None)
cvssv3 = severity.get("cvssv3", None)
            # Try to get the PVE ID if it exists; otherwise keep the CVE ID from above
XVE_ID = severity.get("source", False)
vuln["severity"] = Severity(source=XVE_ID, cvssv2=cvssv2, cvssv3=cvssv3)
else:
vuln["severity"] = None
ignored_expires = vuln.get("ignored_expires", None)
if ignored_expires:
vuln["ignored_expires"] = validate_expiration_date(ignored_expires)
vuln["CVE"] = CVE(name=XVE_ID, cvssv2=cvssv2, cvssv3=cvssv3) if XVE_ID else None
vuln["analyzed_requirement"] = SafetyRequirement(
vuln["analyzed_requirement"]["raw"]
)
vulnerable.append(Vulnerability(**vuln))
return vulnerable, remediations, found_packages
@sync_safety_context
def get_licenses(
*,
session: requests.Session,
db_mirror: Union[Optional[str], bool] = False,
cached: int = 0,
telemetry: bool = True,
) -> Dict[str, Any]:
"""
Retrieves the licenses from the database.
Args:
session (requests.Session): The requests session.
db_mirror (Union[Optional[str], bool]): The database mirror.
cached (int): The cache validity in seconds.
telemetry (bool): Whether to include telemetry data.
Returns:
Dict[str, Any]: The licenses dictionary.
"""
if db_mirror:
mirrors = [db_mirror]
else:
mirrors = API_MIRRORS
db_name = "licenses.json"
for mirror in mirrors:
# mirror can either be a local path or a URL
if is_a_remote_mirror(mirror):
licenses = fetch_database_url(
session, mirror, db_name=db_name, cached=cached, telemetry=telemetry
)
else:
licenses = fetch_database_file(mirror, db_name=db_name, ecosystem=None)
if licenses:
return licenses
raise DatabaseFetchError()
def add_local_notifications(
packages: List[Package], ignore_unpinned_requirements: Optional[bool]
) -> List[Dict[str, str]]:
"""
Adds local notifications for unpinned packages.
Args:
packages (List[Package]): The list of packages.
ignore_unpinned_requirements (Optional[bool]): Whether to ignore unpinned requirements.
Returns:
List[Dict[str, str]]: The list of notifications.
"""
announcements = []
unpinned_packages: List[str] = [
f"{pkg.name}" for pkg in packages if pkg.has_unpinned_req()
]
if unpinned_packages and ignore_unpinned_requirements is not False:
found = len(unpinned_packages)
and_msg = ""
if found >= 2:
last = unpinned_packages.pop()
and_msg = f" and {last}"
pkgs: str = (
f"{', '.join(unpinned_packages)}{and_msg} {'are' if found > 1 else 'is'}"
)
doc_msg: str = get_specifier_range_info(style=False, pin_hint=True)
if ignore_unpinned_requirements is None:
msg = (
f"Warning: {pkgs} unpinned. Safety by default does not "
f"report on potential vulnerabilities in unpinned packages. {doc_msg}"
)
else:
msg = (
f"Warning: {pkgs} unpinned and potential vulnerabilities are "
f"being ignored given `ignore-unpinned-requirements` is True in your config. {doc_msg}"
)
announcements.append({"message": msg, "type": "warning", "local": True})
announcements.extend(SafetyContext().local_announcements)
return announcements
def get_announcements(
session: requests.Session, telemetry: bool = True, with_telemetry: Any = None
) -> List[Dict[str, str]]:
"""
Retrieves announcements from the server.
Args:
session (requests.Session): The requests session.
telemetry (bool): Whether to include telemetry data.
with_telemetry (Optional[Dict[str, Any]]): The telemetry data.
Returns:
List[Dict[str, str]]: The list of announcements.
"""
LOG.info("Getting announcements")
announcements = []
url = f"{DATA_API_BASE_URL}announcements/"
method = "post"
telemetry_data = (
with_telemetry if with_telemetry else build_telemetry_data(telemetry=telemetry)
)
data = asdict(telemetry_data)
request_kwargs = {"timeout": 3}
data_keyword = "json"
source = os.environ.get("SAFETY_ANNOUNCEMENTS_URL", None)
if source:
LOG.debug(f"Getting the announcement from a different source: {source}")
url = source
method = "get"
data = {"telemetry": json.dumps(data)}
data_keyword = "params"
request_kwargs[data_keyword] = data
request_kwargs["url"] = url
request_kwargs["headers"] = get_meta_http_headers()
LOG.debug(f"Telemetry data sent: {data}")
try:
request_func = getattr(session, method)
r = request_func(**request_kwargs)
LOG.debug(r.text)
except Exception as e:
LOG.info(
"Unexpected but HANDLED Exception happened getting the announcements: %s", e
)
return announcements
if r.status_code == 200:
try:
announcements = r.json()
if "announcements" in announcements.keys():
announcements = announcements.get("announcements", [])
else:
LOG.info(
"There is not announcements key in the JSON response, is this a wrong structure?"
)
announcements = []
except json.JSONDecodeError as e:
LOG.info(
"Unexpected but HANDLED Exception happened decoding the announcement response: %s",
e,
)
LOG.info("Announcements fetched")
return announcements
def get_distribution_location(distribution):
"""
Get the installation location of a distribution.
"""
if hasattr(distribution, "_path") and distribution._path:
return str(distribution._path)
# Fallback: try to get location from metadata
try:
location = distribution.locate_file("")
if location:
return str(location)
    except Exception:
pass
return ""
def get_distribution_version(distribution):
"""
Safely get the version of a distribution (Python 3.9+ compatible).
"""
try:
return distribution.version
except AttributeError:
return distribution.metadata.get("Version", "")
def get_distribution_name(distribution):
"""
Safely get the name of a distribution (Python 3.9+ compatible).
"""
try:
# For Python 3.10+
return distribution.name
except AttributeError:
# Fallback for Python 3.9 PathDistribution objects
return distribution.metadata.get("Name", "")
def get_packages(
files: Optional[List[str]] = None, stdin: bool = False
) -> List[Package]:
"""
Retrieves the packages from the given files or standard input.
Args:
files (Optional[List[str]]): The list of files.
stdin (bool): Whether to read from standard input.
Returns:
List[Package]: The list of packages.
"""
if files:
return list(
itertools.chain.from_iterable(
read_requirements(f, resolve=True) for f in files
)
)
if stdin:
return list(read_requirements(sys.stdin))
# Migrated from pkg_resources to importlib.metadata
import importlib.metadata
def allowed_version(pkg: str, version: str) -> bool:
try:
parse_version(version)
except Exception:
SafetyContext.local_announcements.append(
{
"message": f"Version {version} for {pkg} is invalid and is ignored by Safety. Please See PEP 440.",
"type": "warning",
"local": True,
}
)
return False
return True
distributions = list(importlib.metadata.distributions())
paths = {
str(dist._path.parent)
for dist in distributions
if hasattr(dist, "_path") and dist._path
}
SafetyContext().scanned_full_path.extend(paths)
packages = []
for d in distributions:
name = get_distribution_name(d)
version = get_distribution_version(d)
# Skip if name or version couldn't be determined
if not name or not version:
continue
# Skip excluded packages
if name in {"python", "wsgiref", "argparse"} or not allowed_version(
name, version
):
continue
location = get_distribution_location(d)
packages.append(
Package(
name=name,
version=version,
requirements=[SafetyRequirement(f"{name}=={version}", found=location)],
found=location,
insecure_versions=[],
secure_versions=[],
latest_version=None,
latest_version_without_known_vulnerabilities=None,
more_info_url=None,
)
)
return packages
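# Usage sketch: the three branches above map to the three supported sources.
# The file handle is assumed to be an open requirements file, as the check
# command provides it; the file name is a placeholder.
#
#   with open("requirements.txt") as fh:
#       pkgs = get_packages(files=[fh])
#   pkgs = get_packages(stdin=True)   # read a requirements list from stdin
#   pkgs = get_packages()             # fall back to scanning the environment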
def read_vulnerabilities(fh: Any) -> Dict[str, Any]:
"""
Reads vulnerabilities from a file handle.
Args:
fh (Any): The file handle.
Returns:
Dict[str, Any]: The vulnerabilities data.
"""
try:
data = json.load(fh)
except json.JSONDecodeError as e:
raise MalformedDatabase(reason=e, fetched_from=fh.name)
except TypeError as e:
raise MalformedDatabase(reason=e, fetched_from=fh.name)
return data
def get_server_policies(
session: requests.Session,
policy_file: SafetyPolicyFile,
proxy_dictionary: Dict[str, str],
) -> tuple:
"""
Retrieves the server policies.
Args:
session (requests.Session): The requests session.
policy_file (SafetyPolicyFile): The policy file.
proxy_dictionary (Dict[str, str]): The proxy dictionary.
Returns:
tuple: A tuple containing the policy file and the audit and monitor flag.
"""
if session.api_key:
server_policies = fetch_policy(session)
server_audit_and_monitor = server_policies["audit_and_monitor"]
server_safety_policy = server_policies["safety_policy"]
else:
server_audit_and_monitor = False
server_safety_policy = ""
if server_safety_policy and policy_file:
click.secho(
"Warning: both a local policy file '{policy_filename}' and a server sent policy are present. "
"Continuing with the local policy file.".format(
policy_filename=policy_file["filename"]
),
fg="yellow",
file=sys.stderr,
)
elif server_safety_policy:
with tempfile.NamedTemporaryFile(prefix="server-safety-policy-") as tmp:
tmp.write(server_safety_policy.encode("utf-8"))
tmp.seek(0)
policy_file = SafetyPolicyFile().convert(tmp.name, param=None, ctx=None)
LOG.info("Using server side policy file")
return policy_file, server_audit_and_monitor
def save_report(path: str, default_name: str, report: str) -> None:
"""
Saves the report to a file.
Args:
path (str): The path to save the report.
default_name (str): The default name of the report file.
report (str): The report content.
"""
if path:
save_at = path
if os.path.isdir(path):
save_at = os.path.join(path, default_name)
with open(save_at, "w+") as report_file:
report_file.write(report)
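# Behaviour sketch for save_report(): when `path` is a directory the default
# name is appended, otherwise `path` is used as the file name. Both paths and
# `report_json` are placeholders.
#
#   save_report("/tmp/reports", "safety-report.json", report_json)
#   # -> /tmp/reports/safety-report.json when /tmp/reports is a directory
#   save_report("/tmp/out.json", "safety-report.json", report_json)
#   # -> /tmp/out.json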