This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# Package metadata for the safety distribution.
__author__ = """safetycli.com"""
__email__ = 'cli@safetycli.com'

View File

@@ -0,0 +1,6 @@
"""Allow safety to be executable through `python -m safety`."""
from safety.cli import cli

if __name__ == "__main__":  # pragma: no cover
    # prog_name makes usage/help output read "safety" rather than "__main__".
    cli(prog_name="safety")

View File

@@ -0,0 +1,110 @@
import logging
import sys
import json
from typing import Any, IO
import click
from dataclasses import dataclass
from safety.constants import CONTEXT_COMMAND_TYPE
from . import github
from safety.util import SafetyPolicyFile
from safety.scan.constants import CLI_ALERT_COMMAND_HELP
LOG = logging.getLogger(__name__)
def get_safety_cli_legacy_group():
    """Resolve and return the legacy Click group class.

    The import is deferred to call time — presumably to avoid import
    cycles or startup cost; confirm before moving it to module level.
    """
    from safety.cli_util import SafetyCLILegacyGroup as legacy_group_cls

    return legacy_group_cls
def get_context_settings():
    """Build the Click context settings that tag this command as a utility command."""
    from safety.cli_util import CommandType

    settings = {CONTEXT_COMMAND_TYPE: CommandType.UTILITY}
    return settings
@dataclass
class Alert:
    """
    Data class for storing alert details.

    Attributes:
        report (Any): The report data.
        key (str): The API key for the safetycli.com vulnerability database.
        policy (Any): The policy data.
        requirements_files (Any): The requirements files data.
    """

    # Parsed Safety Check JSON report (a dict loaded from --check-report).
    report: Any
    # API key for the safetycli.com vulnerability database.
    key: str
    # Policy configuration; the alert command stores {} when no policy file is given.
    policy: Any = None
    # Not set by the alert command itself — presumably filled in later
    # (e.g. by utils.require_files_report on subcommands); confirm with callers.
    requirements_files: Any = None
@click.group(
    cls=get_safety_cli_legacy_group(),
    help=CLI_ALERT_COMMAND_HELP,
    deprecated=True,
    context_settings=get_context_settings(),
)
@click.option(
    "--check-report",
    help="JSON output of Safety Check to work with.",
    type=click.File("r"),
    default=sys.stdin,
    required=True,
)
@click.option(
    "--key",
    envvar="SAFETY_API_KEY",
    help="API Key for safetycli.com's vulnerability database. Can be set as SAFETY_API_KEY "
    "environment variable.",
    required=True,
)
@click.option(
    "--policy-file",
    type=SafetyPolicyFile(),
    default=".safety-policy.yml",
    help="Define the policy file to be used",
)
@click.pass_context
def alert(
    ctx: click.Context, check_report: IO[str], policy_file: SafetyPolicyFile, key: str
) -> None:
    """
    Command for processing the Safety Check JSON report.

    Parses the JSON report (from a file or stdin), validates that it looks
    like a Safety Check report, and stashes it on the Click context for the
    GitHub subcommands to consume.

    Args:
        ctx (click.Context): The Click context object.
        check_report (IO[str]): The file containing the JSON report.
        policy_file (SafetyPolicyFile): The policy file to be used.
        key (str): The API key for the safetycli.com vulnerability database.
    """
    LOG.info("alert started")
    # Lazy %-style args: the message is only formatted when INFO is enabled.
    LOG.info("check_report is using stdin: %s", check_report == sys.stdin)

    with check_report:
        # TODO: This breaks --help for subcommands
        try:
            safety_report = json.load(check_report)
        except json.decoder.JSONDecodeError as e:
            LOG.info("Error in the JSON report.")
            click.secho("Error decoding input JSON: {}".format(e.msg), fg="red")
            sys.exit(1)

        # "report_meta" is the marker that this JSON really came from Safety Check.
        if "report_meta" not in safety_report:
            click.secho("You must pass in a valid Safety Check JSON report", fg="red")
            sys.exit(1)

        # Share the parsed report, policy and key with subcommands via ctx.obj.
        ctx.obj = Alert(
            report=safety_report, policy=policy_file if policy_file else {}, key=key
        )


# Adding subcommands for GitHub integration
alert.add_command(github.github_pr)
alert.add_command(github.github_issue)

View File

@@ -0,0 +1,575 @@
# type: ignore
import itertools
import logging
import re
import sys
from typing import Any, Optional
import click
try:
import github as pygithub
except ImportError:
pygithub = None
from packaging.specifiers import SpecifierSet
from packaging.utils import canonicalize_name
from . import utils, requirements
LOG = logging.getLogger(__name__)
def create_branch(repo: Any, base_branch: str, new_branch: str) -> None:
    """
    Create a new branch in the given GitHub repository.

    The new branch is pointed at the current head commit of ``base_branch``.

    Args:
        repo (Any): The GitHub repository object.
        base_branch (str): The name of the base branch.
        new_branch (str): The name of the new branch to create.
    """
    base_ref = repo.get_git_ref(f"heads/{base_branch}")
    repo.create_git_ref(ref=f"refs/heads/{new_branch}", sha=base_ref.object.sha)
def delete_branch(repo: Any, branch: str) -> None:
    """
    Delete a branch from the given GitHub repository.

    Args:
        repo (Any): The GitHub repository object.
        branch (str): The name of the branch to delete.
    """
    repo.get_git_ref("heads/" + branch).delete()
@click.command()
@click.option("--repo", help="GitHub standard repo path (eg, my-org/my-project)")
@click.option("--token", help="GitHub Access Token")
@click.option(
    "--base-url",
    help="Optional custom Base URL, if you're using GitHub enterprise",
    default=None,
)
@click.pass_obj
@utils.require_files_report
def github_pr(obj: Any, repo: str, token: str, base_url: Optional[str]) -> None:
    """
    Create a GitHub PR to fix any vulnerabilities using Safety's remediation data.

    This is usually run by a GitHub action. If you're running this manually,
    ensure that your local repo is up to date and on HEAD - otherwise you'll
    see strange results.

    Args:
        obj (Any): The Click context object containing report data.
        repo (str): The GitHub repository path.
        token (str): The GitHub Access Token.
        base_url (Optional[str]): Custom base URL for GitHub Enterprise, if applicable.
    """
    # pygithub is an optional extra; fail fast with an actionable message.
    if pygithub is None:
        click.secho(
            "pygithub is not installed. Did you install Safety with GitHub support? Try pip install safety[github]",
            fg="red",
        )
        sys.exit(1)

    # Load alert configurations from the policy.
    # Each `or {}` guards against an explicit null in the YAML policy.
    alert = obj.policy.get("alert", {}) or {}
    security = alert.get("security", {}) or {}
    config_pr = security.get("github-pr", {}) or {}
    branch_prefix = config_pr.get("branch-prefix", "pyup/")
    pr_prefix = config_pr.get("pr-prefix", "[PyUp] ")
    assignees = config_pr.get("assignees", [])
    labels = config_pr.get("labels", ["security"])
    label_severity = config_pr.get("label-severity", True)
    ignore_cvss_severity_below = config_pr.get("ignore-cvss-severity-below", 0)
    ignore_cvss_unknown_severity = config_pr.get("ignore-cvss-unknown-severity", False)

    # Authenticate with GitHub
    gh = pygithub.Github(token, **({"base_url": base_url} if base_url else {}))
    repo_name = repo
    # NOTE: `repo` is rebound here from the path string to the Repository object.
    repo = gh.get_repo(repo)

    try:
        self_user = gh.get_user().login
    except pygithub.GithubException:
        # If we're using a token from an action (or integration) we can't call `get_user()`. Fall back
        # to assuming we're running under an action
        self_user = "web-flow"

    # Collect all remediations from the report (flattened across packages).
    req_remediations = list(
        itertools.chain.from_iterable(
            rem.get("requirements", {}).values()
            for pkg_name, rem in obj.report["remediations"].items()
        )
    )

    # Get all open pull requests for the repository
    pulls = repo.get_pulls(state="open", sort="created", base=repo.default_branch)
    # Remediations we have not yet handled; entries are discarded as files match.
    pending_updates = set(
        [
            f"{canonicalize_name(req_rem['requirement']['name'])}{req_rem['requirement']['specifier']}"
            for req_rem in req_remediations
        ]
    )
    created = 0

    # TODO: Refactor this loop into a fn to iterate over remediations nicely
    # Iterate over all requirements files and process each remediation
    for name, contents in obj.requirements_files.items():
        raw_contents = contents
        contents = contents.decode("utf-8")  # TODO - encoding?
        parsed_req_file = requirements.RequirementFile(name, contents)
        for remediation in req_remediations:
            pkg = remediation["requirement"]["name"]
            pkg_canonical_name: str = canonicalize_name(pkg)
            analyzed_spec: str = remediation["requirement"]["specifier"]

            # Skip remediations without a recommended version
            if remediation["recommended_version"] is None:
                LOG.debug(
                    f"The GitHub PR alerter only currently supports remediations that have a recommended_version: {pkg}"
                )
                continue

            # We have a single remediation that can have multiple vulnerabilities
            # Find all vulnerabilities associated with the remediation
            vulns = [
                x
                for x in obj.report["vulnerabilities"]
                if x["package_name"] == pkg_canonical_name
                and x["analyzed_requirement"]["specifier"] == analyzed_spec
            ]

            # Skip if all vulnerabilities have unknown severity and the ignore flag is set
            if ignore_cvss_unknown_severity and all(
                x["severity"] is None for x in vulns
            ):
                LOG.debug(
                    "All vulnerabilities have unknown severity, and ignore_cvss_unknown_severity is set."
                )
                continue

            # Track the highest CVSSv3 base score; missing base_score defaults to
            # 10 (treated as critical). Used later for the severity label.
            highest_base_score = 0
            for vuln in vulns:
                if vuln["severity"] is not None:
                    highest_base_score = max(
                        highest_base_score,
                        (vuln["severity"].get("cvssv3", {}) or {}).get(
                            "base_score", 10
                        ),
                    )

            # Skip if none of the vulnerabilities meet the severity threshold
            if ignore_cvss_severity_below:
                at_least_one_match = False
                for vuln in vulns:
                    # Consider a None severity as a match, since it's controlled by a different flag
                    # If we can't find a base_score but we have severity data, assume it's critical for now.
                    if (
                        vuln["severity"] is None
                        or (vuln["severity"].get("cvssv3", {}) or {}).get(
                            "base_score", 10
                        )
                        >= ignore_cvss_severity_below
                    ):
                        at_least_one_match = True

                if not at_least_one_match:
                    LOG.debug(
                        f"None of the vulnerabilities found have a score greater than or equal to the ignore_cvss_severity_below of {ignore_cvss_severity_below}"
                    )
                    continue

            for parsed_req in parsed_req_file.requirements:
                # An unconstrained requirement is treated as ">=0" for matching.
                specs = (
                    SpecifierSet(">=0")
                    if parsed_req.specs == SpecifierSet("")
                    else parsed_req.specs
                )
                # Check if the requirement matches the remediation
                if (
                    canonicalize_name(parsed_req.name) == pkg_canonical_name
                    and str(specs) == analyzed_spec
                ):
                    updated_contents = parsed_req.update_version(
                        contents, remediation["recommended_version"]
                    )
                    pending_updates.discard(f"{pkg_canonical_name}{analyzed_spec}")
                    new_branch = branch_prefix + utils.generate_branch_name(
                        pkg, remediation
                    )
                    skip_create = False

                    # Few possible cases:
                    # 1. No existing PRs exist for this change (don't need to handle)
                    # 2. An existing PR exists, and it's out of date (eg, recommended 0.5.1 and we want 0.5.2)
                    # 3. An existing PR exists, and it's not mergable anymore (eg, needs a rebase)
                    # 4. An existing PR exists, and everything's up to date.
                    # 5. An existing PR exists, but it's not needed anymore (perhaps we've been updated to a later version)
                    # 6. No existing PRs exist, but a branch does exist (perhaps the PR was closed but a stale branch left behind)
                    # In any case, we only act if we've been the only committer to the branch.
                    # Handle various cases for existing pull requests
                    for pr in pulls:
                        if not pr.head.ref.startswith(branch_prefix):
                            continue
                        authors = [
                            commit.committer.login for commit in pr.get_commits()
                        ]
                        only_us = all([x == self_user for x in authors])
                        try:
                            # Branch names are expected to look like
                            # "<prefix>/<pkg>/<spec>/<version>".
                            _, pr_pkg, pr_spec, pr_ver = pr.head.ref.split("/")
                        except ValueError:
                            # It's possible that something weird has manually been done, so skip that
                            # Skip invalid branch names
                            LOG.debug(
                                "Found an invalid branch name on an open PR, that matches our prefix. Skipping."
                            )
                            continue
                        pr_pkg = canonicalize_name(pr_pkg)
                        if pr_pkg != pkg_canonical_name:
                            continue
                        # Case 4: An up-to-date PR exists
                        if (
                            pr_pkg == pkg_canonical_name
                            and pr_spec == analyzed_spec
                            and pr_ver == remediation["recommended_version"]
                            and pr.mergeable
                        ):
                            LOG.debug(
                                f"An up to date PR #{pr.number} for {pkg} was found, no action will be taken."
                            )
                            skip_create = True
                            continue
                        if not only_us:
                            LOG.debug(
                                f"There are other committers on the PR #{pr.number} for {pkg}. No further action will be taken."
                            )
                            continue
                        # Case 2: An existing PR is out of date
                        if (
                            pr_pkg == pkg_canonical_name
                            and pr_spec == analyzed_spec
                            and pr_ver != remediation["recommended_version"]
                        ):
                            LOG.debug(
                                f"Closing stale PR #{pr.number} for {pkg} as a newer recommended version became"
                            )
                            pr.create_issue_comment(
                                "This PR has been replaced, since a newer recommended version became available."
                            )
                            pr.edit(state="closed")
                            delete_branch(repo, pr.head.ref)
                        # Case 3: An existing PR is not mergeable
                        if not pr.mergeable:
                            LOG.debug(
                                f"Closing PR #{pr.number} for {pkg} as it has become unmergable and we were the only committer"
                            )
                            pr.create_issue_comment(
                                "This PR has been replaced since it became unmergable."
                            )
                            pr.edit(state="closed")
                            delete_branch(repo, pr.head.ref)

                    # Skip if no changes were made
                    if updated_contents == contents:
                        LOG.debug(
                            f"Couldn't update {pkg} to {remediation['recommended_version']}"
                        )
                        continue

                    # Skip creation if indicated
                    if skip_create:
                        continue

                    # Create a new branch and commit the changes
                    try:
                        create_branch(repo, repo.default_branch, new_branch)
                    except pygithub.GithubException as e:
                        if e.data["message"] == "Reference already exists":
                            # There might be a stale branch. If the bot is the only committer, nuke it.
                            comparison = repo.compare(repo.default_branch, new_branch)
                            authors = [
                                commit.committer.login for commit in comparison.commits
                            ]
                            only_us = all([x == self_user for x in authors])
                            if only_us:
                                delete_branch(repo, new_branch)
                                create_branch(repo, repo.default_branch, new_branch)
                            else:
                                LOG.debug(
                                    f"The branch '{new_branch}' already exists - but there is no matching PR and this branch has committers other than us. This remediation will be skipped."
                                )
                                continue
                        else:
                            raise e

                    try:
                        # sha must match the current blob on the default branch,
                        # or GitHub rejects the commit (handled below).
                        repo.update_file(
                            path=name,
                            message=utils.generate_commit_message(pkg, remediation),
                            content=updated_contents,
                            branch=new_branch,
                            sha=utils.git_sha1(raw_contents),
                        )
                    except pygithub.GithubException as e:
                        if "does not match" in e.data["message"]:
                            click.secho(
                                f"GitHub blocked a commit on our branch to the requirements file, {name}, as the local hash we computed didn't match the version on {repo.default_branch}. Make sure you're running safety against the latest code on your default branch.",
                                fg="red",
                            )
                            continue
                        else:
                            raise e

                    pr = repo.create_pull(
                        title=pr_prefix + utils.generate_title(pkg, remediation, vulns),
                        body=utils.generate_body(
                            pkg, remediation, vulns, api_key=obj.key
                        ),
                        head=new_branch,
                        base=repo.default_branch,
                    )
                    LOG.debug(f"Created Pull Request to update {pkg}")
                    created += 1

                    # Add assignees and labels to the PR
                    for assignee in assignees:
                        pr.add_to_assignees(assignee)
                    for label in labels:
                        pr.add_to_labels(label)
                    if label_severity:
                        score_as_label = utils.cvss3_score_to_label(highest_base_score)
                        if score_as_label:
                            pr.add_to_labels(score_as_label)

    if len(pending_updates) > 0:
        click.secho(
            "The following remediations were not followed: {}".format(
                ", ".join(pending_updates)
            ),
            fg="red",
        )

    if created:
        click.secho(
            f"Safety successfully created {created} GitHub PR{'s' if created > 1 else ''} for repo {repo_name}"
        )
    else:
        click.secho(
            "No PRs created; please run the command with debug mode for more information."
        )
@click.command()
@click.option("--repo", help="GitHub standard repo path (eg, my-org/my-project)")
@click.option("--token", help="GitHub Access Token")
@click.option(
    "--base-url",
    help="Optional custom Base URL, if you're using GitHub enterprise",
    default=None,
)
@click.pass_obj
@utils.require_files_report  # TODO: For now, it can be removed in the future to support env scans.
def github_issue(obj: Any, repo: str, token: str, base_url: Optional[str]) -> None:
    """
    Create a GitHub Issue for any vulnerabilities found using PyUp's remediation data.

    Normally, this is run by a GitHub action. If you're running this manually,
    ensure that your local repo is up to date and on HEAD - otherwise you'll
    see strange results.

    Args:
        obj (Any): The Click context object containing report data.
        repo (str): The GitHub repository path.
        token (str): The GitHub Access Token.
        base_url (Optional[str]): Custom base URL for GitHub Enterprise, if applicable.
    """
    LOG.info("github_issue")
    # pygithub is an optional extra; fail fast with an actionable message.
    if pygithub is None:
        click.secho(
            "pygithub is not installed. Did you install Safety with GitHub support? Try pip install safety[github]",
            fg="red",
        )
        sys.exit(1)

    # Load alert configurations from the policy.
    # Each `or {}` guards against an explicit null in the YAML policy.
    alert = obj.policy.get("alert", {}) or {}
    security = alert.get("security", {}) or {}
    config_issue = security.get("github-issue", {}) or {}
    issue_prefix = config_issue.get("issue-prefix", "[PyUp] ")
    assignees = config_issue.get("assignees", [])
    labels = config_issue.get("labels", ["security"])
    label_severity = config_issue.get("label-severity", True)
    ignore_cvss_severity_below = config_issue.get("ignore-cvss-severity-below", 0)
    ignore_cvss_unknown_severity = config_issue.get(
        "ignore-cvss-unknown-severity", False
    )

    # Authenticate with GitHub
    gh = pygithub.Github(token, **({"base_url": base_url} if base_url else {}))
    repo_name = repo
    # NOTE: `repo` is rebound here from the path string to the Repository object.
    repo = gh.get_repo(repo)

    # Get all open issues for the repository
    issues = list(repo.get_issues(state="open", sort="created"))
    # Used to detect issues previously created by this alerter, by title.
    ISSUE_TITLE_REGEX = re.escape(issue_prefix) + r"Security Vulnerability in (.+)"

    # Collect all remediations from the report (flattened across packages).
    req_remediations = list(
        itertools.chain.from_iterable(
            rem.get("requirements", {}).values()
            for pkg_name, rem in obj.report["remediations"].items()
        )
    )
    created = 0

    # Iterate over all requirements files and process each remediation
    for name, contents in obj.requirements_files.items():
        contents = contents.decode("utf-8")  # TODO - encoding?
        parsed_req_file = requirements.RequirementFile(name, contents)
        for remediation in req_remediations:
            pkg: str = remediation["requirement"]["name"]
            pkg_canonical_name: str = canonicalize_name(pkg)
            analyzed_spec: str = remediation["requirement"]["specifier"]

            # Skip remediations without a recommended version
            if remediation["recommended_version"] is None:
                LOG.debug(
                    f"The GitHub Issue alerter only currently supports remediations that have a recommended_version: {pkg}"
                )
                continue

            # We have a single remediation that can have multiple vulnerabilities
            # Find all vulnerabilities associated with the remediation
            vulns = [
                x
                for x in obj.report["vulnerabilities"]
                if x["package_name"] == pkg_canonical_name
                and x["analyzed_requirement"]["specifier"] == analyzed_spec
            ]

            # Skip if all vulnerabilities have unknown severity and the ignore flag is set
            if ignore_cvss_unknown_severity and all(
                x["severity"] is None for x in vulns
            ):
                LOG.debug(
                    "All vulnerabilities have unknown severity, and ignore_cvss_unknown_severity is set."
                )
                continue

            # Track the highest CVSSv3 base score; missing base_score defaults to
            # 10 (treated as critical). Used later for the severity label.
            highest_base_score = 0
            for vuln in vulns:
                if vuln["severity"] is not None:
                    highest_base_score = max(
                        highest_base_score,
                        (vuln["severity"].get("cvssv3", {}) or {}).get(
                            "base_score", 10
                        ),
                    )

            # Skip if none of the vulnerabilities meet the severity threshold
            if ignore_cvss_severity_below:
                at_least_one_match = False
                for vuln in vulns:
                    # Consider a None severity as a match, since it's controlled by a different flag
                    # If we can't find a base_score but we have severity data, assume it's critical for now.
                    if (
                        vuln["severity"] is None
                        or (vuln["severity"].get("cvssv3", {}) or {}).get(
                            "base_score", 10
                        )
                        >= ignore_cvss_severity_below
                    ):
                        at_least_one_match = True
                        break

                if not at_least_one_match:
                    LOG.debug(
                        f"None of the vulnerabilities found have a score greater than or equal to the ignore_cvss_severity_below of {ignore_cvss_severity_below}"
                    )
                    continue

            for parsed_req in parsed_req_file.requirements:
                # An unconstrained requirement is treated as ">=0" for matching.
                specs = (
                    SpecifierSet(">=0")
                    if parsed_req.specs == SpecifierSet("")
                    else parsed_req.specs
                )
                if (
                    canonicalize_name(parsed_req.name) == pkg_canonical_name
                    and str(specs) == analyzed_spec
                ):
                    skip = False
                    for issue in issues:
                        match = re.match(ISSUE_TITLE_REGEX, issue.title)
                        if match:
                            group = match.group(1)
                            # Match either the raw or canonicalized package name.
                            if (
                                group == f"{pkg}{analyzed_spec}"
                                or group == f"{pkg_canonical_name}{analyzed_spec}"
                            ):
                                skip = True
                                break

                    # For now, we just skip issues if they already exist - we don't try and update them.
                    # Skip if an issue already exists for this remediation
                    if skip:
                        LOG.debug(
                            f"An issue already exists for {pkg}{analyzed_spec} - skipping"
                        )
                        continue

                    # Create a new GitHub issue
                    pr = repo.create_issue(
                        title=issue_prefix
                        + utils.generate_issue_title(pkg, remediation),
                        body=utils.generate_issue_body(
                            pkg, remediation, vulns, api_key=obj.key
                        ),
                    )
                    created += 1
                    LOG.debug(f"Created issue to update {pkg}")

                    # Add assignees and labels to the issue
                    for assignee in assignees:
                        pr.add_to_assignees(assignee)
                    for label in labels:
                        pr.add_to_labels(label)
                    if label_severity:
                        score_as_label = utils.cvss3_score_to_label(highest_base_score)
                        if score_as_label:
                            pr.add_to_labels(score_as_label)

    if created:
        click.secho(
            f"Safety successfully created {created} new GitHub Issue{'s' if created > 1 else ''} for repo {repo_name}"
        )
    else:
        click.secho(
            "No issues created; please run the command with debug mode for more information."
        )

View File

@@ -0,0 +1,564 @@
# type: ignore
from __future__ import unicode_literals
from packaging.version import parse as parse_version
from packaging.specifiers import SpecifierSet
import requests
from typing import Any, Optional, Generator, Tuple, List
from safety.meta import get_meta_http_headers
from datetime import datetime
from dparse import parse, parser, updater, filetypes
from dparse.dependencies import Dependency
from dparse.parser import setuptools_parse_requirements_backport as parse_requirements
class RequirementFile(object):
    """
    Class representing a requirements file with its content and metadata.

    Parsing is lazy: the file is only parsed (via dparse) on first access of
    `is_valid`, `requirements`, or `other_files`.

    Attributes:
        path (str): The file path.
        content (str): The content of the file.
        sha (Optional[str]): The SHA of the file.
    """

    def __init__(self, path: str, content: str, sha: Optional[str] = None):
        self.path = path
        self.content = content
        self.sha = sha
        # Lazily populated by _parse() on first property access.
        self._requirements: Optional[List] = None
        self._other_files: Optional[List] = None
        self._is_valid: Optional[bool] = None
        # File-type flags set by the corresponding _parse_* helper.
        self.is_pipfile = False
        self.is_pipfile_lock = False
        self.is_setup_cfg = False

    def __str__(self) -> str:
        # Truncate long content so the repr stays readable in logs.
        return (
            "RequirementFile(path='{path}', sha='{sha}', content='{content}')".format(
                path=self.path,
                content=self.content[:30] + "[truncated]"
                if len(self.content) > 30
                else self.content,
                sha=self.sha,
            )
        )

    @property
    def is_valid(self) -> Optional[bool]:
        """
        Checks if the requirements file is valid by parsing it.

        Returns:
            bool: True if the file is valid, False otherwise.
        """
        if self._is_valid is None:
            self._parse()
        return self._is_valid

    @property
    def requirements(self) -> Optional[List]:
        """
        Returns the list of requirements parsed from the file.

        Returns:
            List: The list of requirements.
        """
        if not self._requirements:
            self._parse()
        return self._requirements

    @property
    def other_files(self) -> Optional[List]:
        """
        Returns the list of other files resolved from the requirements file.

        Returns:
            List: The list of other files.
        """
        if not self._other_files:
            self._parse()
        return self._other_files

    @staticmethod
    def parse_index_server(line: str) -> Optional[str]:
        """
        Parses the index server from a given line.

        Args:
            line (str): The line to parse.

        Returns:
            str: The parsed index server.
        """
        return parser.Parser.parse_index_server(line)

    def _hash_parser(self, line: str) -> Optional[Tuple[str, List[str]]]:
        """
        Parses the hashes from a given line.

        Args:
            line (str): The line to parse.

        Returns:
            Optional[Tuple[str, List[str]]]: The result of
            dparse's Parser.parse_hashes for the line.
        """
        return parser.Parser.parse_hashes(line)

    def _parse_requirements_txt(self) -> None:
        """
        Parses the requirements.txt file format.
        """
        self.parse_dependencies(filetypes.requirements_txt)

    def _parse_conda_yml(self) -> None:
        """
        Parses the conda.yml file format.
        """
        self.parse_dependencies(filetypes.conda_yml)

    def _parse_tox_ini(self) -> None:
        """
        Parses the tox.ini file format.
        """
        self.parse_dependencies(filetypes.tox_ini)

    def _parse_pipfile(self) -> None:
        """
        Parses the Pipfile format.
        """
        self.parse_dependencies(filetypes.pipfile)
        self.is_pipfile = True

    def _parse_pipfile_lock(self) -> None:
        """
        Parses the Pipfile.lock format.
        """
        self.parse_dependencies(filetypes.pipfile_lock)
        self.is_pipfile_lock = True

    def _parse_setup_cfg(self) -> None:
        """
        Parses the setup.cfg format.
        """
        self.parse_dependencies(filetypes.setup_cfg)
        self.is_setup_cfg = True

    def _parse(self) -> None:
        """
        Parses the requirements file to extract dependencies and other files.

        Dispatches to a format-specific parser based on the file path suffix;
        anything unrecognized is treated as requirements.txt.
        """
        self._requirements, self._other_files = [], []
        # NOTE(review): a "Pipfile.lock" path matches neither .yml/.yaml nor
        # .ini first, but DOES end with "Pipfile"? No - endswith("Pipfile") is
        # False for "Pipfile.lock", so ordering here is safe; confirm.
        if self.path.endswith(".yml") or self.path.endswith(".yaml"):
            self._parse_conda_yml()
        elif self.path.endswith(".ini"):
            self._parse_tox_ini()
        elif self.path.endswith("Pipfile"):
            self._parse_pipfile()
        elif self.path.endswith("Pipfile.lock"):
            self._parse_pipfile_lock()
        elif self.path.endswith("setup.cfg"):
            self._parse_setup_cfg()
        else:
            self._parse_requirements_txt()
        # Valid means the file produced at least one dependency or reference.
        self._is_valid = len(self._requirements) > 0 or len(self._other_files) > 0

    def parse_dependencies(self, file_type: str) -> None:
        """
        Parses the dependencies from the content based on the file type.

        Populates self._requirements with Requirement objects and
        self._other_files with any referenced (resolved) files.

        Args:
            file_type (str): The type of the file.
        """
        result = parse(
            self.content,
            path=self.path,
            sha=self.sha,
            file_type=file_type,
            marker=(
                ("pyup: ignore file", "pyup:ignore file"),  # file marker
                ("pyup: ignore", "pyup:ignore"),  # line marker
            ),
        )
        for dep in result.dependencies:
            req = Requirement(
                name=dep.name,
                specs=dep.specs,
                line=dep.line,
                lineno=dep.line_numbers[0] if dep.line_numbers else 0,
                extras=dep.extras,
                file_type=file_type,
            )
            req.index_server = dep.index_server
            if self.is_pipfile:
                req.pipfile = self.path
            req.hashes = dep.hashes
            self._requirements.append(req)
        self._other_files = result.resolved_files

    def iter_lines(self, lineno: int = 0) -> Generator[str, None, None]:
        """
        Iterates over lines in the content starting from a specific line number.

        Args:
            lineno (int): The line number to start from.

        Yields:
            str: The next line in the content.
        """
        for line in self.content.splitlines()[lineno:]:
            yield line

    @classmethod
    def resolve_file(cls, file_path: str, line: str) -> str:
        """
        Resolves a file path from a given line.

        Args:
            file_path (str): The file path to resolve.
            line (str): The line containing the file path.

        Returns:
            str: The resolved file path.
        """
        return parser.Parser.resolve_file(file_path, line)
