This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,7 @@
from rich.console import Console
from safety_schemas.models import Vulnerability, RemediationModel
from safety.scan.render import get_render_console
console = Console()
Vulnerability.__render__ = get_render_console(Vulnerability)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,205 @@
# Console Help Theme
CONSOLE_HELP_THEME = {"nhc": "grey82"}
CLI_AUTH_COMMAND_HELP = (
"Authenticate Safety CLI to perform scans. Your default browser will automatically open to "
"https://platform.safetycli.com.\n\n"
"Example:\n safety auth login\n\n"
"For headless authentication, you will receive a URL to paste into an external browser.\n\n"
"Example:\n safety auth login --headless"
)
CLI_AUTH_HEADLESS_HELP = "For headless authentication, you will receive a URL to paste into an external browser."
CLI_SCAN_COMMAND_HELP = (
"Scans a Python project directory."
"\nExample: safety scan to scan the current directory"
)
CLI_SYSTEM_SCAN_COMMAND_HELP = (
"\\[beta] Run a comprehensive scan for packages and vulnerabilities across your entire machine/environment."
"\nExample: safety system-scan"
)
CLI_CHECK_COMMAND_HELP = (
"\\[deprecated] Find vulnerabilities at target files or environments. Now replaced by safety scan, and will be unsupported beyond 1 May 2024."
"\nExample: safety check -r requirements.txt"
)
CLI_LICENSES_COMMAND_HELP = (
"\\[deprecated] Find licenses at target files or environments. This command will be replaced by safety scan, and will be unsupported beyond 1 May 2024."
"\nExample: safety license -r requirements.txt"
)
CLI_ALERT_COMMAND_HELP = (
"\\[deprecated] Create GitHub pull requests or GitHub issues using a `safety check` json report file. Being replaced by newer features."
"\nExample: safety alert --check-report your-report.json --key API_KEY github-pr --repo my-org/my-repo --token github-token"
)
CLI_CHECK_UPDATES_HELP = (
"Check for version updates to Safety CLI.\nExample: safety check-updates"
)
CLI_CONFIGURE_HELP = (
"Set up global configurations for Safety CLI, including proxy settings and organization details."
"\nExample: safety configure --proxy-host 192.168.0.1"
)
CLI_GENERATE_HELP = (
"Generate a boilerplate Safety CLI policy file for customized security policies."
"\nNote: Safety Platform policies will override any local policy files found"
"\nExample: safety generate policy_file"
)
CLI_VALIDATE_HELP = (
"Check if your local Safety CLI policy file is valid."
"\nExample: Example: safety validate --path /path/to/policy.yml"
)
CLI_GATEWAY_CONFIGURE_COMMAND_HELP = (
"Configures the project in the working directory to use Gateway."
)
# Global options help
_CLI_PROXY_TIP_HELP = "[nhc]Note: proxy details can be set globally in a config file.[/nhc]\n\nSee [bold]safety configure --help[/bold]\n\n"
CLI_PROXY_HOST_HELP = (
"Specify a proxy host for network communications. \n\n" + _CLI_PROXY_TIP_HELP
)
CLI_PROXY_PORT_HELP = "Set the proxy port (default: 80).\n\n" + _CLI_PROXY_TIP_HELP
CLI_PROXY_PROTOCOL_HELP = (
"Choose the proxy protocol (default: https).\n\n" + _CLI_PROXY_TIP_HELP
)
CLI_KEY_HELP = (
"The API key required for cicd stage or production stage scans.\n\n"
"[nhc]For development stage scans unset the API key and authenticate using [bold]safety auth[/bold].[/nhc]\n\n"
"[nhc]Tip: the API key can also be set using the environment variable: SAFETY_API_KEY[/nhc]\n\n"
"[bold]Example: safety --key API_KEY scan[/bold]"
)
CLI_STAGE_HELP = (
"Assign a development lifecycle stage to your scan (default: development).\n\n"
"[nhc]This labels the scan and its findings in Safety Platform with this stage.[/nhc]\n\n"
"[bold]Example: safety --stage production scan[/bold]"
)
CLI_DEBUG_HELP = (
"Enable debug mode for detailed output.\n\n"
"[bold]Example: safety --debug scan[/bold]"
)
CLI_DISABLE_OPTIONAL_TELEMETRY_DATA_HELP = (
"Opt-out of sending optional telemetry data. Anonymized telemetry data will remain.\n\n"
"[bold]Example: safety --disable-optional-telemetry scan[/bold]"
)
# Scan Help options
SCAN_POLICY_FILE_HELP = (
"Use a local policy file to configure the scan.\n\n"
"[nhc]Note: Project scan policies defined in Safety Platform will override local policy files[/nhc]\n\n"
"[bold]Example: safety scan --policy-file /path/to/policy.yml[/bold]"
)
SCAN_TARGET_HELP = (
"Define a specific project path to scan. (default: current directory)\n\n"
"[bold]Example: safety scan --target /path/to/project[/bold]"
)
SCAN_OUTPUT_HELP = (
"Set the output format for scan results (default: screen)\n\n"
"[bold]Example: safety scan --output json[/bold]"
)
SCAN_SAVE_AS_HELP = (
"In addition to regular output save the scan results to a json, html, text, or spdx file using: FORMAT FILE_PATH\n\n"
"[bold]Example: safety scan --save-as json results.json[/bold]"
)
SCAN_DETAILED_OUTPUT = (
"Enable a verbose scan report for detailed insights (only for screen output)\n\n"
"[bold]Example: safety scan --detailed-output[/bold]"
)
SCAN_APPLY_FIXES = (
"[bold]Update packages listed in requirements.txt files to secure versions where possible[/bold]\n\n"
"[nhc]Currently supports: requirements.txt files[/nhc]\n\n"
"Note: this will update your requirements.txt file "
)
# System Scan options
SYSTEM_SCAN_POLICY_FILE_HELP = (
"Use a local policy file to configure the scan.\n\n"
"[nhc]Note: Scan policies defined in Safety Platform will override local policy files[/nhc]\n\n"
"[bold]Example: safety scan --policy-file /path/to/policy.yml[/bold]"
)
SYSTEM_SCAN_TARGET_HELP = (
"Define a specific location to start the system scan. (default: current directory)\n\n"
"[bold]Example: safety scan --target /path/to/project[/bold]"
)
SYSTEM_SCAN_OUTPUT_HELP = (
"Set the output format for scan results (default: screen)\n\n"
"[bold]Example: safety scan --output json[/bold]"
)
SYSTEM_SCAN_SAVE_AS_HELP = (
"In addition to the terminal/console output (set by --output), save system-scan results to a screen (text) or json file.\n\n"
"""[nhc]Use [bold]--save-as <FORMAT> <PATH>[/bold]. For example: [bold]--save-as json my-machine-scan.json[/bold] to save the system-scan results to `my-machine-scan.json` in the current directory[/nhc]\n\n"""
"[nhc][Default: json .][/nhc]"
)
# Auth options
CLI_AUTH_LOGIN_HELP = (
"Authenticate with Safety CLI to perform scans. Your default browser will automatically open to https://platform.safetycli.com unless already authenticated.\n\n"
"[bold]Example: safety auth login[/bold]"
)
CLI_AUTH_LOGOUT_HELP = (
"Log out from the current Safety CLI session.\n\n"
"[bold]Example: safety auth logout[/bold]"
)
CLI_AUTH_STATUS_HELP = (
"Show the current authentication status.\n\n"
"[bold]Example: safety auth status[/bold]"
)
# Configure options
CLI_CONFIGURE_PROXY_HOST_HELP = "Specify a proxy host for network communications to be saved into Safety's configuration. \n\n"
CLI_CONFIGURE_PROXY_PORT_HELP = (
"Set the proxy port to be saved into Safety's configuration file (default: 80).\n\n"
)
CLI_CONFIGURE_PROXY_PROTOCOL_HELP = "Choose the proxy protocol to be saved into Safety's configuration file (default: https).\n\n"
CLI_CONFIGURE_PROXY_TIMEOUT = (
"Set the timeout duration for proxy network calls.\n\n"
+ "[bold]Example: safety configure --proxy-timeout 30[/bold]"
)
CLI_CONFIGURE_PROXY_REQUIRED = (
"Enable or disable the requirement for a proxy in network communications\n\n"
+ "[bold]Example: safety configure --proxy-required[/bold]"
)
CLI_CONFIGURE_ORGANIZATION_ID = (
"Set the current device with an organization ID."
" - see your Safety Platform Organization page\n\n"
+ "[bold]Example: safety configure --organization-id your_org_unique_id[/bold]"
)
CLI_CONFIGURE_ORGANIZATION_NAME = (
"Set the current device with an organization name."
" - see your Safety Platform Organization page.\n\n"
+ '[bold]Example: safety configure --organization-name "Your Org Name"[/bold]'
)
CLI_CONFIGURE_SAVE_TO_SYSTEM = (
"Save the configuration to a system config file.\n"
"This will configure Safety CLI for all users on this machine. Use --save-to-user to "
"configure Safety CLI for only your user.\n\n"
"[bold]Example: safety configure --save-to-system[/bold]"
)
# Generate options
CLI_GENERATE_PATH = (
"The path where the generated file will be saved (default: current directory).\n\n"
"[bold]Example: safety generate policy_file --path .my-project-safety-policy.yml[/bold]"
)
CLI_GENERATE_MINIMUM_CVSS_SEVERITY = (
"The minimum CVSS severity to generate the installation policy for.\n\n"
"[bold]Example: safety generate installation_policy --minimum-cvss-severity high[/bold]"
)
# Command default settings
CMD_PROJECT_NAME = "scan"
CMD_SYSTEM_NAME = "system-scan"
DEFAULT_CMD = CMD_PROJECT_NAME
DEFAULT_SPINNER = "bouncingBar"

View File

@@ -0,0 +1,350 @@
from functools import wraps
import logging
import os
from pathlib import Path
from typing import List, Optional
from rich.padding import Padding
from safety_schemas.models import ConfigModel, ProjectModel
from rich.console import Console
from safety.auth.cli import render_email_note
from safety.cli_util import process_auth_status_not_ready
from safety.console import main_console
from safety.constants import SYSTEM_POLICY_FILE, USER_POLICY_FILE
from safety.errors import SafetyException
from safety.scan.main import download_policy, load_policy_file, resolve_policy
from safety.scan.models import ScanOutput, SystemScanOutput
from safety.scan.render import (
print_announcements,
print_header,
print_wait_policy_download,
)
from safety.scan.util import GIT
from ..codebase_utils import load_unverified_project_from_config
from safety.util import build_telemetry_data, pluralize
from safety_schemas.models import (
MetadataModel,
ScanType,
ReportSchemaVersion,
PolicySource,
)
LOG = logging.getLogger(__name__)
def scan_project_command_init(func):
"""
Decorator to make general verifications before each project scan command.
"""
@wraps(func)
def inner(
ctx,
policy_file_path: Optional[Path],
target: Path,
output: ScanOutput,
console: Console = main_console,
*args,
**kwargs,
):
ctx.obj.console = console
ctx.params.pop("console", None)
if output.is_silent():
console.quiet = True
if not ctx.obj.auth.is_valid():
process_auth_status_not_ready(console=console, auth=ctx.obj.auth, ctx=ctx)
upload_request_id = kwargs.pop("upload_request_id", None)
# Load .safety-project.ini
unverified_project = load_unverified_project_from_config(project_root=target)
print_header(console=console, targets=[target])
stage = ctx.obj.auth.stage
session = ctx.obj.auth.client
git_data = GIT(root=target).build_git_data()
origin = None
branch = None
if git_data:
origin = git_data.origin
branch = git_data.branch
if ctx.obj.platform_enabled:
# TODO: Move this to be injected by a codebase service
from safety.init.main import verify_project
link_behavior = "prompt"
if unverified_project.created:
link_behavior = "always"
verify_project(
console,
ctx,
session,
unverified_project,
origin,
link_behavior=link_behavior,
prompt_for_name=True,
)
else:
ctx.obj.project = ProjectModel(
id="",
name="Undefined project",
project_path=unverified_project.project_path,
)
ctx.obj.project.git = git_data
ctx.obj.project.upload_request_id = upload_request_id
if not policy_file_path:
policy_file_path = target / Path(".safety-policy.yml")
# Load Policy file and pull it from CLOUD
local_policy = kwargs.pop("local_policy", load_policy_file(policy_file_path))
cloud_policy = None
if ctx.obj.platform_enabled:
cloud_policy = print_wait_policy_download(
console,
(
download_policy,
{
"session": session,
"project_id": ctx.obj.project.id,
"stage": stage,
"branch": branch,
},
),
)
ctx.obj.project.policy = resolve_policy(local_policy, cloud_policy)
config = (
ctx.obj.project.policy.config
if ctx.obj.project.policy and ctx.obj.project.policy.config
else ConfigModel()
)
# Preserve global telemetry preference.
if ctx.obj.config:
if ctx.obj.config.telemetry_enabled is not None:
config.telemetry_enabled = ctx.obj.config.telemetry_enabled
ctx.obj.config = config
console.print()
if ctx.obj.auth.org and ctx.obj.auth.org.name:
console.print(f"[bold]Organization[/bold]: {ctx.obj.auth.org.name}")
# Check if an API key is set
if ctx.obj.auth.client.get_authentication_type() == "api_key":
details = {"Account": "API key used"}
else:
if ctx.obj.auth.client.get_authentication_type() == "token":
content = ctx.obj.auth.email
if ctx.obj.auth.name != ctx.obj.auth.email:
content = f"{ctx.obj.auth.name}, {ctx.obj.auth.email}"
details = {"Account": f"{content} {render_email_note(ctx.obj.auth)}"}
else:
details = {"Account": f"Offline - {os.getenv('SAFETY_DB_DIR')}"}
if ctx.obj.project.id:
details["Project"] = ctx.obj.project.id
if ctx.obj.project.git:
details[" Git branch"] = ctx.obj.project.git.branch # type: ignore
details[" Environment"] = ctx.obj.auth.stage
msg = "None, using Safety CLI default policies"
if ctx.obj.project.policy:
if ctx.obj.project.policy.source is PolicySource.cloud:
msg = (
"fetched from Safety Platform, "
"ignoring any local Safety CLI policy files"
)
else:
if ctx.obj.project.id:
msg = f"local {ctx.obj.project.id} project scan policy"
else:
msg = "local scan policy file"
details[" Scan policy"] = msg
for k, v in details.items():
console.print(f"[scan_meta_title]{k}[/scan_meta_title]: {v}")
print_announcements(console=console, ctx=ctx)
console.print()
result = func(ctx, target=target, output=output, *args, **kwargs)
return result
return inner
def scan_system_command_init(func):
"""
Decorator to make general verifications before each system scan command.
"""
@wraps(func)
def inner(
ctx,
policy_file_path: Optional[Path],
targets: List[Path],
output: SystemScanOutput,
console: Console = main_console,
*args,
**kwargs,
):
ctx.obj.console = console
ctx.params.pop("console", None)
if output.is_silent():
console.quiet = True
if not ctx.obj.auth.is_valid():
process_auth_status_not_ready(console=console, auth=ctx.obj.auth, ctx=ctx)
console.print()
print_header(console=console, targets=targets, is_system_scan=True)
if not policy_file_path:
if SYSTEM_POLICY_FILE.exists():
policy_file_path = SYSTEM_POLICY_FILE
elif USER_POLICY_FILE.exists():
policy_file_path = USER_POLICY_FILE
# Load Policy file
ctx.obj.system_scan_policy = (
load_policy_file(policy_file_path) if policy_file_path else None
)
config = (
ctx.obj.system_scan_policy.config
if ctx.obj.system_scan_policy and ctx.obj.system_scan_policy.config
else ConfigModel()
)
# Preserve global telemetry preference.
if ctx.obj.config:
if ctx.obj.config.telemetry_enabled is not None:
config.telemetry_enabled = ctx.obj.config.telemetry_enabled
ctx.obj.config = config
if not any(targets):
if any(config.scan.system_targets):
targets = [
Path(t).expanduser().absolute() for t in config.scan.system_targets
]
else:
targets = [Path("/")]
ctx.obj.metadata.scan_locations = targets
console.print()
if ctx.obj.auth.org and ctx.obj.auth.org.name:
console.print(f"[bold]Organization[/bold]: {ctx.obj.auth.org.name}")
details = {
"Account": f"{ctx.obj.auth.name}, {ctx.obj.auth.email}",
"Scan stage": ctx.obj.auth.stage,
}
if ctx.obj.system_scan_policy:
if ctx.obj.system_scan_policy.source is PolicySource.cloud:
policy_type = "remote"
else:
policy_type = f'local ("{ctx.obj.system_scan_policy.id}")'
org_name = " "
if ctx.obj.auth.org and ctx.obj.auth.org.name:
org_name = f" {ctx.obj.auth.org.name} "
details["System scan policy"] = (
f"{policy_type}{org_name}organization policy:"
)
for k, v in details.items():
console.print(f"[bold]{k}[/bold]: {v}")
if ctx.obj.system_scan_policy:
dirs = [ign for ign in ctx.obj.config.scan.ignore if Path(ign).is_dir()]
policy_details = [
f"-> scanning from root {', '.join([str(t) for t in targets])} to a max folder depth of {ctx.obj.config.scan.max_depth}",
f"-> excluding {len(dirs)} {pluralize('directory', len(dirs))} and their sub-directories",
"-> target ecosystems: Python",
]
for policy_detail in policy_details:
console.print(Padding(policy_detail, (0, 0, 0, 1)), emoji=True)
print_announcements(console=console, ctx=ctx)
console.print()
kwargs.update({"targets": targets})
result = func(ctx, *args, **kwargs)
return result
return inner
def inject_metadata(func):
"""
Build metadata per subcommand. A system scan can trigger a project scan,
the project scan will need to build its own metadata.
"""
@wraps(func)
def inner(ctx, *args, **kwargs):
telemetry = build_telemetry_data(
telemetry=ctx.obj.config.telemetry_enabled,
command=ctx.command.name,
subcommand=ctx.invoked_subcommand,
)
auth_type = ctx.obj.auth.client.get_authentication_type()
scan_type = ScanType(ctx.command.name)
target = kwargs.get("target", None)
targets = kwargs.get("targets", None)
if not scan_type:
raise SafetyException("Missing scan_type.")
if scan_type is ScanType.scan:
if not target:
raise SafetyException("Missing target.")
targets = [target]
metadata = MetadataModel(
scan_type=scan_type,
stage=ctx.obj.auth.stage,
scan_locations=targets, # type: ignore
authenticated=ctx.obj.auth.client.is_using_auth_credentials(),
authentication_type=auth_type,
telemetry=telemetry,
schema_version=ReportSchemaVersion.v3_0,
)
ctx.obj.schema = ReportSchemaVersion.v3_0
ctx.obj.metadata = metadata
ctx.obj.telemetry = telemetry
return func(ctx, *args, **kwargs)
return inner