class Requirement(object):
"""
Class representing a single requirement.
Attributes:
name (str): The name of the requirement.
specs (SpecifierSet): The version specifiers for the requirement.
line (str): The line containing the requirement.
lineno (int): The line number of the requirement.
extras (List): The extras for the requirement.
file_type (str): The type of the file containing the requirement.
"""
def __init__(
self,
name: str,
specs: SpecifierSet,
line: str,
lineno: int,
extras: List,
file_type: str,
):
self.name = name
self.key = name.lower()
self.specs = specs
self.line = line
self.lineno = lineno
self.index_server = None
self.extras = extras
self.hashes = []
self.file_type = file_type
self.pipfile: Optional[str] = None
self.hashCmp = (
self.key,
self.specs,
frozenset(self.extras),
)
self._is_insecure = None
self._changelog = None
# Convert compatible releases to a range of versions
if (
len(self.specs._specs) == 1
and next(iter(self.specs._specs))._spec[0] == "~="
):
# convert compatible releases to something more easily consumed,
# e.g. '~=1.2.3' is equivalent to '>=1.2.3,<1.3.0', while '~=1.2'
# is equivalent to '>=1.2,<2.0'
min_version = next(iter(self.specs._specs))._spec[1]
max_version = list(parse_version(min_version).release)
max_version[-1] = 0
max_version[-2] = max_version[-2] + 1
max_version = ".".join(str(x) for x in max_version)
self.specs = SpecifierSet(">=%s,<%s" % (min_version, max_version))
def __eq__(self, other: Any) -> bool:
return isinstance(other, Requirement) and self.hashCmp == other.hashCmp
def __ne__(self, other: Any) -> bool:
return not self == other
def __str__(self) -> str:
return "Requirement.parse({line}, {lineno})".format(
line=self.line, lineno=self.lineno
)
def __repr__(self) -> str:
return self.__str__()
@property
def is_pinned(self) -> bool:
"""
Checks if the requirement is pinned to a specific version.
Returns:
bool: True if pinned, False otherwise.
"""
if (
len(self.specs._specs) == 1
and next(iter(self.specs._specs))._spec[0] == "=="
):
return True
return False
@property
def is_open_ranged(self) -> bool:
"""
Checks if the requirement has an open range of versions.
Returns:
bool: True if open ranged, False otherwise.
"""
if (
len(self.specs._specs) == 1
and next(iter(self.specs._specs))._spec[0] == ">="
):
return True
return False
@property
def is_ranged(self) -> bool:
"""
Checks if the requirement has a range of versions.
Returns:
bool: True if ranged, False otherwise.
"""
return len(self.specs._specs) >= 1 and not self.is_pinned
@property
def is_loose(self) -> bool:
"""
Checks if the requirement has no version specifiers.
Returns:
bool: True if loose, False otherwise.
"""
return len(self.specs._specs) == 0
@staticmethod
def convert_semver(version: str) -> dict:
"""
Converts a version string to a semantic version dictionary.
Args:
version (str): The version string.
Returns:
dict: The semantic version dictionary.
"""
semver = {"major": 0, "minor": 0, "patch": 0}
version_parts = version.split(".")
# don't be overly clever here. repitition makes it more readable and works exactly how
# it is supposed to
try:
semver["major"] = int(version_parts[0])
semver["minor"] = int(version_parts[1])
semver["patch"] = int(version_parts[2])
except (IndexError, ValueError):
pass
return semver
@property
def can_update_semver(self) -> bool:
"""
Checks if the requirement can be updated based on semantic versioning rules.
Returns:
bool: True if it can be updated, False otherwise.
"""
# return early if there's no update filter set
if "pyup: update" not in self.line:
return True
update = self.line.split("pyup: update")[1].strip().split("#")[0]
current_version = Requirement.convert_semver(
next(iter(self.specs._specs))._spec[1]
)
next_version = Requirement.convert_semver(self.latest_version) # type: ignore
if update == "major":
if current_version["major"] < next_version["major"]:
return True
elif update == "minor":
if (
current_version["major"] < next_version["major"]
or current_version["minor"] < next_version["minor"]
):
return True
return False
@property
def filter(self):
"""
Returns the filter for the requirement if specified.
Returns:
Optional[SpecifierSet]: The filter specifier set, or None if not specified.
"""
rqfilter = False
if "rq.filter:" in self.line:
rqfilter = self.line.split("rq.filter:")[1].strip().split("#")[0]
elif "pyup:" in self.line:
if "pyup: update" not in self.line:
rqfilter = self.line.split("pyup:")[1].strip().split("#")[0]
# unset the filter once the date set in 'until' is reached
if "until" in rqfilter:
rqfilter, until = [part.strip() for part in rqfilter.split("until")]
try:
until = datetime.strptime(until, "%Y-%m-%d")
if until < datetime.now():
rqfilter = False
except ValueError:
# wrong date formatting
pass
if rqfilter:
try:
(rqfilter,) = parse_requirements("filter " + rqfilter)
if len(rqfilter.specifier._specs) > 0:
return rqfilter.specifier
except ValueError:
pass
return False
@property
def version(self) -> Optional[str]:
"""
Returns the current version of the requirement.
Returns:
Optional[str]: The current version, or None if not pinned.
"""
if self.is_pinned:
return next(iter(self.specs._specs))._spec[1]
specs = self.specs
if self.filter:
specs = SpecifierSet(
",".join(
[
"".join(s._spec)
for s in list(specs._specs) + list(self.filter._specs)
]
)
)
return self.get_latest_version_within_specs( # type: ignore
specs,
versions=self.package.versions,
prereleases=self.prereleases, # type: ignore
)
def get_hashes(self, version: str) -> List:
"""
Retrieves the hashes for a specific version from PyPI.
Args:
version (str): The version to retrieve hashes for.
Returns:
List: A list of hashes for the specified version.
"""
headers = get_meta_http_headers()
r = requests.get(
"https://pypi.org/pypi/{name}/{version}/json".format(
name=self.key, version=version
),
headers=headers,
)
hashes = []
data = r.json()
for item in data.get("urls", {}):
sha256 = item.get("digests", {}).get("sha256", False)
if sha256:
hashes.append({"hash": sha256, "method": "sha256"})
return hashes
def update_version(
self, content: str, version: str, update_hashes: bool = True
) -> str:
"""
Updates the version of the requirement in the content.
Args:
content (str): The original content.
version (str): The new version to update to.
update_hashes (bool): Whether to update the hashes as well.
Returns:
str: The updated content.
"""
if self.file_type == filetypes.tox_ini:
updater_class = updater.ToxINIUpdater
elif self.file_type == filetypes.conda_yml:
updater_class = updater.CondaYMLUpdater
elif self.file_type == filetypes.requirements_txt:
updater_class = updater.RequirementsTXTUpdater
elif self.file_type == filetypes.pipfile:
updater_class = updater.PipfileUpdater
elif self.file_type == filetypes.pipfile_lock:
updater_class = updater.PipfileLockUpdater
elif self.file_type == filetypes.setup_cfg:
updater_class = updater.SetupCFGUpdater
else:
raise NotImplementedError
dep = Dependency(
name=self.name,
specs=self.specs,
line=self.line,
line_numbers=[
self.lineno,
]
if self.lineno != 0
else None,
dependency_type=self.file_type,
hashes=self.hashes,
extras=self.extras,
)
hashes = []
if self.hashes and update_hashes:
hashes = self.get_hashes(version)
return updater_class.update(
content=content, dependency=dep, version=version, hashes=hashes, spec="=="
)
@classmethod
def parse(
    cls, s: str, lineno: int, file_type: str = filetypes.requirements_txt
) -> "Requirement":
    """
    Parses a requirement from a line of text.

    Args:
        s (str): The line of text.
        lineno (int): The line number.
        file_type (str): The type of the file containing the requirement.

    Returns:
        Requirement: The parsed requirement.
    """
    # setuptools requires a space before an inline comment; normalize the
    # tab-followed-by-comment case before parsing. The original line text is
    # still stored untouched on the instance.
    normalized = s.replace("\t#", "\t #") if "\t#" in s else s
    (parsed,) = parse_requirements(normalized)
    return cls(
        name=parsed.name,
        specs=parsed.specifier,
        line=s,
        lineno=lineno,
        extras=list(parsed.extras),
        file_type=file_type,
    )

View File