View File

@@ -0,0 +1,58 @@
from abc import ABC, abstractmethod
from typing import List
from safety_schemas.models import Ecosystem, FileType, ConfigModel, DependencyResultModel
from typer import FileTextWrite
NOT_IMPLEMENTED = "Not implemented funtion"
class Inspectable(ABC):
"""
Abstract base class defining the interface for objects that can be inspected for dependencies.
"""
@abstractmethod
def inspect(self, config: ConfigModel) -> DependencyResultModel:
"""
Inspects the object and returns the result of the dependency analysis.
Args:
config (ConfigModel): The configuration model for inspection.
Returns:
DependencyResultModel: The result of the dependency inspection.
"""
return NotImplementedError(NOT_IMPLEMENTED)
class Remediable(ABC):
"""
Abstract base class defining the interface for objects that can be remediated.
"""
@abstractmethod
def remediate(self):
"""
Remediates the object to address any detected issues.
"""
return NotImplementedError(NOT_IMPLEMENTED)
class InspectableFile(Inspectable):
"""
Represents an inspectable file within a specific ecosystem and file type.
"""
def __init__(self, file: FileTextWrite):
"""
Initializes an InspectableFile instance.
Args:
file (FileTextWrite): The file to be inspected.
"""
self.file = file
self.ecosystem: Ecosystem
self.file_type: FileType
self.dependency_results: DependencyResultModel = \
DependencyResultModel(dependencies=[])

View File

@@ -0,0 +1,294 @@
from collections import defaultdict
from pathlib import Path
import sys
from typing import Generator, List, Optional
from safety_schemas.models import FileType, PythonDependency
from safety_schemas.models.package import PythonSpecification
from ..base import InspectableFile
from dparse import parse, filetypes
from packaging.specifiers import SpecifierSet
from packaging.version import parse as parse_version
from packaging.utils import canonicalize_name
def get_closest_ver(versions: List[str], version: Optional[str], spec: SpecifierSet) -> dict:
"""
Gets the closest version to the specified version within a list of versions.
Args:
versions (List[str]): The list of versions.
version (Optional[str]): The target version.
spec (SpecifierSet): The version specifier set.
Returns:
dict: A dictionary containing the upper and lower closest versions.
"""
results = {'upper': None, 'lower': None}
if (not version and not spec) or not versions:
return results
sorted_versions = sorted(versions, key=lambda ver: parse_version(ver), reverse=True)
if not version:
sorted_versions = spec.filter(sorted_versions, prereleases=False)
upper = None
lower = None
try:
sorted_versions = list(sorted_versions)
upper = sorted_versions[0]
lower = sorted_versions[-1]
results['upper'] = upper
results['lower'] = lower if upper != lower else None
except IndexError:
pass
return results
current_v = parse_version(version)
for v in sorted_versions:
index = parse_version(v)
if index > current_v:
results['upper'] = index
if index < current_v:
results['lower'] = index
break
return results
def is_pinned_requirement(spec: SpecifierSet) -> bool:
"""
Checks if a requirement is pinned.
Args:
spec (SpecifierSet): The version specifier set.
Returns:
bool: True if the requirement is pinned, False otherwise.
"""
if not spec or len(spec) != 1:
return False
specifier = next(iter(spec))
return (specifier.operator == '==' and '*' != specifier.version[-1]) \
or specifier.operator == '==='
def find_version(requirements: List[PythonSpecification]) -> Optional[str]:
"""
Finds the version of a requirement.
Args:
requirements (List[PythonSpecification]): The list of requirements.
Returns:
Optional[str]: The version if found, otherwise None.
"""
ver = None
if len(requirements) != 1:
return ver
specs = requirements[0].specifier
if is_pinned_requirement(specs):
ver = next(iter(requirements[0].specifier)).version
return ver
def is_supported_by_parser(path: str) -> bool:
"""
Checks if the file path is supported by the parser.
Args:
path (str): The file path.
Returns:
bool: True if supported, False otherwise.
"""
supported_types = (".txt", ".in", ".yml", ".ini", "Pipfile",
"Pipfile.lock", "setup.cfg", "poetry.lock")
return path.endswith(supported_types)
def parse_requirement(dep: str, found: Optional[str]) -> PythonSpecification:
"""
Parses a requirement and creates a PythonSpecification object.
Args:
dep (str): The dependency string.
found (Optional[str]): The found path.
Returns:
PythonSpecification: The parsed requirement.
"""
req = PythonSpecification(dep)
req.found = Path(found).resolve() if found else None
if req.specifier == SpecifierSet(''):
req.specifier = SpecifierSet('>=0')
return req
def read_requirements(fh, resolve: bool = True) -> Generator[PythonDependency, None, None]:
"""
Reads requirements from a file-like object and (optionally) from referenced files.
Args:
fh: The file-like object to read from.
resolve (bool): Whether to resolve referenced files.
Returns:
Generator[PythonDependency, None, None]: A generator of PythonDependency objects.
"""
is_temp_file = not hasattr(fh, 'name')
path = None
found = Path('temp_file')
file_type = filetypes.requirements_txt
absolute_path: Optional[Path] = None
if not is_temp_file and is_supported_by_parser(fh.name):
path = fh.name
absolute_path = Path(path).resolve()
found = absolute_path
file_type = None
content = fh.read()
dependency_file = parse(content, path=path, resolve=resolve,
file_type=file_type)
reqs_pkg = defaultdict(list)
for req in dependency_file.resolved_dependencies:
reqs_pkg[canonicalize_name(req.name)].append(req)
for pkg, reqs in reqs_pkg.items():
specifications = list(
map(lambda req: parse_requirement(req, str(absolute_path)), reqs))
version = find_version(specifications)
yield PythonDependency(name=pkg, version=version,
specifications=specifications,
found=found,
absolute_path=absolute_path,
insecure_versions=[],
secure_versions=[], latest_version=None,
latest_version_without_known_vulnerabilities=None,
more_info_url=None)
def read_dependencies(fh, resolve: bool = True) -> Generator[PythonDependency, None, None]:
"""
Reads dependencies from a file-like object.
Args:
fh: The file-like object to read from.
resolve (bool): Whether to resolve referenced files.
Returns:
Generator[PythonDependency, None, None]: A generator of PythonDependency objects.
"""
path = fh.name
absolute_path = Path(path).resolve()
found = absolute_path
content = fh.read()
dependency_file = parse(content, path=path, resolve=resolve)
reqs_pkg = defaultdict(list)
for req in dependency_file.resolved_dependencies:
reqs_pkg[canonicalize_name(req.name)].append(req)
for pkg, reqs in reqs_pkg.items():
specifications = list(
map(lambda req: parse_requirement(req, str(absolute_path)), reqs))
version = find_version(specifications)
yield PythonDependency(name=pkg, version=version,
specifications=specifications,
found=found,
absolute_path=absolute_path,
insecure_versions=[],
secure_versions=[], latest_version=None,
latest_version_without_known_vulnerabilities=None,
more_info_url=None)
def read_virtual_environment_dependencies(f: InspectableFile) -> Generator[PythonDependency, None, None]:
"""
Reads dependencies from a virtual environment.
Args:
f (InspectableFile): The inspectable file representing the virtual environment.
Returns:
Generator[PythonDependency, None, None]: A generator of PythonDependency objects.
"""
env_path = Path(f.file.name).resolve().parent
if sys.platform.startswith('win'):
site_pkgs_path = env_path / Path("Lib/site-packages/")
else:
site_pkgs_path = Path('lib/')
try:
site_pkgs_path = next((env_path / site_pkgs_path).glob("*/site-packages/"))
except StopIteration:
# Unable to find packages for foo env
return
if not site_pkgs_path.resolve().exists():
# Unable to find packages for foo env
return
dep_paths = site_pkgs_path.glob("*/METADATA")
for path in dep_paths:
if not path.is_file():
continue
dist_info_folder = path.parent
dep_name, dep_version = dist_info_folder.name.replace(".dist-info", "").split("-")
yield PythonDependency(name=dep_name, version=dep_version,
specifications=[
PythonSpecification(f"{dep_name}=={dep_version}",
found=site_pkgs_path)],
found=site_pkgs_path, insecure_versions=[],
secure_versions=[], latest_version=None,
latest_version_without_known_vulnerabilities=None,
more_info_url=None)
def get_dependencies(f: InspectableFile) -> List[PythonDependency]:
"""
Gets the dependencies for the given inspectable file.
Args:
f (InspectableFile): The inspectable file.
Returns:
List[PythonDependency]: A list of PythonDependency objects.
"""
if not f.file_type:
return []
if f.file_type in [FileType.REQUIREMENTS_TXT, FileType.POETRY_LOCK,
FileType.PIPENV_LOCK, FileType.PYPROJECT_TOML]:
return list(read_dependencies(f.file, resolve=True))
if f.file_type == FileType.VIRTUAL_ENVIRONMENT:
return list(read_virtual_environment_dependencies(f))
return []

View File

@@ -0,0 +1,463 @@
from datetime import datetime
import itertools
import logging
from typing import List
from safety_schemas.models import FileType, PythonDependency, ClosestSecureVersion, \
ConfigModel, PythonSpecification, RemediationModel, DependencyResultModel, \
Vulnerability
from safety_schemas.models import VulnerabilitySeverityLabels, IgnoredItemDetail, \
IgnoredItems, IgnoreCodes
from typer import FileTextWrite
from safety.models import Severity
from safety.util import build_remediation_info_url
from ....constants import IGNORE_UNPINNED_REQ_REASON
from ....safety import get_cve_from, get_from_cache, get_vulnerabilities
from ..python.dependencies import get_closest_ver, get_dependencies, \
is_pinned_requirement
from ..base import InspectableFile, Remediable
from packaging.version import parse as parse_version
from packaging.utils import canonicalize_name
from packaging.specifiers import SpecifierSet
LOG = logging.getLogger(__name__)
def ignore_vuln_if_needed(
dependency: PythonDependency, file_type: FileType,
vuln_id: str, cve, ignore_vulns,
ignore_unpinned: bool, ignore_environment: bool,
specification: PythonSpecification,
ignore_severity: List[VulnerabilitySeverityLabels] = []
) -> None:
"""
Ignores vulnerabilities based on the provided rules.
Args:
dependency (PythonDependency): The Python dependency.
file_type (FileType): The type of the file.
vuln_id (str): The vulnerability ID.
cve: The CVE object.
ignore_vulns: The dictionary of ignored vulnerabilities.
ignore_unpinned (bool): Whether to ignore unpinned specifications.
ignore_environment (bool): Whether to ignore environment results.
specification (PythonSpecification): The specification.
ignore_severity (List[VulnerabilitySeverityLabels]): List of severity labels to ignore.
"""
vuln_ignored: bool = vuln_id in ignore_vulns
if vuln_ignored and ignore_vulns[vuln_id].code is IgnoreCodes.manual:
if (not ignore_vulns[vuln_id].expires
or ignore_vulns[vuln_id].expires > datetime.utcnow().date()):
return
del ignore_vulns[vuln_id]
if ignore_environment and file_type is FileType.VIRTUAL_ENVIRONMENT:
reason = "Ignored environment by rule in policy file."
ignore_vulns[vuln_id] = IgnoredItemDetail(
code=IgnoreCodes.environment_dependency, reason=reason)
return
severity_label = VulnerabilitySeverityLabels.UNKNOWN
if cve:
if cve.cvssv3 and cve.cvssv3.get("base_severity", None):
severity_label = VulnerabilitySeverityLabels(
cve.cvssv3["base_severity"].lower())
if severity_label in ignore_severity:
reason = f"{severity_label.value.capitalize()} severity ignored by rule in policy file."
ignore_vulns[vuln_id] = IgnoredItemDetail(
code=IgnoreCodes.cvss_severity, reason=reason)
return
spec_ignored: bool = False
vuln = ignore_vulns.get(vuln_id)
if vuln is not None and vuln.specifications is not None and str(specification.specifier) in vuln.specifications:
spec_ignored = True
if (not spec_ignored) and \
(ignore_unpinned and not specification.is_pinned()):
reason = IGNORE_UNPINNED_REQ_REASON
specifications = set()
specifications.add(str(specification.specifier))
ignore_vulns[vuln_id] = IgnoredItemDetail(
code=IgnoreCodes.unpinned_specification, reason=reason,
specifications=specifications)
def should_fail(config: ConfigModel, vulnerability: Vulnerability) -> bool:
"""
Determines if a vulnerability should cause a failure based on the configuration.
Args:
config (ConfigModel): The configuration model.
vulnerability (Vulnerability): The vulnerability.
Returns:
bool: True if the vulnerability should cause a failure, False otherwise.
"""
if not config.depedendency_vulnerability.fail_on.enabled:
return False
# If Severity is None type, it will be considered as UNKNOWN and NONE
# They are not the same, but we are handling like the same when a
# vulnerability does not have a severity value.
severities = [VulnerabilitySeverityLabels.NONE,
VulnerabilitySeverityLabels.UNKNOWN]
if vulnerability.severity and vulnerability.severity.cvssv3:
base_severity = vulnerability.severity.cvssv3.get(
"base_severity")
if base_severity:
base_severity = base_severity.lower()
# A vulnerability only has a single severity value, this is just
# to handle cases where the severity value is not in the expected
# format and fallback to the default severity values [None, unknown].
matched_severities = [
label
for label in VulnerabilitySeverityLabels
if label.value == base_severity
]
if matched_severities:
severities = matched_severities
else:
LOG.warning(
f"Unexpected base severity value {base_severity} for "
f"{vulnerability.vulnerability_id}"
)
return any(
severity in config.depedendency_vulnerability.fail_on.cvss_severity
for severity in severities
)
def get_vulnerability(
vuln_id: str, cve, data, specifier,
db, name, ignore_vulns: IgnoredItems,
affected: PythonSpecification
) -> Vulnerability:
"""
Creates a Vulnerability object from the given data.
Args:
vuln_id (str): The vulnerability ID.
cve: The CVE object.
data: The vulnerability data.
specifier: The specifier set.
db: The database.
name: The package name.
ignore_vulns (IgnoredItems): The ignored vulnerabilities.
affected (PythonSpecification): The affected specification.
Returns:
Vulnerability: The created Vulnerability object.
"""
base_domain = db.get('meta', {}).get('base_domain')
unpinned_ignored = ignore_vulns[vuln_id].specifications \
if vuln_id in ignore_vulns.keys() else None
should_ignore = not unpinned_ignored or str(affected.specifier) in unpinned_ignored
ignored: bool = bool(ignore_vulns and
vuln_id in ignore_vulns and
should_ignore)
more_info_url = f"{base_domain}{data.get('more_info_path', '')}"
severity = None
if cve and (cve.cvssv2 or cve.cvssv3):
severity = Severity(source=cve.name, cvssv2=cve.cvssv2, cvssv3=cve.cvssv3)
analyzed_requirement = affected
analyzed_version = next(iter(analyzed_requirement.specifier)).version if affected.is_pinned() else None
vulnerable_spec = set()
vulnerable_spec.add(specifier)
reason = None
expires = None
ignore_code = None
if ignored:
reason = ignore_vulns[vuln_id].reason
expires = str(ignore_vulns[vuln_id].expires) if ignore_vulns[vuln_id].expires else None
ignore_code = ignore_vulns[vuln_id].code.value
return Vulnerability(
vulnerability_id=vuln_id,
package_name=name,
ignored=ignored,
ignored_reason=reason,
ignored_expires=expires,
ignored_code=ignore_code,
vulnerable_spec=vulnerable_spec,
all_vulnerable_specs=data.get("specs", []),
analyzed_version=analyzed_version,
analyzed_requirement=str(analyzed_requirement),
advisory=data.get("advisory"),
is_transitive=data.get("transitive", False),
published_date=data.get("published_date"),
fixed_versions=[ver for ver in data.get("fixed_versions", []) if ver],
closest_versions_without_known_vulnerabilities=data.get("closest_secure_versions", []),
resources=data.get("vulnerability_resources"),
CVE=cve,
severity=severity,
affected_versions=data.get("affected_versions", []),
more_info_url=more_info_url
)
class PythonFile(InspectableFile, Remediable):
"""
A class representing a Python file that can be inspected for vulnerabilities and remediated.
"""
def __init__(self, file_type: FileType, file: FileTextWrite) -> None:
"""
Initializes the PythonFile instance.
Args:
file_type (FileType): The type of the file.
file (FileTextWrite): The file object.
"""
super().__init__(file=file)
self.ecosystem = file_type.ecosystem
self.file_type = file_type
def __find_dependency_vulnerabilities__(self, dependencies: List[PythonDependency],
config: ConfigModel) -> None:
"""
Finds vulnerabilities in the dependencies.
Args:
dependencies (List[PythonDependency]): The list of dependencies.
config (ConfigModel): The configuration model.
"""
ignored_vulns_data = {}
ignore_vulns = {} \
if not config.depedendency_vulnerability.ignore_vulnerabilities \
else config.depedendency_vulnerability.ignore_vulnerabilities
ignore_severity = config.depedendency_vulnerability.ignore_cvss_severity
ignore_unpinned = config.depedendency_vulnerability.python_ignore.unpinned_specifications
ignore_environment = config.depedendency_vulnerability.python_ignore.environment_results
db = get_from_cache(db_name="insecure.json", skip_time_verification=True)
if not db:
LOG.debug("Cache data for insecure.json is not available or is invalid.")
return
db_full = None
vulnerable_packages = frozenset(db.get('vulnerable_packages', []))
found_dependencies = {}
specifications = iter([])
for dependency in dependencies:
specifications = itertools.chain(dependency.specifications, specifications)
found_dependencies[
canonicalize_name(dependency.name)
] = dependency
# Let's report by req, pinned in environment will be ==version
for spec in specifications:
vuln_per_req = {}
name = canonicalize_name(spec.name)
dependency: PythonDependency = found_dependencies.get(name, None)
if not dependency:
continue
if not dependency.version:
if not db_full:
db_full = get_from_cache(db_name="insecure_full.json",
skip_time_verification=True)
if not db_full:
LOG.debug("Cache data for insecure_full.json is not available or is invalid.")
return
dependency.refresh_from(db_full)
if name in vulnerable_packages:
# we have a candidate here, build the spec set
for specifier in db['vulnerable_packages'][name]:
spec_set = SpecifierSet(specifiers=specifier)
if spec.is_vulnerable(spec_set, dependency.insecure_versions):
if not db_full:
db_full = get_from_cache(db_name="insecure_full.json",
skip_time_verification=True)
if not db_full:
LOG.debug("Cache data for insecure_full.json is not available or is invalid.")
return
if not dependency.latest_version:
dependency.refresh_from(db_full)
for data in get_vulnerabilities(pkg=name, spec=specifier, db=db_full):
try:
vuln_id: str = str(next(filter(lambda i: i.get('type', None) == 'pyup', data.get('ids', []))).get('id', ''))
except StopIteration:
vuln_id: str = ''
if vuln_id in vuln_per_req:
vuln_per_req[vuln_id].vulnerable_spec.add(specifier)
continue
cve = get_cve_from(data, db_full)
ignore_vuln_if_needed(dependency=dependency,
file_type=self.file_type,
vuln_id=vuln_id, cve=cve,
ignore_vulns=ignore_vulns,
ignore_severity=ignore_severity,
ignore_unpinned=ignore_unpinned,
ignore_environment=ignore_environment,
specification=spec)
include_ignored = True
vulnerability = get_vulnerability(vuln_id, cve, data,
specifier, db_full,
name, ignore_vulns, spec)
should_add_vuln = not (vulnerability.is_transitive and
dependency.found and
dependency.found.parts[-1] == FileType.VIRTUAL_ENVIRONMENT.value)
if vulnerability.ignored:
ignored_vulns_data[
vulnerability.vulnerability_id] = vulnerability
if not self.dependency_results.failed and not vulnerability.ignored:
self.dependency_results.failed = should_fail(config, vulnerability)
if (include_ignored or vulnerability.vulnerability_id not in ignore_vulns) and should_add_vuln:
vuln_per_req[vulnerability.vulnerability_id] = vulnerability
spec.vulnerabilities.append(vulnerability)
# TODO: dep_result Save if it should fail the JOB
self.dependency_results.dependencies = [dep for _, dep in found_dependencies.items()]
self.dependency_results.ignored_vulns = ignore_vulns
self.dependency_results.ignored_vulns_data = ignored_vulns_data
def inspect(self, config: ConfigModel) -> None:
"""
Inspects the file for vulnerabilities based on the given configuration.
Args:
config (ConfigModel): The configuration model.
"""
# We only support vulnerability checking for now
dependencies = get_dependencies(self)
if not dependencies:
self.results = []
self.__find_dependency_vulnerabilities__(dependencies=dependencies,
config=config)
def __get_secure_specifications_for_user__(self, dependency: PythonDependency, db_full,
secure_vulns_by_user=None) -> List[str]:
"""
Gets secure specifications for the user.
Args:
dependency (PythonDependency): The Python dependency.
db_full: The full database.
secure_vulns_by_user: The set of secure vulnerabilities by user.
Returns:
List[str]: The list of secure specifications.
"""
if not db_full:
return
if not secure_vulns_by_user:
secure_vulns_by_user = set()
versions = dependency.get_versions(db_full)
affected_versions = []
for vuln in db_full.get('vulnerable_packages', {}).get(dependency.name, []):
vuln_id: str = str(next(filter(lambda i: i.get('type', None) == 'pyup', vuln.get('ids', []))).get('id', ''))
if vuln_id and vuln_id not in secure_vulns_by_user:
affected_versions += vuln.get('affected_versions', [])
affected_v = set(affected_versions)
sec_ver_for_user = list(versions.difference(affected_v))
return sorted(sec_ver_for_user, key=lambda ver: parse_version(ver), reverse=True)
def remediate(self) -> None:
"""
Remediates the vulnerabilities in the file.
"""
db_full = get_from_cache(db_name="insecure_full.json",
skip_time_verification=True)
if not db_full:
return
for dependency in self.dependency_results.get_affected_dependencies():
secure_versions = dependency.secure_versions
if not secure_versions:
secure_versions = []
secure_vulns_by_user = set(self.dependency_results.ignored_vulns.keys())
if not secure_vulns_by_user:
secure_v = sorted(secure_versions, key=lambda ver: parse_version(ver),
reverse=True)
else:
secure_v = self.__get_secure_specifications_for_user__(
dependency=dependency, db_full=db_full,
secure_vulns_by_user=secure_vulns_by_user)
for specification in dependency.specifications:
if len(specification.vulnerabilities) <= 0:
continue
version = None
if is_pinned_requirement(specification.specifier):
version = next(iter(specification.specifier)).version
closest_secure = {key: str(value) if value else None for key, value in
get_closest_ver(secure_v,
version,
specification.specifier).items()}
closest_secure = ClosestSecureVersion(**closest_secure)
recommended = None
if closest_secure.upper:
recommended = closest_secure.upper
elif closest_secure.lower:
recommended = closest_secure.lower
other_recommended = [other_v for other_v in secure_v if other_v != str(recommended)]
remed_more_info_url = dependency.more_info_url
if remed_more_info_url:
remed_more_info_url = build_remediation_info_url(
base_url=remed_more_info_url, version=version,
spec=str(specification.specifier),
target_version=recommended)
if not remed_more_info_url:
remed_more_info_url = "-"
vulns_found = sum(1 for vuln in specification.vulnerabilities if not vuln.ignored)
specification.remediation = RemediationModel(vulnerabilities_found=vulns_found,
more_info_url=remed_more_info_url,
closest_secure=closest_secure if recommended else None,
recommended=recommended,
other_recommended=other_recommended)

View File

@@ -0,0 +1,88 @@
from pathlib import Path
import logging
from safety_schemas.models import Ecosystem, FileType
from typer import FileTextWrite
from .python.main import PythonFile
from ...encoding import detect_encoding
logger = logging.getLogger(__name__)
class InspectableFileContext:
"""
Context manager for handling the lifecycle of an inspectable file.
This class ensures that the file is properly opened and closed, handling any
exceptions that may occur during the process.
"""
def __init__(self, file_path: Path, file_type: FileType) -> None:
"""
Initializes the InspectableFileContext.
Args:
file_path (Path): The path to the file.
file_type (FileType): The type of the file.
"""
self.file_path = file_path
self.inspectable_file = None
self.file_type = file_type
def __enter__(self): # TODO: Handle permission issue /Applications/...
"""
Enters the runtime context related to this object.
Opens the file and creates the appropriate inspectable file object based on the file type.
Returns:
The inspectable file object.
"""
try:
encoding = detect_encoding(self.file_path)
file: FileTextWrite = open(self.file_path, mode="r+", encoding=encoding) # type: ignore
self.inspectable_file = TargetFile.create(
file_type=self.file_type, file=file
)
except Exception:
logger.exception("Error opening file")
return self.inspectable_file
def __exit__(self, exc_type, exc_value, traceback):
"""
Exits the runtime context related to this object.
Ensures that the file is properly closed.
"""
if self.inspectable_file:
self.inspectable_file.file.close()
class TargetFile:
"""
Factory class for creating inspectable file objects based on the file type and ecosystem.
"""
@classmethod
def create(cls, file_type: FileType, file: FileTextWrite):
"""
Creates an inspectable file object based on the file type and ecosystem.
Args:
file_type (FileType): The type of the file.
file (FileTextWrite): The file object.
Returns:
An instance of the appropriate inspectable file class.
Raises:
ValueError: If the ecosystem or file type is unsupported.
"""
if file_type.ecosystem == Ecosystem.PYTHON:
return PythonFile(file=file, file_type=file_type)
raise ValueError(
"Unsupported ecosystem or file type: "
f"{file_type.ecosystem}:{file_type.value}"
)