@@ -0,0 +1,54 @@
Safety has detected a vulnerable package, [{{ pkg }}]({{ remediation['more_info_url'] }}), that should be updated from **{% if remediation['version'] %}{{ remediation['version'] }}{% else %}{{ remediation['requirement']['specifier'] }}{% endif %}** to **{{ remediation['recommended_version'] }}** to fix {{ vulns | length }} vulnerabilit{{ "y" if vulns|length == 1 else "ies" }}{% if overall_impact %}{{ " rated " + overall_impact if vulns|length == 1 else " with the highest CVSS severity rating being " + overall_impact }}{% endif %}.
To read more about the impact of {{ "this vulnerability" if vulns|length == 1 else "these vulnerabilities" }} see [PyUp's {{ pkg }} page]({{ remediation['more_info_url'] }}).
{{ hint }}
If you're using `pip`, you can run:
```
pip install {{ pkg }}=={{ remediation['recommended_version'] }}
# Followed by a pip freeze
```
<details>
<summary>Vulnerabilities Found</summary>
{% for vuln in vulns %}
* {{ vuln.advisory }}
{% if vuln.severity and vuln.severity.cvssv3 and vuln.severity.cvssv3.base_severity %}
* This vulnerability was rated {{ vuln.severity.cvssv3.base_severity }} ({{ vuln.severity.cvssv3.base_score }}) on CVSSv3.
{% endif %}
* To read more about this vulnerability, see PyUp's [vulnerability page]({{ vuln.more_info_url }})
{% endfor %}
</details>
<details>
<summary>Changelog from {{ remediation['requirement']['name'] }}{{ remediation['requirement']['specifier'] }} to {{ remediation['recommended_version'] }}</summary>
{% if summary_changelog %}
The full changelog is too long to post here. See [PyUp's {{ pkg }} page]({{ remediation['more_info_url'] }}) for more information.
{% else %}
{% for version, log in changelog.items() %}
### {{ version }}
```
{{ log }}
```
{% endfor %}
{% endif %}
</details>
<details>
<summary>Ignoring {{ "This Vulnerability" if vulns|length == 1 else "These Vulnerabilities" }}</summary>
If you wish to [ignore this vulnerability](https://docs.pyup.io/docs/safety-20-policy-file), you can add the following to `.safety-policy.yml` in this repo:
```
security:
ignore-vulnerabilities:{% for vuln in vulns %}
{{ vuln.vulnerability_id }}:
reason: enter a reason as to why you're ignoring this vulnerability
expires: 'YYYY-MM-DD' # datetime string - date this ignore will expire
{% endfor %}
```
</details>

View File

@@ -0,0 +1,47 @@
Vulnerability fix: This PR updates [{{ pkg }}]({{ remediation['more_info_url'] }}) from **{% if remediation['version'] %}{{ remediation['version'] }}{% else %}{{ remediation['requirement']['specifier'] }}{% endif %}** to **{{ remediation['recommended_version'] }}** to fix {{ vulns | length }} vulnerabilit{{ "y" if vulns|length == 1 else "ies" }}{% if overall_impact %}{{ " rated " + overall_impact if vulns|length == 1 else " with the highest CVSS severity rating being " + overall_impact }}{% endif %}.
To read more about the impact of {{ "this vulnerability" if vulns|length == 1 else "these vulnerabilities" }} see [PyUp's {{ pkg }} page]({{ remediation['more_info_url'] }}).
{{ hint }}
<details>
<summary>Vulnerabilities Fixed</summary>
{% for vuln in vulns %}
* {{ vuln.advisory }}
{% if vuln.severity and vuln.severity.cvssv3 and vuln.severity.cvssv3.base_severity %}
* This vulnerability was rated {{ vuln.severity.cvssv3.base_severity }} ({{ vuln.severity.cvssv3.base_score }}) on CVSSv3.
{% endif %}
* To read more about this vulnerability, see PyUp's [vulnerability page]({{ vuln.more_info_url }})
{% endfor %}
</details>
<details>
<summary>Changelog</summary>
{% if summary_changelog %}
The full changelog is too long to post here. See [PyUp's {{ pkg }} page]({{ remediation['more_info_url'] }}) for more information.
{% else %}
{% for version, log in changelog.items() %}
### {{ version }}
```
{{ log }}
```
{% endfor %}
{% endif %}
</details>
<details>
<summary>Ignoring {{ "This Vulnerability" if vulns|length == 1 else "These Vulnerabilities" }}</summary>
If you wish to [ignore this vulnerability](https://docs.pyup.io/docs/safety-20-policy-file), you can add the following to `.safety-policy.yml` in this repo:
```
security:
ignore-vulnerabilities:{% for vuln in vulns %}
{{ vuln.vulnerability_id }}:
reason: enter a reason as to why you're ignoring this vulnerability
expires: 'YYYY-MM-DD' # datetime string - date this ignore will expire
{% endfor %}
```
</details>

View File

@@ -0,0 +1,393 @@
# type: ignore
import hashlib
import os
import sys
from functools import wraps
from typing import Optional, List, Dict, Any
from packaging.version import parse as parse_version
from packaging.specifiers import SpecifierSet
from pathlib import Path
import click
# Jinja2 will only be installed if the optional deps are installed.
# It's fine if our functions fail, but don't let this top level
# import error out.
from safety.models import is_pinned_requirement
from safety.output_utils import (
get_unpinned_hint,
get_specifier_range_info,
get_fix_hint_for_unpinned,
)
try:
import jinja2
except ImportError:
jinja2 = None
import requests
from safety.meta import get_meta_http_headers
def highest_base_score(vulns: List[Dict[str, Any]]) -> float:
    """
    Calculates the highest CVSS base score from a list of vulnerabilities.

    Vulnerabilities with no severity data are skipped; a vulnerability with
    severity data but no CVSSv3 base score is treated as a worst-case 10.

    Args:
        vulns (List[Dict[str, Any]]): The list of vulnerabilities.

    Returns:
        float: The highest CVSS base score (0 if nothing was scored).
    """
    # Local renamed so it no longer shadows the function's own name.
    top_score = 0
    for vuln in vulns:
        severity = vuln["severity"]
        if severity is None:
            continue
        cvssv3 = severity.get("cvssv3", {}) or {}
        top_score = max(top_score, cvssv3.get("base_score", 10))
    return top_score
def generate_branch_name(pkg: str, remediation: Dict[str, Any]) -> str:
    """
    Generates a branch name for a given package and remediation.

    Args:
        pkg (str): The package name.
        remediation (Dict[str, Any]): The remediation data.

    Returns:
        str: "<pkg>/<current specifier>/<recommended version>".
    """
    return "{}/{}/{}".format(
        pkg,
        remediation["requirement"]["specifier"],
        remediation["recommended_version"],
    )
def generate_issue_title(pkg: str, remediation: Dict[str, Any]) -> str:
    """
    Generates an issue title for a given package and remediation.

    Args:
        pkg (str): The package name.
        remediation (Dict[str, Any]): The remediation data.

    Returns:
        str: The generated issue title.
    """
    specifier = remediation["requirement"]["specifier"]
    return "Security Vulnerability in {}{}".format(pkg, specifier)
def get_hint(remediation: Dict[str, Any]) -> str:
    """
    Generates a hint for a given remediation.

    Pinned requirements need no hint; unpinned ones get a fix hint plus
    guidance on pinning and specifier ranges.

    Args:
        remediation (Dict[str, Any]): The remediation data.

    Returns:
        str: The generated hint, or "" for pinned requirements.
    """
    spec_set = SpecifierSet(remediation["requirement"]["specifier"])
    if is_pinned_requirement(spec_set):
        return ""

    fix_hint = get_fix_hint_for_unpinned(remediation)
    unpinned_note = get_unpinned_hint(remediation["requirement"]["name"])
    range_info = get_specifier_range_info(style=False)
    return f"{fix_hint}\n\n{unpinned_note} {range_info}"
def generate_title(
    pkg: str, remediation: Dict[str, Any], vulns: List[Dict[str, Any]]
) -> str:
    """
    Generates a title for a pull request or issue.

    Args:
        pkg (str): The package name.
        remediation (Dict[str, Any]): The remediation data.
        vulns (List[Dict[str, Any]]): The list of vulnerabilities.

    Returns:
        str: The generated title.
    """
    count = len(vulns)
    plural = "y" if count == 1 else "ies"
    # Prefer the concrete installed version; fall back to the specifier when
    # the requirement wasn't pinned.
    from_dependency = remediation["version"] or remediation["requirement"]["specifier"]
    return (
        f"Update {pkg} from {from_dependency} to "
        f"{remediation['recommended_version']} to fix {count} vulnerabilit{plural}"
    )
def generate_body(
    pkg: str, remediation: Dict[str, Any], vulns: List[Dict[str, Any]], *, api_key: str
) -> Optional[str]:
    """
    Generates the body content for a pull request by rendering the
    ``templates/pr.jinja2`` template.

    NOTE(review): relies on the optional jinja2 dependency; if jinja2 failed
    to import above (jinja2 is None) this raises AttributeError — presumably
    callers only reach this with the alert extras installed; confirm.

    Args:
        pkg (str): The package name.
        remediation (Dict[str, Any]): The remediation data.
        vulns (List[Dict[str, Any]]): The list of vulnerabilities.
        api_key (str): The API key for fetching changelog data.

    Returns:
        str: The rendered body. If the full render would exceed GitHub's
        body-size limit, a second render with ``summary_changelog`` set
        (linking to the changelog instead of embedding it) is returned.
    """
    changelog = fetch_changelog(
        pkg,
        remediation["version"],
        remediation["recommended_version"],
        api_key=api_key,
        from_spec=remediation.get("requirement", {}).get("specifier", None),
    )
    # Templates ship alongside this module.
    p = Path(__file__).parent / "templates"
    env = jinja2.Environment(loader=jinja2.FileSystemLoader(Path(p)))  # type: ignore
    template = env.get_template("pr.jinja2")
    # Label the worst CVSS score across all vulns (e.g. "high"); may be None.
    overall_impact = cvss3_score_to_label(highest_base_score(vulns))
    context = {
        "pkg": pkg,
        "remediation": remediation,
        "vulns": vulns,
        "changelog": changelog,
        "overall_impact": overall_impact,
        "summary_changelog": False,
        "hint": get_hint(remediation),
    }
    result = template.render(context)
    # GitHub has a PR body length limit of 65536. If we're going over that, skip the changelog and just use a link.
    if len(result) < 65500:
        return result
    context["summary_changelog"] = True
    return template.render(context)
def generate_issue_body(
    pkg: str, remediation: Dict[str, Any], vulns: List[Dict[str, Any]], *, api_key: str
) -> Optional[str]:
    """
    Generates the body content for an issue by rendering the
    ``templates/issue.jinja2`` template.

    NOTE(review): relies on the optional jinja2 dependency; if jinja2 failed
    to import above (jinja2 is None) this raises AttributeError — presumably
    callers only reach this with the alert extras installed; confirm.

    Args:
        pkg (str): The package name.
        remediation (Dict[str, Any]): The remediation data.
        vulns (List[Dict[str, Any]]): The list of vulnerabilities.
        api_key (str): The API key for fetching changelog data.

    Returns:
        str: The rendered body. If the full render would exceed GitHub's
        body-size limit, a second render with ``summary_changelog`` set
        (linking to the changelog instead of embedding it) is returned.
    """
    changelog = fetch_changelog(
        pkg,
        remediation["version"],
        remediation["recommended_version"],
        api_key=api_key,
        from_spec=remediation.get("requirement", {}).get("specifier", None),
    )
    # Templates ship alongside this module.
    p = Path(__file__).parent / "templates"
    env = jinja2.Environment(loader=jinja2.FileSystemLoader(Path(p)))  # type: ignore
    template = env.get_template("issue.jinja2")
    # Label the worst CVSS score across all vulns (e.g. "high"); may be None.
    overall_impact = cvss3_score_to_label(highest_base_score(vulns))
    context = {
        "pkg": pkg,
        "remediation": remediation,
        "vulns": vulns,
        "changelog": changelog,
        "overall_impact": overall_impact,
        "summary_changelog": False,
        "hint": get_hint(remediation),
    }
    result = template.render(context)
    # GitHub enforces a 65536-char body limit. If we're going over that, skip the changelog and just use a link.
    if len(result) < 65500:
        return result
    context["summary_changelog"] = True
    return template.render(context)
def generate_commit_message(pkg: str, remediation: Dict[str, Any]) -> str:
    """
    Generates a commit message for a given package and remediation.

    Args:
        pkg (str): The package name.
        remediation (Dict[str, Any]): The remediation data.

    Returns:
        str: The generated commit message.
    """
    # Prefer the concrete installed version; fall back to the specifier when
    # the requirement wasn't pinned.
    from_dependency = remediation["version"] or remediation["requirement"]["specifier"]
    return "Update {} from {} to {}".format(
        pkg, from_dependency, remediation["recommended_version"]
    )
def git_sha1(raw_contents: bytes) -> str:
    """
    Calculates the git blob SHA-1 object id of the given raw contents.

    Git hashes a blob as ``sha1(b"blob <size>\\0" + contents)``, so this
    matches what ``git hash-object`` would report for the same bytes.

    Args:
        raw_contents (bytes): The raw contents to hash.

    Returns:
        str: The hex-encoded SHA-1 hash.
    """
    header = "blob {}".format(len(raw_contents)).encode("ascii") + b"\0"
    return hashlib.sha1(header + raw_contents).hexdigest()
def fetch_changelog(
    package: str,
    from_version: Optional[str],
    to_version: str,
    *,
    api_key: str,
    from_spec: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Fetches the changelog for a package from a specified version to another version.

    Args:
        package (str): The package name.
        from_version (Optional[str]): The starting version.
        to_version (str): The ending version.
        api_key (str): The API key for fetching changelog data.
        from_spec (Optional[str]): The specifier for the starting version;
            only used when ``from_version`` is not given.

    Returns:
        Dict[str, Any]: Mapping of version -> changelog entry for each release
        inside the upgrade range, newest first. Empty on HTTP errors.
    """
    to_version_parsed = parse_version(to_version)
    if from_version:
        from_version_parsed = parse_version(from_version)
    else:
        from_version_parsed = None
        from_spec = SpecifierSet(from_spec)
    changelog = {}
    headers = {"X-Api-Key": api_key}
    headers.update(get_meta_http_headers())
    r = requests.get(
        "https://pyup.io/api/v1/changelogs/{}/".format(package), headers=headers
    )
    if r.status_code == 200:
        data = r.json()
        if data:
            # sort the changelog by release, newest first
            sorted_log = sorted(
                data.items(), key=lambda v: parse_version(v[0]), reverse=True
            )
            # go over each release and add it to the log if it's within the
            # "upgrade range" e.g. update from 1.2 to 1.3 includes a changelog
            # for 1.2.1 but not for 0.4 and not for anything newer than 1.3.
            for version, log in sorted_log:
                parsed_version = parse_version(version)
                version_check = from_version and (parsed_version > from_version_parsed)
                spec_check = (
                    from_spec
                    and isinstance(from_spec, SpecifierSet)
                    and from_spec.contains(parsed_version)
                )
                # BUG FIX: the upper bound must apply to both paths. The old
                # condition was `version_check or spec_check and parsed <= to`
                # where `and` binds tighter than `or`, so when `from_version`
                # was set, releases newer than the target version leaked in.
                if (
                    version_check or spec_check
                ) and parsed_version <= to_version_parsed:
                    changelog[version] = log
    return changelog
def cvss3_score_to_label(score: float) -> Optional[str]:
    """
    Converts a CVSS v3 score to a severity label.

    Args:
        score (float): The CVSS v3 score.

    Returns:
        Optional[str]: The severity label, or None when the score falls in
        none of the labeled bands (e.g. 0.0).
    """
    # Same inclusive bands as before, expressed as a lookup table.
    bands = (
        (0.1, 3.9, "low"),
        (4.0, 6.9, "medium"),
        (7.0, 8.9, "high"),
        (9.0, 10.0, "critical"),
    )
    for lower, upper, label in bands:
        if lower <= score <= upper:
            return label
    return None
def require_files_report(func):
    """
    Decorator that ensures the wrapped alert command runs against a report
    generated from requirements files (not an environment scan) and loads the
    raw bytes of every scanned file into ``obj.requirements_files``.

    Exits with status 1 (after printing an error) when the report targets an
    environment or when any scanned file no longer exists on disk.

    Args:
        func: The alert command to wrap; called as ``func(obj, *args, **kwargs)``.

    Returns:
        The wrapped function.
    """

    @wraps(func)
    def inner(obj: Any, *args: Any, **kwargs: Any) -> Any:
        if obj.report["report_meta"]["scan_target"] != "files":
            click.secho(
                "This report was generated against an environment, but this alert command requires "
                "a scan report that was generated against a file. To learn more about the "
                "`safety alert` command visit https://docs.pyup.io/docs/safety-2-alerts",
                fg="red",
            )
            sys.exit(1)
        files = obj.report["report_meta"]["scanned"]
        obj.requirements_files = {}
        for f in files:
            if not os.path.exists(f):
                cwd = os.getcwd()
                click.secho(
                    "A requirements file scanned in the report, {}, does not exist (looking in {}).".format(
                        f, cwd
                    ),
                    fg="red",
                )
                sys.exit(1)
            # Use a context manager so the handle is closed promptly; the
            # original `open(f, "rb").read()` leaked the file handle.
            with open(f, "rb") as fh:
                obj.requirements_files[f] = fh.read()
        return func(obj, *args, **kwargs)

    return inner

View File

@@ -0,0 +1,88 @@
import sys
import logging
logger = logging.getLogger(__name__)
def apply_asyncio_patch():
    """
    Apply a patch to asyncio's exception handling for subprocesses.
    There are some issues with asyncio's exception handling for subprocesses,
    which causes a RuntimeError to be raised when the event loop was already closed.
    This patch catches the RuntimeError and ignores it, which allows the event loop
    to be closed properly.
    Similar issues:
    - https://bugs.python.org/issue39232
    - https://github.com/python/cpython/issues/92841

    Only the exact interpreter-shutdown error messages are swallowed; any
    other RuntimeError/ValueError/OSError is re-raised unchanged.
    """
    import asyncio.base_subprocess

    # Keep a reference to the original finalizer so the patched version can
    # delegate to it.
    original_subprocess_del = asyncio.base_subprocess.BaseSubprocessTransport.__del__

    def patched_subprocess_del(self):
        try:
            original_subprocess_del(self)
        except (RuntimeError, ValueError, OSError) as e:
            # Re-raise anything that isn't the known shutdown noise.
            if isinstance(e, RuntimeError) and str(e) != "Event loop is closed":
                raise
            if isinstance(e, ValueError) and str(e) != "I/O operation on closed pipe":
                raise
            if isinstance(e, OSError) and "[WinError 6]" not in str(e):
                raise
            logger.debug(f"Patched {original_subprocess_del}")

    asyncio.base_subprocess.BaseSubprocessTransport.__del__ = patched_subprocess_del

    if sys.platform == "win32":
        # Windows' proactor event loop has its own pipe-transport finalizers
        # with the same shutdown-time failure modes.
        import asyncio.proactor_events as proactor_events

        original_pipe_del = proactor_events._ProactorBasePipeTransport.__del__

        def patched_pipe_del(self):
            try:
                original_pipe_del(self)
            except (RuntimeError, ValueError) as e:
                if isinstance(e, RuntimeError) and str(e) != "Event loop is closed":
                    raise
                if (
                    isinstance(e, ValueError)
                    and str(e) != "I/O operation on closed pipe"
                ):
                    raise
                logger.debug(f"Patched {original_pipe_del}")

        original_repr = proactor_events._ProactorBasePipeTransport.__repr__

        def patched_repr(self):
            # __repr__ can itself touch the closed pipe; fall back to a
            # static representation in that case.
            try:
                return original_repr(self)
            except ValueError as e:
                if str(e) != "I/O operation on closed pipe":
                    raise
                logger.debug(f"Patched {original_repr}")
                return f"<{self.__class__} [closed]>"

        proactor_events._ProactorBasePipeTransport.__del__ = patched_pipe_del
        proactor_events._ProactorBasePipeTransport.__repr__ = patched_repr

        import subprocess

        original_popen_del = subprocess.Popen.__del__

        def patched_popen_del(self):
            try:
                original_popen_del(self)
            except OSError as e:
                # WinError 6 = invalid handle, seen during interpreter shutdown.
                if "[WinError 6]" not in str(e):
                    raise
                logger.debug(f"Patched {original_popen_del}")

        subprocess.Popen.__del__ = patched_popen_del


# Install the patches as soon as this module is imported.
apply_asyncio_patch()

View File

@@ -0,0 +1,10 @@
from .cli_utils import auth_options, build_client_session, proxy_options, inject_session
from .cli import auth
__all__ = [
"build_client_session",
"proxy_options",
"auth_options",
"inject_session",
"auth",
]

View File

@@ -0,0 +1,402 @@
# type: ignore
import logging
import sys
from datetime import datetime
from safety.auth.models import Auth
from safety.auth.utils import initialize, is_email_verified
from safety.console import main_console as console
from safety.constants import (
MSG_FINISH_REGISTRATION_TPL,
MSG_VERIFICATION_HINT,
DEFAULT_EPILOG,
)
from safety.meta import get_version
from safety.decorators import notify
try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated
from typing import Optional
import click
import typer
from rich.padding import Padding
from typer import Typer
from safety.auth.main import (
clean_session,
get_auth_info,
get_authorization_data,
get_token,
)
from safety.auth.server import process_browser_callback
from safety.events.utils import emit_auth_started, emit_auth_completed
from safety.util import initialize_event_bus
from safety.scan.constants import (
CLI_AUTH_COMMAND_HELP,
CLI_AUTH_HEADLESS_HELP,
CLI_AUTH_LOGIN_HELP,
CLI_AUTH_LOGOUT_HELP,
CLI_AUTH_STATUS_HELP,
)
from ..cli_util import SafetyCLISubGroup, get_command_for, pass_safety_cli_obj
from safety.error_handlers import handle_cmd_exception
from .constants import (
MSG_FAIL_LOGIN_AUTHED,
MSG_FAIL_REGISTER_AUTHED,
MSG_LOGOUT_DONE,
MSG_LOGOUT_FAILED,
MSG_NON_AUTHENTICATED,
)
LOG = logging.getLogger(__name__)
auth_app = Typer(rich_markup_mode="rich", name="auth")
CMD_LOGIN_NAME = "login"
CMD_REGISTER_NAME = "register"
CMD_STATUS_NAME = "status"
CMD_LOGOUT_NAME = "logout"
DEFAULT_CMD = CMD_LOGIN_NAME
@auth_app.callback(
    invoke_without_command=True,
    cls=SafetyCLISubGroup,
    help=CLI_AUTH_COMMAND_HELP,
    epilog=DEFAULT_EPILOG,
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
@pass_safety_cli_obj
def auth(ctx: typer.Context) -> None:
    """
    Authenticate Safety CLI with your account.

    Args:
        ctx (typer.Context): The Typer context object.
    """
    LOG.info("auth started")

    if ctx.invoked_subcommand:
        # A subcommand was given; let it run normally.
        return

    # No subcommand: behave as if the default command had been invoked.
    default_command = get_command_for(name=DEFAULT_CMD, typer_instance=auth_app)
    return ctx.forward(default_command)
def fail_if_authenticated(ctx: typer.Context, with_msg: str) -> None:
    """
    Exits the command (status 0) if the user is already authenticated;
    otherwise returns without side effects.

    Args:
        ctx (typer.Context): The Typer context object.
        with_msg (str): The message to display if authenticated; may contain
            an ``{email}`` placeholder.
    """
    if not get_auth_info(ctx):
        return

    console.print()
    email = f"[green]{ctx.obj.auth.email}[/green]"
    if not ctx.obj.auth.email_verified:
        email = f"{email} {render_email_note(ctx.obj.auth)}"
    console.print(with_msg.format(email=email))
    sys.exit(0)
def render_email_note(auth: Auth) -> str:
    """
    Renders a note indicating whether email verification is required.

    Args:
        auth (Auth): The Auth object.

    Returns:
        str: Markup warning for unverified emails, otherwise "".
    """
    if auth.email_verified:
        return ""
    return "[red](email verification required)[/red]"
def render_successful_login(auth: Auth, organization: Optional[str] = None) -> None:
    """
    Renders a message indicating a successful login.

    Args:
        auth (Auth): The Auth object.
        organization (Optional[str]): The organization name.
    """
    placeholder = "--"
    name = auth.name if auth.name else placeholder
    email = auth.email if auth.email else placeholder
    email_note = render_email_note(auth)

    console.print()
    console.print("[bold][green]You're authenticated[/green][/bold]")

    if name and name != email:
        account = f"[green][bold]Account:[/bold] {name}, {email}[/green] {email_note}"
    else:
        account = f"[green][bold]Account:[/bold] {email}[/green] {email_note}"
    details = [account]
    if organization:
        details.insert(0, f"[green][bold]Organization:[/bold] {organization}[green]")

    for msg in details:
        console.print(Padding(msg, (0, 0, 0, 1)), emoji=True)
@auth_app.command(name=CMD_LOGIN_NAME, help=CLI_AUTH_LOGIN_HELP)
@handle_cmd_exception
@notify
def login(
    ctx: typer.Context,
    headless: Annotated[
        Optional[bool],
        typer.Option(
            "--headless",
            help=CLI_AUTH_HEADLESS_HELP,
        ),
    ] = None,
) -> None:
    """
    Authenticate Safety CLI with your safetycli.com account using your default browser.

    Args:
        ctx (typer.Context): The Typer context object.
        headless (bool): Whether to run in headless mode (print the URL
            instead of opening a browser).
    """
    LOG.info("login started")
    # Normalize the tri-state option (None/True) to a plain bool.
    headless = headless is True
    # Check if the user is already authenticated
    fail_if_authenticated(ctx, with_msg=MSG_FAIL_LOGIN_AUTHED)
    console.print()
    info = None
    brief_msg: str = (
        "Redirecting your browser to log in; once authenticated, "
        "return here to start using Safety"
    )
    if ctx.obj.auth.org:
        console.print(
            f"Logging into [bold]{ctx.obj.auth.org.name}[/bold] organization."
        )
    if headless:
        brief_msg = "Running in headless mode. Please copy and open the following URL in a browser"
    # Get authorization data and generate the authorization URL
    uri, initial_state = get_authorization_data(
        client=ctx.obj.auth.client,
        code_verifier=ctx.obj.auth.code_verifier,
        organization=ctx.obj.auth.org,
        headless=headless,
    )
    click.secho(brief_msg)
    click.echo()
    emit_auth_started(ctx.obj.event_bus, ctx)
    # Process the browser callback to complete the authentication
    info = process_browser_callback(
        uri, initial_state=initial_state, ctx=ctx, headless=headless
    )
    is_success = False
    error_msg = None
    if info:
        if info.get("email", None):
            organization = None
            if ctx.obj.auth.org and ctx.obj.auth.org.name:
                organization = ctx.obj.auth.org.name
            ctx.obj.auth.refresh_from(info)
            if headless:
                console.print()
            initialize(ctx, refresh=True)
            initialize_event_bus(ctx=ctx)
            render_successful_login(ctx.obj.auth, organization=organization)
            is_success = True
            console.print()
            if ctx.obj.auth.org or ctx.obj.auth.email_verified:
                if not getattr(ctx.obj, "only_auth_msg", False):
                    # Fixed message: "--help" (not "-help") and "project's".
                    console.print(
                        "[tip]Tip[/tip]: now try [bold]`safety scan`[/bold] in your project's root "
                        "folder to run a project scan or [bold]`safety --help`[/bold] to learn more."
                    )
            else:
                # Email exists but is unverified: point the user at the
                # registration-completion flow instead of the scan tip.
                console.print(
                    MSG_FINISH_REGISTRATION_TPL.format(email=ctx.obj.auth.email)
                )
                console.print()
                console.print(MSG_VERIFICATION_HINT)
        else:
            click.secho("Safety is now authenticated but your email is missing.")
    else:
        error_msg = ":stop_sign: [red]"
        if ctx.obj.auth.org:
            error_msg += (
                f"Error logging into {ctx.obj.auth.org.name} organization "
                f"with auth ID: {ctx.obj.auth.org.id}."
            )
        else:
            error_msg += "Error logging into Safety."
        # Fixed message: the flag is "--help", not "-help".
        error_msg += (
            " Please try again, or use [bold]`safety auth --help`[/bold] "
            "for more information[/red]"
        )
        console.print(error_msg, emoji=True)
    emit_auth_completed(
        ctx.obj.event_bus, ctx, success=is_success, error_message=error_msg
    )
@auth_app.command(name=CMD_LOGOUT_NAME, help=CLI_AUTH_LOGOUT_HELP)
@handle_cmd_exception
@notify
def logout(ctx: typer.Context) -> None:
    """
    Log out of your current session.

    Args:
        ctx (typer.Context): The Typer context object.
    """
    LOG.info("logout started")

    msg = MSG_NON_AUTHENTICATED
    if get_token("id_token"):
        # An ID token exists, so there is a session to clean up.
        cleaned = clean_session(ctx.obj.auth.client)
        msg = MSG_LOGOUT_DONE if cleaned else MSG_LOGOUT_FAILED
    console.print(msg)
@auth_app.command(name=CMD_STATUS_NAME, help=CLI_AUTH_STATUS_HELP)
@click.option(
    "--ensure-auth/--no-ensure-auth",
    default=False,
    # Fixed typo: "anauthentication" -> "an authentication".
    help="This will keep running the command until an authentication is made.",
)
@click.option(
    "--login-timeout",
    "-w",
    type=int,
    default=600,
    help="Max time allowed to wait for an authentication.",
)
@handle_cmd_exception
@notify
def status(
    ctx: typer.Context, ensure_auth: bool = False, login_timeout: int = 600
) -> None:
    """
    Display Safety CLI's current authentication status.

    Args:
        ctx (typer.Context): The Typer context object.
        ensure_auth (bool): Whether to keep running until authentication is made.
        login_timeout (int): Max time allowed to wait for authentication.
    """
    LOG.info("status started")
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    safety_version = get_version()
    console.print(f"[{current_time}]: Safety {safety_version}")
    info = get_auth_info(ctx)
    initialize(ctx, refresh=True)
    if ensure_auth:
        console.print("running: safety auth status --ensure-auth")
        console.print()
    if info:
        verified = is_email_verified(info)
        email_status = " [red](email not verified)[/red]" if not verified else ""
        console.print(f"[green]Authenticated as {info['email']}[/green]{email_status}")
    elif ensure_auth:
        console.print(
            "Safety is not authenticated. Launching default browser to log in"
        )
        console.print()
        uri, initial_state = get_authorization_data(
            client=ctx.obj.auth.client,
            code_verifier=ctx.obj.auth.code_verifier,
            organization=ctx.obj.auth.org,
            ensure_auth=ensure_auth,
        )
        # Process the browser callback to complete the authentication
        info = process_browser_callback(
            uri, initial_state=initial_state, timeout=login_timeout, ctx=ctx
        )
        if not info:
            # Fixed message: "within the timeout period" (was "without").
            console.print(
                f"[red]Timeout error ({login_timeout} seconds): not successfully authenticated within the timeout period.[/red]"
            )
            sys.exit(1)
        organization = None
        if ctx.obj.auth.org and ctx.obj.auth.org.name:
            organization = ctx.obj.auth.org.name
        render_successful_login(ctx.obj.auth, organization=organization)
        console.print()
    else:
        console.print(MSG_NON_AUTHENTICATED)
@auth_app.command(name=CMD_REGISTER_NAME)
@handle_cmd_exception
@notify
def register(ctx: typer.Context) -> None:
    """
    Create a new user account for the safetycli.com service.

    Args:
        ctx (typer.Context): The Typer context object.
    """
    LOG.info("register started")
    # Check if the user is already authenticated
    fail_if_authenticated(ctx, with_msg=MSG_FAIL_REGISTER_AUTHED)
    # Get authorization data and generate the registration URL
    uri, initial_state = get_authorization_data(
        client=ctx.obj.auth.client,
        code_verifier=ctx.obj.auth.code_verifier,
        sign_up=True,
    )
    console.print(
        "\nRedirecting your browser to register for a free account. Once registered, return here to start using Safety."
    )
    console.print()
    # Process the browser callback to complete the registration
    info = process_browser_callback(uri, initial_state=initial_state, ctx=ctx)
    console.print()
    if info:
        console.print(f"[green]Successfully registered {info.get('email')}[/green]")
        console.print()
    else:
        # Fixed grammar: "at this time" (was "in this time").
        console.print("[red]Unable to register at this time, try again.[/red]")

View File

@@ -0,0 +1,290 @@
import logging
from typing import Dict, Optional, Tuple, Any, Callable
import click
from .main import (
get_auth_info,
get_host_config,
get_organization,
get_proxy_config,
get_redirect_url,
get_token_data,
save_auth_config,
get_token,
clean_session,
)
from authlib.common.security import generate_token
from safety.auth.constants import CLIENT_ID
from safety.auth.models import Organization, Auth
from safety.auth.utils import (
S3PresignedAdapter,
SafetyAuthSession,
get_keys,
is_email_verified,
)
from safety.scan.constants import (
CLI_KEY_HELP,
CLI_PROXY_HOST_HELP,
CLI_PROXY_PORT_HELP,
CLI_PROXY_PROTOCOL_HELP,
CLI_STAGE_HELP,
)
from safety.util import DependentOption, SafetyContext, get_proxy_dict
from safety.models import SafetyCLI
from safety_schemas.models import Stage
LOG = logging.getLogger(__name__)
def build_client_session(
    api_key: Optional[str] = None,
    proxies: Optional[Dict[str, str]] = None,
    headers: Optional[Dict[str, str]] = None,
) -> Tuple[SafetyAuthSession, Dict[str, Any]]:
    """
    Builds and configures the client session for authentication.

    Args:
        api_key (Optional[str]): The API key for authentication; when given
            it is also sent as the ``X-Api-Key`` header.
        proxies (Optional[Dict[str, str]]): Proxy configuration; falls back
            to the global proxy settings from config.ini when omitted.
        headers (Optional[Dict[str, str]]): Additional headers, merged on top
            of the defaults.

    Returns:
        Tuple[SafetyAuthSession, Dict[str, Any]]: The configured client session and OpenID configuration.
    """
    kwargs = {}
    target_proxies = proxies
    # Global proxy defined in the config.ini
    proxy_config, proxy_timeout, proxy_required = get_proxy_config()
    if not proxies:
        target_proxies = proxy_config

    def update_token(tokens, **kwargs):
        # Persist refreshed tokens, then reload them into the live session so
        # in-flight client state matches what was stored.
        save_auth_config(
            access_token=tokens["access_token"],
            id_token=tokens["id_token"],
            refresh_token=tokens["refresh_token"],
        )
        load_auth_session(click_ctx=click.get_current_context(silent=True))  # type: ignore

    client_session = SafetyAuthSession(
        client_id=CLIENT_ID,
        code_challenge_method="S256",
        redirect_uri=get_redirect_url(),
        update_token=update_token,
        scope="openid email profile offline_access",
        **kwargs,
    )
    # Pre-signed S3 downloads must not carry auth headers; a dedicated
    # adapter handles that host prefix.
    client_session.mount("https://pyup.io/static-s3/", S3PresignedAdapter())
    client_session.proxy_required = proxy_required
    client_session.proxy_timeout = proxy_timeout
    client_session.proxies = target_proxies  # type: ignore
    # NOTE(review): this replaces the session's default header set wholesale
    # rather than updating it — presumably intentional; confirm.
    client_session.headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    openid_config = client_session.fetch_openid_config()
    client_session.metadata["token_endpoint"] = openid_config.get(
        "token_endpoint", None
    )
    if api_key:
        client_session.api_key = api_key  # type: ignore
        client_session.headers["X-Api-Key"] = api_key
    if headers:
        client_session.headers.update(headers)
    return client_session, openid_config
def load_auth_session(click_ctx: click.Context) -> None:
    """
    Loads the stored authentication tokens into the session client held on
    the Click context.

    Args:
        click_ctx (click.Context): The Click context object; a warning is
            logged and nothing happens when it is missing.
    """
    if not click_ctx:
        LOG.warning("Click context is needed to be able to load the Auth data.")
        return

    client = click_ctx.obj.auth.client
    keys = click_ctx.obj.auth.keys
    access_token: str = get_token(name="access_token")  # type: ignore
    refresh_token: str = get_token(name="refresh_token")  # type: ignore
    id_token: str = get_token(name="id_token")  # type: ignore
    if access_token and keys:
        try:
            token = get_token_data(access_token, keys, silent_if_expired=True)
            client.token = {
                "access_token": access_token,
                "refresh_token": refresh_token,
                "id_token": id_token,
                "token_type": "bearer",
                "expires_at": token.get("exp", None),  # type: ignore
            }
        except Exception:
            # Log instead of printing to stdout (the original left a stray
            # debug `print(e)` here), then drop the unusable session.
            LOG.exception("Failed to load the auth token data; cleaning session.")
            clean_session(client)
def proxy_options(func: Callable) -> Callable:
    """
    Decorator that defines proxy options for Click commands.

    Options defined per command override the proxy settings defined in the
    config.ini file.

    Args:
        func (Callable): The Click command function.

    Returns:
        Callable: The wrapped Click command function with proxy options.
    """
    # Applied in the same order as before so --help output and option
    # resolution are unchanged.
    option_decorators = (
        click.option(
            "--proxy-protocol",
            type=click.Choice(["http", "https"]),
            default="https",
            cls=DependentOption,
            required_options=["proxy_host"],
            help=CLI_PROXY_PROTOCOL_HELP,
        ),
        click.option(
            "--proxy-port",
            multiple=False,
            type=int,
            default=80,
            cls=DependentOption,
            required_options=["proxy_host"],
            help=CLI_PROXY_PORT_HELP,
        ),
        click.option(
            "--proxy-host",
            multiple=False,
            type=str,
            default=None,
            help=CLI_PROXY_HOST_HELP,
        ),
    )
    for decorate in option_decorators:
        func = decorate(func)
    return func
def auth_options(stage: bool = True) -> Callable:
    """
    Decorator factory that adds authentication options to Click commands.

    Args:
        stage (bool): Whether to include the --stage option.

    Returns:
        Callable: The decorator function.
    """

    def decorator(func: Callable) -> Callable:
        key_option = click.option(
            "--key", default=None, envvar="SAFETY_API_KEY", help=CLI_KEY_HELP
        )
        func = key_option(func)
        if stage:
            stage_option = click.option(
                "--stage", default=None, envvar="SAFETY_STAGE", help=CLI_STAGE_HELP
            )
            func = stage_option(func)
        return func

    return decorator
def inject_session(
    ctx: click.Context,
    proxy_protocol: Optional[str] = None,
    proxy_host: Optional[str] = None,
    proxy_port: Optional[str] = None,
    key: Optional[str] = None,
    stage: Optional[Stage] = None,
    invoked_command: str = "",
) -> Any:
    """
    Build the authenticated client session and attach it to the Click context.

    Resolves the stage and organization from local config, creates the OAuth2
    client session (optionally proxied), restores any persisted tokens, and
    registers a close hook that tears down the session and flushes the event
    bus.

    Args:
        ctx (click.Context): The Click context to populate (``ctx.obj.auth``).
        proxy_protocol (Optional[str]): Proxy protocol ("http"/"https").
        proxy_host (Optional[str]): Proxy host; proxying is enabled when set.
        proxy_port (Optional[str]): Proxy port.
        key (Optional[str]): API key, when authenticating with one.
        stage (Optional[Stage]): Explicit stage; falls back to the host
            config, then Stage.development.
        invoked_command (str): Name of the CLI command being run.
            NOTE(review): currently unused in this body — confirm intent.
    """
    org: Optional[Organization] = get_organization()

    if not stage:
        host_stage = get_host_config(key_name="stage")
        stage = host_stage if host_stage else Stage.development

    proxy_config: Optional[Dict[str, str]] = get_proxy_dict(
        proxy_protocol,  # type: ignore
        proxy_host,  # type: ignore
        proxy_port,  # type: ignore
    )

    client_session, openid_config = build_client_session(
        api_key=key, proxies=proxy_config
    )
    keys = get_keys(client_session, openid_config)

    auth = Auth(
        stage=stage,
        keys=keys,
        org=org,
        client_id=CLIENT_ID,  # type: ignore
        client=client_session,
        code_verifier=generate_token(48),
    )

    if not ctx.obj:
        ctx.obj = SafetyCLI()

    ctx.obj.auth = auth
    # Restore tokens persisted by a previous login before querying user info.
    load_auth_session(ctx)

    info = get_auth_info(ctx)
    if info:
        ctx.obj.auth.name = info.get("name")
        ctx.obj.auth.email = info.get("email")
        ctx.obj.auth.email_verified = is_email_verified(info)  # type: ignore
        SafetyContext().account = info["email"]
    else:
        SafetyContext().account = ""

    @ctx.call_on_close
    def clean_up_on_close():
        # Runs when the Click context closes: release the HTTP session and
        # drain the event bus before the process exits.
        LOG.debug("Closing requests session.")
        ctx.obj.auth.client.close()

        if ctx.obj.event_bus:
            from safety.events.utils import (
                create_internal_event,
                InternalEventType,
                InternalPayload,
            )

            payload = InternalPayload(ctx=ctx)
            flush_event = create_internal_event(
                event_type=InternalEventType.FLUSH_SECURITY_TRACES, payload=payload
            )
            close_event = create_internal_event(
                event_type=InternalEventType.CLOSE_RESOURCES, payload=payload
            )

            flush_future = ctx.obj.event_bus.emit(flush_event)
            close_future = ctx.obj.event_bus.emit(close_event)

            # Wait for both events to be processed
            if flush_future and close_future:
                try:
                    flush_future.result()
                    close_future.result()
                except Exception as e:
                    LOG.warning(f"Error waiting for events to process: {e}")

            ctx.obj.event_bus.stop()

View File

@@ -0,0 +1,35 @@
from pathlib import Path

from safety.constants import USER_CONFIG_DIR, get_config_setting

# Name and full path of the local INI file where auth tokens are persisted.
AUTH_CONFIG_FILE_NAME = "auth.ini"
AUTH_CONFIG_USER = USER_CONFIG_DIR / Path(AUTH_CONFIG_FILE_NAME)

# Host used by the local HTTP server that receives the OAuth browser callback.
HOST: str = "localhost"

# OAuth/OpenID settings resolved from the packaged configuration.
CLIENT_ID = get_config_setting("CLIENT_ID")
AUTH_SERVER_URL = get_config_setting("AUTH_SERVER_URL")
SAFETY_PLATFORM_URL = get_config_setting("SAFETY_PLATFORM_URL")
OPENID_CONFIG_URL = f"{AUTH_SERVER_URL}/.well-known/openid-configuration"

# Token claim names used to check whether the user's email is verified.
CLAIM_EMAIL_VERIFIED_API = "https://api.safetycli.com/email_verified"
CLAIM_EMAIL_VERIFIED_AUTH_SERVER = "email_verified"

# Safety Platform URLs used during the browser-based login/logout flows.
CLI_AUTH = f"{SAFETY_PLATFORM_URL}/cli/auth"
CLI_AUTH_SUCCESS = f"{SAFETY_PLATFORM_URL}/cli/auth/success"
CLI_AUTH_LOGOUT = f"{SAFETY_PLATFORM_URL}/cli/logout"
CLI_CALLBACK = f"{SAFETY_PLATFORM_URL}/cli/callback"
CLI_LOGOUT_SUCCESS = f"{SAFETY_PLATFORM_URL}/cli/logout/success"

# User-facing messages (Rich markup) for the auth commands.
MSG_NON_AUTHENTICATED = (
    "Safety is not authenticated. Please run 'safety auth login' to log in."
)
MSG_FAIL_LOGIN_AUTHED = """[green]You are authenticated as[/green] {email}.
To log into a different account, first logout via: safety auth logout, and then login again."""
MSG_FAIL_REGISTER_AUTHED = "You are currently logged in to {email}, please logout using `safety auth logout` before registering a new account."
MSG_LOGOUT_DONE = "[green]Logout done.[/green]"
MSG_LOGOUT_FAILED = "[red]Logout failed. Try again.[/red]"

View File

@@ -0,0 +1,329 @@
import configparser
from typing import Any, Dict, Optional, Tuple, Union
from authlib.oidc.core import CodeIDToken
from authlib.jose import jwt
from authlib.jose.errors import ExpiredTokenError
from safety.auth.models import Organization
from safety.auth.constants import (
CLI_AUTH_LOGOUT,
CLI_CALLBACK,
AUTH_CONFIG_USER,
CLI_AUTH,
)
from safety.constants import CONFIG
from safety_schemas.models import Stage
from safety.util import get_proxy_dict
def get_authorization_data(
    client,
    code_verifier: str,
    organization: Optional[Organization] = None,
    sign_up: bool = False,
    ensure_auth: bool = False,
    headless: bool = False,
) -> Tuple[str, str]:
    """
    Build the authorization URL used to start the login flow.

    Args:
        client: The authentication client.
        code_verifier (str): The PKCE code verifier.
        organization (Optional[Organization]): Organization to authenticate with.
        sign_up (bool): Whether the URL should start a sign-up flow.
        ensure_auth (bool): Whether to force authentication.
        headless (bool): Whether the flow runs without a browser.

    Returns:
        Tuple[str, str]: The authorization URL and the initial state value.
    """
    extra_params: Dict[str, Any] = {
        "sign_up": sign_up,
        "locale": "en",
        "ensure_auth": ensure_auth,
        "headless": headless,
    }
    if organization:
        extra_params["organization"] = organization.id

    return client.create_authorization_url(
        CLI_AUTH, code_verifier=code_verifier, **extra_params
    )
def get_logout_url(id_token: str) -> str:
    """
    Build the logout URL carrying the given ID token.

    Args:
        id_token (str): The ID token to include in the logout request.

    Returns:
        str: The logout URL.
    """
    return CLI_AUTH_LOGOUT + "?id_token=" + id_token
def get_redirect_url() -> str:
    """Return the callback URL the auth server redirects back to."""
    return CLI_CALLBACK
def get_organization() -> Optional[Organization]:
    """
    Read the organization settings from the config file.

    Returns:
        Optional[Organization]: The configured organization, or None when
        no organization id is present.
    """
    parser = configparser.ConfigParser()
    parser.read(CONFIG)

    section: Union[Dict[str, str], configparser.SectionProxy]
    if "organization" in parser.sections():
        section = parser["organization"]
    else:
        section = {}

    def _cleaned(key: str) -> Optional[str]:
        # Users sometimes quote values in the INI file; strip the quotes.
        raw = section.get(key, None)
        return raw.replace('"', "") if raw else None

    org_id = _cleaned("id")
    if not org_id:
        return None

    return Organization(id=org_id, name=_cleaned("name"))  # type: ignore
def get_auth_info(ctx) -> Optional[Dict]:
    """
    Retrieve the authenticated user's claims from the stored id_token.

    Falls back to a token refresh when the id_token is expired, or when the
    local token says the email is unverified but the auth server disagrees.
    Any unrecoverable error wipes the local session.

    Args:
        ctx: The context object containing authentication data.

    Returns:
        Optional[Dict]: The id_token claims, or None if not authenticated.
    """
    # Local import — presumably avoids a circular import at module load time.
    from safety.auth.utils import is_email_verified

    info = None
    if ctx.obj.auth.client.token:
        try:
            info = get_token_data(get_token(name="id_token"), keys=ctx.obj.auth.keys)  # type: ignore

            verified = is_email_verified(info)  # type: ignore
            if not verified:
                # Double-check against the auth server; its claim may be
                # newer than the locally cached id_token.
                user_info = ctx.obj.auth.client.fetch_user_info()
                verified = is_email_verified(user_info)

                if verified:
                    # Server says verified but the local token disagrees:
                    # force a refresh only in this case.
                    raise ExpiredTokenError

        except ExpiredTokenError:
            # id_token expired, so manually fire a refresh.
            try:
                ctx.obj.auth.client.refresh_token(
                    ctx.obj.auth.client.metadata.get("token_endpoint"),
                    refresh_token=ctx.obj.auth.client.token.get("refresh_token"),
                )
                info = get_token_data(
                    get_token(name="id_token"),  # type: ignore
                    keys=ctx.obj.auth.keys,  # type: ignore
                )
            except Exception as _e:
                # Refresh failed: drop the stale local session.
                clean_session(ctx.obj.auth.client)
        except Exception as _g:
            # Any other decode/validation failure also invalidates the session.
            clean_session(ctx.obj.auth.client)

    return info
def get_token_data(
    token: str, keys: Any, silent_if_expired: bool = False
) -> Optional[Dict]:
    """
    Decode a JWT and validate its claims.

    Args:
        token (str): The token to decode.
        keys (Any): The JWKS keys used for decoding.
        silent_if_expired (bool): When True, an expired token is tolerated
            and its claims are still returned.

    Returns:
        Optional[Dict]: The decoded claims.

    Raises:
        ExpiredTokenError: When the token is expired and
            ``silent_if_expired`` is False.
    """
    claims = jwt.decode(token, keys, claims_cls=CodeIDToken)

    try:
        claims.validate()
    except ExpiredTokenError:
        # Expiration is the only tolerated validation failure, and only
        # when the caller explicitly opted in.
        if not silent_if_expired:
            raise

    return claims
def get_token(name: str = "access_token") -> Optional[str]:
    """
    Retrieve a token stored in the local auth configuration.

    Tokens are persisted in the [auth] section of the local auth INI file
    (access_token, id_token, refresh_token).

    Args:
        name (str): The name of the token to retrieve.

    Returns:
        Optional[str]: The token value, or None when missing or empty.
    """
    parser = configparser.ConfigParser()
    parser.read(AUTH_CONFIG_USER)

    try:
        value = parser["auth"][name]
    except KeyError:
        return None

    # Empty strings (written by clean_session) count as "no token".
    return value if value else None
def get_host_config(key_name: str) -> Optional[Any]:
    """
    Look up a value in the [host] section of the config file.

    Args:
        key_name (str): The name of the configuration key.

    Returns:
        Optional[Any]: The parsed value, or None when absent or invalid.
    """
    parser = configparser.ConfigParser()
    parser.read(CONFIG)

    if not parser.has_section("host"):
        return None

    host_section = dict(parser.items("host"))
    if key_name not in host_section:
        return None

    if key_name == "stage":
        raw = host_section[key_name]
        # Support the old "dev" alias used by earlier config.ini files.
        if raw == "dev":
            raw = "development"
        if raw in {env.value for env in Stage}:
            return Stage(raw)
        return None

    # Only "stage" is currently interpreted; other keys yield None.
    return None
def str_to_bool(s: str) -> bool:
    """
    Convert a string to a boolean value.

    Accepts "true"/"false" in any letter case, plus the literals "1"/"0".

    Args:
        s (str): The string to convert.

    Returns:
        bool: The converted boolean value.

    Raises:
        ValueError: If the string cannot be converted.
    """
    lowered = s.lower()
    if lowered == "true" or s == "1":
        return True
    if lowered == "false" or s == "0":
        return False
    raise ValueError(f"Cannot convert '{s}' to a boolean value.")
def get_proxy_config() -> Tuple[Optional[Dict[str, str]], Optional[int], bool]:
    """
    Read the [proxy] section from the config file.

    Returns:
        Tuple[Optional[Dict[str, str]], Optional[int], bool]: The proxy
        mapping suitable for requests, the timeout value, and whether the
        proxy is required.
    """
    config = configparser.ConfigParser()
    config.read(CONFIG)

    proxy_dictionary: Optional[Dict[str, str]] = None
    required = False
    timeout: Optional[int] = None
    proxy = None

    if config.has_section("proxy"):
        proxy = dict(config.items("proxy"))

    if proxy:
        # Best-effort parsing: a partial or malformed proxy section is
        # ignored rather than aborting the CLI.
        try:
            proxy_dictionary = get_proxy_dict(
                proxy["protocol"],
                proxy["host"],
                proxy["port"],  # type: ignore
            )
            required = str_to_bool(proxy["required"])
            # INI values are read back as strings; coerce so the returned
            # timeout matches the Optional[int] annotation.
            timeout = int(proxy["timeout"])
        except Exception:
            pass

    return proxy_dictionary, timeout, required
def clean_session(client) -> bool:
    """
    Reset the persisted auth tokens and drop the client's in-memory token.

    Args:
        client: The authentication client.

    Returns:
        bool: Always True.
    """
    empty_tokens = {"access_token": "", "id_token": "", "refresh_token": ""}

    parser = configparser.ConfigParser()
    parser["auth"] = empty_tokens
    with open(AUTH_CONFIG_USER, "w") as fh:
        parser.write(fh)

    client.token = None
    return True
def save_auth_config(
    access_token: Optional[str] = None,
    id_token: Optional[str] = None,
    refresh_token: Optional[str] = None,
) -> None:
    """
    Persist the authentication tokens to the local auth configuration.

    Args:
        access_token (Optional[str]): The access token.
        id_token (Optional[str]): The ID token.
        refresh_token (Optional[str]): The refresh token.
    """
    config = configparser.ConfigParser()
    config.read(AUTH_CONFIG_USER)
    # configparser only accepts string option values; a None here raised
    # TypeError when the section was assigned. Store missing tokens as "",
    # which is exactly what clean_session() writes.
    config["auth"] = {
        "access_token": access_token or "",
        "id_token": id_token or "",
        "refresh_token": refresh_token or "",
    }

    with open(AUTH_CONFIG_USER, "w") as configfile:
        config.write(configfile)

View File

@@ -0,0 +1,105 @@
from dataclasses import dataclass
import os
from typing import Any, Optional, Dict
from authlib.integrations.base_client import BaseOAuth
from safety_schemas.models import Stage
@dataclass
class Organization:
    """An organization the user authenticates against."""

    # Organization identifier and display name, as configured locally.
    id: str
    name: str

    def to_dict(self) -> Dict:
        """
        Return a plain-dict representation of this organization.

        Returns:
            dict: Mapping with the "id" and "name" keys.
        """
        return {"id": self.id, "name": self.name}
@dataclass
class Auth:
    """Authentication state for the current CLI invocation."""

    # Organization selected via local config, if any.
    org: Optional[Organization]
    # JWKS keys used to validate tokens.
    keys: Any
    # The client session (API key and/or OAuth token capable).
    client: Any
    # PKCE code verifier generated for the login flow.
    code_verifier: str
    client_id: str
    stage: Optional[Stage] = Stage.development
    email: Optional[str] = None
    name: Optional[str] = None
    email_verified: bool = False

    def is_valid(self) -> bool:
        """
        Check if the authentication information is valid.

        Returns:
            bool: True if valid, False otherwise.
        """
        # A local vulnerability DB directory bypasses remote auth entirely.
        if os.getenv("SAFETY_DB_DIR"):
            return True

        if not self.client:
            return False

        if self.client.api_key:
            return True

        # Token-based auth additionally requires a verified email.
        return bool(self.client.token and self.email_verified)

    def refresh_from(self, info: Dict) -> None:
        """
        Refresh the authentication information from the provided info.

        Args:
            info (dict): The user-info payload to copy from.
        """
        # Local import — presumably avoids a circular import; confirm.
        from safety.auth.utils import is_email_verified

        self.name = info.get("name")
        self.email = info.get("email")
        self.email_verified = is_email_verified(info)  # type: ignore

    def get_auth_method(self) -> str:
        """
        Get the authentication method.

        Returns:
            str: "API Key", "Token", or "None".
        """
        if self.client.api_key:
            return "API Key"

        if self.client.token:
            return "Token"

        return "None"
class XAPIKeyAuth(BaseOAuth):
    """Request auth hook that authenticates via an X-API-Key header."""

    def __init__(self, api_key: str) -> None:
        """
        Initialize the XAPIKeyAuth instance.

        Args:
            api_key (str): The API key to use for authentication.
        """
        self.api_key = api_key

    def __call__(self, r: Any) -> Any:
        """
        Add the API key to the outgoing request headers.

        Args:
            r (Any): The request object.

        Returns:
            Any: The modified request object.
        """
        r.headers["X-API-Key"] = self.api_key
        return r

View File

@@ -0,0 +1,308 @@
# type: ignore
import http.server
import json
import logging
import random
import socket
import sys
import time
from typing import Any, Optional, Dict, Tuple
import urllib.parse
import threading
import click
from safety.auth.utils import is_jupyter_notebook
from safety.console import main_console as console
from safety.auth.constants import (
AUTH_SERVER_URL,
CLI_AUTH_SUCCESS,
CLI_LOGOUT_SUCCESS,
HOST,
)
from safety.auth.main import save_auth_config
from rich.prompt import Prompt
LOG = logging.getLogger(__name__)
def find_available_port() -> Optional[int]:
    """
    Find an available port on localhost within the IANA dynamic port range
    (49152-65535).

    Returns:
        Optional[int]: An available port number, or None if none is free.
    """
    # Dynamic ports IANA; shuffled so concurrent runs are unlikely to race
    # for the same port.
    port_range = list(range(49152, 65536))
    random.shuffle(port_range)

    for port in port_range:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                # bind() proves the port is actually free for our callback
                # server. (The previous connect()-based probe reported a
                # port as available even when it was bound by a process
                # that was not accepting connections.)
                s.bind(("localhost", port))
                return port
            except socket.error:
                # Port is taken; keep looking.
                continue

    return None
def auth_process(
    code: str, state: str, initial_state: str, code_verifier: str, client: Any
) -> Any:
    """
    Exchange the authorization code for tokens, persist them, and fetch
    the user info.

    Args:
        code (str): The authorization code from the callback.
        state (str): The state parameter returned in the callback.
        initial_state (str): The state value originally sent.
        code_verifier (str): The PKCE code verifier.
        client (Any): The OAuth client.

    Returns:
        Any: The user information.

    Raises:
        SystemExit: On a state mismatch or a failed token exchange.
    """
    # CSRF protection: the returned state must match the one we generated.
    state_mismatch = initial_state is None or initial_state != state
    if state_mismatch:
        err = (
            "The state parameter value provided does not match the expected "
            "value. The state parameter is used to protect against Cross-Site "
            "Request Forgery (CSRF) attacks. For security reasons, the "
            "authorization process cannot proceed with an invalid state "
            "parameter value. Please try again, ensuring that the state "
            "parameter value provided in the authorization request matches "
            "the value returned in the callback."
        )
        click.secho(f"Error: {err}", fg="red")
        sys.exit(1)

    try:
        tokens = client.fetch_token(
            url=f"{AUTH_SERVER_URL}/oauth/token",
            code_verifier=code_verifier,
            client_id=client.client_id,
            grant_type="authorization_code",
            code=code,
        )
        save_auth_config(
            access_token=tokens["access_token"],
            id_token=tokens["id_token"],
            refresh_token=tokens["refresh_token"],
        )
        return client.fetch_user_info()
    except Exception as e:
        LOG.exception(e)
        sys.exit(1)
class CallbackHandler(http.server.BaseHTTPRequestHandler):
    """Handles the one-shot OAuth login/logout callbacks on the local server."""

    def auth(self, code: str, state: str, err: str, error_description: str) -> None:
        """
        Handle the authentication callback.

        Args:
            code (str): The authorization code.
            state (str): The state parameter.
            err (str): The error message, if any. Currently unused here;
                state/code validation happens inside auth_process.
            error_description (str): The error description, if any.
        """
        initial_state = self.server.initial_state
        ctx = self.server.ctx
        result = auth_process(
            code=code,
            state=state,
            initial_state=initial_state,
            code_verifier=ctx.obj.auth.code_verifier,
            client=ctx.obj.auth.client,
        )
        # Expose the result to the thread waiting on server.handle_request().
        self.server.callback = result

        self.do_redirect(location=CLI_AUTH_SUCCESS, params={})

    def logout(self) -> None:
        """
        Handle the logout callback and redirect to the logout success page.
        """
        ctx = self.server.ctx
        uri = CLI_LOGOUT_SUCCESS
        if ctx.obj.auth.org:
            # CLI_LOGOUT_SUCCESS carries no query string, so the query must
            # start with "?" (the old code used "&", producing a malformed
            # URL, and then discarded the computed value entirely).
            separator = "&" if "?" in uri else "?"
            uri = f"{uri}{separator}org_id={ctx.obj.auth.org.id}"

        # Fix: redirect to the URI that includes org_id; previously the
        # computed value was ignored and the bare success URL was used.
        self.do_redirect(location=uri, params={})

    def do_GET(self) -> None:
        """
        Handle GET requests, dispatching to the auth or logout flow.
        """
        query = urllib.parse.urlparse(self.path).query
        params = urllib.parse.parse_qs(query)

        callback_type: Optional[str] = None
        try:
            c_type = params.get("type", [])
            if (
                isinstance(c_type, list)
                and len(c_type) == 1
                and isinstance(c_type[0], str)
            ):
                callback_type = c_type[0]
        except Exception:
            msg = "Unable to process the callback, try again."
            self.send_error(400, msg)
            click.secho("Unable to process the callback, try again.")
            return

        if callback_type == "logout":
            self.logout()
            return

        code = params.get("code", [""])[0]
        state = params.get("state", [""])[0]
        err = params.get("error", [""])[0]
        error_description = params.get("error_description", [""])[0]

        self.auth(code=code, state=state, err=err, error_description=error_description)

    def do_redirect(self, location: str, params: Dict) -> None:
        """
        Redirect the client to the specified location.

        Args:
            location (str): The URL to redirect to.
            params (dict): Additional parameters for the redirection
                (currently unused).
        """
        self.send_response(302)
        self.send_header("Location", location)
        self.send_header("Connection", "close")
        # The callback carries one-time auth data; never let it be cached.
        self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
        self.end_headers()

    def log_message(self, format: str, *args: Any) -> None:
        """
        Route the default request logging to the module logger.

        Args:
            format (str): The format string.
            args (Any): Arguments for the format string.
        """
        LOG.info(format % args)
def process_browser_callback(uri: str, **kwargs: Any) -> Any:
    """
    Drive the browser-based (or headless) leg of the login flow.

    In browser mode, starts a one-shot local HTTP server for the OAuth
    callback and opens the browser at the authorization URL. In headless
    mode, prompts the user to paste the server response manually.

    Args:
        uri (str): The authorization URL.
        **kwargs (Any): Supported keys: headless, initial_state, ctx, timeout.

    Returns:
        Any: The user information produced by the callback.

    Raises:
        SystemExit: If there is an error during the process.
    """

    class ThreadedHTTPServer(http.server.HTTPServer):
        def __init__(self, server_address: Tuple, RequestHandlerClass: Any) -> None:
            """
            Initialize the ThreadedHTTPServer.

            Args:
                server_address (Tuple): The server address as a tuple (host, port).
                RequestHandlerClass (Any): The request handler class.
            """
            super().__init__(server_address, RequestHandlerClass)
            # State shared with CallbackHandler: the expected OAuth state,
            # the Click context, and the result produced by the callback.
            self.initial_state = None
            self.ctx = None
            self.callback = None
            self.timeout_reached = False

        def handle_timeout(self) -> None:
            """
            Record that the wait for the callback timed out.
            """
            self.timeout_reached = True
            return super().handle_timeout()

    PORT = find_available_port()

    if not PORT:
        click.secho("No available ports.")
        sys.exit(1)

    try:
        headless = kwargs.get("headless", False)
        initial_state = kwargs.get("initial_state", None)
        ctx = kwargs.get("ctx", None)

        message = "Copy and paste this URL into your browser:\n:icon_warning: Ensure there are no extra spaces, especially at line breaks, as they may break the link."

        if not headless:
            # Start a threaded HTTP server to handle the callback
            server = ThreadedHTTPServer((HOST, PORT), CallbackHandler)
            server.initial_state = initial_state
            server.timeout = kwargs.get("timeout", 600)
            server.ctx = ctx
            server_thread = threading.Thread(target=server.handle_request)
            server_thread.start()

            message = "If the browser does not automatically open in 5 seconds, copy and paste this url into your browser:"

        # The local port is appended so the auth server can redirect back
        # to our one-shot callback server.
        target = uri if headless else f"{uri}&port={PORT}"

        if is_jupyter_notebook():
            # Jupyter renders plain text; Rich [link] markup would be lost.
            console.print(f"{message} {target}")
        else:
            console.print(f"{message} [link={target}]{target}[/link]")

        if headless:
            # Handle the headless mode where user manually provides the response
            exchange_data = None
            while not exchange_data:
                auth_code_text = Prompt.ask(
                    "Paste the response here", default=None, console=console
                )
                try:
                    exchange_data = json.loads(auth_code_text)
                    state = exchange_data["state"]
                    code = exchange_data["code"]
                except Exception:
                    code = state = None

            return auth_process(
                code=code,
                state=state,
                initial_state=initial_state,
                code_verifier=ctx.obj.auth.code_verifier,
                client=ctx.obj.auth.client,
            )
        else:
            # Wait for the browser authentication in non-headless mode
            console.print()
            wait_msg = "waiting for browser authentication"
            with console.status(wait_msg, spinner="bouncingBar"):
                time.sleep(2)
                click.launch(target)
                server_thread.join()
    except OSError as e:
        # NOTE(review): socket.errno relies on socket.py importing errno
        # internally — an implementation detail; `import errno` directly
        # would be safer. Confirm before relying on it.
        if e.errno == socket.errno.EADDRINUSE:
            reason = f"The port {HOST}:{PORT} is currently being used by another application or process. Please choose a different port or terminate the conflicting application/process to free up the port."
        else:
            reason = "An error occurred while performing this operation."

        click.secho(reason)
        sys.exit(1)

    # Only reached on the non-headless path, after the server thread has
    # handled exactly one callback request.
    return server.callback

View File

@@ -0,0 +1,756 @@
# type: ignore
import importlib.util
import json
import logging
from functools import lru_cache
from typing import Any, Callable, Dict, Optional, Tuple, List, Literal
import requests
from authlib.integrations.base_client.errors import OAuthError
from authlib.integrations.requests_client import OAuth2Session
from requests.adapters import HTTPAdapter
from safety_schemas.models import STAGE_ID_MAPPING, Stage
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from safety.auth.constants import (
AUTH_SERVER_URL,
OPENID_CONFIG_URL,
)
from safety.constants import (
PLATFORM_API_CHECK_UPDATES_ENDPOINT,
PLATFORM_API_INITIALIZE_ENDPOINT,
PLATFORM_API_POLICY_ENDPOINT,
PLATFORM_API_PROJECT_CHECK_ENDPOINT,
PLATFORM_API_PROJECT_ENDPOINT,
PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT,
PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT,
PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT,
REQUEST_TIMEOUT,
FeatureType,
get_config_setting,
FIREWALL_AUDIT_PYPI_PACKAGES_ENDPOINT,
FIREWALL_AUDIT_NPMJS_PACKAGES_ENDPOINT,
)
from safety.error_handlers import output_exception
from safety.errors import (
InvalidCredentialError,
NetworkConnectionError,
RequestTimeoutError,
SafetyError,
ServerError,
TooManyRequestsError,
)
from safety.meta import get_meta_http_headers
from safety.models import SafetyCLI
from safety.scan.util import AuthenticationType
from safety.util import SafetyContext
LOG = logging.getLogger(__name__)
def get_keys(
    client_session: OAuth2Session, openid_config: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Fetch the JWKS document referenced by the OpenID configuration.

    Args:
        client_session (OAuth2Session): The session used for the request.
        openid_config (Dict[str, Any]): The OpenID configuration.

    Returns:
        Optional[Dict[str, Any]]: The JWKS payload, or None when the
        configuration exposes no "jwks_uri".
    """
    if "jwks_uri" not in openid_config:
        return None
    # bearer=False: the JWKS endpoint is fetched without attaching a token.
    return client_session.get(url=openid_config["jwks_uri"], bearer=False).json()  # type: ignore
def is_email_verified(info: Dict[str, Any]) -> Optional[bool]:
    """
    Check if the user's email is verified.

    Args:
        info (Dict[str, Any]): The user information.

    Returns:
        bool: Always True — verification is currently disabled (see below).
    """
    # return info.get(CLAIM_EMAIL_VERIFIED_API) or info.get(
    #     CLAIM_EMAIL_VERIFIED_AUTH_SERVER
    # )

    # Always return True to avoid email verification
    # NOTE(review): this short-circuits the email-verification gate used by
    # Auth.is_valid() and get_auth_info(); confirm that permanently skipping
    # verification is intended rather than a temporary workaround.
    return True
def extract_detail(response: requests.Response) -> Optional[str]:
    """
    Pull the "detail" field out of a JSON HTTP response.

    Args:
        response (requests.Response): The response to inspect.

    Returns:
        Optional[str]: The detail message, or None when the body is not
        JSON or carries no "detail" key.
    """
    try:
        return response.json().get("detail")
    except Exception:
        LOG.debug("Failed to extract detail from response: %s", response.status_code)
        return None
def parse_response(func: Callable) -> Callable:
    """
    Decorator that maps HTTP/transport failures to Safety errors and returns
    the parsed JSON body on success.

    Transient failures (connection errors, timeouts, 429s, 5xx) are retried
    up to 3 times with exponential jitter before being re-raised.

    Args:
        func (Callable): The function performing the HTTP request.

    Returns:
        Callable: The wrapped function returning parsed JSON data.
    """

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential_jitter(initial=0.2, max=8.0, exp_base=3, jitter=0.3),
        reraise=True,
        retry=retry_if_exception_type(
            (
                NetworkConnectionError,
                RequestTimeoutError,
                TooManyRequestsError,
                ServerError,
            )
        ),
        before_sleep=before_sleep_log(logging.getLogger("api_client"), logging.WARNING),
    )
    def wrapper(*args, **kwargs):
        # Map transport-level exceptions to the Safety error hierarchy so
        # the retry policy above can match on them.
        try:
            r = func(*args, **kwargs)
        except OAuthError as e:
            LOG.exception("OAuth failed: %s", e)
            raise InvalidCredentialError(
                message="Your token authentication expired, try login again."
            )
        except requests.exceptions.ConnectionError:
            raise NetworkConnectionError()
        except requests.exceptions.Timeout:
            raise RequestTimeoutError()
        except requests.exceptions.RequestException as e:
            raise e

        # TODO: Handle content as JSON and fallback to text for all responses

        if r.status_code == 403:
            reason = extract_detail(response=r)
            raise InvalidCredentialError(
                credential="Failed authentication.", reason=reason
            )

        if r.status_code == 429:
            raise TooManyRequestsError(reason=r.text)

        if r.status_code >= 400 and r.status_code < 500:
            error_code = None
            try:
                data = r.json()
                reason = data.get("detail", "Unable to find reason.")
                error_code = data.get("error_code", None)
            except Exception:
                # Non-JSON error body: fall back to the HTTP reason phrase.
                reason = r.reason

            raise SafetyError(message=reason, error_code=error_code)

        if r.status_code >= 500 and r.status_code < 600:
            reason = extract_detail(response=r)
            LOG.debug("ServerError %s -> Response returned: %s", r.status_code, r.text)
            raise ServerError(reason=reason)

        data = None

        try:
            data = r.json()
        except json.JSONDecodeError as e:
            raise ServerError(message=f"Bad JSON response from the server: {e}")

        return data

    return wrapper
class SafetyAuthSession(OAuth2Session):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
Initialize the SafetyAuthSession.
Args:
*args (Any): Positional arguments for the parent class.
**kwargs (Any): Keyword arguments for the parent class.
"""
super().__init__(*args, **kwargs)
self.proxy_required: bool = False
self.proxy_timeout: Optional[int] = None
self.api_key = None
def get_credential(self) -> Optional[str]:
"""
Get the current authentication credential.
Returns:
Optional[str]: The API key, token, or None.
"""
if self.api_key:
return self.api_key
if self.token:
return SafetyContext().account
return None
def is_using_auth_credentials(self) -> bool:
"""
Check if the session is using authentication credentials.
This does NOT check if the client is authenticated.
Returns:
bool: True if using authentication credentials, False otherwise.
"""
return self.get_authentication_type() != AuthenticationType.none
def get_authentication_type(self) -> AuthenticationType:
"""
Get the type of authentication being used.
Returns:
AuthenticationType: The type of authentication.
"""
if self.api_key:
return AuthenticationType.api_key
if self.token:
return AuthenticationType.token
return AuthenticationType.none
def request(
self,
method: str,
url: str,
withhold_token: bool = False,
auth: Optional[Tuple] = None,
bearer: bool = True,
**kwargs: Any,
) -> requests.Response:
"""
Make an HTTP request with the appropriate authentication.
Use the right auth parameter for Safety supported auth types.
Args:
method (str): The HTTP method.
url (str): The URL to request.
withhold_token (bool): Whether to withhold the token.
auth (Optional[Tuple]): The authentication tuple.
bearer (bool): Whether to use bearer authentication.
**kwargs (Any): Additional keyword arguments.
Returns:
requests.Response: The HTTP response.
Raises:
Exception: If the request fails.
"""
# By default use the token_auth
TIMEOUT_KEYWARD = "timeout"
func_timeout = (
kwargs[TIMEOUT_KEYWARD] if TIMEOUT_KEYWARD in kwargs else REQUEST_TIMEOUT
)
if "headers" not in kwargs:
kwargs["headers"] = {}
kwargs["headers"].update(get_meta_http_headers())
if self.api_key:
kwargs["headers"]["X-Api-Key"] = self.api_key
if not self.token or not bearer:
# Fallback to no token auth
auth = ()
# Override proxies
if self.proxies:
kwargs["proxies"] = self.proxies
if self.proxy_timeout:
kwargs["timeout"] = int(self.proxy_timeout) / 1000
if ("proxies" not in kwargs or not self.proxies) and self.proxy_required:
output_exception(
"Proxy connection is required but there is not a proxy setup.", # type: ignore
exit_code_output=True,
)
request_func = super(SafetyAuthSession, self).request
params = {
"method": method,
"url": url,
"withhold_token": withhold_token,
"auth": auth,
}
params.update(kwargs)
try:
return request_func(**params)
except Exception as e:
LOG.debug("Request failed: %s", e)
if self.proxy_required:
output_exception(
f"Proxy is required but the connection failed because: {e}", # type: ignore
exit_code_output=True,
)
if "proxies" in kwargs or self.proxies:
params["proxies"] = {}
params["timeout"] = func_timeout
self.proxies = {}
message = (
"The proxy configuration failed to function and was disregarded."
)
LOG.debug(message)
if message not in [
a["message"] for a in SafetyContext.local_announcements
]:
SafetyContext.local_announcements.append(
{"message": message, "type": "warning", "local": True}
)
return request_func(**params)
raise e
def fetch_openid_config(self) -> Any:
"""
Fetch the OpenID configuration from the authorization server.
Returns:
Any: The OpenID configuration.
"""
try:
openid_config = self.get(
url=OPENID_CONFIG_URL, timeout=REQUEST_TIMEOUT
).json()
except Exception as e:
LOG.debug("Unable to load the openID config: %s", e)
openid_config = {}
return openid_config
@parse_response
def fetch_user_info(self) -> Any:
"""
Fetch user information from the authorization server.
Returns:
Any: The user information.
"""
USER_INFO_ENDPOINT = f"{AUTH_SERVER_URL}/userinfo"
r = self.get(url=USER_INFO_ENDPOINT)
return r
@parse_response
def check_project(
self,
scan_stage: str,
safety_source: str,
project_slug: Optional[str] = None,
git_origin: Optional[str] = None,
project_slug_source: Optional[str] = None,
) -> Any:
"""
Check project information.
Args:
scan_stage (str): The scan stage.
safety_source (str): The safety source.
project_slug (Optional[str]): The project slug.
git_origin (Optional[str]): The git origin.
project_slug_source (Optional[str]): The project slug source.
Returns:
Any: The project information.
"""
data = {
"scan_stage": scan_stage,
"safety_source": safety_source,
"project_slug": project_slug,
"project_slug_source": project_slug_source,
"git_origin": git_origin,
}
r = self.post(url=PLATFORM_API_PROJECT_CHECK_ENDPOINT, json=data)
return r
@parse_response
def project(self, project_id: str) -> Any:
"""
Get project information.
Args:
project_id (str): The project ID.
Returns:
Any: The project information.
"""
data = {"project": project_id}
return self.get(url=PLATFORM_API_PROJECT_ENDPOINT, params=data)
@parse_response
def download_policy(
self, project_id: Optional[str], stage: Stage, branch: Optional[str]
) -> Any:
"""
Download the project policy.
Args:
project_id (Optional[str]): The project ID.
stage (Stage): The stage.
branch (Optional[str]): The branch.
Returns:
Any: The policy data.
"""
data = {
"project": project_id,
"stage": STAGE_ID_MAPPING[stage],
"branch": branch,
}
return self.get(url=PLATFORM_API_POLICY_ENDPOINT, params=data)
@parse_response
def project_scan_request(self, project_id: str) -> Any:
"""
Request a project scan.
Args:
project_id (str): The project ID.
Returns:
Any: The scan request result.
"""
data = {"project_id": project_id}
return self.post(url=PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT, json=data)
@parse_response
def upload_report(self, json_report: str) -> Any:
"""
Upload a scan report.
Args:
json_report (str): The JSON report.
Returns:
Any: The upload result.
"""
return self.post(
url=PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT,
data=json_report,
headers={"Content-Type": "application/json"},
)
def upload_requirements(self, json_payload: str) -> Any:
"""
Upload a scan report.
Args:
json_payload (str): The JSON payload to upload.
Returns:
Any: The result of the upload operation.
"""
return self.post(
url=PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT,
data=json.dumps(json_payload),
headers={"Content-Type": "application/json"},
)
@parse_response
def check_updates(
self,
version: int,
safety_version: Optional[str] = None,
python_version: Optional[str] = None,
os_type: Optional[str] = None,
os_release: Optional[str] = None,
os_description: Optional[str] = None,
) -> Any:
"""
Check for updates.
Args:
version (int): The version.
safety_version (Optional[str]): The Safety version.
python_version (Optional[str]): The Python version.
os_type (Optional[str]): The OS type.
os_release (Optional[str]): The OS release.
os_description (Optional[str]): The OS description.
Returns:
Any: The update check result.
"""
data = {
"version": version,
"safety_version": safety_version,
"python_version": python_version,
"os_type": os_type,
"os_release": os_release,
"os_description": os_description,
}
return self.get(url=PLATFORM_API_CHECK_UPDATES_ENDPOINT, params=data)
@parse_response
def audit_packages(
    self, packages: List[str], ecosystem: Literal["pypi", "npmjs"]
) -> Any:
    """
    Audit a batch of packages for vulnerabilities.

    Args:
        packages: List of package specifiers.
        ecosystem: The ecosystem the specifiers belong to.

    Returns:
        Any: The packages audit result.
    """
    # Each ecosystem has its own firewall audit endpoint.
    if ecosystem == "npmjs":
        endpoint = FIREWALL_AUDIT_NPMJS_PACKAGES_ENDPOINT
    else:
        endpoint = FIREWALL_AUDIT_PYPI_PACKAGES_ENDPOINT
    payload = {"packages": [{"package_specifier": spec} for spec in packages]}
    return self.post(url=endpoint, json=payload)
@parse_response
def initialize(self) -> Any:
    """
    Initialize a run (best-effort).

    Returns:
        Any: The initialization result, or None when the request times out
        or fails for any other reason — failures are logged, not raised.
    """
    try:
        return self.get(
            url=PLATFORM_API_INITIALIZE_ENDPOINT,
            headers={"Content-Type": "application/json"},
            timeout=5,
        )
    except requests.exceptions.Timeout:
        LOG.error("Auth request to initialize timed out after 5 seconds.")
    except Exception:
        LOG.exception("Exception trying to auth initialize", exc_info=True)
    return None
class S3PresignedAdapter(HTTPAdapter):
    """Transport adapter for S3 presigned URLs.

    Presigned URLs carry their own signature in the query string, so any
    session-level Authorization header must not be forwarded.
    """

    def send(  # type: ignore
        self, request: requests.PreparedRequest, **kwargs: Any
    ) -> requests.Response:
        """
        Strip the Authorization header, then delegate to HTTPAdapter.send.

        Args:
            request (requests.PreparedRequest): The prepared request.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            requests.Response: The response.
        """
        request.headers.pop("Authorization", None)
        return super().send(request, **kwargs)
@lru_cache(maxsize=1)
def is_jupyter_notebook() -> bool:
    """
    Detect whether we are running inside a Jupyter notebook environment,
    including various cloud-hosted notebooks. The result is cached for the
    lifetime of the process.

    Returns:
        bool: True if the environment is identified as a Jupyter notebook
        (or an equivalent cloud-based environment), False otherwise.

    Supported environments:
        - Google Colab
        - Amazon SageMaker
        - Azure Notebooks
        - Kaggle Notebooks
        - Databricks Notebooks
        - Datalore by JetBrains
        - Paperspace Gradient Notebooks
        - Classic Jupyter Notebook and JupyterLab
    """
    # Google Colab: probe the parent package first so the dotted lookup
    # can't trip over a missing "google" namespace package.
    in_colab = (
        importlib.util.find_spec("google") is not None
        and importlib.util.find_spec("google.colab") is not None
    )
    if in_colab:
        return True

    # The other cloud-hosted notebooks are identified by a marker module.
    cloud_markers = ("sagemaker", "azureml", "kaggle", "dbutils", "datalore", "gradient")
    if any(importlib.util.find_spec(name) is not None for name in cloud_markers):
        return True

    # Detect classic Jupyter Notebook, JupyterLab, and other IPython
    # kernel-based environments.
    try:
        from IPython import get_ipython  # type: ignore

        ipython = get_ipython()
        if ipython is not None and "IPKernelApp" in ipython.config:
            return True
    except (ImportError, AttributeError, NameError):
        pass

    return False
def save_flags_config(flags: Dict[FeatureType, bool]) -> None:
    """
    Persist feature flags to the user configuration file.

    Failures while writing are logged (with stack trace) and otherwise
    ignored, so a permissions or disk problem never disrupts core
    functionality; in that case the application keeps using the existing
    or default flag values until the next restart.

    Args:
        flags: Mapping of feature types to their enabled/disabled state.
    """
    import configparser

    from safety.constants import CONFIG_FILE_USER

    parser = configparser.ConfigParser()
    parser.read(CONFIG_FILE_USER)

    if not parser.has_section("settings"):
        parser.add_section("settings")

    # Merge the existing settings with the (upper-cased) flag names; new
    # flag values win on conflicts.
    merged = dict(parser.items("settings"))
    merged.update(
        {feature.name.upper(): str(state) for feature, state in flags.items()}
    )

    for option, value in merged.items():
        parser.set("settings", option, value)

    try:
        with open(CONFIG_FILE_USER, "w") as config_file:
            parser.write(config_file)
    except Exception:
        LOG.exception("Unable to save flags configuration.")
def get_feature_name(feature: FeatureType, as_attr: bool = False) -> str:
    """Return the feature's lower-cased name with an ``enabled`` suffix.

    Args:
        feature: The feature to format the name for.
        as_attr: When True, join with an underscore (attribute style);
            otherwise join with a hyphen (flag style).

    Returns:
        Formatted feature name string, e.g. ``firewall-enabled``.
    """
    joiner = "_" if as_attr else "-"
    return f"{feature.name.lower()}{joiner}enabled"
def str_to_bool(value) -> Optional[bool]:
    """Convert basic string representations to boolean.

    Accepts actual booleans (returned unchanged) and the strings
    ``"true"``/``"false"`` in any casing, with surrounding whitespace
    ignored. Anything else yields None.

    Args:
        value: The value to interpret.

    Returns:
        Optional[bool]: True/False for a recognized value, otherwise None.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        normalized = value.lower().strip()
        # Bug fix: the original used `value in ("true")`, which — lacking a
        # tuple comma — is a *substring* test against the string "true", so
        # "t", "ru" and even "" all matched. Compare against exact words.
        if normalized == "true":
            return True
        if normalized == "false":
            return False
    return None
def initialize(ctx: Any, refresh: bool = True) -> None:
    """
    Initializes the run by loading feature-flag settings onto ``ctx.obj``.

    Local config values are read first; when ``refresh`` is True the server
    is consulted and any differing server values win and are persisted via
    ``save_flags_config``. Flags are finally exposed as attributes on
    ``ctx.obj`` (named per ``feature.attr_name``).

    Args:
        ctx (Any): The context object.
        refresh (bool): Whether to refresh settings from the server. Defaults to True.
    """
    settings = None
    current_values = {}

    if not ctx.obj:
        ctx.obj = SafetyCLI()

    # Start from whatever is stored in the local configuration.
    for feature in FeatureType:
        value = get_config_setting(feature.name)
        if value is not None:
            current_values[feature] = str_to_bool(value)

    if refresh:
        try:
            settings = ctx.obj.auth.client.initialize()  # type: ignore
        except Exception:
            # Best-effort: keep local/default values if the server call fails.
            LOG.info("Unable to initialize, continue with default values.")

    if settings:
        # Server-side values override local ones when they differ.
        for feature in FeatureType:
            server_value = str_to_bool(settings.get(feature.config_key))
            if server_value is not None:
                if (
                    feature not in current_values
                    or current_values[feature] != server_value
                ):
                    current_values[feature] = server_value

    # Persist the merged flags, then mirror them onto the context object.
    save_flags_config(current_values)
    for feature, value in current_values.items():
        if value is not None:
            setattr(ctx.obj, feature.attr_name, value)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,967 @@
import logging
import subprocess
import sys
from collections import defaultdict
from enum import Enum
from functools import wraps
import time
from typing import (
TYPE_CHECKING,
Any,
DefaultDict,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)
import click
import typer
from rich.console import Console
from rich.table import Table
from rich.text import Text
from typer.core import MarkupMode, TyperCommand, TyperGroup
from click.utils import make_str
from safety.constants import (
BETA_PANEL_DESCRIPTION_HELP,
MSG_NO_AUTHD_CICD_PROD_STG,
MSG_NO_AUTHD_CICD_PROD_STG_ORG,
MSG_NO_AUTHD_DEV_STG,
MSG_NO_AUTHD_DEV_STG_ORG_PROMPT,
MSG_NO_AUTHD_DEV_STG_PROMPT,
MSG_NO_AUTHD_NOTE_CICD_PROD_STG_TPL,
MSG_NO_VERIFIED_EMAIL_TPL,
CONTEXT_COMMAND_TYPE,
FeatureType,
)
from safety.scan.constants import CONSOLE_HELP_THEME
from safety.models import SafetyCLI
if TYPE_CHECKING:
from click.core import Command, Context
from safety.auth.models import Auth
LOG = logging.getLogger(__name__)
class CommandType(Enum):
    """Category used to group commands into separate help-output panels."""

    MAIN = "main"  # primary commands panel
    UTILITY = "utility"  # utility commands panel
    BETA = "beta"  # beta commands panel (rendered with a beta notice)
def custom_print_options_panel(
    name: str, params: List[Any], ctx: Any, console: Console
) -> None:
    """
    Print a titled table listing options/arguments with their help text.

    Args:
        name (str): The title of the panel.
        params (List[Any]): The list of options/arguments to print.
        ctx (Any): The context object (unused here; kept for signature parity).
        console (Console): The console to print to.
    """
    table = Table(title=name, show_lines=True)
    for param in params:
        # Parameters without opts/help still get a row, just with blanks.
        opts = getattr(param, "opts", "")
        help_text = getattr(param, "help", "")
        table.add_row(str(opts), help_text)
    console.print(table)
def custom_print_commands_panel(
    name: str, commands: List[Any], console: Console
) -> None:
    """
    Render a titled table with one row per command (name plus help text).

    Args:
        name (str): The title of the panel.
        commands (List[Any]): The list of commands to print.
        console (Console): The console to print to.
    """
    panel_table = Table(title=name, show_lines=True)
    for cmd in commands:
        panel_table.add_row(cmd.name, cmd.help or "")
    console.print(panel_table)
def custom_make_rich_text(text: str) -> Text:
    """
    Wrap a plain string in a rich ``Text`` instance.

    Args:
        text (str): The text to format.

    Returns:
        Text: The formatted rich text.
    """
    rich_text = Text(text)
    return rich_text
def custom_get_help_text(obj: Any) -> Text:
    """
    Return an object's ``help`` attribute wrapped as rich ``Text``.

    Args:
        obj (Any): The object to get help text for.

    Returns:
        Text: The formatted help text.
    """
    help_string = obj.help
    return Text(help_string)
def custom_make_command_help(help_text: str) -> Text:
    """
    Create rich text for command help.

    Args:
        help_text (str): The help text to format.

    Returns:
        Text: The formatted rich text.
    """
    return Text(help_text)
def get_command_for(name: str, typer_instance: typer.Typer) -> click.Command:
    """
    Retrieve a registered command by name from a Typer instance and build
    the corresponding Click command.

    Args:
        name (str): The name of the command.
        typer_instance (typer.Typer): The Typer instance.

    Returns:
        click.Command: The found command.

    Raises:
        ValueError: If no registered command matches ``name``.
    """
    single_command = next(
        (
            command
            for command in typer_instance.registered_commands
            if command.name == name
        ),
        None,
    )
    if not single_command:
        raise ValueError("Unable to find the command name.")
    # Propagate the app-level context settings onto the command before
    # converting it into a Click command.
    single_command.context_settings = typer_instance.info.context_settings
    click_command = typer.main.get_command_from_info(
        single_command,
        pretty_exceptions_short=typer_instance.pretty_exceptions_short,
        rich_markup_mode=typer_instance.rich_markup_mode,
    )
    # Mirror Typer's own behavior: add the shell-completion parameters when
    # the app was created with completion support. NOTE(review): relies on
    # the private `_add_completion` attribute of Typer.
    if typer_instance._add_completion:
        click_install_param, click_show_param = (
            typer.main.get_install_completion_arguments()
        )
        click_command.params.append(click_install_param)
        click_command.params.append(click_show_param)
    return click_command
def pass_safety_cli_obj(func):
    """
    Decorator that guarantees ``ctx.obj`` holds a ``SafetyCLI`` instance
    before the wrapped command body runs.
    """

    @wraps(func)
    def wrapper(ctx, *args, **kwargs):
        if not ctx.obj:
            ctx.obj = SafetyCLI()
        return func(ctx, *args, **kwargs)

    return wrapper
def pretty_format_help(
    obj: Union[click.Command, click.Group], ctx: click.Context, markup_mode: MarkupMode
) -> None:
    """
    Format and print help text in a pretty format.

    Renders, in order: the command/group help text, the usage line, the
    command panels (for groups), option/argument panels (grouped by their
    ``rich_help_panel`` name), the parent's options as "Global-Options",
    and finally the epilog.

    Args:
        obj (Union[click.Command, click.Group]): The Click command or group.
        ctx (click.Context): The Click context.
        markup_mode (MarkupMode): The markup mode (currently unused here).
    """
    from rich.align import Align
    from rich.console import Console
    from rich.padding import Padding
    from rich.theme import Theme
    from typer.rich_utils import (
        ARGUMENTS_PANEL_TITLE,
        COMMANDS_PANEL_TITLE,
        OPTIONS_PANEL_TITLE,
        STYLE_USAGE_COMMAND,
        highlighter,
    )

    console = Console()
    with console.use_theme(Theme(styles=CONSOLE_HELP_THEME)) as theme_context:
        console = theme_context.console
        # Print command / group help if we have some
        if obj.help:
            console.print()
            # Print with some padding
            console.print(
                Padding(Align(custom_get_help_text(obj=obj), pad=False), (0, 1, 0, 1))
            )
        # Print usage
        console.print(
            Padding(highlighter(obj.get_usage(ctx)), 1), style=STYLE_USAGE_COMMAND
        )
        if isinstance(obj, click.MultiCommand):
            # Bucket visible subcommands by their rich help panel name.
            panel_to_commands: DefaultDict[str, List[click.Command]] = defaultdict(list)
            for command_name in obj.list_commands(ctx):
                command = obj.get_command(ctx, command_name)
                if command and not command.hidden:
                    panel_name = (
                        getattr(command, "rich_help_panel", None)
                        or COMMANDS_PANEL_TITLE
                    )
                    panel_to_commands[panel_name].append(command)
            # Print each command group panel
            default_commands = panel_to_commands.get(COMMANDS_PANEL_TITLE, [])
            custom_print_commands_panel(
                name=COMMANDS_PANEL_TITLE,
                commands=default_commands,
                console=console,
            )
            for panel_name, commands in panel_to_commands.items():
                if panel_name == COMMANDS_PANEL_TITLE:
                    # Already printed above
                    continue
                custom_print_commands_panel(
                    name=panel_name,
                    commands=commands,
                    console=console,
                )
        # Bucket the command's own parameters by panel, arguments and
        # options separately.
        panel_to_arguments: DefaultDict[str, List[click.Argument]] = defaultdict(list)
        panel_to_options: DefaultDict[str, List[click.Option]] = defaultdict(list)
        for param in obj.get_params(ctx):
            # Skip if option is hidden
            if getattr(param, "hidden", False):
                continue
            if isinstance(param, click.Argument):
                panel_name = (
                    getattr(param, "rich_help_panel", None) or ARGUMENTS_PANEL_TITLE
                )
                panel_to_arguments[panel_name].append(param)
            elif isinstance(param, click.Option):
                panel_name = (
                    getattr(param, "rich_help_panel", None) or OPTIONS_PANEL_TITLE
                )
                panel_to_options[panel_name].append(param)
        default_options = panel_to_options.get(OPTIONS_PANEL_TITLE, [])
        custom_print_options_panel(
            name=OPTIONS_PANEL_TITLE,
            params=default_options,
            ctx=ctx,
            console=console,
        )
        for panel_name, options in panel_to_options.items():
            if panel_name == OPTIONS_PANEL_TITLE:
                # Already printed above
                continue
            custom_print_options_panel(
                name=panel_name,
                params=options,
                ctx=ctx,
                console=console,
            )
        default_arguments = panel_to_arguments.get(ARGUMENTS_PANEL_TITLE, [])
        custom_print_options_panel(
            name=ARGUMENTS_PANEL_TITLE,
            params=default_arguments,
            ctx=ctx,
            console=console,
        )
        for panel_name, arguments in panel_to_arguments.items():
            if panel_name == ARGUMENTS_PANEL_TITLE:
                # Already printed above
                continue
            custom_print_options_panel(
                name=panel_name,
                params=arguments,
                ctx=ctx,
                console=console,
            )
        # The parent group's options are shown as "Global-Options".
        if ctx.parent:
            params = []
            for param in ctx.parent.command.params:
                if isinstance(param, click.Option):
                    params.append(param)
            custom_print_options_panel(
                name="Global-Options",
                params=params,
                ctx=ctx.parent,
                console=console,
            )
        # Epilogue if we have it
        if obj.epilog:
            # Remove single linebreaks, replace double with single
            lines = obj.epilog.split("\n\n")
            epilogue = "\n".join([x.replace("\n", " ").strip() for x in lines])
            epilogue_text = custom_make_rich_text(text=epilogue)
            console.print(Padding(Align(epilogue_text, pad=False), 1))
def print_main_command_panels(
    *,
    name: str,
    commands_type: CommandType,
    commands: List[click.Command],
    markup_mode: MarkupMode,
    console,
) -> None:
    """
    Print one command panel of the main help output.

    Args:
        name (str): The title of the panel.
        commands_type (CommandType): The category of the listed commands;
            BETA panels get an extra description and unstyled names.
        commands (List[click.Command]): List of commands to display.
        markup_mode (MarkupMode): The markup mode (currently unused here).
        console: The Rich console.
    """
    from rich import box
    from rich.panel import Panel
    from rich.table import Table
    from rich.text import Text
    from typer.rich_utils import (
        ALIGN_COMMANDS_PANEL,
        STYLE_COMMANDS_PANEL_BORDER,
        STYLE_COMMANDS_TABLE_BORDER_STYLE,
        STYLE_COMMANDS_TABLE_BOX,
        STYLE_COMMANDS_TABLE_LEADING,
        STYLE_COMMANDS_TABLE_PAD_EDGE,
        STYLE_COMMANDS_TABLE_PADDING,
        STYLE_COMMANDS_TABLE_ROW_STYLES,
        STYLE_COMMANDS_TABLE_SHOW_LINES,
    )

    # Reuse Typer's table styling; "box" is popped because Table wants the
    # resolved box object, not its name.
    t_styles: Dict[str, Any] = {
        "show_lines": STYLE_COMMANDS_TABLE_SHOW_LINES,
        "leading": STYLE_COMMANDS_TABLE_LEADING,
        "box": STYLE_COMMANDS_TABLE_BOX,
        "border_style": STYLE_COMMANDS_TABLE_BORDER_STYLE,
        "row_styles": STYLE_COMMANDS_TABLE_ROW_STYLES,
        "pad_edge": STYLE_COMMANDS_TABLE_PAD_EDGE,
        "padding": STYLE_COMMANDS_TABLE_PADDING,
    }
    box_style = getattr(box, t_styles.pop("box"), None)

    commands_table = Table(
        highlight=False,
        show_header=False,
        expand=True,
        box=box_style,
        **t_styles,
    )

    # Fixed name column; the help column absorbs the remaining width.
    console_width = 80
    column_width = 25

    if console.size and console.size[0] > 80:
        console_width = console.size[0]

    from rich.console import Group

    description = None
    if commands_type is CommandType.BETA:
        description = Group(Text(""), Text(BETA_PANEL_DESCRIPTION_HELP), Text(""))

    commands_table.add_column(
        style="bold cyan", no_wrap=True, width=column_width, max_width=column_width
    )
    commands_table.add_column(width=console_width - column_width)

    rows = []

    for command in commands:
        helptext = command.short_help or command.help or ""
        command_name = command.name or ""
        # Beta command names get an explicit empty style (no cyan bold).
        command_name_text = (
            Text(command_name, style="")
            if commands_type is CommandType.BETA
            else Text(command_name)
        )
        rows.append(
            [
                command_name_text,
                custom_make_command_help(
                    help_text=helptext,
                ),
            ]
        )
        # Blank spacer row between commands.
        rows.append([])
    for row in rows:
        commands_table.add_row(*row)
    if commands_table.row_count:
        renderables = (
            [description, commands_table]
            if description is not None
            else [Text(""), commands_table]
        )
        console.print(
            Panel(
                Group(*renderables),
                border_style=STYLE_COMMANDS_PANEL_BORDER,
                title=name,
                title_align=ALIGN_COMMANDS_PANEL,
            )
        )
# The help output for the main safety root command: `safety --help`
def format_main_help(
    obj: Union[click.Command, click.Group], ctx: click.Context, markup_mode: MarkupMode
) -> None:
    """
    Format the main help output for the safety root command
    (``safety --help``).

    Unlike ``pretty_format_help``, commands are grouped into MAIN/UTILITY/
    BETA panels based on each command's ``CONTEXT_COMMAND_TYPE`` context
    setting, and arguments are printed before options.

    Args:
        obj (Union[click.Command, click.Group]): The Click command or group.
        ctx (click.Context): The Click context.
        markup_mode (MarkupMode): The markup mode.
    """
    from rich.align import Align
    from rich.console import Console
    from rich.padding import Padding
    from rich.theme import Theme
    from typer.rich_utils import (
        ARGUMENTS_PANEL_TITLE,
        COMMANDS_PANEL_TITLE,
        OPTIONS_PANEL_TITLE,
        STYLE_USAGE_COMMAND,
        highlighter,
    )

    typer_console = Console()
    with typer_console.use_theme(Theme(styles=CONSOLE_HELP_THEME)) as theme_context:
        console = theme_context.console
        # Print command / group help if we have some
        if obj.help:
            console.print()
            # Print with some padding
            console.print(
                Padding(
                    Align(
                        custom_get_help_text(obj=obj),
                        pad=False,
                    ),
                    (0, 1, 0, 1),
                )
            )
        # Print usage
        console.print(
            Padding(highlighter(obj.get_usage(ctx)), 1), style=STYLE_USAGE_COMMAND
        )
        if isinstance(obj, click.MultiCommand):
            UTILITY_COMMANDS_PANEL_TITLE = "Utility commands"
            BETA_COMMANDS_PANEL_TITLE = "Beta Commands"
            COMMANDS_PANEL_TITLE_CONSTANTS = {
                CommandType.MAIN: COMMANDS_PANEL_TITLE,
                CommandType.UTILITY: UTILITY_COMMANDS_PANEL_TITLE,
                CommandType.BETA: BETA_COMMANDS_PANEL_TITLE,
            }
            panel_to_commands: Dict[CommandType, List[click.Command]] = {}
            # Keep order of panels
            for command_type in COMMANDS_PANEL_TITLE_CONSTANTS.keys():
                panel_to_commands[command_type] = []
            # Bucket visible commands by their declared command type,
            # defaulting to MAIN.
            for command_name in obj.list_commands(ctx):
                command = obj.get_command(ctx, command_name)
                if command and not command.hidden:
                    command_type = command.context_settings.get(
                        CONTEXT_COMMAND_TYPE, CommandType.MAIN
                    )
                    panel_to_commands[command_type].append(command)
            for command_type, commands in panel_to_commands.items():
                print_main_command_panels(
                    name=COMMANDS_PANEL_TITLE_CONSTANTS[command_type],
                    commands_type=command_type,
                    commands=commands,
                    markup_mode=markup_mode,
                    console=console,
                )
        # Bucket the root command's own parameters by panel.
        panel_to_arguments: DefaultDict[str, List[click.Argument]] = defaultdict(list)
        panel_to_options: DefaultDict[str, List[click.Option]] = defaultdict(list)
        for param in obj.get_params(ctx):
            # Skip if option is hidden
            if getattr(param, "hidden", False):
                continue
            if isinstance(param, click.Argument):
                panel_name = (
                    getattr(param, "rich_help_panel", None) or ARGUMENTS_PANEL_TITLE
                )
                panel_to_arguments[panel_name].append(param)
            elif isinstance(param, click.Option):
                panel_name = (
                    getattr(param, "rich_help_panel", None) or OPTIONS_PANEL_TITLE
                )
                panel_to_options[panel_name].append(param)
        default_arguments = panel_to_arguments.get(ARGUMENTS_PANEL_TITLE, [])
        custom_print_options_panel(
            name=ARGUMENTS_PANEL_TITLE,
            params=default_arguments,
            ctx=ctx,
            console=console,
        )
        for panel_name, arguments in panel_to_arguments.items():
            if panel_name == ARGUMENTS_PANEL_TITLE:
                # Already printed above
                continue
            custom_print_options_panel(
                name=panel_name,
                params=arguments,
                ctx=ctx,
                console=console,
            )
        default_options = panel_to_options.get(OPTIONS_PANEL_TITLE, [])
        custom_print_options_panel(
            name=OPTIONS_PANEL_TITLE,
            params=default_options,
            ctx=ctx,
            console=console,
        )
        for panel_name, options in panel_to_options.items():
            if panel_name == OPTIONS_PANEL_TITLE:
                # Already printed above
                continue
            custom_print_options_panel(
                name=panel_name,
                params=options,
                ctx=ctx,
                console=console,
            )
        # Epilogue if we have it
        if obj.epilog:
            # Remove single linebreaks, replace double with single
            lines = obj.epilog.split("\n\n")
            epilogue = "\n".join([x.replace("\n", " ").strip() for x in lines])
            epilogue_text = custom_make_rich_text(text=epilogue)
            console.print(Padding(Align(epilogue_text, pad=False), 1))
def process_auth_status_not_ready(console, auth: "Auth", ctx: typer.Context) -> None:
    """
    Handle the process when the authentication status is not ready.

    Depending on the stage and authentication state, this either prompts
    the user to log in/register interactively (development stage), prints
    login instructions, or reports a missing verified email. Note that
    every failure path terminates the process via ``sys.exit``.

    Args:
        console: The Rich console.
        auth (Auth): The Auth object.
        ctx (typer.Context): The Typer context.
    """
    from rich.prompt import Confirm, Prompt

    from safety_schemas.models import Stage

    from safety.auth.constants import CLI_AUTH, MSG_NON_AUTHENTICATED

    if not auth.client or not auth.client.is_using_auth_credentials():
        if auth.stage is Stage.development:
            console.print()
            if auth.org:
                # Known org: offer to launch the login flow directly.
                confirmed = Confirm.ask(
                    MSG_NO_AUTHD_DEV_STG_ORG_PROMPT,
                    choices=["Y", "N", "y", "n"],
                    show_choices=False,
                    show_default=False,
                    default=True,
                    console=console,
                )
                if not confirmed:
                    sys.exit(0)

                from safety.auth.cli import auth_app

                login_command = get_command_for(name="login", typer_instance=auth_app)
                ctx.invoke(login_command)
            else:
                # No org: let the user pick between login and register.
                console.print(MSG_NO_AUTHD_DEV_STG)
                console.print()
                choices = ["L", "R", "l", "r"]
                next_command = Prompt.ask(
                    MSG_NO_AUTHD_DEV_STG_PROMPT,
                    default=None,
                    choices=choices,
                    show_choices=False,
                    console=console,
                )

                from safety.auth.cli import auth_app

                login_command = get_command_for(name="login", typer_instance=auth_app)
                register_command = get_command_for(
                    name="register", typer_instance=auth_app
                )
                if next_command is None or next_command.lower() not in choices:
                    sys.exit(0)

                console.print()
                if next_command.lower() == "r":
                    ctx.invoke(register_command)
                else:
                    ctx.invoke(login_command)
                    if not ctx.obj.auth.email_verified:
                        sys.exit(1)
        else:
            # Non-interactive stages (CI/CD, production): print instructions
            # and fail.
            if not auth.org:
                console.print(MSG_NO_AUTHD_CICD_PROD_STG_ORG.format(LOGIN_URL=CLI_AUTH))
            else:
                console.print(MSG_NO_AUTHD_CICD_PROD_STG)
                console.print(
                    MSG_NO_AUTHD_NOTE_CICD_PROD_STG_TPL.format(
                        LOGIN_URL=CLI_AUTH, SIGNUP_URL=f"{CLI_AUTH}/?sign_up=True"
                    )
                )
            sys.exit(1)
    elif not auth.email_verified:
        console.print()
        console.print(
            MSG_NO_VERIFIED_EMAIL_TPL.format(
                email=auth.email if auth.email else "Missing email"
            )
        )
        sys.exit(1)
    else:
        console.print(MSG_NON_AUTHENTICATED)
        sys.exit(1)
class CustomContext(click.Context):
    """Click context enriched with Safety-specific metadata.

    Tracks the command's panel category, the feature flag associated with
    it, a monotonic start timestamp, and the alias the user actually typed.
    """

    def __init__(
        self,
        command: "Command",
        parent: Optional["Context"] = None,
        command_type: CommandType = CommandType.MAIN,
        feature_type: Optional[FeatureType] = None,
        **kwargs,
    ) -> None:
        # Record the start time first so it is as close as possible to
        # context creation.
        self.started_at = time.monotonic()
        self.command_type = command_type
        self.feature_type = feature_type
        # Filled in later when a command alias (e.g. "pip3") is detected.
        self.command_alias_used: Optional[str] = None
        super().__init__(command, parent=parent, **kwargs)
class SafetyCLISubGroup(TyperGroup):
    """
    Custom TyperGroup with additional functionality for Safety CLI.
    """

    context_class = CustomContext

    def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
        """
        Format help message with rich formatting.

        Args:
            ctx (click.Context): Click context.
            formatter (click.HelpFormatter): Click help formatter (unused;
                output goes to a rich console instead).
        """
        pretty_format_help(self, ctx, markup_mode=self.rich_markup_mode)

    def format_usage(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
        """
        Format the usage line, injecting the [GLOBAL-OPTIONS] placeholder
        after the parent command path for nested commands.

        Args:
            ctx (click.Context): Click context.
            formatter (click.HelpFormatter): Click help formatter.
        """
        command_path = ctx.command_path
        pieces = self.collect_usage_pieces(ctx)
        main_group = ctx.parent
        if main_group:
            command_path = (
                f"{main_group.command_path} [GLOBAL-OPTIONS] {ctx.command.name}"
            )
        formatter.write_usage(command_path, " ".join(pieces))

    def command(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> click.Command:  # type: ignore[override]
        """
        Create a new command.

        Bug fix: the original override called ``super().command`` but
        discarded its return value, so using this method as a decorator
        yielded ``None`` instead of the decorator click/typer produce.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            click.Command: The created command (decorator).
        """
        return super().command(*args, **kwargs)
class SafetyCLICommand(TyperCommand):
    """Custom TyperCommand with rich help and Safety usage formatting."""

    context_class = CustomContext

    def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
        """
        Render this command's help with the rich-based formatter.

        Args:
            ctx (click.Context): Click context.
            formatter (click.HelpFormatter): Click help formatter (unused;
                output goes to a rich console instead).
        """
        pretty_format_help(self, ctx, markup_mode=self.rich_markup_mode)

    def format_usage(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
        """
        Write the usage line, inserting [GLOBAL-OPTIONS] after the parent
        command path when this command is nested.

        Args:
            ctx (click.Context): Click context.
            formatter (click.HelpFormatter): Click help formatter.
        """
        pieces = self.collect_usage_pieces(ctx)
        parent = ctx.parent
        if parent:
            command_path = f"{parent.command_path} [GLOBAL-OPTIONS] {ctx.command.name}"
        else:
            command_path = ctx.command_path
        formatter.write_usage(command_path, " ".join(pieces))
class SafetyCLILegacyGroup(click.Group):
    """
    Custom Click Group to handle legacy command-line arguments.

    Keeps a shadow ``all_commands`` registry of everything ever registered,
    while ``self.commands`` is filtered per-invocation by feature flags.
    """

    context_class = CustomContext

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Unfiltered registry; `self.commands` gets replaced by the
        # feature-flag-filtered view in invoke()/list_commands().
        self.all_commands = {}

    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)

    def add_command(self, cmd, name=None) -> None:
        # Register in click's own map and in the unfiltered registry.
        super().add_command(cmd, name)
        name = name or cmd.name
        self.all_commands[name] = cmd

    def parse_args(self, ctx: click.Context, args: List[str]) -> List[str]:
        """
        Parse arguments, rewriting pip aliases and hoisting legacy
        `check`/`license` options (--key, proxy settings) to group level.
        """
        ctx = cast(CustomContext, ctx)
        if len(args) >= 1:
            # Any "pip"-like alias (e.g. "pip3") is recorded and rewritten
            # to the canonical "pip" command.
            if "pip" in args[0] and ctx:
                ctx.command_alias_used = args[0]
                args[0] = "pip"

        parsed_args = super().parse_args(ctx, args)
        args = ctx.args

        # Workaround for legacy check options, that now are global options
        subcommand_args = set(args)
        PROXY_HOST_OPTIONS = set(["--proxy-host", "-ph"])
        # NOTE(review): due to operator precedence this condition reads as
        # `"check" in protected_args OR ("license" in protected_args AND
        # <legacy options present>)`, so for `check` the legacy parsing
        # always runs regardless of options — confirm whether `(A or B)
        # and C` was intended (harmless today: parse_legacy_args returns
        # (None, None) when no options are present).
        if (
            "check" in ctx.protected_args
            or "license" in ctx.protected_args
            and (
                bool(
                    PROXY_HOST_OPTIONS.intersection(subcommand_args)
                    or "--key" in subcommand_args
                )
            )
        ):
            proxy_options, key = self.parse_legacy_args(args)
            if proxy_options:
                ctx.params.update(proxy_options)

            if key:
                ctx.params.update({"key": key})

        return parsed_args

    def parse_legacy_args(
        self, args: List[str]
    ) -> Tuple[Optional[Dict[str, str]], Optional[str]]:
        """
        Parse legacy command-line arguments for proxy settings and keys.

        Args:
            args (List[str]): List of command-line arguments.

        Returns:
            Tuple[Optional[Dict[str, str]], Optional[str]]: Parsed proxy options and key.
        """
        # Defaults: https on port 80; proxy options are only returned when
        # a host was actually given.
        options = {"proxy_protocol": "https", "proxy_port": 80, "proxy_host": None}
        key = None

        for i, arg in enumerate(args):
            if arg in ["--proxy-protocol", "-pr"] and i + 1 < len(args):
                options["proxy_protocol"] = args[i + 1]
            elif arg in ["--proxy-port", "-pp"] and i + 1 < len(args):
                options["proxy_port"] = int(args[i + 1])
            elif arg in ["--proxy-host", "-ph"] and i + 1 < len(args):
                options["proxy_host"] = args[i + 1]
            elif arg in ["--key"] and i + 1 < len(args):
                key = args[i + 1]

        proxy = options if options["proxy_host"] else None
        return proxy, key

    def get_filtered_commands(self, ctx: click.Context) -> Dict[str, click.Command]:
        """
        Return `self.commands` minus commands gated behind disabled
        feature flags ("firewall" is always kept).
        """
        from safety.auth.utils import initialize

        initialize(ctx, refresh=False)

        # Filter commands here:
        from .constants import CONTEXT_FEATURE_TYPE

        disabled_features = [
            feature_type
            for feature_type in FeatureType
            if not getattr(ctx.obj, feature_type.attr_name, False)
        ]
        return {
            k: v
            for k, v in self.commands.items()
            if v.context_settings.get(CONTEXT_FEATURE_TYPE, None)
            not in disabled_features
            or k in ["firewall"]
        }

    def invoke(self, ctx: click.Context) -> None:
        """
        Invoke the command, handling legacy arguments.

        Pops the group-level session parameters, injects an auth session,
        refreshes feature flags when an API key is used, filters the
        command map, then delegates to click.

        Args:
            ctx (click.Context): Click context.
        """
        session_kwargs = {
            "ctx": ctx,
            "proxy_protocol": ctx.params.pop("proxy_protocol", None),
            "proxy_host": ctx.params.pop("proxy_host", None),
            "proxy_port": ctx.params.pop("proxy_port", None),
            "key": ctx.params.pop("key", None),
            "stage": ctx.params.pop("stage", None),
        }
        invoked_command = make_str(next(iter(ctx.protected_args), ""))
        from safety.auth.cli_utils import inject_session

        inject_session(**session_kwargs, invoked_command=invoked_command)

        # call initialize if the --key is used.
        if session_kwargs["key"]:
            from safety.auth.utils import initialize

            initialize(ctx, refresh=True)

        self.commands = self.get_filtered_commands(ctx)

        # Now, invoke the original behavior
        super(SafetyCLILegacyGroup, self).invoke(ctx)

    def list_commands(self, ctx: click.Context) -> List[str]:
        """Override click.Group.list_commands with custom filtering"""
        self.commands = self.get_filtered_commands(ctx)
        return super().list_commands(ctx)

    def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
        """
        Format help message with rich formatting.

        Args:
            ctx (click.Context): Click context.
            formatter (click.HelpFormatter): Click help formatter.
        """
        # The main `safety --help`
        if self.name == "cli":
            format_main_help(self, ctx, markup_mode="rich")
        # All other help outputs
        else:
            pretty_format_help(self, ctx, markup_mode="rich")
class SafetyCLILegacyCommand(click.Command):
    """Click command wrapper used by the legacy entry points; renders its
    help through the rich-based formatter."""

    context_class = CustomContext

    def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
        """
        Render help with rich formatting (the Click formatter is unused).

        Args:
            ctx (click.Context): Click context.
            formatter (click.HelpFormatter): Click help formatter.
        """
        pretty_format_help(self, ctx, markup_mode="rich")
def get_git_branch_name() -> Optional[str]:
    """
    Retrieve the current Git branch name.

    Returns:
        Optional[str]: The branch name, or None when git is unavailable,
        the command fails, or it prints nothing.
    """
    command = ["git", "rev-parse", "--abbrev-ref", "HEAD"]
    try:
        output = subprocess.check_output(
            command,
            stderr=subprocess.DEVNULL,
            text=True,
        )
    except Exception:
        return None
    branch = output.strip()
    return branch or None

View File

@@ -0,0 +1,177 @@
import logging
from pathlib import Path
from typing import Optional
from safety.codebase.render import render_initialization_result
from safety.errors import SafetyError
from safety.events.utils.emission import (
emit_codebase_detection_status,
emit_codebase_setup_completed,
)
from safety.init.main import launch_auth_if_needed
from safety.tool.main import find_local_tool_files
from safety.util import clean_project_id
from typing_extensions import Annotated
import typer
from safety.cli_util import SafetyCLISubGroup, SafetyCLICommand
from .constants import (
CMD_CODEBASE_INIT_NAME,
CMD_HELP_CODEBASE_INIT,
CMD_HELP_CODEBASE,
CMD_CODEBASE_GROUP_NAME,
CMD_HELP_CODEBASE_INIT_DISABLE_FIREWALL,
CMD_HELP_CODEBASE_INIT_LINK_TO,
CMD_HELP_CODEBASE_INIT_NAME,
CMD_HELP_CODEBASE_INIT_PATH,
)
from ..cli_util import CommandType, get_command_for
from ..error_handlers import handle_cmd_exception
from ..decorators import notify
from ..constants import CONTEXT_COMMAND_TYPE, DEFAULT_EPILOG
from safety.console import main_console as console
from .main import initialize_codebase, prepare_unverified_codebase
logger = logging.getLogger(__name__)

# Shared Typer app configuration for the codebase command group.
cli_apps_opts = {
    "rich_markup_mode": "rich",
    "cls": SafetyCLISubGroup,
    "name": CMD_CODEBASE_GROUP_NAME,
}
codebase_app = typer.Typer(**cli_apps_opts)

# Subcommand forwarded to when `safety codebase` is run without one.
DEFAULT_CMD = CMD_CODEBASE_INIT_NAME
@codebase_app.callback(
    invoke_without_command=True,
    cls=SafetyCLISubGroup,
    help=CMD_HELP_CODEBASE,
    epilog=DEFAULT_EPILOG,
    context_settings={
        "allow_extra_args": True,
        "ignore_unknown_options": True,
        CONTEXT_COMMAND_TYPE: CommandType.BETA,
    },
)
def codebase(
    ctx: typer.Context,
):
    """
    Group command for Safety Codebase (project) operations. Without a
    subcommand, this forwards to the default command.
    """
    logger.info("codebase started")

    if ctx.invoked_subcommand:
        return None

    # No subcommand given: behave as if the default one had been invoked.
    fallback = get_command_for(name=DEFAULT_CMD, typer_instance=codebase_app)
    return ctx.forward(fallback)
@codebase_app.command(
    cls=SafetyCLICommand,
    help=CMD_HELP_CODEBASE_INIT,
    name=CMD_CODEBASE_INIT_NAME,
    epilog=DEFAULT_EPILOG,
    options_metavar="[OPTIONS]",
    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
@handle_cmd_exception
@notify
def init(
    ctx: typer.Context,
    name: Annotated[
        Optional[str],
        typer.Option(
            help=CMD_HELP_CODEBASE_INIT_NAME,
            callback=lambda name: clean_project_id(name) if name else None,
        ),
    ] = None,
    link_to: Annotated[
        Optional[str],
        typer.Option(
            "--link-to",
            help=CMD_HELP_CODEBASE_INIT_LINK_TO,
            callback=lambda name: clean_project_id(name) if name else None,
        ),
    ] = None,
    skip_firewall_setup: Annotated[
        bool, typer.Option(help=CMD_HELP_CODEBASE_INIT_DISABLE_FIREWALL)
    ] = False,
    codebase_path: Annotated[
        Path,
        typer.Option(
            "--path",
            exists=True,
            file_okay=False,
            dir_okay=True,
            writable=True,
            readable=True,
            resolve_path=True,
            show_default=False,
            help=CMD_HELP_CODEBASE_INIT_PATH,
        ),
    ] = Path("."),
):
    """
    Initialize a Safety Codebase. The codebase may be entirely new to Safety Platform,
    or may already exist in Safety Platform and the user is wanting to initialize it locally.

    Raises:
        typer.BadParameter: If both --link-to and --name are given.
        SafetyError: If no organization can be resolved after auth.
    """
    logger.info("codebase init started")

    # --link-to and --name are mutually exclusive.
    if link_to and name:
        raise typer.BadParameter("--link-to and --name cannot be used together")

    # Make sure the user is authenticated and resolve their organization.
    org_slug = launch_auth_if_needed(ctx, console)

    if not org_slug:
        raise SafetyError(
            "Organization not found, please run 'safety auth status' or 'safety auth login'"
        )

    # Firewall is enabled by default when the org has the feature, unless
    # explicitly skipped.
    should_enable_firewall = not skip_firewall_setup and ctx.obj.firewall_enabled

    # Build the local (not yet platform-verified) codebase representation.
    unverified_codebase = prepare_unverified_codebase(
        codebase_path=codebase_path,
        user_provided_name=name,
        user_provided_link_to=link_to,
    )

    # Report whether known package-manager files were found in the path.
    local_files = find_local_tool_files(codebase_path)
    emit_codebase_detection_status(
        event_bus=ctx.obj.event_bus,
        ctx=ctx,
        detected=any(local_files),
        detected_files=local_files if local_files else None,
    )

    project_file_created, project_status = initialize_codebase(
        ctx=ctx,
        console=console,
        codebase_path=codebase_path,
        unverified_codebase=unverified_codebase,
        org_slug=org_slug,
        link_to=link_to,
        should_enable_firewall=should_enable_firewall,
    )

    codebase_init_status = (
        "reinitialized" if unverified_codebase.created else project_status
    )
    codebase_id = ctx.obj.project.id if ctx.obj.project and ctx.obj.project.id else None

    render_initialization_result(
        console=console,
        codebase_init_status=codebase_init_status,
        codebase_id=codebase_id,
    )

    emit_codebase_setup_completed(
        event_bus=ctx.obj.event_bus,
        ctx=ctx,
        is_created=project_file_created,
        codebase_id=codebase_id,
    )

View File

@@ -0,0 +1,27 @@
# --- `safety codebase` command help text ---
CMD_HELP_CODEBASE_INIT = "Initialize a Safety Codebase (like git init for security). Sets up a new codebase or connects your local project to an existing one on Safety Platform."
CMD_HELP_CODEBASE = (
    "[BETA] Manage your Safety Codebase integration.\nExample: safety codebase init"
)

# Command/group names as registered with the CLI.
CMD_CODEBASE_GROUP_NAME = "codebase"
CMD_CODEBASE_INIT_NAME = "init"

# init options help
CMD_HELP_CODEBASE_INIT_NAME = "Name of the codebase. Defaults to GIT origin name, parent directory name, or random string if parent directory is unnamed. The value will be normalized for use as an identifier."
CMD_HELP_CODEBASE_INIT_LINK_TO = (
    "Link to an existing codebase using its codebase slug (found in Safety Platform)."
)
CMD_HELP_CODEBASE_INIT_DISABLE_FIREWALL = "Don't enable Firewall protection for this codebase (enabled by default when available in your organization)"
CMD_HELP_CODEBASE_INIT_PATH = (
    "Path to the codebase directory. Defaults to current directory."
)

# User-facing result/error messages; {codebase_name} is substituted at render time.
CODEBASE_INIT_REINITIALIZED = "Reinitialized existing codebase {codebase_name}"
CODEBASE_INIT_ALREADY_EXISTS = "A codebase already exists in this directory. Please delete .safety-project.ini and run `safety codebase init` again to initialize a new codebase."
CODEBASE_INIT_NOT_FOUND_LINK_TO = "\nError: codebase '{codebase_name}' specified with --link-to does not exist.\n\nTo create a new codebase instead, use one of:\n safety codebase init\n safety codebase init --name \"custom name\"\n\nTo link to an existing codebase, verify the codebase id and try again."
CODEBASE_INIT_NOT_FOUND_PROJECT_FILE = "\nError: codebase '{codebase_name}' specified with the current .safety-project.ini file does not exist.\n\nTo create a new codebase instead, delete the corrupted .safety-project.ini file and then use one of:\n safety codebase init\n safety codebase init --name \"custom name\"\n\nTo link to an existing codebase, verify the codebase id and try again."
CODEBASE_INIT_LINKED = "Linked to codebase {codebase_name}."
CODEBASE_INIT_CREATED = "Created new codebase {codebase_name}."
CODEBASE_INIT_ERROR = "Error: unable to initialize the codebase. Please try again."

View File

@@ -0,0 +1,115 @@
from typing import TYPE_CHECKING, Optional
from ..codebase_utils import load_unverified_project_from_config
from safety.errors import SafetyError, SafetyException
from pathlib import Path
from safety.codebase.constants import (
CODEBASE_INIT_ERROR,
CODEBASE_INIT_NOT_FOUND_LINK_TO,
CODEBASE_INIT_NOT_FOUND_PROJECT_FILE,
)
from safety.init.main import create_project
from typer import Context
from rich.console import Console
import sys
if TYPE_CHECKING:
from ..codebase_utils import UnverifiedProjectModel
def initialize_codebase(
    ctx: Context,
    console: Console,
    codebase_path: Path,
    unverified_codebase: "UnverifiedProjectModel",
    org_slug: str,
    link_to: Optional[str] = None,
    should_enable_firewall: bool = False,
):
    """
    Create the codebase on Safety Platform or link the local directory to an
    existing one, then optionally enable Firewall for the directory.

    Args:
        ctx: Typer context; ``ctx.obj.project`` is expected to be populated by
            ``create_project`` on success.
        console: Console used for prompts and output.
        codebase_path: Directory being initialized.
        unverified_codebase: Local project data loaded from .safety-project.ini.
        org_slug: Organization slug the codebase belongs to.
        link_to: Slug of an existing codebase to link to, if provided.
        should_enable_firewall: When True, configure the local directory for
            Firewall after a project is available.

    Returns:
        Tuple ``(project_file_created, project_status)`` as returned by
        ``create_project``.

    Raises:
        SafetyError: When the requested codebase does not exist, or when a
            matching codebase exists but cannot be linked in non-interactive mode.
        SafetyException: When no project is available after initialization.
    """
    is_interactive = sys.stdin.isatty()
    # Defaults: ask the user before linking, create the codebase when missing.
    link_behavior = "prompt"
    create_if_missing = True

    is_codebase_file_created = unverified_codebase.created

    # An explicit --link-to, or an existing .safety-project.ini with an id,
    # means we must link to that exact codebase and never create a new one.
    if link_to or is_codebase_file_created:
        link_behavior = "always"
        create_if_missing = False
    elif not is_interactive:
        # Without a TTY we cannot prompt, so never auto-link.
        link_behavior = "never"

    project_file_created, project_status = create_project(
        ctx=ctx,
        console=console,
        target=codebase_path,
        unverified_project=unverified_codebase,
        create_if_missing=create_if_missing,
        link_behavior=link_behavior,
    )

    if project_status == "not_found":
        # Pick the message that matches where the (missing) codebase id came from.
        codebase_name = "Unknown"
        msg = "Codebase not found."
        if link_to:
            msg = CODEBASE_INIT_NOT_FOUND_LINK_TO
            codebase_name = link_to
        elif is_codebase_file_created:
            msg = CODEBASE_INIT_NOT_FOUND_PROJECT_FILE
            codebase_name = unverified_codebase.id
        raise SafetyError(msg.format(codebase_name=codebase_name))
    elif project_status == "found" and not is_interactive:
        # Non-TTY mode: Project exists but we can't link (link_behavior="never")
        suggested_name = unverified_codebase.id
        raise SafetyError(
            f"Project '{suggested_name}' already exists. "
            f"In non-interactive mode, use --link-to '{suggested_name}' to link to the existing project, "
            f"or use --name with a different project name to create a new one."
        )

    if not ctx.obj.project:
        raise SafetyException(CODEBASE_INIT_ERROR)

    if should_enable_firewall:
        from ..tool.main import configure_local_directory

        configure_local_directory(codebase_path, org_slug, ctx.obj.project.id)

    return project_file_created, project_status
def fail_if_codebase_name_mismatch(
    provided_name: str,
    unverified_codebase: "UnverifiedProjectModel",
) -> None:
    """
    Abort when the user-supplied name disagrees with the id already stored in
    the local config, so an existing codebase is not overwritten by a typo.
    """
    existing_id = unverified_codebase.id
    if not existing_id or provided_name == existing_id:
        return
    from safety.codebase.constants import CODEBASE_INIT_ALREADY_EXISTS

    raise SafetyError(CODEBASE_INIT_ALREADY_EXISTS)
def prepare_unverified_codebase(
    codebase_path: Path,
    user_provided_name: Optional[str] = None,
    user_provided_link_to: Optional[str] = None,
) -> "UnverifiedProjectModel":
    """
    Build the unverified codebase model for *codebase_path*, applying any
    user-requested name or link target after checking it does not clash with
    an id already stored on disk.
    """
    codebase = load_unverified_project_from_config(project_root=codebase_path)

    requested_id = user_provided_name or user_provided_link_to
    if not requested_id:
        return codebase

    fail_if_codebase_name_mismatch(
        provided_name=requested_id,
        unverified_codebase=codebase,
    )
    codebase.id = requested_id
    return codebase

View File

@@ -0,0 +1,34 @@
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from rich.console import Console
def render_initialization_result(
    console: "Console",
    codebase_init_status: Optional[str] = None,
    codebase_id: Optional[str] = None,
):
    """
    Print the outcome of `safety codebase init`: a status-specific message
    when both a status and an id are available, otherwise a generic error.
    """
    if not (codebase_init_status and codebase_id):
        console.print("Error: unable to initialize codebase")
        return

    template = None
    if codebase_init_status == "created":
        from safety.codebase.constants import CODEBASE_INIT_CREATED as template
    elif codebase_init_status == "linked":
        from safety.codebase.constants import CODEBASE_INIT_LINKED as template
    elif codebase_init_status == "reinitialized":
        from safety.codebase.constants import CODEBASE_INIT_REINITIALIZED as template

    if template:
        console.print(template.format(codebase_name=codebase_id))

View File

@@ -0,0 +1,89 @@
import configparser
import logging
from pathlib import Path
from dataclasses import dataclass
from typing import Optional
from safety_schemas.models import ProjectModel
# Name of the per-codebase config file stored at the project root.
PROJECT_CONFIG = ".safety-project.ini"
# Section/option names used inside that config file.
PROJECT_CONFIG_SECTION = "project"
PROJECT_CONFIG_ID = "id"
PROJECT_CONFIG_URL = "url"
PROJECT_CONFIG_NAME = "name"

logger = logging.getLogger(__name__)
@dataclass
class UnverifiedProjectModel:
    """
    Data class representing an unverified project model.
    """

    # Project identifier read from the config file; None when none was stored.
    id: Optional[str]
    # Path to the .safety-project.ini file itself (not the project root).
    project_path: Path
    # True when the config file already contained an id
    # (see load_unverified_project_from_config).
    created: bool
    name: Optional[str] = None
    url_path: Optional[str] = None
def load_unverified_project_from_config(project_root: Path) -> UnverifiedProjectModel:
    """
    Read the .safety-project.ini under *project_root* and build an
    UnverifiedProjectModel from whatever values are present.

    Args:
        project_root (Path): The root directory of the project.

    Returns:
        UnverifiedProjectModel: Parsed (but not server-verified) project data;
        ``created`` is True only when an id was found in the file.
    """
    parser = configparser.ConfigParser()
    config_file = project_root / PROJECT_CONFIG
    parser.read(config_file)

    def _option(option: str) -> Optional[str]:
        # Missing section/option simply yields None.
        return parser.get(PROJECT_CONFIG_SECTION, option, fallback=None)

    project_id = _option(PROJECT_CONFIG_ID)
    return UnverifiedProjectModel(
        id=project_id,
        url_path=_option(PROJECT_CONFIG_URL),
        name=_option(PROJECT_CONFIG_NAME),
        project_path=config_file,
        created=bool(project_id),
    )
def save_project_info(project: ProjectModel, project_path: Path) -> bool:
    """
    Persist the project's id (and, when present, url and name) into the
    configuration file, preserving any other content already in it.

    Args:
        project (ProjectModel): The ProjectModel object containing project
            information.
        project_path (Path): The path to the configuration file.

    Returns:
        bool: True if the project information was saved successfully, False
        otherwise.
    """
    parser = configparser.ConfigParser()
    parser.read(project_path)

    if PROJECT_CONFIG_SECTION not in parser.sections():
        parser[PROJECT_CONFIG_SECTION] = {}
    section = parser[PROJECT_CONFIG_SECTION]

    section[PROJECT_CONFIG_ID] = project.id
    if project.url_path:
        section[PROJECT_CONFIG_URL] = project.url_path
    if project.name:
        section[PROJECT_CONFIG_NAME] = project.name

    try:
        with open(project_path, "w") as configfile:
            parser.write(configfile)
    except Exception:
        logger.exception("Error saving project info")
        return False
    return True

View File

@@ -0,0 +1,149 @@
from functools import lru_cache
import logging
import os
import sys
from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union
from rich.console import Console
from rich.theme import Theme
from safety.emoji import load_emoji
if TYPE_CHECKING:
from rich.console import HighlighterType, JustifyMethod, OverflowMethod
from rich.style import Style
from rich.text import Text
LOG = logging.getLogger(__name__)
@lru_cache()
def should_use_ascii():
    """
    Check if we should use ASCII alternatives for emojis.

    Returns True unless stdout reports a UTF-8-capable encoding.

    Bug fix: ``sys.stdout.encoding`` can be None (e.g. when stdout has been
    replaced or redirected by a non-file object); the original code then
    crashed calling ``.lower()`` on None. Coerce None/missing to "".
    """
    encoding = (getattr(sys.stdout, "encoding", "") or "").lower()
    # cp65001 is the Windows code page alias for UTF-8.
    return encoding not in {"utf-8", "utf8", "cp65001", "utf-8-sig"}
def get_spinner_animation() -> List[str]:
    """
    Return the spinner frame sequence, falling back to an ASCII bar
    animation when the output encoding cannot handle unicode frames.
    """
    ascii_frames = [
        "[ ]",
        "[= ]",
        "[== ]",
        "[=== ]",
        "[====]",
        "[ ===]",
        "[ ==]",
        "[ =]",
    ]
    unicode_frames = ["", "", "", "", "", "", "", "", "", ""]
    return ascii_frames if should_use_ascii() else unicode_frames
def replace_non_ascii_chars(text: str) -> str:
    """
    Replace non-ascii characters with ascii alternatives.

    Known characters are mapped explicitly; anything else non-ASCII is
    replaced via ``encode("ascii", "replace")``.
    """
    # NOTE(review): the keys below render as empty strings here; they are
    # presumably single non-ASCII characters (e.g. an em dash and a curly
    # quote, matching the "-" and "'" replacements) — confirm against the
    # original file, since "".replace("", "-") would behave very differently.
    CHARS_MAP = {
        "": "-",
        "": "'",
    }
    for char, replacement in CHARS_MAP.items():
        text = text.replace(char, replacement)
    try:
        text.encode("ascii")
    except UnicodeEncodeError:
        # NOTE(review): this branch runs when UNHANDLED non-ascii characters
        # remain in the text — the message wording looks inverted.
        LOG.warning("No handled non-ascii characters detected, encoding with replace")
        text = text.encode("ascii", "replace").decode("ascii")
    return text
class SafeConsole(Console):
    """
    Console subclass that handles emoji encoding issues by detecting
    problematic encoding environments and replacing emojis with ASCII alternatives.

    Uses string replacement for custom emoji namespace to avoid private API usage.
    """

    def render_str(
        self,
        text: str,
        *,
        style: Union[str, "Style"] = "",
        justify: Optional["JustifyMethod"] = None,
        overflow: Optional["OverflowMethod"] = None,
        emoji: Optional[bool] = None,
        markup: Optional[bool] = None,
        highlight: Optional[bool] = None,
        highlighter: Optional["HighlighterType"] = None,
    ) -> "Text":
        """
        Override render_str to pre-process our custom emojis before Rich handles the text.
        """
        use_ascii = should_use_ascii()
        # Expand custom :icon_*: codes (ASCII fallbacks when needed).
        text = load_emoji(text, use_ascii=use_ascii)
        if use_ascii:
            # Strip any remaining non-ASCII characters the terminal
            # encoding could not represent.
            text = replace_non_ascii_chars(text)
        # Let Rich handle everything else normally
        return super().render_str(
            text,
            style=style,
            justify=justify,
            overflow=overflow,
            emoji=emoji,
            markup=markup,
            highlight=highlight,
            highlighter=highlighter,
        )
# Rich theme used across Safety CLI output; style names are referenced by
# renderers via markup tags (e.g. [file_title]...[/file_title]).
SAFETY_THEME = {
    "file_title": "bold default on default",
    "dep_name": "bold yellow on default",
    "scan_meta_title": "bold default on default",
    "vuln_brief": "red on default",
    "rem_brief": "bold green on default",
    "rem_severity": "bold red on default",
    "brief_severity": "bold default on default",
    "status.spinner": "green",
    "recommended_ver": "bold cyan on default",
    "vuln_id": "bold default on default",
    "number": "bold cyan on default",
    "link": "underline bright_blue on default",
    "tip": "bold default on default",
    "specifier": "bold cyan on default",
    "vulns_found_number": "red on default",
}

# NON_INTERACTIVE=1 forces terminal-style output without interactive features.
non_interactive = os.getenv("NON_INTERACTIVE") == "1"

console_kwargs: Dict[str, Any] = {
    "theme": Theme(SAFETY_THEME, inherit=False),
    # Disable Rich's own emoji handling when the encoding can't represent
    # unicode; SafeConsole substitutes ASCII alternatives instead.
    "emoji": not should_use_ascii(),
}

if non_interactive:
    LOG.info(
        "NON_INTERACTIVE environment variable is set, forcing non-interactive mode"
    )
    console_kwargs["force_terminal"] = True
    console_kwargs["force_interactive"] = False

# Shared console instance used across the CLI.
main_console = SafeConsole(**console_kwargs)

View File

@@ -0,0 +1,259 @@
# -*- coding: utf-8 -*-
import configparser
import os
from enum import Enum
from pathlib import Path
from typing import Optional
from safety.meta import get_version
# Version of the vulnerability-DB JSON schema; also used to namespace the cache dir.
JSON_SCHEMA_VERSION = "2.0.0"

# TODO fix this
OPEN_MIRRORS = [
    f"https://pyup.io/aws/safety/free/{JSON_SCHEMA_VERSION}/",
]

# Name of the per-user / system configuration directory.
DIR_NAME = ".safety"
def get_system_dir() -> Path:
    """
    Get the system directory for the safety configuration.

    The base location can be overridden with the SAFETY_SYSTEM_CONFIG_PATH
    environment variable; otherwise a per-platform default is used.

    Returns:
        Path: The system directory path (base directory joined with DIR_NAME).
    """
    # `sys` is kept as a local import (it is not imported at module level);
    # the original's duplicate local `import os` was removed — `os` is already
    # imported at the top of this module.
    import sys

    raw_dir = os.getenv("SAFETY_SYSTEM_CONFIG_PATH")
    if not raw_dir:
        # Only consult the platform defaults when no explicit override is set.
        app_data = os.environ.get("ALLUSERSPROFILE", None)
        if sys.platform.startswith("win") and app_data:
            raw_dir = app_data
        elif sys.platform.startswith("darwin"):
            raw_dir = "/Library/Application Support"
        elif sys.platform.startswith("linux"):
            raw_dir = "/etc"
        else:
            raw_dir = "/"

    return Path(raw_dir, DIR_NAME)
def get_user_dir() -> Path:
    """
    Get the per-user directory for the safety configuration.

    Returns:
        Path: ``~/<DIR_NAME>`` with the home directory expanded.
    """
    return Path("~", DIR_NAME).expanduser()
# Resolved configuration locations.
USER_CONFIG_DIR = get_user_dir()
SYSTEM_CONFIG_DIR = get_system_dir()

# Cache directory is namespaced by schema version (dots stripped) so
# incompatible cache formats don't collide.
CACHE_FILE_DIR = USER_CONFIG_DIR / f"{JSON_SCHEMA_VERSION.replace('.', '')}"
DB_CACHE_FILE = CACHE_FILE_DIR / "cache.json"

CONFIG_FILE_NAME = "config.ini"
CONFIG_FILE_SYSTEM = SYSTEM_CONFIG_DIR / CONFIG_FILE_NAME if SYSTEM_CONFIG_DIR else None
CONFIG_FILE_USER = USER_CONFIG_DIR / CONFIG_FILE_NAME

# Effective config file: the system-wide one wins when it exists.
CONFIG = (
    CONFIG_FILE_SYSTEM
    if CONFIG_FILE_SYSTEM and CONFIG_FILE_SYSTEM.exists()
    else CONFIG_FILE_USER
)

SAFETY_POLICY_FILE_NAME = ".safety-policy.yml"
SYSTEM_POLICY_FILE = SYSTEM_CONFIG_DIR / SAFETY_POLICY_FILE_NAME
USER_POLICY_FILE = USER_CONFIG_DIR / SAFETY_POLICY_FILE_NAME

DEFAULT_DOMAIN = "safetycli.com"
DEFAULT_EMAIL = f"support@{DEFAULT_DOMAIN}"
class URLSettings(Enum):
    """
    Hard-coded default Safety service URLs and client id. These act as the
    fallbacks for get_config_setting(), which allows overriding each value
    via the config file's [settings] section.
    """

    PLATFORM_API_BASE_URL = f"https://platform.{DEFAULT_DOMAIN}/cli/api/v1"
    DATA_API_BASE_URL = f"https://data.{DEFAULT_DOMAIN}/api/v1/safety/"
    CLIENT_ID = "AWnwFBMr9DdZbxbDwYxjm4Gb24pFTnMp"
    AUTH_SERVER_URL = f"https://auth.{DEFAULT_DOMAIN}"
    SAFETY_PLATFORM_URL = f"https://platform.{DEFAULT_DOMAIN}"
    FIREWALL_API_BASE_URL = "https://pkgs.safetycli.com"
class FeatureType(Enum):
    """
    Defines server-controlled features for dynamic feature management.

    Each enum value represents a toggleable feature controlled through
    server-side configuration, enabling gradual rollouts to different user
    segments. Features are cached during CLI initialization.

    History:
        Created to support progressive feature rollouts and A/B testing without
        disturbing users.
    """

    # Server-togglable feature flags.
    FIREWALL = "firewall"
    PLATFORM = "platform"
    EVENTS = "events"

    @property
    def config_key(self) -> str:
        """For JSON/config lookup e.g. 'feature-a-enabled'"""
        return f"{self.name.lower()}-enabled"

    @property
    def attr_name(self) -> str:
        """For Python attribute access e.g. 'feature_a_enabled'"""
        return f"{self.name.lower()}_enabled"
def get_config_setting(name: str, default=None) -> Optional[str]:
    """
    Get the configuration setting from the config file or defaults.

    Args:
        name (str): The name of the setting to retrieve.
        default: Value returned when the setting is absent. When ``name``
            matches a URLSettings member, that member's value is used as the
            default instead.

    Returns:
        Optional[str]: The configured value, the default, or None.
    """
    config = configparser.ConfigParser()
    config.read(CONFIG)

    # Known URL settings fall back to their hard-coded enum defaults.
    if name in [setting.name for setting in URLSettings]:
        default = URLSettings[name]

    if "settings" in config.sections() and name in config["settings"]:
        value = config["settings"][name]
        if value:
            return value

    # Bug fix: only URLSettings members carry `.value`; a caller-supplied
    # plain default (e.g. a string) must be returned as-is instead of
    # raising AttributeError as the original `default.value if default
    # else default` did.
    if isinstance(default, URLSettings):
        return default.value
    return default
# Effective service URLs: config-file overrides applied on top of the
# URLSettings defaults (see get_config_setting).
DATA_API_BASE_URL = get_config_setting("DATA_API_BASE_URL")
PLATFORM_API_BASE_URL = get_config_setting("PLATFORM_API_BASE_URL")

# Platform API endpoints derived from the base URL.
PLATFORM_API_PROJECT_ENDPOINT = f"{PLATFORM_API_BASE_URL}/project"
PLATFORM_API_PROJECT_CHECK_ENDPOINT = f"{PLATFORM_API_BASE_URL}/project-check"
PLATFORM_API_POLICY_ENDPOINT = f"{PLATFORM_API_BASE_URL}/policy"
PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT = (
    f"{PLATFORM_API_BASE_URL}/project-scan-request"
)
PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT = f"{PLATFORM_API_BASE_URL}/scan"
PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT = (
    f"{PLATFORM_API_BASE_URL}/process_files"
)
PLATFORM_API_CHECK_UPDATES_ENDPOINT = f"{PLATFORM_API_BASE_URL}/versions-and-configs"
PLATFORM_API_INITIALIZE_ENDPOINT = f"{PLATFORM_API_BASE_URL}/initialize"
PLATFORM_API_EVENTS_ENDPOINT = f"{PLATFORM_API_BASE_URL}/events"

# Firewall (package audit) endpoints.
FIREWALL_API_BASE_URL = get_config_setting("FIREWALL_API_BASE_URL")
FIREWALL_AUDIT_PYPI_PACKAGES_ENDPOINT = f"{FIREWALL_API_BASE_URL}/audit/pypi/packages/"
FIREWALL_AUDIT_NPMJS_PACKAGES_ENDPOINT = (
    f"{FIREWALL_API_BASE_URL}/audit/npmjs/packages/"
)

API_MIRRORS = [DATA_API_BASE_URL]

# Fetch the REQUEST_TIMEOUT from the environment variable, defaulting to 30 if not set
REQUEST_TIMEOUT = int(os.getenv("SAFETY_REQUEST_TIMEOUT", 30))

# Colors
YELLOW = "yellow"
RED = "red"
GREEN = "green"

# MESSAGES
IGNORE_UNPINNED_REQ_REASON = (
    "This vulnerability is being ignored due to the 'ignore-unpinned-requirements' flag (default True). "
    "To change this, set 'ignore-unpinned-requirements' to False under 'security' in your policy file. "
    "See https://docs.pyup.io/docs/safety-20-policy-file for more information."
)

# REGEXES
HASH_REGEX_GROUPS = r"--hash[=| ](\w+):(\w+)"

DOCS_API_KEY_URL = "https://docs.safetycli.com/cli/api-keys"

# Authentication prompts, varying by stage (dev vs CI/CD) and org context.
MSG_NO_AUTHD_DEV_STG = "Please login or register Safety CLI [bold](free forever)[/bold] to scan and secure your projects with Safety"
MSG_NO_AUTHD_DEV_STG_PROMPT = "(R)egister for a free account in 30 seconds, or (L)ogin with an existing account to continue (R/L)"
MSG_NO_AUTHD_DEV_STG_ORG_PROMPT = "Please log in to secure your projects with Safety. Press enter to continue to log in (Y/N)"
MSG_NO_AUTHD_CICD_PROD_STG = "Enter your Safety API key to scan projects in CI/CD using the --key argument or setting your API key in the SAFETY_API_KEY environment variable."
MSG_NO_AUTHD_CICD_PROD_STG_ORG = f"""
Login to get your API key
To log in: [link]{{LOGIN_URL}}[/link]
Read more at: [link]{DOCS_API_KEY_URL}[/link]
"""
MSG_NO_AUTHD_NOTE_CICD_PROD_STG_TPL = f"""
Login or register for a free account to get your API key
To log in: [link]{{LOGIN_URL}}[/link]
To register: [link]{{SIGNUP_URL}}[/link]
Read more at: [link]{DOCS_API_KEY_URL}[/link]
"""
MSG_FINISH_REGISTRATION_TPL = (
    "To complete your account open the “verify your email” email sent to {email}"
)
# NOTE(review): "Cant" below looks like a stripped apostrophe ("Can't");
# runtime string intentionally left unchanged here.
MSG_VERIFICATION_HINT = "Cant find the verification email? Login at [link]`https://platform.safetycli.com/login/`[/link] to resend the verification email"
MSG_NO_VERIFIED_EMAIL_TPL = f"""Email verification is required for {{email}}
{MSG_FINISH_REGISTRATION_TPL}
{MSG_VERIFICATION_HINT}"""

# Exit codes
EXIT_CODE_OK = 0
EXIT_CODE_FAILURE = 1
EXIT_CODE_VULNERABILITIES_FOUND = 64
EXIT_CODE_INVALID_AUTH_CREDENTIAL = 65
EXIT_CODE_TOO_MANY_REQUESTS = 66
EXIT_CODE_UNABLE_TO_LOAD_LOCAL_VULNERABILITY_DB = 67
EXIT_CODE_UNABLE_TO_FETCH_VULNERABILITY_DB = 68
EXIT_CODE_MALFORMED_DB = 69
EXIT_CODE_INVALID_PROVIDED_REPORT = 70
EXIT_CODE_INVALID_REQUIREMENT = 71
EXIT_CODE_EMAIL_NOT_VERIFIED = 72

# For deprecated messages
BAR_LINE = "+===========================================================================================================================================================================================+"

BETA_PANEL_DESCRIPTION_HELP = "These commands are experimental and part of our commitment to delivering innovative features. As we refine functionality, they may be significantly altered or, in rare cases, removed without prior notice. We welcome your feedback and encourage cautious use."

# Keys stored in the Click/Typer context settings for each command.
CONTEXT_COMMAND_TYPE = "command_type"
CONTEXT_FEATURE_TYPE = "feature_type"

CLI_VERSION = get_version()
CLI_WEBSITE_URL = "https://safetycli.com"
CLI_DOCUMENTATION_URL = "https://docs.safetycli.com"
CLI_SUPPORT_EMAIL = "support@safetycli.com"

# Main Safety --help data:
CLI_MAIN_INTRODUCTION = (
    f"Safety CLI 3 - Vulnerability Scanning for Secure Python Development\n\n"
    "Leverage the most comprehensive vulnerability data available to secure your projects against vulnerable and malicious packages. Safety CLI is a Python dependency vulnerability scanner that enhances software supply chain security at every stage of development.\n\n"
    f"Documentation: {CLI_DOCUMENTATION_URL}\n"
    f"Contact: {CLI_SUPPORT_EMAIL}\n"
)

DEFAULT_EPILOG = (
    f"\nSafety CLI version: {CLI_VERSION}\n"
    f"\nDocumentation: {CLI_DOCUMENTATION_URL}\n\n\n\n"
    "Made with love by Safety Cybersecurity\n\n"
    f"{CLI_WEBSITE_URL}\n\n"
    f"{CLI_SUPPORT_EMAIL}\n"
)

View File

@@ -0,0 +1,40 @@
from functools import wraps
from safety.events.utils import emit_command_executed
def notify(func):
    """
    Decorator that emits a command-executed event after the wrapped command
    finishes.

    The event is emitted on a normal return (exit code 0) and when the
    command terminates via sys.exit() — the SystemExit code is forwarded,
    defaulting to 1 for non-integer codes. Any other exception propagates
    without emitting the event.

    Args:
        func (callable): Command callback receiving the context first.

    Returns:
        callable: The wrapped function.

    Example:
        @notify
        def my_command(ctx, *args, **kwargs):
            ...
    """

    @wraps(func)
    def inner(ctx, *args, **kwargs):
        try:
            outcome = func(ctx, *args, **kwargs)
            emit_command_executed(ctx.obj.event_bus, ctx, returned_code=0)
            return outcome
        except SystemExit as exc:
            # sys.exit() path: forward the exit code with the event.
            code = exc.code if isinstance(exc.code, int) else 1
            emit_command_executed(ctx.obj.event_bus, ctx, returned_code=code)
            raise

    return inner

View File

@@ -0,0 +1,104 @@
# Custom emoji namespace mapping
import re
from typing import Match
# Custom :icon_*: emoji codes -> unicode characters.
# NOTE(review): some values render as empty strings here; they are presumably
# single emoji characters lost in extraction — confirm against the original file.
CUSTOM_EMOJI_MAP = {
    "icon_check": "",
    "icon_warning": "⚠️",
    "icon_info": "",
}

# ASCII fallback mapping for problematic environments
ASCII_FALLBACK_MAP = {
    "icon_check": "+",
    "icon_warning": "!",
    "icon_info": "i",
    "white_heavy_check_mark": "++",
    "white_check_mark": "+",
    "check_mark": "+",
    "heavy_check_mark": "+",
    "shield": "[SHIELD]",
    "x": "X",
    "lock": "[LOCK]",
    "key": "[KEY]",
    "pencil": "[EDIT]",
    "arrow_up": "^",
    "stop_sign": "[STOP]",
    "warning": "!",
    "locked": "[LOCK]",
    "pushpin": "[PIN]",
    "magnifying_glass_tilted_left": "[SCAN]",
    "fire": "[CRIT]",
    "yellow_circle": "[HIGH]",
    "sparkles": "*",
    "mag_right": "[VIEW]",
    "link": "->",
    "light_bulb": "[TIP]",
    "trophy": "[DONE]",
    "rocket": ">>",
    "busts_in_silhouette": "[TEAM]",
    "floppy_disk": "[SAVE]",
    "heavy_plus_sign": "[ADD]",
    "books": "[DOCS]",
    "speech_balloon": "[HELP]",
}

# Pre-compiled regex for emoji processing (Rich-style); matches only the
# custom :icon_*: namespace.
CUSTOM_EMOJI_PATTERN = re.compile(r"(:icon_\w+:)")
def process_custom_emojis(text: str, use_ascii: bool = False) -> str:
    """
    Expand the custom :icon_*: emoji codes before Rich processes the text.

    Only the custom namespace is touched; regular Rich emoji codes are left
    for Rich itself (or for the ASCII fallback pass). Unknown :icon_*: codes
    are kept verbatim.
    """
    if not isinstance(text, str) or ":icon_" not in text:
        return text

    lookup = ASCII_FALLBACK_MAP if use_ascii else CUSTOM_EMOJI_MAP

    def _expand(match: Match[str]) -> str:
        token = match.group(1)  # e.g. ":icon_check:"
        return lookup.get(token[1:-1], token)

    return CUSTOM_EMOJI_PATTERN.sub(_expand, text)
def process_rich_emojis_fallback(text: str) -> str:
    """
    Swap Rich :emoji_name: codes for ASCII stand-ins where a fallback exists.

    Codes without a registered fallback are left untouched.
    """
    pattern = re.compile(r":([a-zA-Z0-9_]+):")

    def _ascii_or_original(match: Match[str]) -> str:
        fallback = ASCII_FALLBACK_MAP.get(match.group(1), None)
        return fallback if fallback else match.group(0)

    return pattern.sub(_ascii_or_original, text)
def load_emoji(text: str, use_ascii: bool = False) -> str:
    """
    Resolve emoji codes in *text*: the custom :icon_*: namespace always,
    plus ASCII fallbacks for regular Rich codes when use_ascii is True.
    """
    resolved = process_custom_emojis(text, use_ascii)
    if not use_ascii:
        return resolved
    return process_rich_emojis_fallback(resolved)

View File

@@ -0,0 +1,28 @@
from pathlib import Path
import logging
logger = logging.getLogger(__name__)
def detect_encoding(file_path: Path) -> str:
    """
    Best-effort text-encoding detection via BOM sniffing.

    UTF-8 is assumed unless the file starts with a UTF-16 BOM or a UTF-8
    BOM ("utf-8-sig"), which covers the common Windows-generated cases.
    Any read error is logged and falls back to "utf-8".
    """
    try:
        with open(file_path, "rb") as handle:
            head = handle.read(3)
        if head.startswith((b"\xff\xfe", b"\xfe\xff")):
            return "utf-16"
        if head.startswith(b"\xef\xbb\xbf"):
            return "utf-8-sig"
        return "utf-8"
    except Exception:
        logger.exception("Error detecting encoding")
        return "utf-8"

View File

@@ -0,0 +1,89 @@
# Standard library imports
import logging
import sys
import traceback
from functools import wraps
from typing import TYPE_CHECKING, Optional
# Third-party imports
import click
# Local imports
from safety.constants import EXIT_CODE_FAILURE, EXIT_CODE_OK
from safety.errors import SafetyError, SafetyException
from safety.events.utils import emit_command_error
if TYPE_CHECKING:
from safety.scan.models import ScanOutput
LOG = logging.getLogger(__name__)
def output_exception(exception: Exception, exit_code_output: bool = True) -> None:
    """
    Print *exception* to stderr in red and terminate the process.

    Args:
        exception (Exception): The error to report; when it provides
            get_exit_code(), that code is used.
        exit_code_output (bool): When False, always exit with EXIT_CODE_OK.

    Exits:
        Always raises SystemExit with the computed exit code.
    """
    click.secho(str(exception), fg="red", file=sys.stderr)

    if not exit_code_output:
        sys.exit(EXIT_CODE_OK)

    get_code = getattr(exception, "get_exit_code", None)
    sys.exit(get_code() if get_code else EXIT_CODE_FAILURE)
def handle_cmd_exception(func):
    """
    Decorator to handle exceptions in command functions.

    Click exceptions are re-raised (Click renders them itself); SafetyError
    and unexpected exceptions are rendered via output_exception, which exits.
    Every error path also records a command-error event.

    Args:
        func: The command function to wrap.

    Returns:
        The wrapped function.
    """

    @wraps(func)
    def inner(ctx, output: Optional["ScanOutput"] = None, *args, **kwargs):
        if output:
            from safety.scan.models import ScanOutput

            kwargs.update({"output": output})

            # Machine-readable "none" output bypasses the error handling so
            # the output stream is not polluted with rendered errors.
            if output is ScanOutput.NONE:
                return func(ctx, *args, **kwargs)

        try:
            return func(ctx, *args, **kwargs)
        except click.ClickException as e:
            # Click renders its own exceptions; just record the error event.
            emit_command_error(
                ctx.obj.event_bus, ctx, message=str(e), traceback=traceback.format_exc()
            )
            raise e
        except SafetyError as e:
            LOG.exception("Expected SafetyError happened: %s", e)
            emit_command_error(
                ctx.obj.event_bus, ctx, message=str(e), traceback=traceback.format_exc()
            )
            output_exception(e, exit_code_output=True)
        except Exception as e:
            emit_command_error(
                ctx.obj.event_bus, ctx, message=str(e), traceback=traceback.format_exc()
            )
            LOG.exception("Unexpected Exception happened: %s", e)
            # Wrap unknown errors so the user always sees a SafetyException.
            exception = e if isinstance(e, SafetyException) else SafetyException(info=e)
            output_exception(exception, exit_code_output=True)

    return inner

View File

@@ -0,0 +1,287 @@
from typing import Optional
from safety.constants import (
EXIT_CODE_EMAIL_NOT_VERIFIED,
EXIT_CODE_FAILURE,
EXIT_CODE_INVALID_AUTH_CREDENTIAL,
EXIT_CODE_INVALID_PROVIDED_REPORT,
EXIT_CODE_INVALID_REQUIREMENT,
EXIT_CODE_MALFORMED_DB,
EXIT_CODE_TOO_MANY_REQUESTS,
EXIT_CODE_UNABLE_TO_FETCH_VULNERABILITY_DB,
EXIT_CODE_UNABLE_TO_LOAD_LOCAL_VULNERABILITY_DB,
)
class SafetyException(Exception):
    """
    Base exception for Safety CLI errors.

    Args:
        message (str): The error message template.
        info (str): Additional detail interpolated into the template.
    """

    def __init__(self, message: str = "Unhandled exception happened: {info}", info: str = ""):
        formatted = message.format(info=info)
        self.message = formatted
        super().__init__(formatted)

    def get_exit_code(self) -> int:
        """
        Get the exit code associated with this exception.

        Returns:
            int: The generic failure exit code.
        """
        return EXIT_CODE_FAILURE
class SafetyError(Exception):
    """
    Generic Safety CLI error.

    Args:
        message (str): Human-readable error message.
        error_code (Optional[int]): Optional machine-readable error code.
    """

    def __init__(self, message: str = "Unhandled Safety generic error", error_code: Optional[int] = None):
        self.message = message
        self.error_code = error_code
        super().__init__(message)

    def get_exit_code(self) -> int:
        """
        Get the exit code associated with this error.

        Returns:
            int: The generic failure exit code.
        """
        return EXIT_CODE_FAILURE
class MalformedDatabase(SafetyError):
    """
    Error raised when the vulnerability database is malformed.

    Args:
        reason (Optional[str]): Optional detail appended to the message.
        fetched_from (str): The source of the fetched data.
        message (str): The error message template.
    """

    def __init__(self, reason: Optional[str] = None, fetched_from: str = "server",
                 message: str = "Sorry, something went wrong.\n"
                                "Safety CLI cannot read the data fetched from {fetched_from} because it is malformed.\n"):
        # Bug fix: the original assigned `info` twice — the first conditional
        # expression was immediately clobbered by an unconditional format that
        # produced "Reason, None" for a missing reason (masked only by a later
        # guard). A single guarded format yields the same observable message
        # for every input without the dead code.
        info = f"Reason, {reason}" if reason else ""
        self.message = message.format(fetched_from=fetched_from) + info
        super().__init__(self.message)

    def get_exit_code(self) -> int:
        """
        Get the exit code associated with this error.

        Returns:
            int: The exit code.
        """
        return EXIT_CODE_MALFORMED_DB
class DatabaseFetchError(SafetyError):
    """
    Error raised when the vulnerability database cannot be fetched.

    Also the base class for the more specific fetch-related errors below.

    Args:
        message (str): The error message.
    """

    def __init__(self, message: str = "Unable to load vulnerability database"):
        self.message = message
        super().__init__(self.message)

    def get_exit_code(self) -> int:
        """
        Get the exit code associated with this error.

        Returns:
            int: The exit code.
        """
        return EXIT_CODE_UNABLE_TO_FETCH_VULNERABILITY_DB
class InvalidProvidedReportError(SafetyError):
    """
    Error raised when the provided report is invalid for applying fixes.

    Args:
        message (str): The error message.
    """

    def __init__(self, message: str = "Unable to apply fix: the report needs to be generated from a file. "
                                      "Environment isn't supported yet."):
        self.message = message
        super().__init__(self.message)

    def get_exit_code(self) -> int:
        """
        Get the exit code associated with this error.

        Returns:
            int: The exit code.
        """
        return EXIT_CODE_INVALID_PROVIDED_REPORT
class InvalidRequirementError(SafetyError):
    """
    Error raised when a requirement line cannot be parsed.

    Args:
        message (str): The error message template ({line} placeholder).
        line (str): The requirement line that failed to parse.
    """

    def __init__(self, message: str = "Unable to parse the requirement: {line}", line: str = ""):
        formatted = message.format(line=line)
        self.message = formatted
        super().__init__(formatted)

    def get_exit_code(self) -> int:
        """
        Get the exit code associated with this error.

        Returns:
            int: The exit code.
        """
        return EXIT_CODE_INVALID_REQUIREMENT
class DatabaseFileNotFoundError(DatabaseFetchError):
    """
    Error raised when the local vulnerability database file cannot be found.

    Args:
        db (Optional[str]): The database file path that was searched.
        message (str): The error message template ({db} placeholder).
    """

    def __init__(self, db: Optional[str] = None, message: str = "Unable to find vulnerability database in {db}"):
        self.db = db
        formatted = message.format(db=db)
        self.message = formatted
        super().__init__(formatted)

    def get_exit_code(self) -> int:
        """
        Get the exit code associated with this error.

        Returns:
            int: The exit code.
        """
        return EXIT_CODE_UNABLE_TO_LOAD_LOCAL_VULNERABILITY_DB
class InvalidCredentialError(DatabaseFetchError):
    """
    Error raised when authentication credentials are invalid.

    Args:
        credential (Optional[str]): The invalid credential; when given it is
            quoted inside the message.
        message (str): The error message template ({credential}/{link}).
        reason (Optional[str]): Optional detail appended to the message.
    """

    def __init__(self, credential: Optional[str] = None,
                 message: str = "Your authentication credential{credential}is invalid. See {link}.",
                 reason: Optional[str] = None):
        self.credential = credential
        self.link = 'https://docs.safetycli.com/safety-docs/support/invalid-api-key-error'
        # The template has no space around {credential}; the substitution
        # supplies either " 'cred' " or a single space to keep spacing intact.
        self.message = message.format(credential=f" '{self.credential}' ", link=self.link) if self.credential else message.format(credential=' ', link=self.link)
        info = f" Reason: {reason}"
        self.message = self.message + (info if reason else "")
        super().__init__(self.message)

    def get_exit_code(self) -> int:
        """
        Get the exit code associated with this error.

        Returns:
            int: The exit code.
        """
        return EXIT_CODE_INVALID_AUTH_CREDENTIAL
class NotVerifiedEmailError(SafetyError):
    """
    Error raised when the user's email is not verified.

    Args:
        message (str): The error message.
    """

    def __init__(self, message: str = "email is not verified"):
        self.message = message
        super().__init__(self.message)

    def get_exit_code(self) -> int:
        """
        Get the exit code associated with this error.

        Returns:
            int: The exit code.
        """
        return EXIT_CODE_EMAIL_NOT_VERIFIED
class TooManyRequestsError(DatabaseFetchError):
    """
    Raised when the server rate-limits our requests.

    Args:
        reason: Optional extra detail appended to the message.
        message: Base error message.
    """

    def __init__(self, reason: Optional[str] = None,
                 message: str = "Too many requests."):
        suffix = f" Reason: {reason}" if reason else ""
        self.message = message + suffix
        super().__init__(self.message)

    def get_exit_code(self) -> int:
        """Exit code used for rate-limit failures."""
        return EXIT_CODE_TOO_MANY_REQUESTS
class NetworkConnectionError(DatabaseFetchError):
    """
    Raised when the server cannot be reached over the network.

    Args:
        message: The error message shown to the user.
    """

    def __init__(self, message: str = "Check your network connection, unable to reach the server."):
        self.message = message
        super().__init__(message)
class RequestTimeoutError(DatabaseFetchError):
    """
    Raised when a request to the server times out.

    Args:
        message: The error message shown to the user.
    """

    def __init__(self, message: str = "Check your network connection, the request timed out."):
        self.message = message
        super().__init__(message)
class ServerError(DatabaseFetchError):
    """
    Raised for server-side failures.

    Args:
        reason: Optional extra detail appended to the message.
        message: Base error message shown to the user.
    """

    def __init__(self, reason: Optional[str] = None,
                 message: str = "Sorry, something went wrong.\n"
                 "Our engineers are working quickly to resolve the issue."):
        suffix = f" Reason: {reason}" if reason else ""
        self.message = message + suffix
        super().__init__(self.message)

View File

@@ -0,0 +1,25 @@
from .handlers import EventHandler
from .types import (
CloseResourcesEvent,
CommandErrorEvent,
CommandExecutedEvent,
FirewallConfiguredEvent,
FirewallDisabledEvent,
FirewallHeartbeatEvent,
FlushSecurityTracesEvent,
PackageInstalledEvent,
PackageUninstalledEvent,
)
__all__ = [
"EventHandler",
"CloseResourcesEvent",
"FlushSecurityTracesEvent",
"CommandExecutedEvent",
"CommandErrorEvent",
"PackageInstalledEvent",
"PackageUninstalledEvent",
"FirewallHeartbeatEvent",
"FirewallConfiguredEvent",
"FirewallDisabledEvent",
]

View File

@@ -0,0 +1,7 @@
from .bus import EventBus
from .utils import start_event_bus
__all__ = [
"EventBus",
"start_event_bus",
]

View File

@@ -0,0 +1,348 @@
"""
Core EventBus implementation.
"""
import asyncio
import queue
import threading
import time
import logging
from concurrent.futures import Future
from typing import Dict, List, Any, Optional, Callable, TypeVar
from dataclasses import dataclass, field
from ..handlers import EventHandler
from safety_schemas.models.events import Event, EventTypeBase, PayloadBase
@dataclass
class EventBusMetrics:
    """
    Metrics for the event bus.
    """

    # Total events handed to emit() (including ones later dropped on a full queue).
    events_emitted: int = 0
    # Handler invocations that completed successfully.
    events_processed: int = 0
    # Handler invocations that raised.
    events_failed: int = 0
    # Largest queue size observed at emit time.
    queue_high_water_mark: int = 0
    # Per-handler-class lists of handling durations, in seconds.
    handler_durations: Dict[str, List[float]] = field(default_factory=dict)
E = TypeVar("E", bound=Event)
# Define bounded type variables
EventTypeT = TypeVar("EventTypeT", bound=EventTypeBase)
PayloadT = TypeVar("PayloadT", bound=PayloadBase)
class EventBus:
    """
    Event bus that runs in a separate thread with its own asyncio event loop.

    This class manages event subscription and dispatching across threads.
    This is an approach to leverage asyncio without migrating the current
    codebase to async: producers call emit() from any thread; a dedicated
    thread drains the queue and awaits handlers on a private loop.
    """

    def __init__(self, max_queue_size: int = 1000):
        """
        Initialize the event bus.

        Args:
            max_queue_size: Maximum number of events that can be queued
        """
        # Handlers keyed by event type; each type may have many handlers.
        self._handlers: Dict[EventTypeBase, List[EventHandler[Any]]] = {}
        # Queue for passing events from main thread to event bus thread
        self._event_queue: queue.Queue = queue.Queue(maxsize=max_queue_size)
        # Thread management
        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._shutdown_event = threading.Event()
        # Setup logging
        self.logger = logging.getLogger("event_bus")
        # Metrics
        self.metrics = EventBusMetrics()

    def subscribe(
        self, event_types: List[EventTypeBase], handler: EventHandler[E]
    ) -> None:
        """
        Subscribe a handler to one or more event types.

        Args:
            event_types: The list of event types to subscribe to
            handler: The handler to register
        """
        for event_type in event_types:
            if event_type not in self._handlers:
                self._handlers[event_type] = []
            self.logger.info(
                f"Registering handler {handler.__class__.__name__} "
                f"for event type {event_type}"
            )
            self._handlers[event_type].append(handler)

    def emit(
        self,
        event: Event[EventTypeT, PayloadT],
        block: bool = False,
        timeout: Optional[float] = None,
    ) -> Optional[Future]:
        """
        Emit an event to be processed by the event bus.

        Args:
            event: The event to emit
            block: Whether to block if the queue is full
            timeout: How long to wait if blocking

        Returns:
            Future that will contain the results, or None if the event
            couldn't be queued
        """
        # NOTE(review): the event is still enqueued when the bus is not
        # running (only a warning is logged) — confirm this is intentional.
        if not self._running:
            self.logger.warning("Event bus is not running, but an event was emitted")
        self.metrics.events_emitted += 1
        # Create a future to track the results
        future = Future()
        try:
            # Track queue size for metrics
            current_size = self._event_queue.qsize()
            self.metrics.queue_high_water_mark = max(
                current_size, self.metrics.queue_high_water_mark
            )
            # Put the event in the queue
            self._event_queue.put((event, future), block=block, timeout=timeout)
            self.logger.debug("Emitted %s (%s)", event.type, event.id)
            return future
        except queue.Full:
            # Queue saturated: surface the drop via the future's exception.
            self.logger.error(f"Event queue is full, dropping event: {event}")
            future.set_exception(RuntimeError("Event queue is full"))
            return future

    def start(self):
        """Start the bus worker thread. No-op if already running."""
        if self._running:
            return
        self._running = True
        self._shutdown_event.clear()
        # Daemon thread so a still-running bus never blocks interpreter exit.
        self._thread = threading.Thread(target=self._run_event_loop, daemon=True)
        self._thread.start()

    def stop(self, timeout=5.0):
        """
        Request shutdown and wait up to `timeout` seconds for the worker
        loop to finish. Returns True if shutdown completed in time.
        """
        if not self._running:
            return True
        self._running = False
        self._event_queue.put((None, None), block=False)  # Send sentinel
        return self._shutdown_event.wait(timeout)

    def _run_event_loop(self):
        """Thread target: owns a private asyncio loop that drains the queue."""
        self._loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._loop)

        async def main():
            pending_tasks = set()
            # Process the queue until shutdown (keep draining after stop()
            # until the queue is empty or the sentinel arrives).
            while self._running or not self._event_queue.empty():
                try:
                    # Get the next event with a short timeout
                    try:
                        event, future = self._event_queue.get(timeout=0.1)
                    except queue.Empty:
                        # Yield to pending handler tasks before polling again.
                        await asyncio.sleep(0.01)
                        continue
                    # Check for shutdown sentinel
                    if event is None:
                        self.logger.info("Received shutdown sentinel")
                        break
                    # Process the event
                    task = asyncio.create_task(self._dispatch_event(event, future))
                    self.logger.debug(f"Dispatching {event.type} ({event.id})")
                    pending_tasks.add(task)
                    task.add_done_callback(lambda t: pending_tasks.discard(t))
                except Exception as e:
                    self.logger.exception(f"Error processing event: {e}")
            # Wait for any pending tasks before exiting
            if pending_tasks:
                self.logger.info(f"Waiting for {len(pending_tasks)} pending tasks")
                await asyncio.gather(*pending_tasks, return_exceptions=True)

        try:
            # Single run_until_complete call for the entire lifecycle
            self._loop.run_until_complete(main())
        finally:
            self._loop.close()
            # Signal stop() that shutdown has finished.
            self._shutdown_event.set()

    async def _dispatch_event(
        self, event: Event[EventTypeBase, PayloadBase], future: Future
    ) -> None:
        """
        Dispatch an event to all registered handlers.

        Args:
            event: The event to dispatch
            future: Future to set with the results
        """
        results = []
        handlers = self._handlers.get(event.type, [])
        if not handlers:
            self.logger.warning(f"No handlers registered for event type {event.type}")
            future.set_result([])
            return
        # Create tasks for all handlers and run them concurrently
        tasks = []
        for handler in handlers:
            task = asyncio.create_task(self._handle_event(handler, event))
            tasks.append(task)
        trace_id = event.correlation_id if event.correlation_id else "-"
        self.logger.debug(
            "Event %s | %s | %s Handler(s) Task(s)", trace_id, event.type, len(tasks)
        )
        # Wait for all handlers to complete; exceptions are returned, not raised.
        handler_results = await asyncio.gather(*tasks, return_exceptions=True)
        self.logger.info(
            "Event %s | %s | %s Handler(s) Completed",
            trace_id,
            event.type,
            len(handler_results),
        )
        # Process results (exceptions are logged and excluded from results)
        for i, result in enumerate(handler_results):
            if isinstance(result, Exception):
                self.logger.error(
                    "Event %s | %s | Handler %d failed: %s",
                    trace_id,
                    event.type,
                    i,
                    str(result),
                    exc_info=result,
                )
            else:
                self.logger.debug(
                    "Event %s | %s | Handler %d succeeded: %s",
                    trace_id,
                    event.type,
                    i,
                    str(result),
                )
                results.append(result)
        # Set the result on the future
        if not future.done():
            future.set_result(results)

    async def _handle_event(self, handler: EventHandler[E], event: E) -> Any:
        """
        Handle a single event with error handling and metrics.

        Args:
            handler: The handler to use
            event: The event to handle

        Returns:
            The result from the handler
        """
        handler_name = handler.__class__.__name__
        start_time = time.time()
        try:
            # Call the handler
            result = await handler.handle(event)
            # Record successful processing
            self.metrics.events_processed += 1
            # Record timing
            duration = time.time() - start_time
            if handler_name not in self.metrics.handler_durations:
                self.metrics.handler_durations[handler_name] = []
            self.metrics.handler_durations[handler_name].append(duration)
            self.logger.debug(
                f"Handler {handler_name} processed {event.__class__.__name__} "
                f"in {duration:.3f}s"
            )
            return result
        except Exception as e:
            # Record failure; re-raise so gather() reports it per-handler.
            self.metrics.events_failed += 1
            self.logger.exception(
                f"Handler {handler_name} failed to process {event.__class__.__name__}: {e}"
            )
            raise

    def get_metrics(self) -> dict:
        """
        Get the current metrics for the event bus.

        Returns:
            Dictionary of metrics
        """
        metrics: dict[str, Any] = {
            "events_emitted": self.metrics.events_emitted,
            "events_processed": self.metrics.events_processed,
            "events_failed": self.metrics.events_failed,
            "current_queue_size": self._event_queue.qsize(),
            "queue_high_water_mark": self.metrics.queue_high_water_mark,
        }
        # Add handler metrics (count plus avg/max/min duration in seconds)
        handler_metrics = {}
        for handler_name, durations in self.metrics.handler_durations.items():
            if not durations:
                continue
            handler_metrics[handler_name] = {
                "count": len(durations),
                "avg_duration": sum(durations) / len(durations),
                "max_duration": max(durations),
                "min_duration": min(durations),
            }
        metrics["handlers"] = handler_metrics
        return metrics

    def emit_with_callback(
        self, event: Event, callback: Callable[[List[Any]], None]
    ) -> None:
        """
        Emit an event and register a callback for when it completes.

        Args:
            event: The event to emit
            callback: Function to call with the results when complete
                (not called when the dispatch raised)
        """
        future = self.emit(event)
        if future:
            future.add_done_callback(
                lambda f: callback(f.result()) if not f.exception() else None
            )

View File

@@ -0,0 +1,57 @@
from typing import TYPE_CHECKING
from .bus import EventBus
from safety_schemas.models.events import EventType
from safety.events.types import InternalEventType
from safety.events.handlers import SecurityEventsHandler
from safety.constants import PLATFORM_API_EVENTS_ENDPOINT
if TYPE_CHECKING:
from safety.models import SafetyCLI
from safety.auth.utils import SafetyAuthSession
def start_event_bus(obj: "SafetyCLI", session: "SafetyAuthSession"):
    """
    Initializes the event bus with the default security events handler
    for authenticated users.

    Creates and starts an EventBus, attaches it to `obj.event_bus`, builds
    a SecurityEventsHandler from the session's credentials and subscribes
    it to the auth/command/internal events. When the firewall is enabled,
    its event handlers are registered as well.

    Args:
        obj (SafetyCLI): The main application object.
        session (SafetyAuthSession): The authentication session containing
            the necessary credentials and proxies.
    """
    bus = EventBus()
    bus.start()
    obj.event_bus = bus

    access_token = None
    if session.token:
        access_token = session.token.get("access_token")

    obj.security_events_handler = SecurityEventsHandler(
        api_endpoint=PLATFORM_API_EVENTS_ENDPOINT,
        proxies=session.proxies,  # type: ignore
        auth_token=access_token,
        api_key=session.api_key,
    )

    subscribed_events = [
        EventType.AUTH_STARTED,
        EventType.AUTH_COMPLETED,
        EventType.COMMAND_EXECUTED,
        EventType.COMMAND_ERROR,
        InternalEventType.CLOSE_RESOURCES,
        InternalEventType.FLUSH_SECURITY_TRACES,
    ]
    bus.subscribe(subscribed_events, obj.security_events_handler)

    if obj.firewall_enabled:
        from safety.firewall.events.utils import register_event_handlers

        register_event_handlers(obj.event_bus, obj=obj)

View File

@@ -0,0 +1,5 @@
from .base import EventHandler
from .common import SecurityEventsHandler
__all__ = ["EventHandler", "SecurityEventsHandler"]

View File

@@ -0,0 +1,32 @@
"""
Event handler definitions for the event bus system.
"""
from abc import ABC, abstractmethod
from typing import Any, TypeVar, Generic
from safety_schemas.models.events import Event
# Type variable for event types
EventType = TypeVar("EventType", bound=Event)
class EventHandler(Generic[EventType], ABC):
    """
    Abstract base class for event handlers.

    Concrete handlers implement `handle`; instances are registered on the
    event bus via subscribe() and awaited when a matching event arrives.
    """

    @abstractmethod
    async def handle(self, event: EventType) -> Any:
        """
        Handle an event asynchronously.

        Args:
            event: The event to handle

        Returns:
            Any result from handling the event
        """
        pass

View File

@@ -0,0 +1,333 @@
import asyncio
import functools
import logging
import os
import sys
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import uuid
import httpx
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
import tenacity
from safety.meta import get_identifier, get_meta_http_headers, get_version
from ..types import (
CommandErrorEvent,
CommandExecutedEvent,
CloseResourcesEvent,
InternalEventType,
FlushSecurityTracesEvent,
)
from ..handlers import EventHandler
if TYPE_CHECKING:
from safety_schemas.models.events import EventContext
from safety.events.utils import InternalPayload
from safety.models import SafetyCLI
# Union of every event this handler accepts: outbound security events plus
# the internal control events (flush / close) it reacts to.
SecurityEventTypes = Union[
    CommandExecutedEvent,
    CommandErrorEvent,
    FlushSecurityTracesEvent,
    CloseResourcesEvent,
]
class SecurityEventsHandler(EventHandler[SecurityEventTypes]):
    """
    Handler that collects events in memory and flushes them when requested.

    Regular events are serialized and buffered; a FLUSH_SECURITY_TRACES
    event sends the whole buffer to the API, and CLOSE_RESOURCES tears
    down the HTTP client.
    """

    def __init__(
        self,
        api_endpoint: str,
        proxies: Optional[Dict[str, str]] = None,
        auth_token: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        """
        Initialize the telemetry handler.

        Args:
            api_endpoint: URL to send events to
            proxies: Optional dictionary of proxy settings
            auth_token: Optional authentication token for the API
            api_key: Optional API key for authentication
        """
        self.api_endpoint = api_endpoint
        self.proxies = proxies
        self.auth_token = auth_token
        self.api_key = api_key
        # Storage for collected events (JSON-serializable dicts)
        self.collected_events: List[Dict[str, Any]] = []
        # HTTP client (created lazily on first flush)
        self.http_client: Optional[httpx.AsyncClient] = None
        # Logging
        self.logger = logging.getLogger("security_events_handler")
        # Event types that should not be collected (to avoid recursion)
        self.excluded_event_types = [
            InternalEventType.FLUSH_SECURITY_TRACES,
        ]

    async def handle(self, event: SecurityEventTypes) -> Dict[str, Any]:
        """
        Handle an event - either collect it or process a flush request.

        Args:
            event: The event to handle

        Returns:
            Status dictionary
        """
        if event.type is InternalEventType.CLOSE_RESOURCES:
            self.logger.info("Received request to close resources")
            await self.close_async()
            return {"closed": True}
        if event.type is InternalEventType.FLUSH_SECURITY_TRACES:
            self.logger.info(f"Received flush request from {event.source}")
            return await self.flush(event_payload=event.payload)
        # Don't collect excluded event types.
        # FIX: the original compared the whole event against the excluded
        # *types* (`any(event == t ...)`), which never matched; the event's
        # type must be compared instead.
        if event.type in self.excluded_event_types:
            return {"skipped": True, "reason": "excluded_event_type"}
        try:
            event_data = event.model_dump(mode="json")
        except Exception:
            # Unserializable events are dropped (best-effort telemetry).
            return {"collected": False, "event_count": len(self.collected_events)}
        # Add to in-memory collection
        self.collected_events.append(event_data)
        event_count = len(self.collected_events)
        self.logger.debug(
            f"Collected event: {event.type}, total event count: {event_count}"
        )
        return {"collected": True, "event_count": event_count}

    async def _build_context_data(self, obj: Optional["SafetyCLI"]) -> "EventContext":
        """
        Generate context data for telemetry events.

        Returns:
            EventContext with client, runtime, project and tag information.
        """
        from safety_schemas.models.events.types import SourceType
        from safety.events.utils.context import create_event_context

        project = getattr(obj, "project", None)
        tags = None
        try:
            if obj and obj.auth and obj.auth.stage:
                tags = [obj.auth.stage.value]
        except AttributeError:
            pass
        version = get_version() or "unknown"
        path = ""
        try:
            path = sys.argv[0]
        except (IndexError, TypeError):
            pass
        # Context collection may block (host/user lookups); run it in a
        # worker thread to keep the event loop responsive.
        context = await asyncio.get_event_loop().run_in_executor(
            None,
            functools.partial(
                create_event_context,
                SourceType(get_identifier()),
                version,
                path,
                project,
                tags,
            ),
        )
        return context

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=0.1, min=0.2, max=1.0),
        retry=retry_if_exception_type(
            (httpx.NetworkError, httpx.TimeoutException, httpx.HTTPStatusError)
        ),
        before_sleep=before_sleep_log(logging.getLogger("api_client"), logging.WARNING),
    )
    async def _send_events(
        self, payload: dict, headers: dict
    ) -> Optional[httpx.Response]:
        """
        Send events to the API with retry logic.

        Args:
            payload: The data payload to send
            headers: The HTTP headers to include

        Returns:
            Response from the API or None if http_client is not initialized

        Raises:
            Exception if all retries fail
        """
        if self.http_client is None:
            self.logger.warning("Cannot send events: HTTP client not initialized")
            return None
        TIMEOUT = int(os.getenv("SAFETY_REQUEST_TIMEOUT_EVENTS", 10))
        response = await self.http_client.post(
            self.api_endpoint, json=payload, headers=headers, timeout=TIMEOUT
        )
        response.raise_for_status()
        return response

    async def flush(self, event_payload: "InternalPayload") -> Dict[str, Any]:
        """
        Send all collected events to the API endpoint.

        On failure the events are put back into the buffer so a later
        flush can retry them.

        Returns:
            Status dictionary
        """
        # If no events, just return early
        if not self.collected_events:
            return {"status": "no_events", "count": 0}
        # Get a copy of events and clear the original list
        events_to_send = self.collected_events.copy()
        self.collected_events.clear()
        event_count = len(events_to_send)
        self.logger.info(
            "[Flush] -> Sending %s events to %s", event_count, self.api_endpoint
        )
        IDEMPOTENCY_KEY = str(uuid.uuid4())
        # Get context data that will be shared across all events
        obj = event_payload.ctx.obj if event_payload.ctx else None
        context = await self._build_context_data(obj=obj)
        self.logger.info("Context data built")
        for event_data in events_to_send:
            event_data["context"] = context.model_dump(mode="json")
        payload = {"events": events_to_send}
        # Create HTTP client if needed
        if self.http_client is None:
            # TODO: Add proxy support
            self.http_client = httpx.AsyncClient(proxy=None)
        headers = {
            "Content-Type": "application/json",
            "X-Idempotency-Key": IDEMPOTENCY_KEY,
        }
        headers.update(get_meta_http_headers())
        # Add authentication
        if self.api_key:
            headers["X-Api-Key"] = self.api_key
        elif self.auth_token:
            headers["Authorization"] = f"Bearer {self.auth_token}"
        try:
            # Send the request with retries
            response = await self._send_events(payload, headers)
            # Handle case where http_client was None and _send_events returned None
            if response is None:
                self.logger.warning("Events not sent: HTTP client not initialized")
                # Put events back in collection
                self.collected_events = events_to_send + self.collected_events
                return {
                    "status": "error",
                    "count": event_count,
                    "error": "HTTP client not initialized",
                }
            self.logger.info(
                f"Successfully sent {event_count} events, status: {response.status_code}"
            )
            return {
                "status": "success",
                "count": event_count,
                "http_status": response.status_code,
            }
        except tenacity.RetryError as retry_exc:
            # Put events back in collection
            self.collected_events = events_to_send + self.collected_events
            exc = retry_exc.last_attempt.exception()
            # FIX: the original stored the whole httpx.Response object as
            # the status; extract the numeric status code when present.
            status_code = None
            response_obj = getattr(exc, "response", None)
            if response_obj is not None:
                status_code = getattr(response_obj, "status_code", None)
            self.logger.error(f"Failed after retries: {exc}")
            result = {"status": "error", "count": event_count, "error": repr(exc)}
            if status_code:
                result["http_status"] = status_code
            return result
        except Exception as exc:
            # Handle any other unexpected exceptions
            self.collected_events = events_to_send + self.collected_events
            self.logger.exception(f"Unexpected error: {exc}")
            return {"status": "error", "count": event_count, "error": repr(exc)}

    async def close_async(self):
        """Close the HTTP client asynchronously."""
        if self.http_client:
            await self.http_client.aclose()
            self.http_client = None
            self.logger.debug("HTTP client closed")

    def close(self):
        """
        Synchronous method to close the handler.

        This is a non-blocking method that initiates closure but doesn't
        wait for it. The event bus will handle the actual closing
        asynchronously.
        """
        self.logger.info("Initiating telemetry handler shutdown")
        # The actual close will happen when the event loop processes events
        # Just log the intent and let the event loop handle it
        return {"status": "shutdown_initiated"}

    def get_stats(self) -> Dict[str, Any]:
        """
        Get current telemetry statistics.

        Returns:
            Dictionary of statistics
        """
        event_count = len(self.collected_events)
        # Group events by type.
        # NOTE(review): serialized events use the "event_type" key here —
        # confirm model_dump() emits that key rather than "type".
        event_types = {}
        for event in self.collected_events:
            event_type = event.get("event_type", "unknown")
            if event_type not in event_types:
                event_types[event_type] = 0
            event_types[event_type] += 1
        return {
            "events_in_memory": event_count,
            "event_types": event_types,
            "api_endpoint": self.api_endpoint,
        }

View File

@@ -0,0 +1,28 @@
from .aliases import (
CloseResourcesEvent,
CommandErrorEvent,
CommandExecutedEvent,
FirewallConfiguredEvent,
FirewallDisabledEvent,
FirewallHeartbeatEvent,
FlushSecurityTracesEvent,
PackageInstalledEvent,
PackageUninstalledEvent,
EventBusReadyEvent,
)
from .base import InternalEventType, InternalPayload
__all__ = [
"CloseResourcesEvent",
"FlushSecurityTracesEvent",
"InternalEventType",
"InternalPayload",
"CommandExecutedEvent",
"CommandErrorEvent",
"PackageInstalledEvent",
"PackageUninstalledEvent",
"FirewallHeartbeatEvent",
"FirewallConfiguredEvent",
"FirewallDisabledEvent",
"EventBusReadyEvent",
]

View File

@@ -0,0 +1,81 @@
from typing import Literal
from safety_schemas.models.events import Event, EventType
from safety_schemas.models.events.payloads import (
AuthCompletedPayload,
AuthStartedPayload,
CodebaseSetupCompletedPayload,
CodebaseSetupResponseCreatedPayload,
FirewallConfiguredPayload,
FirewallDisabledPayload,
FirewallSetupCompletedPayload,
FirewallSetupResponseCreatedPayload,
InitScanCompletedPayload,
InitStartedPayload,
PackageInstalledPayload,
PackageUninstalledPayload,
CommandExecutedPayload,
CommandErrorPayload,
FirewallHeartbeatPayload,
CodebaseDetectionStatusPayload,
)
from .base import InternalEventType, InternalPayload
# Concrete Event specializations: each alias pins an EventType literal to
# its payload model.
CommandExecutedEvent = Event[
    Literal[EventType.COMMAND_EXECUTED], CommandExecutedPayload
]
CommandErrorEvent = Event[Literal[EventType.COMMAND_ERROR], CommandErrorPayload]
PackageInstalledEvent = Event[
    Literal[EventType.PACKAGE_INSTALLED], PackageInstalledPayload
]
PackageUninstalledEvent = Event[
    Literal[EventType.PACKAGE_UNINSTALLED], PackageUninstalledPayload
]
FirewallHeartbeatEvent = Event[
    Literal[EventType.FIREWALL_HEARTBEAT], FirewallHeartbeatPayload
]
FirewallConfiguredEvent = Event[
    Literal[EventType.FIREWALL_CONFIGURED], FirewallConfiguredPayload
]
FirewallDisabledEvent = Event[
    Literal[EventType.FIREWALL_DISABLED], FirewallDisabledPayload
]
InitStartedEvent = Event[Literal[EventType.INIT_STARTED], InitStartedPayload]
AuthStartedEvent = Event[Literal[EventType.AUTH_STARTED], AuthStartedPayload]
AuthCompletedEvent = Event[Literal[EventType.AUTH_COMPLETED], AuthCompletedPayload]
# Firewall setup events
FirewallSetupResponseCreatedEvent = Event[
    Literal[EventType.FIREWALL_SETUP_RESPONSE_CREATED],
    FirewallSetupResponseCreatedPayload,
]
FirewallSetupCompletedEvent = Event[
    Literal[EventType.FIREWALL_SETUP_COMPLETED], FirewallSetupCompletedPayload
]
# Codebase setup events
CodebaseDetectionStatusEvent = Event[
    Literal[EventType.CODEBASE_DETECTION_STATUS], CodebaseDetectionStatusPayload
]
CodebaseSetupResponseCreatedEvent = Event[
    Literal[EventType.CODEBASE_SETUP_RESPONSE_CREATED],
    CodebaseSetupResponseCreatedPayload,
]
CodebaseSetupCompletedEvent = Event[
    Literal[EventType.CODEBASE_SETUP_COMPLETED], CodebaseSetupCompletedPayload
]
# Scan events
InitScanCompletedEvent = Event[
    Literal[EventType.INIT_SCAN_COMPLETED], InitScanCompletedPayload
]
# Internal events.
# FIX: wrap the enum members in Literal[...] for consistency with every
# other alias (the originals subscripted Event with raw enum members).
CloseResourcesEvent = Event[
    Literal[InternalEventType.CLOSE_RESOURCES], InternalPayload
]
FlushSecurityTracesEvent = Event[
    Literal[InternalEventType.FLUSH_SECURITY_TRACES], InternalPayload
]
EventBusReadyEvent = Event[Literal[InternalEventType.EVENT_BUS_READY], InternalPayload]

View File

@@ -0,0 +1,24 @@
from typing import TYPE_CHECKING, Any, Optional
from typing_extensions import Annotated
from pydantic import ConfigDict
from safety_schemas.models.events import EventTypeBase, PayloadBase
if TYPE_CHECKING:
pass
class InternalEventType(EventTypeBase):
    """
    Internal event types.

    Used for in-process coordination (not business/security telemetry).
    """

    # Ask handlers to release held resources (e.g. HTTP clients).
    CLOSE_RESOURCES = "com.safetycli.close_resources"
    # Ask the security handler to send its buffered events upstream.
    FLUSH_SECURITY_TRACES = "com.safetycli.flush_security_traces"
    # Presumably emitted once the bus thread/loop is up — confirm at emit site.
    EVENT_BUS_READY = "com.safetycli.event_bus_ready"
class InternalPayload(PayloadBase):
    """Payload for internal events; carries the active CLI context."""

    # Presumably a click CustomContext, typed as Any (annotated only) —
    # confirm against callers.
    ctx: Optional[Annotated[Any, "CustomContext"]] = None

    # Extra fields are allowed so internal events can attach ad-hoc data.
    model_config = ConfigDict(extra="allow")

View File

@@ -0,0 +1,34 @@
from .emission import (
emit_command_error,
emit_command_executed,
emit_firewall_disabled,
emit_diff_operations,
emit_firewall_configured,
emit_tool_command_executed,
emit_firewall_heartbeat,
emit_init_started,
emit_auth_started,
emit_auth_completed,
)
from .creation import (
create_internal_event,
InternalEventType,
InternalPayload,
)
__all__ = [
"emit_command_error",
"emit_command_executed",
"emit_firewall_disabled",
"create_internal_event",
"InternalEventType",
"InternalPayload",
"emit_firewall_configured",
"emit_diff_operations",
"emit_init_started",
"emit_auth_started",
"emit_auth_completed",
"emit_tool_command_executed",
"emit_firewall_heartbeat",
]

View File

@@ -0,0 +1,79 @@
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar, cast, overload
if TYPE_CHECKING:
from safety.events.event_bus import EventBus
from safety.cli_util import CustomContext
def should_emit(
    event_bus: Optional["EventBus"], ctx: Optional["CustomContext"]
) -> bool:
    """
    Gate shared by all event emissions: nothing is emitted without a bus.

    Be aware that ctx depends on the command being parsed; if the emit
    function is called from the entrypoint group command, ctx will not
    have the command parsed yet.
    """
    return event_bus is not None
def should_emit_firewall_heartbeat(ctx: Optional["CustomContext"]) -> bool:
    """
    True only when a context exists and its CLI object has the firewall
    feature enabled.
    """
    return bool(ctx and ctx.obj.firewall_enabled)
# Define TypeVars for better typing
F = TypeVar("F", bound=Callable[..., Any])
R = TypeVar("R")


@overload
def conditional_emitter(emit_func: F, *, conditions: None = None) -> F: ...


@overload
def conditional_emitter(
    emit_func: None = None,
    *,
    conditions: Optional[List[Callable[[Optional["CustomContext"]], bool]]] = None,
) -> Callable[[F], F]: ...


def conditional_emitter(
    emit_func=None,
    *,
    conditions: Optional[List[Callable[[Optional["CustomContext"]], bool]]] = None,
):
    """
    Decorator that gates an emit function behind runtime conditions.

    The wrapped function only runs when `should_emit` passes and, when
    `conditions` were supplied, every condition returns True for the ctx.
    Supports both bare (@conditional_emitter) and parameterized
    (@conditional_emitter(conditions=[...])) usage.
    """

    def decorator(func: F) -> F:
        @wraps(func)
        def wrapper(event_bus, ctx=None, *args, **kwargs):
            if not should_emit(event_bus, ctx):
                return None
            if conditions and not all(check(ctx) for check in conditions):
                return None
            return func(event_bus, ctx, *args, **kwargs)

        return cast(F, wrapper)  # Cast to help type checker

    # Bare decorator usage vs. parameterized usage.
    return decorator if emit_func is None else decorator(emit_func)

View File

@@ -0,0 +1,163 @@
import getpass
import os
from pathlib import Path
import site
import socket
import sys
import platform
from typing import List, Optional
from safety_schemas.models.events.context import (
ClientInfo,
EventContext,
HostInfo,
OsInfo,
ProjectInfo,
PythonInfo,
RuntimeInfo,
UserInfo,
)
from safety_schemas.models.events.types import SourceType
from safety_schemas.models import ProjectModel
def get_user_info() -> UserInfo:
    """
    Build a UserInfo from the current login name and home directory.
    """
    username = getpass.getuser()
    home_dir = Path.home()
    return UserInfo(name=username, home_dir=str(home_dir))
def get_os_info() -> OsInfo:
    """
    Get basic OS information using only the platform module.

    Returns:
        OsInfo with architecture, platform, name, version and kernel_version;
        fields that cannot be determined are left as None.
    """
    # Initialize with required fields
    os_info = {
        "architecture": platform.machine(),
        "platform": platform.system(),
        "name": None,
        "version": None,
        "kernel_version": None,
    }
    python_version = sys.version_info
    # FIX: the Windows value of sys.platform is "win32"; the original
    # compared against the misspelled "wind32", so the Windows branch was
    # never taken and name/version stayed None on Windows.
    if sys.platform == "win32":
        os_info["version"] = platform.release()
        os_info["kernel_version"] = platform.version()
        os_info["name"] = "windows"
    elif sys.platform == "darwin":
        os_info["version"] = platform.mac_ver()[0]
        os_info["kernel_version"] = platform.release()
        os_info["name"] = "macos"
    elif sys.platform == "linux":
        os_info["kernel_version"] = platform.release()
        # platform.freedesktop_os_release() only exists on Python >= 3.10
        if python_version >= (3, 10):
            try:
                os_release = platform.freedesktop_os_release()
                # Use ID for name (more consistent for programmatic use)
                os_info["name"] = os_release.get("ID", "linux")
                os_info["version"] = os_release.get("VERSION_ID")
            except (OSError, AttributeError):
                # If freedesktop_os_release fails, keep values as is
                pass
    return OsInfo(**os_info)
def get_host_info() -> HostInfo:
    """
    Collect information about the host machine.

    Resolves the hostname's addresses and returns one IPv4 (skipping
    loopback) and one IPv6 (skipping loopback and fe80::1); both are None
    when name resolution fails.
    """
    hostname = socket.gethostname()
    ipv4_addresses = set()
    ipv6_addresses = set()
    try:
        host_info = socket.getaddrinfo(hostname, None)
        for info in host_info:
            ip_family = info[0]
            ip = str(info[4][0])
            if ip_family == socket.AF_INET:
                # Skip IPv4 loopback addresses
                if not ip.startswith("127."):
                    ipv4_addresses.add(ip)
            elif ip_family == socket.AF_INET6:
                # Skip IPv6 loopback / the fe80::1 link-local alias
                if not ip.startswith("::1") and ip != "fe80::1":
                    ipv6_addresses.add(ip)
        # Prioritize addresses
        # NOTE(review): the IPv4 generator below has no filter, so the
        # fallback argument is redundant — any set element already
        # satisfies the first next(). Confirm the intended priority rule.
        primary_ipv4 = next(
            (ip for ip in ipv4_addresses),
            next(iter(ipv4_addresses)) if ipv4_addresses else None,
        )
        # Prefer a non-link-local IPv6; fall back to any collected address.
        primary_ipv6 = next(
            (ip for ip in ipv6_addresses if not ip.startswith("fe80:")),
            next(iter(ipv6_addresses)) if ipv6_addresses else None,
        )
    except socket.gaierror:
        primary_ipv4 = None
        primary_ipv6 = None
    return HostInfo(name=hostname, ipv4=primary_ipv4, ipv6=primary_ipv6, timezone=None)
def get_python_info() -> PythonInfo:
    """
    Collect detailed information about the Python environment
    (interpreter version/paths, site-packages locations, encodings).
    """
    major, minor = sys.version_info.major, sys.version_info.minor
    return PythonInfo(
        version=f"{major}.{minor}",
        path=sys.executable,
        sys_path=sys.path,
        implementation=platform.python_implementation(),
        implementation_version=platform.python_version(),
        sys_prefix=sys.prefix,
        site_packages=site.getsitepackages(),
        user_site_enabled=bool(site.ENABLE_USER_SITE),
        user_site_packages=site.getusersitepackages(),
        encoding=sys.getdefaultencoding(),
        filesystem_encoding=sys.getfilesystemencoding(),
    )
def create_event_context(
    client_identifier: SourceType,
    client_version: str,
    client_path: str,
    project: Optional[ProjectModel] = None,
    tags: Optional[List[str]] = None,
) -> EventContext:
    """
    Assemble the full EventContext (client, runtime, optional project, tags)
    attached to outgoing events.
    """
    client_info = ClientInfo(
        identifier=client_identifier, version=client_version, path=client_path
    )
    runtime_info = RuntimeInfo(
        workdir=os.getcwd(),
        user=get_user_info(),
        os=get_os_info(),
        host=get_host_info(),
        python=get_python_info(),
    )
    project_info = (
        ProjectInfo(id=project.id, url=project.url_path) if project else None
    )
    return EventContext(
        client=client_info, runtime=runtime_info, project=project_info, tags=tags
    )

View File

@@ -0,0 +1,48 @@
import time
from typing import Optional, TypeVar
from safety_schemas.models.events import Event, EventTypeBase, PayloadBase, SourceType
from safety.meta import get_identifier
from ..types import InternalEventType, InternalPayload
PayloadBaseT = TypeVar("PayloadBaseT", bound=PayloadBase)
EventTypeBaseT = TypeVar("EventTypeBaseT", bound=EventTypeBase)
def create_event(
    payload: PayloadBaseT,
    event_type: EventTypeBaseT,
    source: Optional[SourceType] = None,
    timestamp: Optional[int] = None,
    correlation_id: Optional[str] = None,
    **kwargs,
) -> Event[EventTypeBaseT, PayloadBaseT]:
    """
    Generic factory function for creating any type of event.

    Args:
        payload: The event payload.
        event_type: The event type discriminator.
        source: Event source; defaults to the CLI's own identifier.
        timestamp: Unix timestamp; defaults to the current time.
        correlation_id: Optional id used to group related events.
        **kwargs: Extra fields forwarded to the Event model.

    Note:
        The previous signature evaluated ``int(time.time())`` and
        ``SourceType(get_identifier())`` as default expressions, which
        Python evaluates ONCE at import time — every event created with
        the default carried the module-load timestamp. The None sentinels
        compute fresh values per call.
    """
    if source is None:
        source = SourceType(get_identifier())
    if timestamp is None:
        timestamp = int(time.time())

    return Event(
        timestamp=timestamp,
        payload=payload,
        type=event_type,
        source=source,
        correlation_id=correlation_id,
        **kwargs,
    )
def create_internal_event(
    event_type: InternalEventType,
    payload: InternalPayload,
) -> Event[InternalEventType, InternalPayload]:
    """
    Build an internal (bus-control) event stamped with the current time
    and the CLI's own source identifier.
    """
    now = int(time.time())
    origin = SourceType(get_identifier())

    return Event(
        timestamp=now,
        source=origin,
        type=event_type,
        payload=payload,
    )

View File

@@ -0,0 +1,110 @@
import re
from typing import Any, List, Optional
from click.core import ParameterSource as ClickParameterSource
from safety_schemas.models.events.types import ParamSource
def is_sensitive_parameter(param_name: str) -> bool:
    """
    Determine if a parameter name likely contains sensitive information.

    Matches (case-insensitively) names containing pass/password, token,
    key, or auth anywhere in the string.
    """
    # Single alternation equivalent to checking each pattern separately:
    # pass(word)?, token, key, auth.
    return re.search(r"(?i)pass(word)?|token|key|auth", param_name) is not None
def scrub_sensitive_value(value: str) -> str:
    """
    Return *value* with substrings that look like credentials replaced
    by placeholder dashes.

    Non-string inputs pass through untouched, as do bare CLI flags
    (``-x`` / ``--long-flag`` without an ``=``), which carry no payload.
    """
    if not isinstance(value, str):
        return value

    if "=" not in value and re.match(r"^-{1,2}[\w-]+$", value):
        return value

    # Applied in order; later patterns see the output of earlier ones.
    replacements = (
        # This will replace ports too, but that's fine
        (r"\b\w+:\w+\b", "-:-"),
        (r"Basic\s+[A-Za-z0-9+/=]+", "Basic -"),
        (r"Bearer\s+[A-Za-z0-9._~+/=-]+", "Bearer -"),
        (r"\b[A-Za-z0-9_-]{20,}\b", "-"),
        (
            r"((?:token|api|apikey|key|auth|secret|password|access|jwt|bearer|credential|pwd)=)([^&\s]+)",
            r"\1-",
        ),
    )

    scrubbed = value
    for pattern, replacement in replacements:
        scrubbed = re.sub(pattern, replacement, scrubbed)
    return scrubbed
def clean_parameter(param_name: str, param_value: Any) -> Any:
    """
    Return a safe-to-log version of a parameter value.

    Non-string values pass through unchanged. Values belonging to a
    sensitive-looking parameter name are fully masked; every other
    string is pattern-scrubbed for embedded secrets.
    """
    if not isinstance(param_value, str):
        return param_value

    return (
        "-"
        if is_sensitive_parameter(param_name)
        else scrub_sensitive_value(param_value)
    )
def get_command_path(ctx) -> List[str]:
    """
    Build the command hierarchy for a Click context, top-level first.

    The root "cli" group is reported as "safety"; contexts without a
    command are skipped.
    """
    path: List[str] = []
    node = ctx
    while node is not None:
        if node.command:
            label = node.command.name
            path.append("safety" if label == "cli" else label)
        node = node.parent
    # Collected bottom-up; reverse so the root command comes first.
    return path[::-1]
def get_root_context(ctx):
    """
    Get the top-level parent context by following the parent chain.
    """
    return ctx if ctx.parent is None else get_root_context(ctx.parent)
def translate_param_source(source: Optional[ClickParameterSource]) -> ParamSource:
    """
    Translate Click's ParameterSource enum to our ParamSource enum.

    Any unmapped or missing source yields ParamSource.UNKNOWN.

    Note:
        PROMPT and CONFIG_FILE only exist in newer Click versions. The
        previous implementation used ``getattr(..., None)`` expressions
        directly as dict keys, so on older Click both collapsed into a
        single ``None`` key — causing ``source=None`` to be reported as
        PROMPT/CONFIG instead of UNKNOWN. They are now added only when
        the attribute actually exists.
    """
    mapping = {
        ClickParameterSource.COMMANDLINE: ParamSource.COMMANDLINE,
        ClickParameterSource.ENVIRONMENT: ParamSource.ENVIRONMENT,
        ClickParameterSource.DEFAULT: ParamSource.DEFAULT,
    }

    prompt = getattr(ClickParameterSource, "PROMPT", None)
    if prompt is not None:
        mapping[prompt] = ParamSource.PROMPT

    config_file = getattr(ClickParameterSource, "CONFIG_FILE", None)
    if config_file is not None:
        mapping[config_file] = ParamSource.CONFIG

    return mapping.get(source, ParamSource.UNKNOWN)

View File

@@ -0,0 +1,681 @@
from concurrent.futures import Future
import logging
from pathlib import Path
import re
import shutil
import subprocess
import sys
import time
from typing import (
TYPE_CHECKING,
Dict,
List,
Optional,
Tuple,
Union,
)
import uuid
from safety.utils.pyapp_utils import get_path, get_env
from safety_schemas.models.events import Event, EventType
from safety_schemas.models.events.types import ToolType
from safety_schemas.models.events.payloads import (
CodebaseDetectionStatusPayload,
CodebaseSetupCompletedPayload,
CodebaseSetupResponseCreatedPayload,
DependencyFile,
FirewallConfiguredPayload,
FirewallDisabledPayload,
FirewallSetupCompletedPayload,
FirewallSetupResponseCreatedPayload,
InitExitStep,
InitExitedPayload,
InitScanCompletedPayload,
PackageInstalledPayload,
PackageUninstalledPayload,
PackageUpdatedPayload,
CommandExecutedPayload,
ToolCommandExecutedPayload,
CommandErrorPayload,
AliasConfig,
IndexConfig,
ToolStatus,
CommandParam,
ProcessStatus,
FirewallHeartbeatPayload,
InitStartedPayload,
AuthStartedPayload,
AuthCompletedPayload,
)
import typer
from ..event_bus import EventBus
from ..types.base import InternalEventType, InternalPayload
from .creation import (
create_event,
)
from .data import (
clean_parameter,
get_command_path,
get_root_context,
scrub_sensitive_value,
translate_param_source,
)
from .conditions import conditional_emitter, should_emit_firewall_heartbeat
if TYPE_CHECKING:
from safety.models import SafetyCLI, ToolResult
from safety.cli_util import CustomContext
from safety.init.types import FirewallConfigStatus
from safety.tool.environment_diff import PackageLocation
logger = logging.getLogger(__name__)
@conditional_emitter
def send_and_flush(event_bus: "EventBus", event: Event) -> Optional[Future]:
    """
    Emit an event, then emit a FLUSH_SECURITY_TRACES internal event so
    the bus drains without being closed.

    Args:
        event_bus: The event bus to emit on
        event: The event to emit

    Returns:
        The flush future's result when it completes in time, else None.
    """
    emit_future = event_bus.emit(event)

    # The flush marker tells the bus to drain queued security traces.
    flush_event = create_event(
        payload=InternalPayload(),
        event_type=InternalEventType.FLUSH_SECURITY_TRACES,
    )
    flush_future = event_bus.emit(flush_event)

    # Wait briefly on both; failures are logged, never raised.
    if emit_future:
        try:
            emit_future.result(timeout=0.5)
        except Exception:
            logger.error("Emit Failed %s (%s)", event.type, event.id)

    if flush_future:
        try:
            return flush_future.result(timeout=0.5)
        except Exception:
            logger.error("Flush Failed for event %s", event.id)

    return None
@conditional_emitter(conditions=[should_emit_firewall_heartbeat])
def emit_firewall_heartbeat(
    event_bus: "EventBus", ctx: Optional["CustomContext"], *, tools: List[ToolStatus]
):
    """
    Emit a FIREWALL_HEARTBEAT event carrying the current tool statuses.
    """
    event_bus.emit(
        create_event(
            payload=FirewallHeartbeatPayload(tools=tools),
            event_type=EventType.FIREWALL_HEARTBEAT,
        )
    )
@conditional_emitter
def emit_firewall_disabled(
    event_bus: "EventBus",
    ctx: Optional["CustomContext"] = None,
    *,
    reason: Optional[str],
):
    """
    Emit a FIREWALL_DISABLED event with the (optional) disable reason.
    """
    event_bus.emit(
        create_event(
            payload=FirewallDisabledPayload(reason=reason),
            event_type=EventType.FIREWALL_DISABLED,
        )
    )
def status_to_tool_status(status: "FirewallConfigStatus") -> List[ToolStatus]:
    """
    Convert a firewall configuration status mapping into ToolStatus
    entries.

    Each tool is probed with ``<tool> --version`` to determine
    reachability and extract a version string. Probing is best-effort:
    a tool that hangs or fails to launch is reported as unreachable
    instead of crashing (or stalling) the whole status collection —
    the previous implementation ran the subprocess with no timeout and
    no exception guard.
    """
    filtered_path = get_path()
    tools: List[ToolStatus] = []

    for tool_type, configs in status.items():
        alias_config = (
            configs["alias"] if isinstance(configs["alias"], AliasConfig) else None
        )
        index_config = (
            configs["index"] if isinstance(configs["index"], IndexConfig) else None
        )

        tool_name = tool_type.value
        command_path = shutil.which(tool_name, path=filtered_path)
        reachable = False
        version = "unknown"

        if command_path:
            result = None
            try:
                # Bound the probe so a misbehaving tool cannot hang the CLI.
                result = subprocess.run(
                    [command_path, "--version"],
                    capture_output=True,
                    text=True,
                    env=get_env(),
                    timeout=15,
                )
            except (OSError, subprocess.SubprocessError):
                logger.debug("Version probe failed for %s", command_path)

            if result is not None and result.returncode == 0:
                reachable = True
                # First dotted number in the output is taken as the version.
                version_match = re.search(r"(\d+\.\d+(?:\.\d+)?)", result.stdout)
                if version_match:
                    version = version_match.group(1)
        else:
            # Tool not on PATH; report its bare name as the path.
            command_path = tool_name

        tools.append(
            ToolStatus(
                type=tool_type,
                command_path=command_path,
                version=version,
                reachable=reachable,
                alias_config=alias_config,
                index_config=index_config,
            )
        )

    return tools
@conditional_emitter
def emit_firewall_configured(
    event_bus: "EventBus",
    ctx: Optional["CustomContext"] = None,
    *,
    status: "FirewallConfigStatus",
):
    """
    Emit a FIREWALL_CONFIGURED event describing each tool's current
    alias/index configuration and reachability.
    """
    event_bus.emit(
        create_event(
            payload=FirewallConfiguredPayload(tools=status_to_tool_status(status)),
            event_type=EventType.FIREWALL_CONFIGURED,
        )
    )
@conditional_emitter
def emit_diff_operations(
    event_bus: "EventBus",
    ctx: "CustomContext",
    *,
    added: Dict["PackageLocation", str],
    removed: Dict["PackageLocation", str],
    updated: Dict["PackageLocation", Tuple[str, str]],
    tool_path: Optional[str],
    by_tool: ToolType,
):
    """
    Emit one package event per install/uninstall/update detected in an
    environment diff, all sharing a single correlation id.
    """
    obj: "SafetyCLI" = ctx.obj
    correlation_id = obj.correlation_id

    # Only mint a correlation id when there is actually something to emit.
    if (added or removed or updated) and not correlation_id:
        correlation_id = obj.correlation_id = str(uuid.uuid4())

    common = {"tool_path": tool_path, "tool": by_tool}

    def _emit(payload, event_type):
        # Shared emission path for the three package event kinds.
        event_bus.emit(
            create_event(
                payload=payload,
                event_type=event_type,
                correlation_id=correlation_id,
            )
        )

    for location, installed_version in added.items():
        _emit(
            PackageInstalledPayload(
                package_name=location.name,
                location=location.location,
                version=installed_version,
                **common,
            ),
            EventType.PACKAGE_INSTALLED,
        )

    for location, removed_version in removed.items():
        _emit(
            PackageUninstalledPayload(
                package_name=location.name,
                location=location.location,
                version=removed_version,
                **common,
            ),
            EventType.PACKAGE_UNINSTALLED,
        )

    for location, (old_version, new_version) in updated.items():
        _emit(
            PackageUpdatedPayload(
                package_name=location.name,
                location=location.location,
                previous_version=old_version,
                current_version=new_version,
                **common,
            ),
            EventType.PACKAGE_UPDATED,
        )
@conditional_emitter
def emit_tool_command_executed(
    event_bus: "EventBus", ctx: "CustomContext", *, tool: ToolType, result: "ToolResult"
) -> None:
    """
    Emit a TOOL_COMMAND_EXECUTED event for a wrapped package-manager
    invocation, scrubbing sensitive data from the command and output.
    """
    if not ctx.obj.correlation_id:
        ctx.obj.correlation_id = str(uuid.uuid4())

    process = result.process
    payload = ToolCommandExecutedPayload(
        tool=tool,
        tool_path=result.tool_path,
        raw_command=[clean_parameter("", part) for part in process.args],
        duration_ms=result.duration_ms,
        status=ProcessStatus(
            stdout=process.stdout,
            stderr=process.stderr,
            return_code=process.returncode,
        ),
    )

    # Scrub after binary coercion to str
    if payload.status.stdout:
        payload.status.stdout = scrub_sensitive_value(payload.status.stdout)
    if payload.status.stderr:
        payload.status.stderr = scrub_sensitive_value(payload.status.stderr)

    event_bus.emit(
        create_event(
            correlation_id=ctx.obj.correlation_id,
            payload=payload,
            event_type=EventType.TOOL_COMMAND_EXECUTED,
        )
    )
@conditional_emitter
def emit_command_executed(
    event_bus: "EventBus", ctx: "CustomContext", *, returned_code: int
) -> None:
    """
    Emit a COMMAND_EXECUTED event for the finished CLI command, with
    scrubbed argv/parameters, duration, and exit code. Waits briefly for
    delivery since this typically runs at process exit.
    """
    NA = ""

    root_context = get_root_context(ctx)
    started_at = getattr(root_context, "started_at", None) if root_context else None
    # Fall back to 1ms when the root context never recorded a start time.
    duration_ms = (
        int((time.monotonic() - started_at) * 1000) if started_at is not None else 1
    )

    params: List[CommandParam] = []
    for position, param in enumerate(ctx.command.params):
        name = param.name if param.name is not None else NA
        value = clean_parameter(name, ctx.params.get(name))
        source = translate_param_source(ctx.get_parameter_source(name))
        params.append(
            CommandParam(
                position=position,
                name=name if name else None,
                value=value,
                source=source,
            )
        )

    payload = CommandExecutedPayload(
        command_name=ctx.command.name if ctx.command.name is not None else NA,
        command_path=get_command_path(ctx),
        raw_command=[clean_parameter("", part) for part in sys.argv],
        parameters=params,
        duration_ms=duration_ms,
        status=ProcessStatus(
            return_code=returned_code,
        ),
    )

    event = create_event(
        correlation_id=ctx.obj.correlation_id,
        payload=payload,
        event_type=EventType.COMMAND_EXECUTED,
    )
    try:
        if future := event_bus.emit(event):
            future.result(timeout=0.5)
    except Exception:
        logger.error("Emit Failed %s (%s)", event.type, event.id)
@conditional_emitter
def emit_command_error(
    event_bus: "EventBus",
    ctx: "CustomContext",
    *,
    message: str,
    traceback: Optional[str] = None,
) -> None:
    """
    Emit a CommandErrorEvent with sensitive data scrubbed.

    Args:
        event_bus: The event bus to emit on.
        ctx: The Click context (may lack a resolved command).
        message: The error message to report (scrubbed before sending).
        traceback: Optional formatted traceback (scrubbed before sending).
    """
    # Resolve the command name as a string (or None). The previous code
    # reused one variable for both the click Command object and its name,
    # leaking the Command object into the payload when the name was falsy.
    command = getattr(ctx, "command", None)
    command_name = command.name if command is not None and command.name else None

    scrub_traceback = scrub_sensitive_value(traceback) if traceback else None

    payload = CommandErrorPayload(
        command_name=command_name,
        raw_command=[scrub_sensitive_value(arg) for arg in sys.argv],
        command_path=get_command_path(ctx),
        error_message=scrub_sensitive_value(message),
        stacktrace=scrub_traceback,
    )
    event_bus.emit(
        create_event(
            payload=payload,
            event_type=EventType.COMMAND_ERROR,
        )
    )
def emit_init_started(
    event_bus: "EventBus", ctx: Union["CustomContext", typer.Context]
) -> None:
    """
    Emit an InitStartedEvent and store it as a pending event in SafetyCLI object.

    Args:
        event_bus: The event bus to emit on
        ctx: The Click context containing the SafetyCLI object
    """
    obj: "SafetyCLI" = ctx.obj
    obj.correlation_id = obj.correlation_id or str(uuid.uuid4())

    event = create_event(
        correlation_id=obj.correlation_id,
        payload=InitStartedPayload(),
        event_type=EventType.INIT_STARTED,
    )

    # Queue the event for a later flush if immediate delivery failed.
    if not send_and_flush(event_bus, event):
        obj.pending_events.append(event)
def emit_auth_started(event_bus: "EventBus", ctx: "CustomContext") -> None:
    """
    Emit an AuthStartedEvent and store it as a pending event in SafetyCLI object.

    Args:
        event_bus: The event bus to emit on
        ctx: The Click context containing the SafetyCLI object
    """
    obj: "SafetyCLI" = ctx.obj
    obj.correlation_id = obj.correlation_id or str(uuid.uuid4())

    event = create_event(
        correlation_id=obj.correlation_id,
        payload=AuthStartedPayload(),
        event_type=EventType.AUTH_STARTED,
    )

    # Queue the event for a later flush if immediate delivery failed.
    if not send_and_flush(event_bus, event):
        obj.pending_events.append(event)
@conditional_emitter
def emit_auth_completed(
    event_bus: "EventBus",
    ctx: "CustomContext",
    *,
    success: bool = True,
    error_message: Optional[str] = None,
) -> None:
    """
    Emit an AuthCompletedEvent, first replaying any queued pending
    events so they share the same flush cycle.

    Args:
        event_bus: The event bus to emit on
        ctx: The Click context containing the SafetyCLI object
        success: Whether authentication was successful
        error_message: Optional error message if authentication failed
    """
    obj: "SafetyCLI" = ctx.obj
    obj.correlation_id = obj.correlation_id or str(uuid.uuid4())

    completed = create_event(
        correlation_id=obj.correlation_id,
        payload=AuthCompletedPayload(success=success, error_message=error_message),
        event_type=EventType.AUTH_COMPLETED,
    )

    # Drain events queued while the bus wasn't ready (init/auth started).
    for queued in obj.pending_events:
        event_bus.emit(queued)
    obj.pending_events.clear()

    send_and_flush(event_bus, completed)
@conditional_emitter
def emit_firewall_setup_response_created(
    event_bus: "EventBus",
    ctx: Union["CustomContext", typer.Context],
    *,
    user_consent_requested: bool,
    user_consent: Optional[bool] = None,
) -> None:
    """
    Emit a FIREWALL_SETUP_RESPONSE_CREATED event (user consent prompt
    outcome) and flush the bus.
    """
    obj: "SafetyCLI" = ctx.obj
    obj.correlation_id = obj.correlation_id or str(uuid.uuid4())

    send_and_flush(
        event_bus,
        create_event(
            correlation_id=obj.correlation_id,
            payload=FirewallSetupResponseCreatedPayload(
                user_consent_requested=user_consent_requested,
                user_consent=user_consent,
            ),
            event_type=EventType.FIREWALL_SETUP_RESPONSE_CREATED,
        ),
    )
@conditional_emitter
def emit_codebase_setup_response_created(
    event_bus: "EventBus",
    ctx: Union["CustomContext", typer.Context],
    *,
    user_consent_requested: bool,
    user_consent: Optional[bool] = None,
) -> None:
    """
    Emit a CODEBASE_SETUP_RESPONSE_CREATED event (user consent prompt
    outcome) and flush the bus.
    """
    obj: "SafetyCLI" = ctx.obj
    obj.correlation_id = obj.correlation_id or str(uuid.uuid4())

    send_and_flush(
        event_bus,
        create_event(
            correlation_id=obj.correlation_id,
            payload=CodebaseSetupResponseCreatedPayload(
                user_consent_requested=user_consent_requested,
                user_consent=user_consent,
            ),
            event_type=EventType.CODEBASE_SETUP_RESPONSE_CREATED,
        ),
    )
@conditional_emitter
def emit_codebase_detection_status(
    event_bus: "EventBus",
    ctx: Union["CustomContext", typer.Context],
    *,
    detected: bool,
    detected_files: Optional[List[Path]] = None,
) -> None:
    """
    Emit a CODEBASE_DETECTION_STATUS event listing any dependency files
    found during init, then flush the bus.
    """
    obj: "SafetyCLI" = ctx.obj
    obj.correlation_id = obj.correlation_id or str(uuid.uuid4())

    dependency_files = None
    if detected_files:
        dependency_files = [
            DependencyFile(file_path=str(path)) for path in detected_files
        ]

    send_and_flush(
        event_bus,
        create_event(
            correlation_id=obj.correlation_id,
            payload=CodebaseDetectionStatusPayload(
                detected=detected, dependency_files=dependency_files
            ),
            event_type=EventType.CODEBASE_DETECTION_STATUS,
        ),
    )
@conditional_emitter
def emit_init_scan_completed(
    event_bus: "EventBus",
    ctx: Union["CustomContext", typer.Context],
    *,
    scan_id: Optional[str],
) -> None:
    """
    Emit an INIT_SCAN_COMPLETED event (with the scan id, when known)
    and flush the bus.
    """
    obj: "SafetyCLI" = ctx.obj
    obj.correlation_id = obj.correlation_id or str(uuid.uuid4())

    send_and_flush(
        event_bus,
        create_event(
            correlation_id=obj.correlation_id,
            payload=InitScanCompletedPayload(scan_id=scan_id),
            event_type=EventType.INIT_SCAN_COMPLETED,
        ),
    )
@conditional_emitter
def emit_codebase_setup_completed(
    event_bus: "EventBus",
    ctx: Union["CustomContext", typer.Context],
    *,
    is_created: bool,
    codebase_id: Optional[str] = None,
) -> None:
    """
    Emit a CODEBASE_SETUP_COMPLETED event (whether a codebase was
    created, and its id when available) and flush the bus.
    """
    obj: "SafetyCLI" = ctx.obj
    obj.correlation_id = obj.correlation_id or str(uuid.uuid4())

    send_and_flush(
        event_bus,
        create_event(
            correlation_id=obj.correlation_id,
            payload=CodebaseSetupCompletedPayload(
                is_created=is_created, codebase_id=codebase_id
            ),
            event_type=EventType.CODEBASE_SETUP_COMPLETED,
        ),
    )
@conditional_emitter
def emit_firewall_setup_completed(
    event_bus: "EventBus",
    ctx: "CustomContext",
    *,
    status: "FirewallConfigStatus",
) -> None:
    """
    Emit a FIREWALL_SETUP_COMPLETED event describing each tool's final
    configuration, then flush the bus.
    """
    obj: "SafetyCLI" = ctx.obj
    obj.correlation_id = obj.correlation_id or str(uuid.uuid4())

    send_and_flush(
        event_bus,
        create_event(
            correlation_id=obj.correlation_id,
            payload=FirewallSetupCompletedPayload(
                tools=status_to_tool_status(status),
            ),
            event_type=EventType.FIREWALL_SETUP_COMPLETED,
        ),
    )
@conditional_emitter
def emit_init_exited(
    event_bus: "EventBus",
    ctx: Union["CustomContext", typer.Context],
    *,
    exit_step: InitExitStep,
) -> None:
    """
    Emit an INIT_EXITED event recording the step at which the init flow
    was abandoned, then flush the bus.
    """
    obj: "SafetyCLI" = ctx.obj
    obj.correlation_id = obj.correlation_id or str(uuid.uuid4())

    send_and_flush(
        event_bus,
        create_event(
            correlation_id=obj.correlation_id,
            payload=InitExitedPayload(exit_step=exit_step),
            event_type=EventType.INIT_EXITED,
        ),
    )

Some files were not shown because too many files have changed in this diff Show More