View File

@@ -0,0 +1,7 @@
from .file_finder import FileFinder
from .handlers import PythonFileHandler
__all__ = [
"FileFinder",
"PythonFileHandler"
]

View File

@@ -0,0 +1,167 @@
# type: ignore
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
from safety_schemas.models import Ecosystem, FileType
from .handlers import FileHandler, ECOSYSTEM_HANDLER_MAPPING
LOG = logging.getLogger(__name__)
def should_exclude(excludes: Set[Path], to_analyze: Path) -> bool:
"""
Determines whether a given path should be excluded based on the provided exclusion set.
Args:
excludes (Set[Path]): Set of paths to exclude.
to_analyze (Path): The path to analyze.
Returns:
bool: True if the path should be excluded, False otherwise.
"""
if not to_analyze.is_absolute():
to_analyze = to_analyze.resolve()
for exclude in excludes:
if not exclude.is_absolute():
exclude = exclude.resolve()
try:
if to_analyze == exclude or to_analyze.relative_to(exclude):
return True
except ValueError:
pass
return False
class FileFinder:
""" "
Defines a common interface to agree in what type of components Safety is trying to
find depending on the language type.
"""
def __init__(
self,
max_level: int,
ecosystems: List[Ecosystem],
target: Path,
live_status=None,
exclude: Optional[List[str]] = None,
include_files: Optional[Dict[FileType, List[Path]]] = None,
handlers: Optional[Set[FileHandler]] = None,
) -> None:
"""
Initializes the FileFinder with the specified parameters.
Args:
max_level (int): Maximum directory depth to search.
ecosystems (List[Ecosystem]): List of ecosystems to consider.
target (Path): Target directory to search.
console: Console object for output.
live_status: Live status object for updates.
exclude (Optional[List[str]]): List of patterns to exclude from the search.
include_files (Optional[Dict[FileType, List[Path]]]): Dictionary of files to include in the search.
handlers (Optional[Set[FileHandler]]): Set of file handlers.
"""
self.max_level = max_level
self.target = target
self.include_files = include_files
# If no handlers are provided, initialize them from the ecosystem mapping
if not handlers:
handlers = set(
ECOSYSTEM_HANDLER_MAPPING[ecosystem]() for ecosystem in ecosystems
)
self.handlers = handlers
self.file_count = 0
self.exclude_dirs: Set[Path] = set()
self.exclude_files: Set[Path] = set()
exclude = [] if not exclude else exclude
# Populate the exclude_dirs and exclude_files sets based on the provided patterns
for pattern in exclude:
for path in Path(target).glob(pattern):
if path.is_dir():
self.exclude_dirs.add(path)
else:
self.exclude_files.add(path)
self.live_status = live_status
def process_directory(
self, dir_path: str, max_deep: Optional[int] = None
) -> Tuple[str, Dict[str, Set[Path]]]:
"""
Processes the specified directory to find files matching the handlers' criteria.
Args:
dir_path (str): The directory path to process.
max_deep (Optional[int]): Maximum depth to search within the directory.
Returns:
Tuple[str, Dict[str, Set[Path]]]: The directory path and a dictionary of file types and their corresponding paths.
"""
files: Dict[str, Set[Path]] = {}
level: int = 0
initial_depth = len(Path(dir_path).parts) - 1
for root, dirs, filenames in os.walk(dir_path):
root_path = Path(root)
current_depth = len(root_path.parts) - initial_depth
# Filter directories based on exclusion criteria
dirs[:] = [
d
for d in dirs
if not should_exclude(
excludes=self.exclude_dirs, to_analyze=(root_path / Path(d))
)
]
if dirs:
LOG.info(f"Directories to inspect -> {', '.join(dirs)}")
LOG.info(f"Current -> {root}")
if self.live_status:
self.live_status.update(f":mag: Scanning {root}")
# Stop descending into directories if the maximum depth is reached
if max_deep is not None and current_depth > max_deep:
# Don't go deeper
del dirs[:]
# Filter filenames based on exclusion criteria
filenames[:] = [
f
for f in filenames
if not should_exclude(excludes=self.exclude_files, to_analyze=Path(f))
]
self.file_count += len(filenames)
for file_name in filenames:
for handler in self.handlers:
file_type = handler.can_handle(root, file_name, self.include_files)
if file_type:
inspectable_file: Path = Path(root, file_name)
if file_type.value not in files or not files[file_type.value]:
files[file_type.value] = set()
files[file_type.value].add(inspectable_file)
break
level += 1
return dir_path, files
def search(self) -> Tuple[str, Dict[str, Set[Path]]]:
"""
Initiates the search for files within the target directory.
Returns:
Tuple[str, Dict[str, Set[Path]]]: The target directory and a dictionary of file types and their corresponding paths.
"""
return self.process_directory(self.target, self.max_level)

View File

@@ -0,0 +1,122 @@
from abc import ABC, abstractmethod
import os
from pathlib import Path
from types import MappingProxyType
from typing import Dict, List, Optional, Optional, Tuple
from safety_schemas.models import Ecosystem, FileType
NOT_IMPLEMENTED = "You should implement this."
class FileHandler(ABC):
"""
Abstract base class for file handlers that define how to handle specific types of files
within an ecosystem.
"""
def __init__(self) -> None:
self.ecosystem: Optional[Ecosystem] = None
def can_handle(self, root: str, file_name: str, include_files: Dict[FileType, List[Path]]) -> Optional[FileType]:
"""
Determines if the handler can handle the given file based on its type and inclusion criteria.
Args:
root (str): The root directory of the file.
file_name (str): The name of the file.
include_files (Dict[FileType, List[Path]]): Dictionary of file types and their paths to include.
Returns:
Optional[FileType]: The type of the file if it can be handled, otherwise None.
"""
# Keeping it simple for now
if not self.ecosystem:
return None
for f_type in self.ecosystem.file_types:
if f_type in include_files:
current = Path(root, file_name).resolve()
paths = [p.resolve() if p.is_absolute() else (root / p).resolve() for p in include_files[f_type]]
if current in paths:
return f_type
# Let's compare by name only for now
# We can put heavier logic here, but for speed reasons,
# right now is very basic, we will improve this later.
# Custom matching per File Type
if file_name.lower().endswith(f_type.value.lower()):
return f_type
return None
@abstractmethod
def download_required_assets(self, session) -> Dict[str, str]:
"""
Abstract method to download required assets for handling files. Should be implemented
by subclasses.
Args:
session: The session object for making network requests.
Returns:
Dict[str, str]: A dictionary of downloaded assets.
"""
return NotImplementedError(NOT_IMPLEMENTED)
class PythonFileHandler(FileHandler):
"""
Handler for Python files within the Python ecosystem.
"""
# Example of a Python File Handler
def __init__(self) -> None:
super().__init__()
self.ecosystem = Ecosystem.PYTHON
def download_required_assets(self, session) -> None:
"""
Downloads the required assets for handling Python files, specifically the Safety database.
Args:
session: The session object for making network requests.
"""
from safety.safety import fetch_database
SAFETY_DB_DIR = os.getenv("SAFETY_DB_DIR")
db = False if SAFETY_DB_DIR is None else SAFETY_DB_DIR
# Fetch both the full and partial Safety databases
fetch_database(session=session, full=False, db=db, cached=True,
telemetry=True, ecosystem=Ecosystem.PYTHON,
from_cache=False)
fetch_database(session=session, full=True, db=db, cached=True,
telemetry=True, ecosystem=Ecosystem.PYTHON,
from_cache=False)
class SafetyProjectFileHandler(FileHandler):
"""
Handler for Safety project files within the Safety project ecosystem.
"""
# Example of a Python File Handler
def __init__(self) -> None:
super().__init__()
self.ecosystem = Ecosystem.SAFETY_PROJECT
def download_required_assets(self, session) -> None:
"""
No required assets to download for Safety project files.
"""
pass
# Mapping of ecosystems to their corresponding file handlers
ECOSYSTEM_HANDLER_MAPPING = MappingProxyType({
Ecosystem.PYTHON: PythonFileHandler,
Ecosystem.SAFETY_PROJECT: SafetyProjectFileHandler,
})

View File

@@ -0,0 +1,227 @@
import random
import time
from rich.text import Text
from rich.style import Style
# -----------------------------
# Celebration Effects
# -----------------------------
def show_confetti(console):
# Characters to use as confetti
chars = ["*", "o", "+", "~"]
width = console.size.width
height = console.size.height
frames = 10
for _ in range(frames):
console.clear()
for __ in range(random.randint(50, 100)): # number of confetti pieces
x = random.randint(0, max(0, width - 1))
y = random.randint(0, max(0, height - 2))
char = random.choice(chars)
color = random.choice(["red", "green", "yellow", "blue", "magenta", "cyan"])
console.print(
Text(char, style=Style(color=color)),
end="",
style=color,
justify="left",
overflow="ignore",
no_wrap=True,
soft_wrap=False,
)
console.file.write(f"\x1b[{y};{x}H") # Move cursor to position
console.file.flush()
time.sleep(0.3)
console.clear()
console.print(
"The confetti has settled! Congrats on a clean scan!", style="bold green"
)
def show_trophy(console):
"""Displays a celebratory trophy with sparkles."""
trophy = (
r"""
___________
'._==_==_=_.'
.-\: /-.
| (|:. |) |
'-|:. |-'
\::. /
'::. .'
) (
_.' '._
`"""
""""`
"""
)
for _ in range(5): # Trophy animation
console.clear()
sparkles = random.choice(
[":sparkles:", ":glowing_star:", ":dizzy:", ":party_popper:"]
)
console.print(trophy, style="bold yellow")
console.print(
f"{sparkles} Scan Complete! No vulnerabilities found! {sparkles}",
style="bold green",
justify="center",
)
time.sleep(0.5)
console.print("Your code is SAFE and SOUND! :trophy:", style="bold yellow")
def show_balloons(console):
"""Displays celebratory balloons popping."""
balloons = [":balloon:", ":party_popper:", ":sparkles:", ":collision:"]
width = console.size.width
for _ in range(10): # Number of balloons
console.clear()
for __ in range(random.randint(5, 10)): # Balloons per frame
x = random.randint(0, width - 1)
balloon = random.choice(balloons)
console.print(
Text(balloon, style=Style(color="yellow")), end="", overflow="ignore"
)
console.file.write(f"\x1b[{random.randint(1, 10)};{x}H")
console.file.flush()
time.sleep(0.5)
console.print(
":balloon: POP! :party_popper: No vulnerabilities detected!", style="bold green"
)
def show_victory_parade(console):
"""Displays a victory parade of emojis."""
parade = [
":party_popper:",
":confetti_ball:",
":trophy:",
":partying_face:",
":sparkles:",
":laptop_computer:",
":locked:",
":white_heavy_check_mark:",
]
width = console.size.width
for _ in range(20): # Duration of parade
console.clear()
line = " ".join(random.choices(parade, k=width // 2))
console.print(line, style="bold green", justify="center")
time.sleep(0.2)
console.print(
"The parade is over. Your code is safe! :trophy:", style="bold yellow"
)
def show_confetti_rain(console):
"""Displays a colorful confetti rain effect."""
colors = ["red", "green", "yellow", "blue", "magenta", "cyan", "white"]
width = console.size.width
for _ in range(10): # Number of confetti frames
console.clear()
for __ in range(100): # Confetti pieces per frame
x = random.randint(0, width - 1)
char = random.choice(["*", "+", "~", ":sparkles:", "o"])
color = random.choice(colors)
console.print(
Text(char, style=Style(color=color)), end="", overflow="ignore"
)
console.file.write(f"\x1b[{random.randint(1, 10)};{x}H")
console.file.flush()
time.sleep(0.3)
console.print(
":party_popper: Confetti celebration complete! You're vulnerability-free! :party_popper:",
style="bold cyan",
)
def show_fireworks_display(console):
"""Displays a celebratory fireworks animation."""
fireworks = [
":collision:",
":sparkles:",
":glowing_star:",
":fireworks:",
":sparkler:",
]
width = console.size.width
for _ in range(15): # Number of fireworks
x = random.randint(5, width - 5)
y = random.randint(2, 8)
firework = random.choice(fireworks)
color = random.choice(["red", "yellow", "green", "blue", "magenta"])
console.print(
Text(firework, style=Style(color=color)), end="", overflow="ignore"
)
console.file.write(f"\x1b[{y};{x}H") # Position fireworks
console.file.flush()
time.sleep(0.3)
console.print(
":fireworks: Fireworks display finished! Code is secure! :fireworks:",
style="bold magenta",
)
def show_star_trail(console):
"""Displays a shooting star trail effect."""
stars = [":white_medium_star:", ":glowing_star:", ":sparkles:", ":dizzy:"]
width = console.size.width
for _ in range(10): # Number of shooting stars
console.clear()
start_x = random.randint(0, width // 2)
trail = "".join(random.choices(stars, k=10))
console.print(f"{' ' * start_x}{trail}", style="bold yellow", justify="left")
time.sleep(0.3)
console.print(
":sparkles: Your code shines bright with no vulnerabilities! :sparkles:",
style="bold cyan",
)
def show_celebration_wave(console):
"""Displays a celebratory wave effect with emojis."""
emojis = [
":party_popper:",
":confetti_ball:",
":sparkles:",
":partying_face:",
":balloon:",
]
width = console.size.width
wave = [random.choice(emojis) for _ in range(width)]
for _ in range(10): # Number of waves
console.clear()
line = "".join(wave)
console.print(line, style="bold yellow", justify="center")
wave.insert(0, wave.pop()) # Shift wave
time.sleep(0.3)
console.print(
":water_wave: Celebration wave ends! Your scan is clean! :glowing_star:",
style="bold green",
)
# List of all celebratory effects
CELEBRATION_EFFECTS = [
show_confetti,
show_trophy,
show_balloons,
show_victory_parade,
show_confetti_rain,
show_fireworks_display,
show_star_trail,
show_celebration_wave,
]

View File

@@ -0,0 +1,284 @@
import random
import time
import os
from safety.scan.fun_mode.celebration_effects import CELEBRATION_EFFECTS
# -----------------------------
# Data: ASCII Arts, Fortunes, EMOJIS, etc.
# -----------------------------
ASCII_ARTS = {
"ascii": [
# Hedgehog guarding a shield
r"""
/\ /\
{ `---' }
{ O O }
~~> V <~~
\ \|/ /
`-----'__
/ \ `^\_
{ }\ |\_\_ *Safely Protected from vulnerabilities!*
| \_/ |/ / \_\_
\__/ /(_E \__/
( /
MM
""",
# Cat with a shield
r"""
/\_/\
=( °w° )= *Purrr... no vulnerabilities dare to cross!*
( * )
---( )---
/ \
/ ^ \
( ( ) )
\_) |_(_/
__||___||__
| |
| SAFE |
| & |
| SECURE! |
|___________|
/ \
| |
""",
# Bunny with a shield
r"""
(\_/)
( •_•) *Hop-hop! SafetyCLI ready, no vulns here!*
( >:carrot:< )
/ \
/ | | \
/ | | \
/ |___| \
( )
\____|||_____/
||||
__||__
| |
| SAFE |
| FROM |
| BUGS |
|______|
""",
# Dog behind a shield
r"""
/ \__
( o\____
/ O
/ (_______/ *Woof! Our shield is strong, no vulns inside!*
/_____/
""",
]
}
FORTUNES = [
"Your dependencies are safer than a password-manager's vault.",
"Your code sparkles with zero known vulnerabilities!",
"All vulnerabilities fear your security prowess!",
"Your build is as solid as a rock!",
"Your code is a fortress; no bug can breach it.",
"In the realm of code, you are the vigilant guardian.",
"Each line you write fortifies the castle of your code.",
"Your code is a well-oiled machine, impervious to rust.",
"Your code is a symphony of security and efficiency.",
"Your code is a beacon of safety in the digital ocean.",
"Your code is a masterpiece, untouched by the hands of vulnerabilities.",
"Your code stands tall, a citadel against cyber threats.",
"Your code is a tapestry woven with threads of safety.",
"Your code is a lighthouse, guiding ships away from the rocks of vulnerabilities.",
"Your code is a garden where no weeds of bugs can grow.",
"In the realm of software, your security measures are legendary.",
]
EMOJIS = [
":dog_face:",
":dog2:",
":guide_dog:",
":service_dog:",
":poodle:",
":wolf:",
":fox_face:",
":cat_face:",
":cat2:",
":cat2:",
":lion_face:",
":tiger_face:",
":tiger2:",
":leopard:",
":horse_face:",
":deer:",
":deer:",
":racehorse:",
":unicorn_face:",
":zebra:",
":deer:",
":bison:",
":cow_face:",
":ox:",
":water_buffalo:",
":cow2:",
":ram:",
":sheep:",
":goat:",
":dromedary_camel:",
":two-hump_camel:",
":llama:",
":giraffe:",
":elephant:",
":mammoth:",
":rhinoceros:",
":hippopotamus:",
":mouse_face:",
":mouse2:",
":rat:",
":hamster:",
":rabbit_face:",
":rabbit2:",
":chipmunk:",
":beaver:",
":hedgehog:",
":bat:",
":bear:",
":polar_bear:",
":koala:",
":panda_face:",
":otter:",
":kangaroo:",
":badger:",
":turkey:",
":chicken:",
":rooster:",
":baby_chick:",
":hatched_chick:",
":bird:",
":penguin:",
":dove:",
":eagle:",
":duck:",
":swan:",
":owl:",
":dodo:",
":flamingo:",
":peacock:",
":parrot:",
":bird:",
":goose:",
":phoenix:",
":frog:",
":crocodile:",
":turtle:",
":lizard:",
":dragon:",
":sauropod:",
":t-rex:",
":whale:",
":whale2:",
":flipper:",
":seal:",
":fish:",
":tropical_fish:",
":blowfish:",
":shark:",
":octopus:",
":jellyfish:",
":crab:",
":lobster:",
":squid:",
":snail:",
":butterfly:",
":bug:",
":bee:",
]
# -----------------------------
# Helper functions (Effects)
# -----------------------------
def show_race(console):
# Pick two different EMOJIS at random
emoji1, emoji2 = random.sample(EMOJIS, 2)
finish_line = 50
pos1 = 0
pos2 = 0
console.print("Ready... Set... Go!", style="bold cyan")
time.sleep(1)
console.clear()
while True:
# Move contestants forward by random increments
pos1 += random.randint(1, 3)
pos2 += random.randint(1, 3)
console.clear()
console.print("[green]Finish line[/green]" + " " * (finish_line - 10) + "|")
line1 = " " * pos1 + emoji1
line2 = " " * pos2 + emoji2
console.print(f"{emoji1} lane: {line1}")
console.print(f"{emoji2} lane: {line2}")
time.sleep(0.3)
finished1 = pos1 >= finish_line
finished2 = pos2 >= finish_line
if finished1 and finished2:
console.print(
"It's a tie! Both reached the finish line at the same time!",
style="bold magenta",
)
break
elif finished1:
console.print(
f"The {emoji1} wins! Slow and steady (or maybe fast?), it prevailed!",
style="bold green",
)
break
elif finished2:
console.print(
f"The {emoji2} wins! Speed and agility triumphed!", style="bold green"
)
break
time.sleep(2)
console.clear()
console.print("Hope you enjoyed the race! :party_popper:", style="bold cyan")
# -----------------------------
# Main Easter Egg Dispatcher
# -----------------------------
def run_easter_egg(console, exit_code: int) -> None:
"""
Runs an easter egg based on the SAFETY_FUN_MODE environment variable.
This function can be easily removed or commented out.
"""
egg_mode = os.getenv("SAFETY_FUN_MODE", "").strip().lower()
allowed_modes = {"ascii", "fx", "race", "fortune"}
if exit_code == 0 and egg_mode in allowed_modes:
if egg_mode == "ascii":
art = random.choice(ASCII_ARTS["ascii"])
console.print(art, style="green")
elif egg_mode == "fx":
effect = random.choice(CELEBRATION_EFFECTS)
effect(console) # Run the randomly selected effect
elif egg_mode == "race":
show_race(console)
elif egg_mode == "fortune":
fortune_message = random.choice(FORTUNES)
console.print(f"\n[italic cyan]{fortune_message}[/italic cyan]\n")

View File

@@ -0,0 +1,564 @@
import logging
from enum import Enum
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Generator,
List,
Optional,
Tuple,
Union,
Literal,
)
from pydantic import BaseModel, ConfigDict
import typer
from safety.auth.constants import SAFETY_PLATFORM_URL
from safety.errors import SafetyException
from safety.scan.command import (
ScannableEcosystems,
initialize_file_finder,
scan_project_directory,
)
from safety.scan.main import (
download_policy,
load_policy_file,
process_files,
resolve_policy,
)
from safety_schemas.models import (
Ecosystem,
FileModel,
FileType,
RemediationModel,
ReportModel,
ReportSchemaVersion,
ScanType,
VulnerabilitySeverityLabels,
MetadataModel,
TelemetryModel,
ProjectModel,
Stage,
AuthenticationType,
)
from safety.scan.util import GIT
from safety.util import build_telemetry_data
# Define typed models for scan results
class ScanResultType(str, Enum):
"""Types of scan results that can be yielded by the init_scan function"""
INIT = "init"
PROGRESS = "progress"
UPLOADING = "uploading"
STATUS = "status"
COMPLETE = "complete"
class BaseScanResult(BaseModel):
"""
Base class for all scan results with common attributes
"""
# No fields here - each subclass will define its own type
pass
class InitScanResult(BaseScanResult):
"""
Initial scan result with basic dependency info
"""
model_config = ConfigDict(frozen=True)
type: Literal[ScanResultType.INIT]
dependencies: int
progress: int = 0
class ProgressScanResult(BaseScanResult):
"""
Progress update during scanning with current counts
"""
model_config = ConfigDict(frozen=True)
type: Literal[ScanResultType.PROGRESS]
percent: int
dependencies: int
critical: Optional[int] = None
high: Optional[int] = None
medium: Optional[int] = None
low: Optional[int] = None
others: Optional[int] = None
fixes: Optional[int] = None
fixed_vulns: Optional[int] = None
file: str
file_pkg_count: int
file_count: int
venv_count: int
vulns_count: int
class CompleteScanResult(BaseScanResult):
"""
Final scan result with complete vulnerability counts
"""
model_config = ConfigDict(frozen=True)
type: Literal[ScanResultType.COMPLETE]
scan_id: Optional[str] = None
percent: int = 100
dependencies: int
critical: int
high: int
medium: int
low: int
others: int
vulns_count: int
fixes: int
fixed_vulns: int
codebase_url: Optional[str] = None
class StatusScanResult(BaseScanResult):
"""
Generic status update that can be used for any process
"""
model_config = ConfigDict(frozen=True)
type: Literal[ScanResultType.STATUS]
message: str
action: str # The specific action being performed (e.g., "analyzing", "preparing")
percent: Optional[int] = None
class UploadingScanResult(BaseScanResult):
"""
Status update when uploading results to server
"""
model_config = ConfigDict(frozen=True)
type: Literal[ScanResultType.UPLOADING]
message: str
percent: Optional[int] = None
# Union type for all possible result types
ScanResult = Union[
InitScanResult,
ProgressScanResult,
StatusScanResult,
UploadingScanResult,
CompleteScanResult,
]
LOG = logging.getLogger(__name__)
if TYPE_CHECKING:
from safety_schemas.models import (
ConfigModel,
ProjectModel,
MetadataModel,
TelemetryModel,
ReportModel,
FileModel,
)
def init_scan(
ctx: Any,
target: Path,
config: "ConfigModel",
metadata: "MetadataModel",
telemetry: "TelemetryModel",
project: "ProjectModel",
use_server_matching: bool = False,
) -> Generator[ScanResult, None, Tuple["ReportModel", List["FileModel"]]]:
"""
Core scanning logic that yields results as they become available.
Contains no UI-related code - purely logic for scanning.
Args:
ctx: The context object with necessary configurations
target: The target directory to scan
config: The application configuration
metadata: Metadata to include in the report
telemetry: Telemetry data to include in the report
project: The project object
version: The schema version
use_server_matching: Whether to use server-side vulnerability matching
Yields:
Dict containing scan progress information and results as they become available
Returns:
Tuple containing the final report model and list of files
"""
# Emit status that scan is starting
yield StatusScanResult(
type=ScanResultType.STATUS,
message="Starting safety scan",
action="initializing",
percent=0,
)
# Initialize ecosystems
ecosystems = [Ecosystem(member.value) for member in list(ScannableEcosystems)]
# Initialize file finder and locate project files
from rich.console import Console
console = Console()
console.quiet = True
yield StatusScanResult(
type=ScanResultType.STATUS,
message="Locating project files",
action="discovering",
percent=5,
)
file_finder = initialize_file_finder(ctx, target, None, ecosystems)
yield StatusScanResult(
type=ScanResultType.STATUS,
message="Scanning project directory",
action="scanning",
percent=10,
)
_, file_paths = scan_project_directory(file_finder, console)
total_files = sum(len(file_set) for file_set in file_paths.values())
yield StatusScanResult(
type=ScanResultType.STATUS,
message=f"Found {total_files} files to analyze",
action="analyzing",
percent=15,
)
# Initialize counters and data structures
files: List[FileModel] = []
count = 0 # Total dependencies processed
affected_count = 0
critical_vulns_count = 0
high_vulns_count = 0
medium_vulns_count = 0
low_vulns_count = 0
others_vulns_count = 0
vulns_count = 0
fixes_count = 0
total_resolved_vulns = 0
file_count = 0
venv_count = 0
scan_id = None
# Count the total number of files across all types
# Initial yield with dependency info
yield InitScanResult(type=ScanResultType.INIT, dependencies=count)
# Status update before processing files
yield StatusScanResult(
type=ScanResultType.STATUS,
message="Processing files for dependencies and vulnerabilities",
action="analyzing",
percent=20,
)
# Process each file for dependencies and vulnerabilities
for idx, (path, analyzed_file) in enumerate(
process_files(
paths=file_paths,
config=config,
use_server_matching=use_server_matching,
obj=ctx.obj,
target=target,
)
):
# Calculate progress percentage
# Calculate progress and ensure it never exceeds 100%
if total_files > 0:
progress = min(int((idx + 1) / total_files * 100), 100)
else:
progress = 100
# Update counts for dependencies
file_pkg_count = len(analyzed_file.dependency_results.dependencies)
count += file_pkg_count
# Track environment/file types
if analyzed_file.file_type is FileType.VIRTUAL_ENVIRONMENT:
venv_count += 1
else:
file_count += 1
# Get affected specifications
affected_specifications = (
analyzed_file.dependency_results.get_affected_specifications()
)
affected_count += len(affected_specifications)
# Count vulnerabilities by severity
current_critical = 0
current_high = 0
current_medium = 0
current_low = 0
current_others = 0
current_fixes = 0
current_resolved_vulns = 0
# Process each affected specification
for spec in affected_specifications:
# Access vulnerabilities
for vuln in spec.vulnerabilities:
if vuln.ignored:
continue
vulns_count += 1
# Determine vulnerability severity
severity = severity = VulnerabilitySeverityLabels.UNKNOWN
if (
hasattr(vuln, "CVE")
and vuln.CVE
and hasattr(vuln.CVE, "cvssv3")
and vuln.CVE.cvssv3
):
severity_str = vuln.CVE.cvssv3.get("base_severity", "none").lower()
severity = VulnerabilitySeverityLabels(severity_str)
# Count based on severity
if severity is VulnerabilitySeverityLabels.CRITICAL:
current_critical += 1
elif severity is VulnerabilitySeverityLabels.HIGH:
current_high += 1
elif severity is VulnerabilitySeverityLabels.MEDIUM:
current_medium += 1
elif severity is VulnerabilitySeverityLabels.LOW:
current_low += 1
else:
current_others += 1
# Check for available fixes - safely access remediation attributes
if spec.remediation:
# Access remediation properties safely without relying on specific attribute names
remediation: RemediationModel = spec.remediation
has_recommended_version = True if remediation.recommended else False
if has_recommended_version:
current_fixes += 1
current_resolved_vulns += len(
[v for v in spec.vulnerabilities if not v.ignored]
)
# Update total counts
critical_vulns_count += current_critical
high_vulns_count += current_high
medium_vulns_count += current_medium
low_vulns_count += current_low
others_vulns_count += current_others
fixes_count += current_fixes
total_resolved_vulns += current_resolved_vulns
# Save file data for further processing
file = FileModel(
location=path,
file_type=analyzed_file.file_type,
results=analyzed_file.dependency_results,
)
files.append(file)
# Yield current analysis results
yield ProgressScanResult(
type=ScanResultType.PROGRESS,
percent=progress,
dependencies=count,
critical=critical_vulns_count,
high=high_vulns_count,
medium=medium_vulns_count,
low=low_vulns_count,
others=others_vulns_count,
vulns_count=vulns_count,
fixes=fixes_count,
fixed_vulns=total_resolved_vulns,
file=str(path),
file_pkg_count=file_pkg_count,
file_count=file_count,
venv_count=venv_count,
)
# All files processed, create the report
project.files = files
yield StatusScanResult(
type=ScanResultType.STATUS,
message="Creating final report",
action="reporting",
percent=90,
)
# Convert dictionaries to model objects if needed
if isinstance(metadata, dict):
metadata_model = MetadataModel(**metadata)
else:
metadata_model = metadata
if isinstance(telemetry, dict):
telemetry_model = TelemetryModel(**telemetry)
else:
telemetry_model = telemetry
report = ReportModel(
version=ReportSchemaVersion.v3_0,
metadata=metadata_model,
telemetry=telemetry_model,
files=[],
projects=[project],
)
# Emit uploading status before starting upload
yield UploadingScanResult(
type=ScanResultType.UPLOADING, message="Preparing to upload scan results"
)
# TODO: Decouple platform upload logic
try:
# Convert report to JSON format
yield UploadingScanResult(
type=ScanResultType.UPLOADING,
message="Converting report to JSON format",
percent=25,
)
json_format = report.as_v30().json()
# Start upload
yield UploadingScanResult(
type=ScanResultType.UPLOADING,
message="Uploading results to Safety platform",
percent=50,
)
result = ctx.obj.auth.client.upload_report(json_format)
# Upload complete
yield UploadingScanResult(
type=ScanResultType.UPLOADING,
message="Upload completed successfully",
percent=100,
)
scan_id = result.get("uuid")
codebase_url = f"{SAFETY_PLATFORM_URL}{result['url']}"
except Exception as e:
# Emit error status
yield UploadingScanResult(
type=ScanResultType.UPLOADING, message=f"Error uploading results: {str(e)}"
)
raise e
# Final yield with completed flag
yield CompleteScanResult(
type=ScanResultType.COMPLETE,
dependencies=count,
critical=critical_vulns_count,
high=high_vulns_count,
medium=medium_vulns_count,
low=low_vulns_count,
others=others_vulns_count,
vulns_count=vulns_count,
fixes=fixes_count,
fixed_vulns=total_resolved_vulns,
codebase_url=codebase_url,
scan_id=scan_id,
)
# Return the complete report and files
return report, files
def start_scan(
ctx: "typer.Context",
auth_type: AuthenticationType,
is_authenticated: bool,
target: Path,
client: Any,
project: ProjectModel,
branch: Optional[str] = None,
stage: Stage = Stage.development,
platform_enabled: bool = False,
telemetry_enabled: bool = True,
use_server_matching: bool = False,
) -> Generator["ScanResult", None, Tuple["ReportModel", List["FileModel"]]]:
"""
Initialize and start a scan, returning an iterator that yields scan results.
This function handles setting up all required parameters for the scan.
Args:
ctx: The Typer context object containing configuration and project information
target: The target directory to scan
use_server_matching: Whether to use server-side vulnerability matching
Returns:
An iterator that yields scan results
"""
if not branch:
if git_data := GIT(root=target).build_git_data():
branch = git_data.branch
command_name = "scan"
telemetry = build_telemetry_data(
telemetry=telemetry_enabled, command=command_name, subcommand=None
)
scan_type = ScanType(command_name)
targets = [target]
if not scan_type:
raise SafetyException("Missing scan_type.")
metadata = MetadataModel(
scan_type=scan_type,
stage=stage,
scan_locations=targets,
authenticated=is_authenticated,
authentication_type=auth_type,
telemetry=telemetry,
schema_version=ReportSchemaVersion.v3_0,
)
policy_file_path = target / Path(".safety-policy.yml")
# Load Policy file and pull it from CLOUD
local_policy = load_policy_file(policy_file_path)
cloud_policy = None
if platform_enabled:
cloud_policy = download_policy(
client, project_id=project.id, stage=stage, branch=branch
)
project.policy = resolve_policy(local_policy, cloud_policy)
config = (
project.policy.config
if project.policy and project.policy.config
else ConfigModel()
)
return init_scan(
ctx=ctx,
target=target,
config=config,
metadata=metadata,
telemetry=telemetry,
project=project,
use_server_matching=use_server_matching,
)

View File

@@ -0,0 +1,279 @@
import logging
import os
import platform
import time
from pathlib import Path
from typing import Any, Dict, Generator, Optional, Set, Tuple
from pydantic import ValidationError
from safety_schemas.models import (
ConfigModel,
FileType,
PolicyFileModel,
PolicySource,
ScanType,
Stage,
)
from safety.scan.util import GIT
from ..encoding import detect_encoding
from ..auth.utils import SafetyAuthSession
from ..errors import SafetyError
from .ecosystems.base import InspectableFile
from .ecosystems.target import InspectableFileContext
from .models import ScanExport
from ..meta import get_version
LOG = logging.getLogger(__name__)
def download_policy(
session: SafetyAuthSession, project_id: str, stage: Stage, branch: Optional[str]
) -> Optional[PolicyFileModel]:
"""
Downloads the policy file from the cloud for the given project and stage.
Args:
session (SafetyAuthSession): SafetyAuthSession object for authentication.
project_id (str): The ID of the project.
stage (Stage): The stage of the project.
branch (Optional[str]): The branch of the project (optional).
Returns:
Optional[PolicyFileModel]: PolicyFileModel object if successful, otherwise None.
"""
result = session.download_policy(project_id=project_id, stage=stage, branch=branch)
if result and "uuid" in result and result["uuid"]:
LOG.debug(f"Loading CLOUD policy file {result['uuid']} from cloud.")
LOG.debug(result)
uuid = result["uuid"]
err = f'Unable to load the Safety Policy file ("{uuid}"), from cloud.'
config = None
try:
yml_raw = result["settings"]
# TODO: Move this to safety_schemas
parse = "parse_obj"
import importlib
module_name = "safety_schemas.config.schemas.v3_0.main"
module = importlib.import_module(module_name)
config_model = module.Config
validated_policy_file = getattr(config_model, parse)(yml_raw)
config = ConfigModel.from_v30(obj=validated_policy_file)
except ValidationError as e:
LOG.error(f"Failed to parse policy file {uuid}.", exc_info=True)
raise SafetyError(f"{err}, details: {e}")
except ValueError as e:
LOG.error(f"Wrong YML file for policy file {uuid}.", exc_info=True)
raise SafetyError(f"{err}, details: {e}")
return PolicyFileModel(
id=result["uuid"], source=PolicySource.cloud, location=None, config=config
)
return None
def load_policy_file(path: Path) -> Optional[PolicyFileModel]:
"""
Loads a policy file from the specified path.
Args:
path (Path): The path to the policy file.
Returns:
Optional[PolicyFileModel]: PolicyFileModel object if successful, otherwise None.
"""
config = None
if not path or not path.exists():
return None
err = (
f'Unable to load the Safety Policy file ("{path}"), this command '
"only supports version 3.0"
)
try:
config = ConfigModel.parse_policy_file(raw_report=path)
except ValidationError as e:
LOG.error(f"Failed to parse policy file {path}.", exc_info=True)
raise SafetyError(f"{err}, details: {e}")
except ValueError as e:
LOG.error(f"Wrong YML file for policy file {path}.", exc_info=True)
raise SafetyError(f"{err}, details: {e}")
return PolicyFileModel(
id=str(path), source=PolicySource.local, location=path, config=config
)
def resolve_policy(
local_policy: Optional[PolicyFileModel], cloud_policy: Optional[PolicyFileModel]
) -> Optional[PolicyFileModel]:
"""
Resolves the policy to be used, preferring cloud policy over local policy.
Args:
local_policy (Optional[PolicyFileModel]): The local policy file model (optional).
cloud_policy (Optional[PolicyFileModel]): The cloud policy file model (optional).
Returns:
Optional[PolicyFileModel]: The resolved PolicyFileModel object.
"""
policy = None
if cloud_policy:
policy = cloud_policy
elif local_policy:
policy = local_policy
return policy
def save_report_as(
scan_type: ScanType, export_type: ScanExport, at: Path, report: Any
) -> None:
"""
Saves the scan report to the specified location.
Args:
scan_type (ScanType): The type of scan.
export_type (ScanExport): The type of export.
at (Path): The path to save the report.
report (Any): The report content.
"""
tag = int(time.time())
if at.is_dir():
at = at / Path(
f"{scan_type.value}-{export_type.get_default_file_name(tag=tag)}"
)
with open(at, "w+") as report_file:
report_file.write(report)
def build_meta(target: Path) -> Dict[str, Any]:
"""
Build the meta JSON object for a file.
Args:
target (Path): The path of the repository.
Returns:
Dict[str, Any]: The metadata dictionary.
"""
target_obj = target.resolve()
git_utils = GIT(target_obj)
git_data = git_utils.build_git_data()
git_metadata = {
"branch": git_data.branch if git_data else None,
"commit": git_data.commit if git_data else None,
"dirty": git_data.dirty if git_data else None,
"tag": git_data.tag if git_data else None,
"origin": git_data.origin if git_data else None,
}
os_metadata = {
"type": os.environ.get("SAFETY_OS_TYPE", None) or platform.system(),
"release": os.environ.get("SAFETY_OS_RELEASE", None) or platform.release(),
"description": os.environ.get("SAFETY_OS_DESCRIPTION", None)
or platform.platform(),
}
python_metadata = {
"version": platform.python_version(),
}
client_metadata = {
"version": get_version(),
}
return {
"target": str(target),
"os": os_metadata,
"git": git_metadata,
"python": python_metadata,
"client": client_metadata,
}
def process_files(
paths: Dict[str, Set[Path]],
config: Optional[ConfigModel] = None,
use_server_matching: bool = False,
obj=None,
target=Path("."),
) -> Generator[Tuple[Path, InspectableFile], None, None]:
"""
Processes the files and yields each file path along with its inspectable file.
Args:
paths (Dict[str, Set[Path]]): A dictionary of file paths by file type.
config (Optional[ConfigModel]): The configuration model (optional).
Yields:
Tuple[Path, InspectableFile]: A tuple of file path and inspectable file.
"""
if not config:
config = ConfigModel()
# old GET implementation
if not use_server_matching:
for file_type_key, f_paths in paths.items():
file_type = FileType(file_type_key)
if not file_type or not file_type.ecosystem:
continue
for f_path in f_paths:
with InspectableFileContext(
f_path, file_type=file_type
) as inspectable_file:
if inspectable_file and inspectable_file.file_type:
inspectable_file.inspect(config=config)
inspectable_file.remediate()
yield f_path, inspectable_file
# new POST implementation
else:
files = []
meta = build_meta(target)
for file_type_key, f_paths in paths.items():
file_type = FileType(file_type_key)
if not file_type or not file_type.ecosystem:
continue
for f_path in f_paths:
relative_path = os.path.relpath(f_path, start=os.getcwd())
# Read the file content
try:
with open(f_path, "r", encoding=detect_encoding(f_path)) as file:
content = file.read()
except Exception as e:
LOG.error(f"Error reading file {f_path}: {e}")
continue
# Append metadata to the payload
files.append(
{
"name": relative_path,
"content": content,
}
)
# Prepare the payload with metadata at the top level
payload = {
"meta": meta,
"files": files,
}
response = obj.auth.client.upload_requirements(payload) # type: ignore
if response.status_code == 200:
LOG.info("Scan Payload successfully sent to the API.")
else:
LOG.error(
f"Failed to send scan payload to the API. Status code: {response.status_code}"
)
LOG.error(f"Response: {response.text}")

View File

@@ -0,0 +1,119 @@
from enum import Enum
from typing import Optional
class FormatMixin:
"""
Mixin class providing format-related utilities for Enum classes.
"""
@classmethod
def is_format(cls, format_sub: Optional[Enum], format_instance: Enum) -> bool:
"""
Check if the value is a variant of the specified format.
Args:
format_sub (Optional[Enum]): The format to check.
format_instance (Enum): The instance of the format to compare against.
Returns:
bool: True if the format matches, otherwise False.
"""
if not format_sub:
return False
if format_sub is format_instance:
return True
prefix = format_sub.value.split('@')[0]
return prefix == format_instance.value
@property
def version(self) -> Optional[str]:
"""
Return the version of the format.
Returns:
Optional[str]: The version of the format if available, otherwise None.
"""
result = self.value.split('@')
if len(result) == 2:
return result[1]
return None
class ScanOutput(FormatMixin, str, Enum):
"""
Enum representing different scan output formats.
"""
JSON = "json"
SPDX = "spdx"
SPDX_2_3 = "spdx@2.3"
SPDX_2_2 = "spdx@2.2"
HTML = "html"
SCREEN = "screen"
NONE = "none"
def is_silent(self) -> bool:
"""
Check if the output format is silent.
Returns:
bool: True if the output format is silent, otherwise False.
"""
return self in (ScanOutput.JSON, ScanOutput.SPDX, ScanOutput.SPDX_2_3, ScanOutput.SPDX_2_2, ScanOutput.HTML)
class ScanExport(FormatMixin, str, Enum):
"""
Enum representing different scan export formats.
"""
JSON = "json"
SPDX = "spdx"
SPDX_2_3 = "spdx@2.3"
SPDX_2_2 = "spdx@2.2"
HTML = "html"
def get_default_file_name(self, tag: int) -> str:
"""
Get the default file name for the export format.
Args:
tag (int): A unique tag to include in the file name.
Returns:
str: The default file name.
"""
if self is ScanExport.JSON:
return f"safety-report-{tag}.json"
elif self in [ScanExport.SPDX, ScanExport.SPDX_2_3, ScanExport.SPDX_2_2]:
return f"safety-report-spdx-{tag}.json"
elif self is ScanExport.HTML:
return f"safety-report-{tag}.html"
else:
raise ValueError("Unsupported scan export type")
class SystemScanOutput(str, Enum):
"""
Enum representing different system scan output formats.
"""
JSON = "json"
SCREEN = "screen"
def is_silent(self) -> bool:
"""
Check if the output format is silent.
Returns:
bool: True if the output format is silent, otherwise False.
"""
return self in (SystemScanOutput.JSON,)
class SystemScanExport(str, Enum):
"""
Enum representing different system scan export formats.
"""
JSON = "json"

View File

@@ -0,0 +1,824 @@
# type: ignore
import datetime
import itertools
import json
import logging
import time
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
import sys
import typer
from rich.console import Console
from rich.padding import Padding
from rich.prompt import Prompt
from rich.text import Text
from safety_schemas.models import (
Ecosystem,
FileType,
IgnoreCodes,
PolicyFileModel,
PolicySource,
ProjectModel,
PythonDependency,
ReportModel,
Vulnerability,
)
from safety import safety
from safety.auth.constants import SAFETY_PLATFORM_URL
from safety.errors import SafetyException
from safety.meta import get_version
from safety.output_utils import parse_html
from safety.scan.constants import DEFAULT_SPINNER
from safety.util import clean_project_id, get_basic_announcements
LOG = logging.getLogger(__name__)
def render_header(targets: List[Path], is_system_scan: bool) -> Text:
"""
Render the header text for the scan.
Args:
targets (List[Path]): List of target paths for the scan.
is_system_scan (bool): Indicates if the scan is a system scan.
Returns:
Text: Rendered header text.
"""
version = get_version()
scan_datetime = datetime.datetime.now(datetime.timezone.utc).strftime(
"%Y-%m-%d %H:%M:%S %Z"
)
action = f"scanning {', '.join([str(t) for t in targets])}"
if is_system_scan:
action = "running [bold]system scan[/bold]"
return Text.from_markup(f"[bold]Safety[/bold] {version} {action}\n{scan_datetime}")
def print_header(console, targets: List[Path], is_system_scan: bool = False) -> None:
"""
Print the header for the scan.
Args:
console (Console): The console for output.
targets (List[Path]): List of target paths for the scan.
is_system_scan (bool): Indicates if the scan is a system scan.
"""
console.print(render_header(targets, is_system_scan), markup=True)
def print_announcements(console: Console, ctx: typer.Context):
"""
Print announcements from Safety.
Args:
console (Console): The console for output.
ctx (typer.Context): The context of the Typer command.
"""
colors = {"error": "red", "warning": "yellow", "info": "default"}
announcements = safety.get_announcements(
ctx.obj.auth.client,
telemetry=ctx.obj.config.telemetry_enabled,
with_telemetry=ctx.obj.telemetry,
)
basic_announcements = get_basic_announcements(announcements, False)
if any(basic_announcements):
console.print()
console.print("[bold]Safety Announcements:[/bold]")
console.print()
for announcement in announcements:
color = colors.get(announcement.get("type", "info"), "default")
console.print(f"[{color}]* {announcement.get('message')}[/{color}]")
def print_detected_ecosystems_section(
console: Console, file_paths: Dict[str, Set[Path]], include_safety_prjs: bool = True
) -> None:
"""
Print detected ecosystems section.
Args:
console (Console): The console for output.
file_paths (Dict[str, Set[Path]]): Dictionary of file paths by type.
include_safety_prjs (bool): Whether to include safety projects.
"""
detected: Dict[Ecosystem, Dict[FileType, int]] = {}
for file_type_key, f_paths in file_paths.items():
file_type = FileType(file_type_key)
if file_type.ecosystem:
if file_type.ecosystem not in detected:
detected[file_type.ecosystem] = {}
detected[file_type.ecosystem][file_type] = len(f_paths)
for ecosystem, f_type_count in detected.items():
if not include_safety_prjs and ecosystem is Ecosystem.SAFETY_PROJECT:
continue
brief = "Found "
file_types = []
for f_type, count in f_type_count.items():
file_types.append(f"{count} {f_type.human_name(plural=count > 1)}")
if len(file_types) > 1:
brief += ", ".join(file_types[:-1]) + " and " + file_types[-1]
else:
brief += file_types[0]
msg = f"{ecosystem.name.replace('_', ' ').title()} detected. {brief}"
console.print(msg)
def print_fixes_section(
console: Console,
requirements_txt_found: bool = False,
is_detailed_output: bool = False,
) -> None:
"""
Print the section on applying fixes.
Args:
console (Console): The console for output.
requirements_txt_found (bool): Indicates if a requirements.txt file was found.
is_detailed_output (bool): Indicates if detailed output is enabled.
"""
console.print("-" * console.size.width)
console.print("Apply Fixes")
console.print("-" * console.size.width)
console.print()
if requirements_txt_found:
console.print(
"[green]Run `safety scan --apply-fixes`[/green] to update these packages and fix these vulnerabilities. "
"Documentation, limitations, and configurations for applying automated fixes: [link]https://docs.safetycli.com/safety-docs/vulnerability-remediation/applying-fixes[/link]"
)
console.print()
console.print(
"Alternatively, use your package manager to update packages to their secure versions. Always check for breaking changes when updating packages."
)
else:
msg = "Use your package manager to update packages to their secure versions. Always check for breaking changes when updating packages."
console.print(msg)
if not is_detailed_output:
console.print(
"[tip]Tip[/tip]: For more detailed output on each vulnerability, add the `--detailed-output` flag to safety scan."
)
console.print()
console.print("-" * console.size.width)
def print_summary(
console: Console,
total_issues_with_duplicates: int,
total_ignored_issues: int,
project: ProjectModel,
dependencies_count: int = 0,
fixes_count: int = 0,
resolved_vulns_per_fix: int = 0,
is_detailed_output: bool = False,
ignored_vulns_data: Optional[Dict[str, Vulnerability]] = None,
) -> None:
"""
Prints a concise summary of scan results including vulnerabilities, fixes, and ignored vulnerabilities.
This function summarizes the results of a security scan, displaying the number of dependencies scanned,
vulnerabilities found, suggested fixes, and the impact of those fixes. It also optionally provides a
detailed breakdown of ignored vulnerabilities based on predefined policies.
Args:
console (Console): The console object used to print formatted output.
total_issues_with_duplicates (int): The total number of security issues, including duplicates.
total_ignored_issues (int): The number of issues that were ignored based on project policies.
project (ProjectModel): The project model containing the scanned project details and policies.
dependencies_count (int, optional): The total number of dependencies scanned for vulnerabilities. Defaults to 0.
fixes_count (int, optional): The number of fixes suggested by the scan. Defaults to 0.
resolved_vulns_per_fix (int, optional): The number of vulnerabilities that can be resolved by the suggested fixes. Defaults to 0.
is_detailed_output (bool, optional): Flag to indicate whether detailed output, especially for ignored vulnerabilities, should be shown. Defaults to False.
ignored_vulns_data (Optional[Dict[str, Vulnerability]], optional): A dictionary of vulnerabilities that were ignored, categorized by their reason for being ignored. Defaults to None.
Returns:
None: This function does not return any value. It prints the summary to the console.
Usage:
Call this function after a vulnerability scan to display the results in a clear, formatted manner.
Example:
print_summary(console, unique_issues, 10, 2, project_model, dependencies_count=5, fixes_count=2)
"""
from ..util import pluralize
# Set the policy message based on the project source
if project.policy:
policy_msg = (
"policy fetched from Safety Platform"
if project.policy.source is PolicySource.cloud
else f"local {project.id or 'scan policy file'} project scan policy"
)
else:
policy_msg = "default Safety CLI policies"
console.print(
f"Tested [number]{dependencies_count}[/number] {pluralize('dependency', dependencies_count)} for security issues using {policy_msg}"
)
if total_issues_with_duplicates == 0:
console.print("0 security issues found, 0 fixes suggested.")
else:
# Print security issues and ignored vulnerabilities
console.print(
f"[number]{total_issues_with_duplicates}[/number] {pluralize('vulnerability', total_issues_with_duplicates)} found, "
f"[number]{total_ignored_issues}[/number] ignored due to policy."
)
console.print(
f"[number]{fixes_count}[/number] {pluralize('fix', fixes_count)} suggested, resolving [number]{resolved_vulns_per_fix}[/number] vulnerabilities."
)
if is_detailed_output:
if not ignored_vulns_data:
ignored_vulns_data = iter([])
manual_ignored = {}
cvss_severity_ignored = {}
cvss_severity_ignored_pkgs = set()
unpinned_ignored = {}
unpinned_ignored_pkgs = set()
environment_ignored = {}
environment_ignored_pkgs = set()
for vuln_data in ignored_vulns_data:
code = IgnoreCodes(vuln_data.ignored_code)
if code is IgnoreCodes.manual:
manual_ignored[vuln_data.vulnerability_id] = vuln_data
elif code is IgnoreCodes.cvss_severity:
cvss_severity_ignored[vuln_data.vulnerability_id] = vuln_data
cvss_severity_ignored_pkgs.add(vuln_data.package_name)
elif code is IgnoreCodes.unpinned_specification:
unpinned_ignored[vuln_data.vulnerability_id] = vuln_data
unpinned_ignored_pkgs.add(vuln_data.package_name)
elif code is IgnoreCodes.environment_dependency:
environment_ignored[vuln_data.vulnerability_id] = vuln_data
environment_ignored_pkgs.add(vuln_data.package_name)
if manual_ignored:
count = len(manual_ignored)
console.print(
f"[number]{count}[/number] were manually ignored due to the project policy:"
)
for vuln in manual_ignored.values():
render_to_console(
vuln,
console,
rich_kwargs={"emoji": True, "overflow": "crop"},
detailed_output=is_detailed_output,
)
if cvss_severity_ignored:
count = len(cvss_severity_ignored)
console.print(
f"[number]{count}[/number] {pluralize('vulnerability', count)} {pluralize('was', count)} ignored because "
"of their severity or exploitability impacted the following"
f" {pluralize('package', len(cvss_severity_ignored_pkgs))}: {', '.join(cvss_severity_ignored_pkgs)}"
)
if environment_ignored:
count = len(environment_ignored)
console.print(
f"[number]{count}[/number] {pluralize('vulnerability', count)} {pluralize('was', count)} ignored because "
"they are inside an environment dependency."
)
if unpinned_ignored:
count = len(unpinned_ignored)
console.print(
f"[number]{count}[/number] {pluralize('vulnerability', count)} {pluralize('was', count)} ignored because "
f"{pluralize('this', len(unpinned_ignored_pkgs))} {pluralize('package', len(unpinned_ignored_pkgs))} {pluralize('has', len(unpinned_ignored_pkgs))} unpinned specs: "
f"{', '.join(unpinned_ignored_pkgs)}"
)
def print_wait_project_verification(
console: Console,
project_id: str,
closure: Tuple[Any, Dict[str, Any]],
on_error_delay: int = 1,
) -> Any:
"""
Print a waiting message while verifying a project.
Args:
console (Console): The console for output.
project_id (str): The project ID.
closure (Tuple[Any, Dict[str, Any]]): The function and its arguments to call.
on_error_delay (int): Delay in seconds on error.
Returns:
Any: The status of the project verification.
"""
status = None
wait_msg = f"Verifying project {project_id} with Safety Platform."
with console.status(wait_msg, spinner=DEFAULT_SPINNER):
try:
f, kwargs = closure
status = f(**kwargs)
except Exception as e:
LOG.exception(f"Unable to verify the project, reason: {e}")
reason = (
"We are currently unable to verify the project, "
"and it is necessary to link the scan to a specific "
f"project. \n\nAdditional Information: \n{e}"
)
raise SafetyException(message=reason)
if not status:
wait_msg = f'Unable to verify "{project_id}". Starting again...'
time.sleep(on_error_delay)
return status
def print_project_info(console: Console, project: ProjectModel):
"""
Print information about the project.
Args:
console (Console): The console for output.
project (ProjectModel): The project model.
"""
config_msg = "loaded without policies or custom configuration."
if project.policy:
if project.policy.source is PolicySource.local:
rel_location = (
project.policy.location.name if project.policy.location else ""
)
config_msg = f"configuration and policies fetched from {rel_location}."
else:
config_msg = " policies fetched from Safety Platform."
msg = f"[bold]{project.id} project found[/bold] - {config_msg}"
console.print(msg)
def print_wait_policy_download(
console: Console, closure: Tuple[Any, Dict[str, Any]]
) -> Optional[PolicyFileModel]:
"""
Print a waiting message while downloading a policy from the cloud.
Args:
console (Console): The console for output.
closure (Tuple[Any, Dict[str, Any]]): The function and its arguments to call.
Returns:
Optional[PolicyFileModel]: The downloaded policy file model.
"""
policy = None
wait_msg = "Looking for a policy from cloud..."
with console.status(wait_msg, spinner=DEFAULT_SPINNER):
try:
f, kwargs = closure
policy = f(**kwargs)
except Exception as e:
LOG.exception(f"Policy download failed, reason: {e}")
console.print("Not using cloud policy file.")
if policy:
wait_msg = "Policy fetched from Safety Platform."
else:
# TODO: Send a log
pass
return policy
def prompt_project_id(console: Console, default_id: str) -> str:
"""
Prompt the user to set a project ID, on a non-interactive mode it will
fallback to the default ID parameter.
"""
default_prj_id = clean_project_id(default_id)
if not sys.stdin.isatty():
LOG.info("Fallback to default project id, because of non-interactive mode.")
return default_prj_id
prompt_text = f"\nEnter a name for this codebase (or press [bold]Enter[/bold] to use '\\[{default_prj_id}]')"
while True:
result = Prompt.ask(
prompt_text, console=console, default=default_prj_id, show_default=False
)
return clean_project_id(result) if result != default_prj_id else default_prj_id
def prompt_link_project(console: Console, prj_name: str, prj_admin_email: str) -> bool:
"""
Prompt the user to link the scan with an existing project. If the console is not interactive
it will fallback to True.
Args:
console (Console): The console for output.
prj_name (str): The project name.
prj_admin_email (str): The project admin email.
Returns:
bool: True if the user wants to link the scan, False otherwise.
"""
if not sys.stdin.isatty():
LOG.info("Linking to existing project because of non-interactive mode.")
return True
console.print(
"\n[bold]Safety found an existing codebase with this name in your organization:[/bold]"
)
for detail in (
f"[bold]Codebase name:[/bold] {prj_name}",
f"[bold]Codebase admin:[/bold] {prj_admin_email}",
):
console.print(Padding(detail, (0, 0, 0, 2)), emoji=True)
console.print()
prompt_question = "Do you want to link it with this existing codebase?"
answer = Prompt.ask(
prompt=prompt_question,
choices=["y", "n"],
default="y",
show_default=True,
console=console,
).lower()
return answer == "y"
def render_to_console(
cls: Vulnerability,
console: Console,
rich_kwargs: Dict[str, Any],
detailed_output: bool = False,
) -> None:
"""
Render a vulnerability to the console.
Args:
cls (Vulnerability): The vulnerability instance.
console (Console): The console for output.
rich_kwargs (Dict[str, Any]): Additional arguments for rendering.
detailed_output (bool): Indicates if detailed output is enabled.
"""
cls.__render__(console, detailed_output, rich_kwargs)
def get_render_console(entity_type: Any) -> Any:
"""
Get the render function for a specific entity type.
Args:
entity_type (Any): The entity type.
Returns:
Any: The render function.
"""
if entity_type is Vulnerability:
def __render__(self, console: Console, detailed_output: bool, rich_kwargs):
if not rich_kwargs:
rich_kwargs = {}
pre = " Ignored:" if self.ignored else ""
severity_detail = None
if self.severity and self.severity.source:
severity_detail = self.severity.source
if self.severity.cvssv3 and "base_severity" in self.severity.cvssv3:
severity_detail += f", CVSS Severity {self.severity.cvssv3['base_severity'].upper()}"
advisory_length = 200 if detailed_output else 110
console.print(
Padding(
f"->{pre} Vuln ID [vuln_id]{self.vulnerability_id}[/vuln_id]: {severity_detail if severity_detail else ''}",
(0, 0, 0, 2),
),
**rich_kwargs,
)
console.print(
Padding(
f"{self.advisory[:advisory_length]}{'...' if len(self.advisory) > advisory_length else ''}",
(0, 0, 0, 5),
),
**rich_kwargs,
)
if detailed_output:
console.print(
Padding(
f"For more information: [link]{self.more_info_url}[/link]",
(0, 0, 0, 5),
),
**rich_kwargs,
)
return __render__
def render_scan_html(report: ReportModel, obj: Any) -> str:
"""
Render the scan report to HTML.
Args:
report (ReportModel): The scan report model.
obj (Any): The object containing additional settings.
Returns:
str: The rendered HTML report.
"""
from safety.scan.command import ScannableEcosystems
project = report.projects[0] if any(report.projects) else None
scanned_packages = 0
affected_packages = 0
ignored_packages = 0
remediations_recommended = 0
ignored_vulnerabilities = 0
vulnerabilities = 0
vulns_per_file = defaultdict(int)
remed_per_file = defaultdict(int)
for file in project.files:
scanned_packages += len(file.results.dependencies)
affected_packages += len(file.results.get_affected_dependencies())
ignored_vulnerabilities += len(file.results.ignored_vulns)
for spec in file.results.get_affected_specifications():
vulnerabilities += len(spec.vulnerabilities)
vulns_per_file[file.location] += len(spec.vulnerabilities)
if spec.remediation:
remed_per_file[file.location] += 1
remediations_recommended += 1
ignored_packages += len(file.results.ignored_vulns)
# TODO: Get this information for the report model (?)
summary = {
"scanned_packages": scanned_packages,
"affected_packages": affected_packages,
"remediations_recommended": remediations_recommended,
"ignored_vulnerabilities": ignored_vulnerabilities,
"vulnerabilities": vulnerabilities,
}
vulnerabilities = []
# TODO: This should be based on the configs per command
ecosystems = [
(
f"{ecosystem.name.title()}",
[file_type.human_name(plural=True) for file_type in ecosystem.file_types],
)
for ecosystem in [
Ecosystem(member.value) for member in list(ScannableEcosystems)
]
]
settings = {
"audit_and_monitor": True,
"platform_url": SAFETY_PLATFORM_URL,
"ecosystems": ecosystems,
}
template_context = {
"report": report,
"summary": summary,
"announcements": [],
"project": project,
"platform_enabled": obj.platform_enabled,
"settings": settings,
"vulns_per_file": vulns_per_file,
"remed_per_file": remed_per_file,
}
return parse_html(kwargs=template_context, template="scan/index.html")
def generate_spdx_creation_info(spdx_version: str, project_identifier: str) -> Any:
"""
Generate SPDX creation information.
Args:
spdx_version (str): The SPDX version.
project_identifier (str): The project identifier.
Returns:
Any: The SPDX creation information.
"""
from spdx_tools.spdx.model import (
Actor,
ActorType,
CreationInfo,
)
version = int(time.time())
SPDX_ID_TYPE = "SPDXRef-DOCUMENT"
DOC_NAME = f"{project_identifier}-{version}"
DOC_NAMESPACE = f"https://spdx.org/spdxdocs/{DOC_NAME}"
# DOC_NAMESPACE = f"urn:safety:{project_identifier}:{version}"
DOC_COMMENT = f"This document was created using SPDX {spdx_version}"
CREATOR_COMMENT = (
"Safety CLI automatically created this SPDX document from a scan report."
)
TOOL_ID = "safety"
TOOL_VERSION = get_version()
doc_creator = Actor(
actor_type=ActorType.TOOL, name=f"{TOOL_ID}-{TOOL_VERSION}", email=None
)
creation_info = CreationInfo(
spdx_version=f"SPDX-{spdx_version}",
spdx_id=SPDX_ID_TYPE,
name=DOC_NAME,
document_namespace=DOC_NAMESPACE,
creators=[doc_creator],
created=datetime.datetime.now(),
document_comment=DOC_COMMENT,
creator_comment=CREATOR_COMMENT,
)
return creation_info
def create_pkg_ext_ref(*, package: PythonDependency, version: Optional[str]) -> Any:
"""
Create an external package reference for SPDX.
Args:
package (PythonDependency): The package dependency.
version (Optional[str]): The package version.
Returns:
Any: The external package reference.
"""
from spdx_tools.spdx.model import (
ExternalPackageRef,
ExternalPackageRefCategory,
)
version_detail = f"@{version}" if version else ""
pkg_ref = ExternalPackageRef(
ExternalPackageRefCategory.PACKAGE_MANAGER,
"purl",
f"pkg:pypi/{package.name}{version_detail}",
)
return pkg_ref
def create_packages(dependencies: List[PythonDependency]) -> List[Any]:
"""
Create a list of SPDX packages.
Args:
dependencies (List[PythonDependency]): List of Python dependencies.
Returns:
List[Any]: List of SPDX packages.
"""
from spdx_tools.spdx.model import (
Package,
)
from spdx_tools.spdx.model.spdx_no_assertion import SpdxNoAssertion
doc_pkgs = []
pkgs_added = set([])
for dependency in dependencies:
for spec in dependency.specifications:
pkg_version = (
next(iter(spec.specifier)).version
if spec.is_pinned()
else f"{spec.specifier}"
)
dep_name = dependency.name.replace("_", "-")
pkg_id = (
f"SPDXRef-pip-{dep_name}-{pkg_version}"
if spec.is_pinned()
else f"SPDXRef-pip-{dep_name}"
)
if pkg_id in pkgs_added:
continue
pkg_ref = create_pkg_ext_ref(package=dependency, version=pkg_version)
pkg = Package(
spdx_id=pkg_id,
name=f"pip:{dep_name}",
download_location=SpdxNoAssertion(),
version=pkg_version,
file_name="",
supplier=SpdxNoAssertion(),
files_analyzed=False,
license_concluded=SpdxNoAssertion(),
license_declared=SpdxNoAssertion(),
copyright_text=SpdxNoAssertion(),
external_references=[pkg_ref],
)
pkgs_added.add(pkg_id)
doc_pkgs.append(pkg)
return doc_pkgs
def create_spdx_document(*, report: ReportModel, spdx_version: str) -> Optional[Any]:
"""
Create an SPDX document.
Args:
report (ReportModel): The scan report model.
spdx_version (str): The SPDX version.
Returns:
Optional[Any]: The SPDX document.
"""
from spdx_tools.spdx.model import (
Document,
Relationship,
RelationshipType,
)
project = report.projects[0] if any(report.projects) else None
if not project:
return None
prj_id = project.id
if not prj_id:
parent_name = project.project_path.parent.name
prj_id = parent_name if parent_name else str(int(time.time()))
creation_info = generate_spdx_creation_info(
spdx_version=spdx_version, project_identifier=prj_id
)
depedencies = iter([])
for file in project.files:
depedencies = itertools.chain(depedencies, file.results.dependencies)
packages = create_packages(depedencies)
# Requirement for document to have atleast one relationship
relationship = Relationship(
"SPDXRef-DOCUMENT", RelationshipType.DESCRIBES, "SPDXRef-DOCUMENT"
)
spdx_doc = Document(creation_info, packages, [], [], [], [relationship], [])
return spdx_doc
def render_scan_spdx(
report: ReportModel, obj: Any, spdx_version: Optional[str]
) -> Optional[Any]:
"""
Render the scan report to SPDX format.
Args:
report (ReportModel): The scan report model.
obj (Any): The object containing additional settings.
spdx_version (Optional[str]): The SPDX version.
Returns:
Optional[Any]: The rendered SPDX document in JSON format.
"""
from spdx_tools.spdx.writer.write_utils import convert, validate_and_deduplicate
# Set to latest supported if a version is not specified
if not spdx_version:
spdx_version = "2.3"
document_obj = create_spdx_document(report=report, spdx_version=spdx_version)
document_obj = validate_and_deduplicate(
document=document_obj, validate=True, drop_duplicates=True
)
doc = None
if document_obj:
doc = convert(document=document_obj, converter=None)
return json.dumps(doc) if doc else None

View File

@@ -0,0 +1,251 @@
from enum import Enum
import logging
import os
from pathlib import Path
import subprocess
from typing import TYPE_CHECKING, Optional, Tuple
from safety.scan.finder.handlers import (
FileHandler,
PythonFileHandler,
SafetyProjectFileHandler,
)
from safety_schemas.models import Stage
if TYPE_CHECKING:
from safety_schemas.models import GITModel
LOG = logging.getLogger(__name__)
class Language(str, Enum):
"""
Enum representing supported programming languages.
"""
python = "python"
javascript = "javascript"
safety_project = "safety_project"
def handler(self) -> FileHandler:
"""
Get the appropriate file handler for the language.
Returns:
FileHandler: The file handler for the language.
"""
if self is Language.python:
return PythonFileHandler()
if self is Language.safety_project:
return SafetyProjectFileHandler()
return PythonFileHandler()
class Output(Enum):
"""
Enum representing output formats.
"""
json = "json"
class AuthenticationType(str, Enum):
"""
Enum representing authentication types.
"""
token = "token"
api_key = "api_key"
none = "unauthenticated"
def is_allowed_in(self, stage: Stage = Stage.development) -> bool:
"""
Check if the authentication type is allowed in the given stage.
Args:
stage (Stage): The current stage.
Returns:
bool: True if the authentication type is allowed, otherwise False.
"""
if self is AuthenticationType.none:
return False
if stage == Stage.development and self is AuthenticationType.api_key:
return False
if (not stage == Stage.development) and self is AuthenticationType.token:
return False
return True
class GIT:
"""
Class representing Git operations.
"""
ORIGIN_CMD: Tuple[str, ...] = ("remote", "get-url", "origin")
BRANCH_CMD: Tuple[str, ...] = ("symbolic-ref", "--short", "-q", "HEAD")
TAG_CMD: Tuple[str, ...] = ("describe", "--tags", "--exact-match")
DESCRIBE_CMD: Tuple[str, ...] = (
"describe",
'--match=""',
"--always",
"--abbrev=40",
"--dirty",
)
GIT_CHECK_CMD: Tuple[str, ...] = ("rev-parse", "--is-inside-work-tree")
def __init__(self, root: Path = Path(".")) -> None:
"""
Initialize the GIT class with the given root directory.
Args:
root (Path): The root directory for Git operations.
"""
self.git = ("git", "-C", root.resolve())
def __run__(
self, cmd: Tuple[str, ...], env_var: Optional[str] = None
) -> Optional[str]:
"""
Run a Git command.
Args:
cmd (Tuple[str, ...]): The Git command to run.
env_var (Optional[str]): An optional environment variable to check for the command result.
Returns:
Optional[str]: The result of the Git command, or None if an error occurred.
"""
if env_var and os.environ.get(env_var):
return os.environ.get(env_var)
try:
return (
subprocess.run(
self.git + cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
)
.stdout.decode("utf-8")
.strip()
)
except Exception as e:
LOG.exception(e)
return None
def origin(self) -> Optional[str]:
"""
Get the Git origin URL.
Returns:
Optional[str]: The Git origin URL, or None if an error occurred.
"""
return self.__run__(self.ORIGIN_CMD, env_var="SAFETY_GIT_ORIGIN")
def branch(self) -> Optional[str]:
"""
Get the current Git branch.
Returns:
Optional[str]: The current Git branch, or None if an error occurred.
"""
return self.__run__(self.BRANCH_CMD, env_var="SAFETY_GIT_BRANCH")
def tag(self) -> Optional[str]:
"""
Get the current Git tag.
Returns:
Optional[str]: The current Git tag, or None if an error occurred.
"""
return self.__run__(self.TAG_CMD, env_var="SAFETY_GIT_TAG")
def describe(self) -> Optional[str]:
"""
Get the Git describe output.
Returns:
Optional[str]: The Git describe output, or None if an error occurred.
"""
return self.__run__(self.DESCRIBE_CMD)
def dirty(self, raw_describe: str) -> bool:
"""
Check if the working directory is dirty.
Args:
raw_describe (str): The raw describe output.
Returns:
bool: True if the working directory is dirty, otherwise False.
"""
if (is_dirty := os.environ.get("SAFETY_GIT_DIRTY")) and is_dirty in ["0", "1"]:
return bool(int(is_dirty))
return raw_describe.endswith("-dirty")
def commit(self, raw_describe: str) -> Optional[str]:
"""
Get the current Git commit hash.
Args:
raw_describe (str): The raw describe output.
Returns:
Optional[str]: The current Git commit hash, or None if an error occurred.
"""
if os.environ.get("SAFETY_GIT_COMMIT"):
return os.environ.get("SAFETY_GIT_COMMIT")
try:
return raw_describe.split("-dirty")[0]
except Exception:
pass
def is_git(self) -> bool:
"""
Check if the current directory is a Git repository.
Returns:
bool: True if the current directory is a Git repository, otherwise False.
"""
result = self.__run__(self.GIT_CHECK_CMD)
if result == "true":
return True
return False
def build_git_data(self) -> Optional["GITModel"]:
"""
Build a GITModel object with Git data.
Returns:
Optional[GITModel]: The GITModel object with Git data, or None if the directory is not a Git repository.
"""
from safety_schemas.models import GITModel
if self.is_git():
raw_describe = self.describe()
commit = None
dirty = False
# TODO: describe fails when there are not commits,
# GitModel needs to support this case too
if raw_describe:
commit = self.commit(raw_describe)
dirty = self.dirty(raw_describe)
return GITModel(
branch=self.branch(),
tag=self.tag(),
commit=commit,
dirty=dirty,
origin=self.origin(),
)
return None

View File

@@ -0,0 +1,75 @@
import os
from pathlib import Path
from typing import Optional, Tuple
import typer
from safety.scan.models import ScanExport, ScanOutput
from safety_schemas.models import AuthenticationType
MISSING_SPDX_EXTENSION_MSG = "spdx extra is not installed, please install it with: pip install safety[spdx]"
def raise_if_not_spdx_extension_installed() -> None:
"""
Raises an error if the spdx extension is not installed.
"""
try:
import spdx_tools.spdx
except Exception as e:
raise typer.BadParameter(MISSING_SPDX_EXTENSION_MSG)
def save_as_callback(save_as: Optional[Tuple[ScanExport, Path]]) -> Tuple[Optional[str], Optional[Path]]:
"""
Callback function to handle save_as parameter and validate if spdx extension is installed.
Args:
save_as (Optional[Tuple[ScanExport, Path]]): The export type and path.
Returns:
Tuple[Optional[str], Optional[Path]]: The validated export type and path.
"""
export_type, export_path = save_as if save_as else (None, None)
if ScanExport.is_format(export_type, ScanExport.SPDX):
raise_if_not_spdx_extension_installed()
return (export_type.value, export_path) if export_type and export_path else (export_type, export_path)
def output_callback(output: ScanOutput) -> str:
"""
Callback function to handle output parameter and validate if spdx extension is installed.
Args:
output (ScanOutput): The output format.
Returns:
str: The validated output format.
"""
if ScanOutput.is_format(output, ScanExport.SPDX):
raise_if_not_spdx_extension_installed()
return output.value
def fail_if_not_allowed_stage(ctx: typer.Context):
"""
Fail the command if the authentication type is not allowed in the current stage.
Args:
ctx (typer.Context): The context of the Typer command.
"""
if ctx.resilient_parsing:
return
stage = ctx.obj.auth.stage
auth_type: AuthenticationType = ctx.obj.auth.client.get_authentication_type()
if os.getenv("SAFETY_DB_DIR"):
return
if not auth_type.is_allowed_in(stage):
raise typer.BadParameter(f"'{auth_type.value}' auth type isn't allowed with " \
f"the '{stage}' stage.")