updates
This commit is contained in:
@@ -0,0 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
__author__ = """safetycli.com"""
|
||||
__email__ = 'cli@safetycli.com'
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Allow safety to be executable through `python -m safety`."""
|
||||
from safety.cli import cli
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
cli(prog_name="safety")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,110 @@
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
from typing import Any, IO
|
||||
import click
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from safety.constants import CONTEXT_COMMAND_TYPE
|
||||
|
||||
from . import github
|
||||
from safety.util import SafetyPolicyFile
|
||||
from safety.scan.constants import CLI_ALERT_COMMAND_HELP
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_safety_cli_legacy_group():
|
||||
from safety.cli_util import SafetyCLILegacyGroup
|
||||
|
||||
return SafetyCLILegacyGroup
|
||||
|
||||
|
||||
def get_context_settings():
|
||||
from safety.cli_util import CommandType
|
||||
|
||||
return {CONTEXT_COMMAND_TYPE: CommandType.UTILITY}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Alert:
|
||||
"""
|
||||
Data class for storing alert details.
|
||||
|
||||
Attributes:
|
||||
report (Any): The report data.
|
||||
key (str): The API key for the safetycli.com vulnerability database.
|
||||
policy (Any): The policy data.
|
||||
requirements_files (Any): The requirements files data.
|
||||
"""
|
||||
|
||||
report: Any
|
||||
key: str
|
||||
policy: Any = None
|
||||
requirements_files: Any = None
|
||||
|
||||
|
||||
@click.group(
|
||||
cls=get_safety_cli_legacy_group(),
|
||||
help=CLI_ALERT_COMMAND_HELP,
|
||||
deprecated=True,
|
||||
context_settings=get_context_settings(),
|
||||
)
|
||||
@click.option(
|
||||
"--check-report",
|
||||
help="JSON output of Safety Check to work with.",
|
||||
type=click.File("r"),
|
||||
default=sys.stdin,
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--key",
|
||||
envvar="SAFETY_API_KEY",
|
||||
help="API Key for safetycli.com's vulnerability database. Can be set as SAFETY_API_KEY "
|
||||
"environment variable.",
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--policy-file",
|
||||
type=SafetyPolicyFile(),
|
||||
default=".safety-policy.yml",
|
||||
help="Define the policy file to be used",
|
||||
)
|
||||
@click.pass_context
|
||||
def alert(
|
||||
ctx: click.Context, check_report: IO[str], policy_file: SafetyPolicyFile, key: str
|
||||
) -> None:
|
||||
"""
|
||||
Command for processing the Safety Check JSON report.
|
||||
|
||||
Args:
|
||||
ctx (click.Context): The Click context object.
|
||||
check_report (IO[str]): The file containing the JSON report.
|
||||
policy_file (SafetyPolicyFile): The policy file to be used.
|
||||
key (str): The API key for the safetycli.com vulnerability database.
|
||||
"""
|
||||
LOG.info("alert started")
|
||||
LOG.info(f"check_report is using stdin: {check_report == sys.stdin}")
|
||||
|
||||
with check_report:
|
||||
# TODO: This breaks --help for subcommands
|
||||
try:
|
||||
safety_report = json.load(check_report)
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
LOG.info("Error in the JSON report.")
|
||||
click.secho("Error decoding input JSON: {}".format(e.msg), fg="red")
|
||||
sys.exit(1)
|
||||
|
||||
if "report_meta" not in safety_report:
|
||||
click.secho("You must pass in a valid Safety Check JSON report", fg="red")
|
||||
sys.exit(1)
|
||||
|
||||
ctx.obj = Alert(
|
||||
report=safety_report, policy=policy_file if policy_file else {}, key=key
|
||||
)
|
||||
|
||||
|
||||
# Adding subcommands for GitHub integration
|
||||
alert.add_command(github.github_pr)
|
||||
alert.add_command(github.github_issue)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,575 @@
|
||||
# type: ignore
|
||||
import itertools
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
import click
|
||||
|
||||
try:
|
||||
import github as pygithub
|
||||
except ImportError:
|
||||
pygithub = None
|
||||
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.utils import canonicalize_name
|
||||
|
||||
from . import utils, requirements
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_branch(repo: Any, base_branch: str, new_branch: str) -> None:
|
||||
"""
|
||||
Create a new branch in the given GitHub repository.
|
||||
|
||||
Args:
|
||||
repo (Any): The GitHub repository object.
|
||||
base_branch (str): The name of the base branch.
|
||||
new_branch (str): The name of the new branch to create.
|
||||
"""
|
||||
ref = repo.get_git_ref("heads/" + base_branch)
|
||||
repo.create_git_ref(ref="refs/heads/" + new_branch, sha=ref.object.sha)
|
||||
|
||||
|
||||
def delete_branch(repo: Any, branch: str) -> None:
|
||||
"""
|
||||
Delete a branch from the given GitHub repository.
|
||||
|
||||
Args:
|
||||
repo (Any): The GitHub repository object.
|
||||
branch (str): The name of the branch to delete.
|
||||
"""
|
||||
ref = repo.get_git_ref(f"heads/{branch}")
|
||||
ref.delete()
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--repo", help="GitHub standard repo path (eg, my-org/my-project)")
|
||||
@click.option("--token", help="GitHub Access Token")
|
||||
@click.option(
|
||||
"--base-url",
|
||||
help="Optional custom Base URL, if you're using GitHub enterprise",
|
||||
default=None,
|
||||
)
|
||||
@click.pass_obj
|
||||
@utils.require_files_report
|
||||
def github_pr(obj: Any, repo: str, token: str, base_url: Optional[str]) -> None:
|
||||
"""
|
||||
Create a GitHub PR to fix any vulnerabilities using Safety's remediation data.
|
||||
|
||||
This is usually run by a GitHub action. If you're running this manually, ensure that your local repo is up to date and on HEAD - otherwise you'll see strange results.
|
||||
|
||||
Args:
|
||||
obj (Any): The Click context object containing report data.
|
||||
repo (str): The GitHub repository path.
|
||||
token (str): The GitHub Access Token.
|
||||
base_url (Optional[str]): Custom base URL for GitHub Enterprise, if applicable.
|
||||
"""
|
||||
if pygithub is None:
|
||||
click.secho(
|
||||
"pygithub is not installed. Did you install Safety with GitHub support? Try pip install safety[github]",
|
||||
fg="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Load alert configurations from the policy
|
||||
alert = obj.policy.get("alert", {}) or {}
|
||||
security = alert.get("security", {}) or {}
|
||||
config_pr = security.get("github-pr", {}) or {}
|
||||
|
||||
branch_prefix = config_pr.get("branch-prefix", "pyup/")
|
||||
pr_prefix = config_pr.get("pr-prefix", "[PyUp] ")
|
||||
assignees = config_pr.get("assignees", [])
|
||||
labels = config_pr.get("labels", ["security"])
|
||||
label_severity = config_pr.get("label-severity", True)
|
||||
ignore_cvss_severity_below = config_pr.get("ignore-cvss-severity-below", 0)
|
||||
ignore_cvss_unknown_severity = config_pr.get("ignore-cvss-unknown-severity", False)
|
||||
|
||||
# Authenticate with GitHub
|
||||
gh = pygithub.Github(token, **({"base_url": base_url} if base_url else {}))
|
||||
repo_name = repo
|
||||
repo = gh.get_repo(repo)
|
||||
try:
|
||||
self_user = gh.get_user().login
|
||||
except pygithub.GithubException:
|
||||
# If we're using a token from an action (or integration) we can't call `get_user()`. Fall back
|
||||
# to assuming we're running under an action
|
||||
self_user = "web-flow"
|
||||
|
||||
# Collect all remediations from the report
|
||||
req_remediations = list(
|
||||
itertools.chain.from_iterable(
|
||||
rem.get("requirements", {}).values()
|
||||
for pkg_name, rem in obj.report["remediations"].items()
|
||||
)
|
||||
)
|
||||
|
||||
# Get all open pull requests for the repository
|
||||
pulls = repo.get_pulls(state="open", sort="created", base=repo.default_branch)
|
||||
pending_updates = set(
|
||||
[
|
||||
f"{canonicalize_name(req_rem['requirement']['name'])}{req_rem['requirement']['specifier']}"
|
||||
for req_rem in req_remediations
|
||||
]
|
||||
)
|
||||
|
||||
created = 0
|
||||
|
||||
# TODO: Refactor this loop into a fn to iterate over remediations nicely
|
||||
# Iterate over all requirements files and process each remediation
|
||||
for name, contents in obj.requirements_files.items():
|
||||
raw_contents = contents
|
||||
contents = contents.decode("utf-8") # TODO - encoding?
|
||||
parsed_req_file = requirements.RequirementFile(name, contents)
|
||||
|
||||
for remediation in req_remediations:
|
||||
pkg = remediation["requirement"]["name"]
|
||||
pkg_canonical_name: str = canonicalize_name(pkg)
|
||||
analyzed_spec: str = remediation["requirement"]["specifier"]
|
||||
|
||||
# Skip remediations without a recommended version
|
||||
if remediation["recommended_version"] is None:
|
||||
LOG.debug(
|
||||
f"The GitHub PR alerter only currently supports remediations that have a recommended_version: {pkg}"
|
||||
)
|
||||
continue
|
||||
|
||||
# We have a single remediation that can have multiple vulnerabilities
|
||||
# Find all vulnerabilities associated with the remediation
|
||||
vulns = [
|
||||
x
|
||||
for x in obj.report["vulnerabilities"]
|
||||
if x["package_name"] == pkg_canonical_name
|
||||
and x["analyzed_requirement"]["specifier"] == analyzed_spec
|
||||
]
|
||||
|
||||
# Skip if all vulnerabilities have unknown severity and the ignore flag is set
|
||||
if ignore_cvss_unknown_severity and all(
|
||||
x["severity"] is None for x in vulns
|
||||
):
|
||||
LOG.debug(
|
||||
"All vulnerabilities have unknown severity, and ignore_cvss_unknown_severity is set."
|
||||
)
|
||||
continue
|
||||
|
||||
highest_base_score = 0
|
||||
for vuln in vulns:
|
||||
if vuln["severity"] is not None:
|
||||
highest_base_score = max(
|
||||
highest_base_score,
|
||||
(vuln["severity"].get("cvssv3", {}) or {}).get(
|
||||
"base_score", 10
|
||||
),
|
||||
)
|
||||
|
||||
# Skip if none of the vulnerabilities meet the severity threshold
|
||||
if ignore_cvss_severity_below:
|
||||
at_least_one_match = False
|
||||
for vuln in vulns:
|
||||
# Consider a None severity as a match, since it's controlled by a different flag
|
||||
# If we can't find a base_score but we have severity data, assume it's critical for now.
|
||||
if (
|
||||
vuln["severity"] is None
|
||||
or (vuln["severity"].get("cvssv3", {}) or {}).get(
|
||||
"base_score", 10
|
||||
)
|
||||
>= ignore_cvss_severity_below
|
||||
):
|
||||
at_least_one_match = True
|
||||
|
||||
if not at_least_one_match:
|
||||
LOG.debug(
|
||||
f"None of the vulnerabilities found have a score greater than or equal to the ignore_cvss_severity_below of {ignore_cvss_severity_below}"
|
||||
)
|
||||
continue
|
||||
|
||||
for parsed_req in parsed_req_file.requirements:
|
||||
specs = (
|
||||
SpecifierSet(">=0")
|
||||
if parsed_req.specs == SpecifierSet("")
|
||||
else parsed_req.specs
|
||||
)
|
||||
|
||||
# Check if the requirement matches the remediation
|
||||
if (
|
||||
canonicalize_name(parsed_req.name) == pkg_canonical_name
|
||||
and str(specs) == analyzed_spec
|
||||
):
|
||||
updated_contents = parsed_req.update_version(
|
||||
contents, remediation["recommended_version"]
|
||||
)
|
||||
pending_updates.discard(f"{pkg_canonical_name}{analyzed_spec}")
|
||||
|
||||
new_branch = branch_prefix + utils.generate_branch_name(
|
||||
pkg, remediation
|
||||
)
|
||||
skip_create = False
|
||||
|
||||
# Few possible cases:
|
||||
# 1. No existing PRs exist for this change (don't need to handle)
|
||||
# 2. An existing PR exists, and it's out of date (eg, recommended 0.5.1 and we want 0.5.2)
|
||||
# 3. An existing PR exists, and it's not mergable anymore (eg, needs a rebase)
|
||||
# 4. An existing PR exists, and everything's up to date.
|
||||
# 5. An existing PR exists, but it's not needed anymore (perhaps we've been updated to a later version)
|
||||
# 6. No existing PRs exist, but a branch does exist (perhaps the PR was closed but a stale branch left behind)
|
||||
# In any case, we only act if we've been the only committer to the branch.
|
||||
# Handle various cases for existing pull requests
|
||||
for pr in pulls:
|
||||
if not pr.head.ref.startswith(branch_prefix):
|
||||
continue
|
||||
|
||||
authors = [
|
||||
commit.committer.login for commit in pr.get_commits()
|
||||
]
|
||||
only_us = all([x == self_user for x in authors])
|
||||
|
||||
try:
|
||||
_, pr_pkg, pr_spec, pr_ver = pr.head.ref.split("/")
|
||||
except ValueError:
|
||||
# It's possible that something weird has manually been done, so skip that
|
||||
# Skip invalid branch names
|
||||
LOG.debug(
|
||||
"Found an invalid branch name on an open PR, that matches our prefix. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
pr_pkg = canonicalize_name(pr_pkg)
|
||||
|
||||
if pr_pkg != pkg_canonical_name:
|
||||
continue
|
||||
|
||||
# Case 4: An up-to-date PR exists
|
||||
if (
|
||||
pr_pkg == pkg_canonical_name
|
||||
and pr_spec == analyzed_spec
|
||||
and pr_ver == remediation["recommended_version"]
|
||||
and pr.mergeable
|
||||
):
|
||||
LOG.debug(
|
||||
f"An up to date PR #{pr.number} for {pkg} was found, no action will be taken."
|
||||
)
|
||||
|
||||
skip_create = True
|
||||
continue
|
||||
|
||||
if not only_us:
|
||||
LOG.debug(
|
||||
f"There are other committers on the PR #{pr.number} for {pkg}. No further action will be taken."
|
||||
)
|
||||
continue
|
||||
|
||||
# Case 2: An existing PR is out of date
|
||||
if (
|
||||
pr_pkg == pkg_canonical_name
|
||||
and pr_spec == analyzed_spec
|
||||
and pr_ver != remediation["recommended_version"]
|
||||
):
|
||||
LOG.debug(
|
||||
f"Closing stale PR #{pr.number} for {pkg} as a newer recommended version became"
|
||||
)
|
||||
|
||||
pr.create_issue_comment(
|
||||
"This PR has been replaced, since a newer recommended version became available."
|
||||
)
|
||||
pr.edit(state="closed")
|
||||
delete_branch(repo, pr.head.ref)
|
||||
|
||||
# Case 3: An existing PR is not mergeable
|
||||
if not pr.mergeable:
|
||||
LOG.debug(
|
||||
f"Closing PR #{pr.number} for {pkg} as it has become unmergable and we were the only committer"
|
||||
)
|
||||
|
||||
pr.create_issue_comment(
|
||||
"This PR has been replaced since it became unmergable."
|
||||
)
|
||||
pr.edit(state="closed")
|
||||
delete_branch(repo, pr.head.ref)
|
||||
|
||||
# Skip if no changes were made
|
||||
if updated_contents == contents:
|
||||
LOG.debug(
|
||||
f"Couldn't update {pkg} to {remediation['recommended_version']}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Skip creation if indicated
|
||||
if skip_create:
|
||||
continue
|
||||
|
||||
# Create a new branch and commit the changes
|
||||
try:
|
||||
create_branch(repo, repo.default_branch, new_branch)
|
||||
except pygithub.GithubException as e:
|
||||
if e.data["message"] == "Reference already exists":
|
||||
# There might be a stale branch. If the bot is the only committer, nuke it.
|
||||
comparison = repo.compare(repo.default_branch, new_branch)
|
||||
authors = [
|
||||
commit.committer.login for commit in comparison.commits
|
||||
]
|
||||
only_us = all([x == self_user for x in authors])
|
||||
|
||||
if only_us:
|
||||
delete_branch(repo, new_branch)
|
||||
create_branch(repo, repo.default_branch, new_branch)
|
||||
else:
|
||||
LOG.debug(
|
||||
f"The branch '{new_branch}' already exists - but there is no matching PR and this branch has committers other than us. This remediation will be skipped."
|
||||
)
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
|
||||
try:
|
||||
repo.update_file(
|
||||
path=name,
|
||||
message=utils.generate_commit_message(pkg, remediation),
|
||||
content=updated_contents,
|
||||
branch=new_branch,
|
||||
sha=utils.git_sha1(raw_contents),
|
||||
)
|
||||
except pygithub.GithubException as e:
|
||||
if "does not match" in e.data["message"]:
|
||||
click.secho(
|
||||
f"GitHub blocked a commit on our branch to the requirements file, {name}, as the local hash we computed didn't match the version on {repo.default_branch}. Make sure you're running safety against the latest code on your default branch.",
|
||||
fg="red",
|
||||
)
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
|
||||
pr = repo.create_pull(
|
||||
title=pr_prefix + utils.generate_title(pkg, remediation, vulns),
|
||||
body=utils.generate_body(
|
||||
pkg, remediation, vulns, api_key=obj.key
|
||||
),
|
||||
head=new_branch,
|
||||
base=repo.default_branch,
|
||||
)
|
||||
LOG.debug(f"Created Pull Request to update {pkg}")
|
||||
|
||||
created += 1
|
||||
|
||||
# Add assignees and labels to the PR
|
||||
for assignee in assignees:
|
||||
pr.add_to_assignees(assignee)
|
||||
|
||||
for label in labels:
|
||||
pr.add_to_labels(label)
|
||||
|
||||
if label_severity:
|
||||
score_as_label = utils.cvss3_score_to_label(highest_base_score)
|
||||
if score_as_label:
|
||||
pr.add_to_labels(score_as_label)
|
||||
|
||||
if len(pending_updates) > 0:
|
||||
click.secho(
|
||||
"The following remediations were not followed: {}".format(
|
||||
", ".join(pending_updates)
|
||||
),
|
||||
fg="red",
|
||||
)
|
||||
|
||||
if created:
|
||||
click.secho(
|
||||
f"Safety successfully created {created} GitHub PR{'s' if created > 1 else ''} for repo {repo_name}"
|
||||
)
|
||||
else:
|
||||
click.secho(
|
||||
"No PRs created; please run the command with debug mode for more information."
|
||||
)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--repo", help="GitHub standard repo path (eg, my-org/my-project)")
|
||||
@click.option("--token", help="GitHub Access Token")
|
||||
@click.option(
|
||||
"--base-url",
|
||||
help="Optional custom Base URL, if you're using GitHub enterprise",
|
||||
default=None,
|
||||
)
|
||||
@click.pass_obj
|
||||
@utils.require_files_report # TODO: For now, it can be removed in the future to support env scans.
|
||||
def github_issue(obj: Any, repo: str, token: str, base_url: Optional[str]) -> None:
|
||||
"""
|
||||
Create a GitHub Issue for any vulnerabilities found using PyUp's remediation data.
|
||||
|
||||
Normally, this is run by a GitHub action. If you're running this manually, ensure that your local repo is up to date and on HEAD - otherwise you'll see strange results.
|
||||
|
||||
Args:
|
||||
obj (Any): The Click context object containing report data.
|
||||
repo (str): The GitHub repository path.
|
||||
token (str): The GitHub Access Token.
|
||||
base_url (Optional[str]): Custom base URL for GitHub Enterprise, if applicable.
|
||||
"""
|
||||
LOG.info("github_issue")
|
||||
|
||||
if pygithub is None:
|
||||
click.secho(
|
||||
"pygithub is not installed. Did you install Safety with GitHub support? Try pip install safety[github]",
|
||||
fg="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Load alert configurations from the policy
|
||||
alert = obj.policy.get("alert", {}) or {}
|
||||
security = alert.get("security", {}) or {}
|
||||
config_issue = security.get("github-issue", {}) or {}
|
||||
|
||||
issue_prefix = config_issue.get("issue-prefix", "[PyUp] ")
|
||||
assignees = config_issue.get("assignees", [])
|
||||
labels = config_issue.get("labels", ["security"])
|
||||
|
||||
label_severity = config_issue.get("label-severity", True)
|
||||
ignore_cvss_severity_below = config_issue.get("ignore-cvss-severity-below", 0)
|
||||
ignore_cvss_unknown_severity = config_issue.get(
|
||||
"ignore-cvss-unknown-severity", False
|
||||
)
|
||||
|
||||
# Authenticate with GitHub
|
||||
gh = pygithub.Github(token, **({"base_url": base_url} if base_url else {}))
|
||||
repo_name = repo
|
||||
repo = gh.get_repo(repo)
|
||||
|
||||
# Get all open issues for the repository
|
||||
issues = list(repo.get_issues(state="open", sort="created"))
|
||||
ISSUE_TITLE_REGEX = re.escape(issue_prefix) + r"Security Vulnerability in (.+)"
|
||||
req_remediations = list(
|
||||
itertools.chain.from_iterable(
|
||||
rem.get("requirements", {}).values()
|
||||
for pkg_name, rem in obj.report["remediations"].items()
|
||||
)
|
||||
)
|
||||
|
||||
created = 0
|
||||
|
||||
# Iterate over all requirements files and process each remediation
|
||||
for name, contents in obj.requirements_files.items():
|
||||
contents = contents.decode("utf-8") # TODO - encoding?
|
||||
parsed_req_file = requirements.RequirementFile(name, contents)
|
||||
|
||||
for remediation in req_remediations:
|
||||
pkg: str = remediation["requirement"]["name"]
|
||||
pkg_canonical_name: str = canonicalize_name(pkg)
|
||||
analyzed_spec: str = remediation["requirement"]["specifier"]
|
||||
|
||||
# Skip remediations without a recommended version
|
||||
if remediation["recommended_version"] is None:
|
||||
LOG.debug(
|
||||
f"The GitHub Issue alerter only currently supports remediations that have a recommended_version: {pkg}"
|
||||
)
|
||||
continue
|
||||
|
||||
# We have a single remediation that can have multiple vulnerabilities
|
||||
# Find all vulnerabilities associated with the remediation
|
||||
vulns = [
|
||||
x
|
||||
for x in obj.report["vulnerabilities"]
|
||||
if x["package_name"] == pkg_canonical_name
|
||||
and x["analyzed_requirement"]["specifier"] == analyzed_spec
|
||||
]
|
||||
|
||||
# Skip if all vulnerabilities have unknown severity and the ignore flag is set
|
||||
if ignore_cvss_unknown_severity and all(
|
||||
x["severity"] is None for x in vulns
|
||||
):
|
||||
LOG.debug(
|
||||
"All vulnerabilities have unknown severity, and ignore_cvss_unknown_severity is set."
|
||||
)
|
||||
continue
|
||||
|
||||
highest_base_score = 0
|
||||
for vuln in vulns:
|
||||
if vuln["severity"] is not None:
|
||||
highest_base_score = max(
|
||||
highest_base_score,
|
||||
(vuln["severity"].get("cvssv3", {}) or {}).get(
|
||||
"base_score", 10
|
||||
),
|
||||
)
|
||||
|
||||
# Skip if none of the vulnerabilities meet the severity threshold
|
||||
if ignore_cvss_severity_below:
|
||||
at_least_one_match = False
|
||||
for vuln in vulns:
|
||||
# Consider a None severity as a match, since it's controlled by a different flag
|
||||
# If we can't find a base_score but we have severity data, assume it's critical for now.
|
||||
if (
|
||||
vuln["severity"] is None
|
||||
or (vuln["severity"].get("cvssv3", {}) or {}).get(
|
||||
"base_score", 10
|
||||
)
|
||||
>= ignore_cvss_severity_below
|
||||
):
|
||||
at_least_one_match = True
|
||||
break
|
||||
|
||||
if not at_least_one_match:
|
||||
LOG.debug(
|
||||
f"None of the vulnerabilities found have a score greater than or equal to the ignore_cvss_severity_below of {ignore_cvss_severity_below}"
|
||||
)
|
||||
continue
|
||||
|
||||
for parsed_req in parsed_req_file.requirements:
|
||||
specs = (
|
||||
SpecifierSet(">=0")
|
||||
if parsed_req.specs == SpecifierSet("")
|
||||
else parsed_req.specs
|
||||
)
|
||||
if (
|
||||
canonicalize_name(parsed_req.name) == pkg_canonical_name
|
||||
and str(specs) == analyzed_spec
|
||||
):
|
||||
skip = False
|
||||
for issue in issues:
|
||||
match = re.match(ISSUE_TITLE_REGEX, issue.title)
|
||||
if match:
|
||||
group = match.group(1)
|
||||
if (
|
||||
group == f"{pkg}{analyzed_spec}"
|
||||
or group == f"{pkg_canonical_name}{analyzed_spec}"
|
||||
):
|
||||
skip = True
|
||||
break
|
||||
|
||||
# For now, we just skip issues if they already exist - we don't try and update them.
|
||||
# Skip if an issue already exists for this remediation
|
||||
if skip:
|
||||
LOG.debug(
|
||||
f"An issue already exists for {pkg}{analyzed_spec} - skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
# Create a new GitHub issue
|
||||
pr = repo.create_issue(
|
||||
title=issue_prefix
|
||||
+ utils.generate_issue_title(pkg, remediation),
|
||||
body=utils.generate_issue_body(
|
||||
pkg, remediation, vulns, api_key=obj.key
|
||||
),
|
||||
)
|
||||
created += 1
|
||||
LOG.debug(f"Created issue to update {pkg}")
|
||||
|
||||
# Add assignees and labels to the issue
|
||||
for assignee in assignees:
|
||||
pr.add_to_assignees(assignee)
|
||||
|
||||
for label in labels:
|
||||
pr.add_to_labels(label)
|
||||
|
||||
if label_severity:
|
||||
score_as_label = utils.cvss3_score_to_label(highest_base_score)
|
||||
if score_as_label:
|
||||
pr.add_to_labels(score_as_label)
|
||||
|
||||
if created:
|
||||
click.secho(
|
||||
f"Safety successfully created {created} new GitHub Issue{'s' if created > 1 else ''} for repo {repo_name}"
|
||||
)
|
||||
else:
|
||||
click.secho(
|
||||
"No issues created; please run the command with debug mode for more information."
|
||||
)
|
||||
@@ -0,0 +1,564 @@
|
||||
# type: ignore
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from packaging.version import parse as parse_version
|
||||
from packaging.specifiers import SpecifierSet
|
||||
import requests
|
||||
from typing import Any, Optional, Generator, Tuple, List
|
||||
from safety.meta import get_meta_http_headers
|
||||
|
||||
from datetime import datetime
|
||||
from dparse import parse, parser, updater, filetypes
|
||||
from dparse.dependencies import Dependency
|
||||
from dparse.parser import setuptools_parse_requirements_backport as parse_requirements
|
||||
|
||||
|
||||
class RequirementFile(object):
|
||||
"""
|
||||
Class representing a requirements file with its content and metadata.
|
||||
|
||||
Attributes:
|
||||
path (str): The file path.
|
||||
content (str): The content of the file.
|
||||
sha (Optional[str]): The SHA of the file.
|
||||
"""
|
||||
|
||||
def __init__(self, path: str, content: str, sha: Optional[str] = None):
|
||||
self.path = path
|
||||
self.content = content
|
||||
self.sha = sha
|
||||
self._requirements: Optional[List] = None
|
||||
self._other_files: Optional[List] = None
|
||||
self._is_valid = None
|
||||
self.is_pipfile = False
|
||||
self.is_pipfile_lock = False
|
||||
self.is_setup_cfg = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
"RequirementFile(path='{path}', sha='{sha}', content='{content}')".format(
|
||||
path=self.path,
|
||||
content=self.content[:30] + "[truncated]"
|
||||
if len(self.content) > 30
|
||||
else self.content,
|
||||
sha=self.sha,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def is_valid(self) -> Optional[bool]:
|
||||
"""
|
||||
Checks if the requirements file is valid by parsing it.
|
||||
|
||||
Returns:
|
||||
bool: True if the file is valid, False otherwise.
|
||||
"""
|
||||
if self._is_valid is None:
|
||||
self._parse()
|
||||
return self._is_valid
|
||||
|
||||
@property
|
||||
def requirements(self) -> Optional[List]:
|
||||
"""
|
||||
Returns the list of requirements parsed from the file.
|
||||
|
||||
Returns:
|
||||
List: The list of requirements.
|
||||
"""
|
||||
if not self._requirements:
|
||||
self._parse()
|
||||
return self._requirements
|
||||
|
||||
@property
|
||||
def other_files(self) -> Optional[List]:
|
||||
"""
|
||||
Returns the list of other files resolved from the requirements file.
|
||||
|
||||
Returns:
|
||||
List: The list of other files.
|
||||
"""
|
||||
if not self._other_files:
|
||||
self._parse()
|
||||
return self._other_files
|
||||
|
||||
@staticmethod
|
||||
def parse_index_server(line: str) -> Optional[str]:
|
||||
"""
|
||||
Parses the index server from a given line.
|
||||
|
||||
Args:
|
||||
line (str): The line to parse.
|
||||
|
||||
Returns:
|
||||
str: The parsed index server.
|
||||
"""
|
||||
return parser.Parser.parse_index_server(line)
|
||||
|
||||
def _hash_parser(self, line: str) -> Optional[Tuple[str, List[str]]]:
|
||||
"""
|
||||
Parses the hashes from a given line.
|
||||
|
||||
Args:
|
||||
line (str): The line to parse.
|
||||
|
||||
Returns:
|
||||
List: The list of parsed hashes.
|
||||
"""
|
||||
return parser.Parser.parse_hashes(line)
|
||||
|
||||
def _parse_requirements_txt(self) -> None:
|
||||
"""
|
||||
Parses the requirements.txt file format.
|
||||
"""
|
||||
self.parse_dependencies(filetypes.requirements_txt)
|
||||
|
||||
def _parse_conda_yml(self) -> None:
|
||||
"""
|
||||
Parses the conda.yml file format.
|
||||
"""
|
||||
self.parse_dependencies(filetypes.conda_yml)
|
||||
|
||||
def _parse_tox_ini(self) -> None:
|
||||
"""
|
||||
Parses the tox.ini file format.
|
||||
"""
|
||||
self.parse_dependencies(filetypes.tox_ini)
|
||||
|
||||
def _parse_pipfile(self) -> None:
|
||||
"""
|
||||
Parses the Pipfile format.
|
||||
"""
|
||||
self.parse_dependencies(filetypes.pipfile)
|
||||
self.is_pipfile = True
|
||||
|
||||
def _parse_pipfile_lock(self) -> None:
|
||||
"""
|
||||
Parses the Pipfile.lock format.
|
||||
"""
|
||||
self.parse_dependencies(filetypes.pipfile_lock)
|
||||
self.is_pipfile_lock = True
|
||||
|
||||
def _parse_setup_cfg(self) -> None:
|
||||
"""
|
||||
Parses the setup.cfg format.
|
||||
"""
|
||||
self.parse_dependencies(filetypes.setup_cfg)
|
||||
self.is_setup_cfg = True
|
||||
|
||||
def _parse(self) -> None:
|
||||
"""
|
||||
Parses the requirements file to extract dependencies and other files.
|
||||
"""
|
||||
self._requirements, self._other_files = [], []
|
||||
if self.path.endswith(".yml") or self.path.endswith(".yaml"):
|
||||
self._parse_conda_yml()
|
||||
elif self.path.endswith(".ini"):
|
||||
self._parse_tox_ini()
|
||||
elif self.path.endswith("Pipfile"):
|
||||
self._parse_pipfile()
|
||||
elif self.path.endswith("Pipfile.lock"):
|
||||
self._parse_pipfile_lock()
|
||||
elif self.path.endswith("setup.cfg"):
|
||||
self._parse_setup_cfg()
|
||||
else:
|
||||
self._parse_requirements_txt()
|
||||
self._is_valid = len(self._requirements) > 0 or len(self._other_files) > 0
|
||||
|
||||
def parse_dependencies(self, file_type: str) -> None:
|
||||
"""
|
||||
Parses the dependencies from the content based on the file type.
|
||||
|
||||
Args:
|
||||
file_type (str): The type of the file.
|
||||
"""
|
||||
result = parse(
|
||||
self.content,
|
||||
path=self.path,
|
||||
sha=self.sha,
|
||||
file_type=file_type,
|
||||
marker=(
|
||||
("pyup: ignore file", "pyup:ignore file"), # file marker
|
||||
("pyup: ignore", "pyup:ignore"), # line marker
|
||||
),
|
||||
)
|
||||
for dep in result.dependencies:
|
||||
req = Requirement(
|
||||
name=dep.name,
|
||||
specs=dep.specs,
|
||||
line=dep.line,
|
||||
lineno=dep.line_numbers[0] if dep.line_numbers else 0,
|
||||
extras=dep.extras,
|
||||
file_type=file_type,
|
||||
)
|
||||
req.index_server = dep.index_server
|
||||
if self.is_pipfile:
|
||||
req.pipfile = self.path
|
||||
req.hashes = dep.hashes
|
||||
self._requirements.append(req)
|
||||
self._other_files = result.resolved_files
|
||||
|
||||
def iter_lines(self, lineno: int = 0) -> Generator[str, None, None]:
|
||||
"""
|
||||
Iterates over lines in the content starting from a specific line number.
|
||||
|
||||
Args:
|
||||
lineno (int): The line number to start from.
|
||||
|
||||
Yields:
|
||||
str: The next line in the content.
|
||||
"""
|
||||
for line in self.content.splitlines()[lineno:]:
|
||||
yield line
|
||||
|
||||
@classmethod
|
||||
def resolve_file(cls, file_path: str, line: str) -> str:
|
||||
"""
|
||||
Resolves a file path from a given line.
|
||||
|
||||
Args:
|
||||
file_path (str): The file path to resolve.
|
||||
line (str): The line containing the file path.
|
||||
|
||||
Returns:
|
||||
str: The resolved file path.
|
||||
"""
|
||||
return parser.Parser.resolve_file(file_path, line)
|
||||
|
||||
|
||||
class Requirement(object):
|
||||
"""
|
||||
Class representing a single requirement.
|
||||
|
||||
Attributes:
|
||||
name (str): The name of the requirement.
|
||||
specs (SpecifierSet): The version specifiers for the requirement.
|
||||
line (str): The line containing the requirement.
|
||||
lineno (int): The line number of the requirement.
|
||||
extras (List): The extras for the requirement.
|
||||
file_type (str): The type of the file containing the requirement.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
specs: SpecifierSet,
|
||||
line: str,
|
||||
lineno: int,
|
||||
extras: List,
|
||||
file_type: str,
|
||||
):
|
||||
self.name = name
|
||||
self.key = name.lower()
|
||||
self.specs = specs
|
||||
self.line = line
|
||||
self.lineno = lineno
|
||||
self.index_server = None
|
||||
self.extras = extras
|
||||
self.hashes = []
|
||||
self.file_type = file_type
|
||||
self.pipfile: Optional[str] = None
|
||||
|
||||
self.hashCmp = (
|
||||
self.key,
|
||||
self.specs,
|
||||
frozenset(self.extras),
|
||||
)
|
||||
|
||||
self._is_insecure = None
|
||||
self._changelog = None
|
||||
|
||||
# Convert compatible releases to a range of versions
|
||||
if (
|
||||
len(self.specs._specs) == 1
|
||||
and next(iter(self.specs._specs))._spec[0] == "~="
|
||||
):
|
||||
# convert compatible releases to something more easily consumed,
|
||||
# e.g. '~=1.2.3' is equivalent to '>=1.2.3,<1.3.0', while '~=1.2'
|
||||
# is equivalent to '>=1.2,<2.0'
|
||||
min_version = next(iter(self.specs._specs))._spec[1]
|
||||
max_version = list(parse_version(min_version).release)
|
||||
max_version[-1] = 0
|
||||
max_version[-2] = max_version[-2] + 1
|
||||
max_version = ".".join(str(x) for x in max_version)
|
||||
|
||||
self.specs = SpecifierSet(">=%s,<%s" % (min_version, max_version))
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, Requirement) and self.hashCmp == other.hashCmp
|
||||
|
||||
def __ne__(self, other: Any) -> bool:
|
||||
return not self == other
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "Requirement.parse({line}, {lineno})".format(
|
||||
line=self.line, lineno=self.lineno
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.__str__()
|
||||
|
||||
@property
|
||||
def is_pinned(self) -> bool:
|
||||
"""
|
||||
Checks if the requirement is pinned to a specific version.
|
||||
|
||||
Returns:
|
||||
bool: True if pinned, False otherwise.
|
||||
"""
|
||||
if (
|
||||
len(self.specs._specs) == 1
|
||||
and next(iter(self.specs._specs))._spec[0] == "=="
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_open_ranged(self) -> bool:
|
||||
"""
|
||||
Checks if the requirement has an open range of versions.
|
||||
|
||||
Returns:
|
||||
bool: True if open ranged, False otherwise.
|
||||
"""
|
||||
if (
|
||||
len(self.specs._specs) == 1
|
||||
and next(iter(self.specs._specs))._spec[0] == ">="
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_ranged(self) -> bool:
|
||||
"""
|
||||
Checks if the requirement has a range of versions.
|
||||
|
||||
Returns:
|
||||
bool: True if ranged, False otherwise.
|
||||
"""
|
||||
return len(self.specs._specs) >= 1 and not self.is_pinned
|
||||
|
||||
@property
|
||||
def is_loose(self) -> bool:
|
||||
"""
|
||||
Checks if the requirement has no version specifiers.
|
||||
|
||||
Returns:
|
||||
bool: True if loose, False otherwise.
|
||||
"""
|
||||
return len(self.specs._specs) == 0
|
||||
|
||||
@staticmethod
|
||||
def convert_semver(version: str) -> dict:
|
||||
"""
|
||||
Converts a version string to a semantic version dictionary.
|
||||
|
||||
Args:
|
||||
version (str): The version string.
|
||||
|
||||
Returns:
|
||||
dict: The semantic version dictionary.
|
||||
"""
|
||||
semver = {"major": 0, "minor": 0, "patch": 0}
|
||||
version_parts = version.split(".")
|
||||
# don't be overly clever here. repitition makes it more readable and works exactly how
|
||||
# it is supposed to
|
||||
try:
|
||||
semver["major"] = int(version_parts[0])
|
||||
semver["minor"] = int(version_parts[1])
|
||||
semver["patch"] = int(version_parts[2])
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
return semver
|
||||
|
||||
@property
|
||||
def can_update_semver(self) -> bool:
|
||||
"""
|
||||
Checks if the requirement can be updated based on semantic versioning rules.
|
||||
|
||||
Returns:
|
||||
bool: True if it can be updated, False otherwise.
|
||||
"""
|
||||
# return early if there's no update filter set
|
||||
if "pyup: update" not in self.line:
|
||||
return True
|
||||
update = self.line.split("pyup: update")[1].strip().split("#")[0]
|
||||
current_version = Requirement.convert_semver(
|
||||
next(iter(self.specs._specs))._spec[1]
|
||||
)
|
||||
next_version = Requirement.convert_semver(self.latest_version) # type: ignore
|
||||
if update == "major":
|
||||
if current_version["major"] < next_version["major"]:
|
||||
return True
|
||||
elif update == "minor":
|
||||
if (
|
||||
current_version["major"] < next_version["major"]
|
||||
or current_version["minor"] < next_version["minor"]
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def filter(self):
|
||||
"""
|
||||
Returns the filter for the requirement if specified.
|
||||
|
||||
Returns:
|
||||
Optional[SpecifierSet]: The filter specifier set, or None if not specified.
|
||||
"""
|
||||
rqfilter = False
|
||||
if "rq.filter:" in self.line:
|
||||
rqfilter = self.line.split("rq.filter:")[1].strip().split("#")[0]
|
||||
elif "pyup:" in self.line:
|
||||
if "pyup: update" not in self.line:
|
||||
rqfilter = self.line.split("pyup:")[1].strip().split("#")[0]
|
||||
# unset the filter once the date set in 'until' is reached
|
||||
if "until" in rqfilter:
|
||||
rqfilter, until = [part.strip() for part in rqfilter.split("until")]
|
||||
try:
|
||||
until = datetime.strptime(until, "%Y-%m-%d")
|
||||
if until < datetime.now():
|
||||
rqfilter = False
|
||||
except ValueError:
|
||||
# wrong date formatting
|
||||
pass
|
||||
if rqfilter:
|
||||
try:
|
||||
(rqfilter,) = parse_requirements("filter " + rqfilter)
|
||||
if len(rqfilter.specifier._specs) > 0:
|
||||
return rqfilter.specifier
|
||||
except ValueError:
|
||||
pass
|
||||
return False
|
||||
|
||||
@property
|
||||
def version(self) -> Optional[str]:
|
||||
"""
|
||||
Returns the current version of the requirement.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The current version, or None if not pinned.
|
||||
"""
|
||||
if self.is_pinned:
|
||||
return next(iter(self.specs._specs))._spec[1]
|
||||
|
||||
specs = self.specs
|
||||
if self.filter:
|
||||
specs = SpecifierSet(
|
||||
",".join(
|
||||
[
|
||||
"".join(s._spec)
|
||||
for s in list(specs._specs) + list(self.filter._specs)
|
||||
]
|
||||
)
|
||||
)
|
||||
return self.get_latest_version_within_specs( # type: ignore
|
||||
specs,
|
||||
versions=self.package.versions,
|
||||
prereleases=self.prereleases, # type: ignore
|
||||
)
|
||||
|
||||
def get_hashes(self, version: str) -> List:
|
||||
"""
|
||||
Retrieves the hashes for a specific version from PyPI.
|
||||
|
||||
Args:
|
||||
version (str): The version to retrieve hashes for.
|
||||
|
||||
Returns:
|
||||
List: A list of hashes for the specified version.
|
||||
"""
|
||||
headers = get_meta_http_headers()
|
||||
r = requests.get(
|
||||
"https://pypi.org/pypi/{name}/{version}/json".format(
|
||||
name=self.key, version=version
|
||||
),
|
||||
headers=headers,
|
||||
)
|
||||
hashes = []
|
||||
data = r.json()
|
||||
|
||||
for item in data.get("urls", {}):
|
||||
sha256 = item.get("digests", {}).get("sha256", False)
|
||||
if sha256:
|
||||
hashes.append({"hash": sha256, "method": "sha256"})
|
||||
return hashes
|
||||
|
||||
def update_version(
|
||||
self, content: str, version: str, update_hashes: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Updates the version of the requirement in the content.
|
||||
|
||||
Args:
|
||||
content (str): The original content.
|
||||
version (str): The new version to update to.
|
||||
update_hashes (bool): Whether to update the hashes as well.
|
||||
|
||||
Returns:
|
||||
str: The updated content.
|
||||
"""
|
||||
if self.file_type == filetypes.tox_ini:
|
||||
updater_class = updater.ToxINIUpdater
|
||||
elif self.file_type == filetypes.conda_yml:
|
||||
updater_class = updater.CondaYMLUpdater
|
||||
elif self.file_type == filetypes.requirements_txt:
|
||||
updater_class = updater.RequirementsTXTUpdater
|
||||
elif self.file_type == filetypes.pipfile:
|
||||
updater_class = updater.PipfileUpdater
|
||||
elif self.file_type == filetypes.pipfile_lock:
|
||||
updater_class = updater.PipfileLockUpdater
|
||||
elif self.file_type == filetypes.setup_cfg:
|
||||
updater_class = updater.SetupCFGUpdater
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
dep = Dependency(
|
||||
name=self.name,
|
||||
specs=self.specs,
|
||||
line=self.line,
|
||||
line_numbers=[
|
||||
self.lineno,
|
||||
]
|
||||
if self.lineno != 0
|
||||
else None,
|
||||
dependency_type=self.file_type,
|
||||
hashes=self.hashes,
|
||||
extras=self.extras,
|
||||
)
|
||||
hashes = []
|
||||
if self.hashes and update_hashes:
|
||||
hashes = self.get_hashes(version)
|
||||
|
||||
return updater_class.update(
|
||||
content=content, dependency=dep, version=version, hashes=hashes, spec="=="
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def parse(
|
||||
cls, s: str, lineno: int, file_type: str = filetypes.requirements_txt
|
||||
) -> "Requirement":
|
||||
"""
|
||||
Parses a requirement from a line of text.
|
||||
|
||||
Args:
|
||||
s (str): The line of text.
|
||||
lineno (int): The line number.
|
||||
file_type (str): The type of the file containing the requirement.
|
||||
|
||||
Returns:
|
||||
Requirement: The parsed requirement.
|
||||
"""
|
||||
# setuptools requires a space before the comment. If this isn't the case, add it.
|
||||
if "\t#" in s:
|
||||
(parsed,) = parse_requirements(s.replace("\t#", "\t #"))
|
||||
else:
|
||||
(parsed,) = parse_requirements(s)
|
||||
|
||||
return cls(
|
||||
name=parsed.name,
|
||||
specs=parsed.specifier,
|
||||
line=s,
|
||||
lineno=lineno,
|
||||
extras=list(parsed.extras),
|
||||
file_type=file_type,
|
||||
)
|
||||
@@ -0,0 +1,54 @@
|
||||
Safety has detected a vulnerable package, [{{ pkg }}]({{ remediation['more_info_url'] }}), that should be updated from **{% if remediation['version'] %}{{ remediation['version'] }}{% else %}{{ remediation['requirement']['specifier'] }}{% endif %}** to **{{ remediation['recommended_version'] }}** to fix {{ vulns | length }} vulnerabilit{{ "y" if vulns|length == 1 else "ies" }}{% if overall_impact %}{{ " rated " + overall_impact if vulns|length == 1 else " with the highest CVSS severity rating being " + overall_impact }}{% endif %}.
|
||||
|
||||
To read more about the impact of {{ "this vulnerability" if vulns|length == 1 else "these vulnerabilities" }} see [PyUp’s {{ pkg }} page]({{ remediation['more_info_url'] }}).
|
||||
|
||||
{{ hint }}
|
||||
|
||||
If you're using `pip`, you can run:
|
||||
|
||||
```
|
||||
pip install {{ pkg }}=={{ remediation['recommended_version'] }}
|
||||
# Followed by a pip freeze
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary>Vulnerabilities Found</summary>
|
||||
{% for vuln in vulns %}
|
||||
* {{ vuln.advisory }}
|
||||
{% if vuln.severity and vuln.severity.cvssv3 and vuln.severity.cvssv3.base_severity %}
|
||||
* This vulnerability was rated {{ vuln.severity.cvssv3.base_severity }} ({{ vuln.severity.cvssv3.base_score }}) on CVSSv3.
|
||||
{% endif %}
|
||||
* To read more about this vulnerability, see PyUp’s [vulnerability page]({{ vuln.more_info_url }})
|
||||
{% endfor %}
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Changelog from {{ remediation['requirement']['name'] }}{{ remediation['requirement']['specifier'] }} to {{ remediation['recommended_version'] }}</summary>
|
||||
{% if summary_changelog %}
|
||||
The full changelog is too long to post here. See [PyUp’s {{ pkg }} page]({{ remediation['more_info_url'] }}) for more information.
|
||||
{% else %}
|
||||
{% for version, log in changelog.items() %}
|
||||
### {{ version }}
|
||||
|
||||
```
|
||||
{{ log }}
|
||||
```
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Ignoring {{ "This Vulnerability" if vulns|length == 1 else "These Vulnerabilities" }}</summary>
|
||||
|
||||
If you wish to [ignore this vulnerability](https://docs.pyup.io/docs/safety-20-policy-file), you can add the following to `.safety-policy.yml` in this repo:
|
||||
|
||||
```
|
||||
security:
|
||||
ignore-vulnerabilities:{% for vuln in vulns %}
|
||||
{{ vuln.vulnerability_id }}:
|
||||
reason: enter a reason as to why you're ignoring this vulnerability
|
||||
expires: 'YYYY-MM-DD' # datetime string - date this ignore will expire
|
||||
{% endfor %}
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -0,0 +1,47 @@
|
||||
Vulnerability fix: This PR updates [{{ pkg }}]({{ remediation['more_info_url'] }}) from **{% if remediation['version'] %}{{ remediation['version'] }}{% else %}{{ remediation['requirement']['specifier'] }}{% endif %}** to **{{ remediation['recommended_version'] }}** to fix {{ vulns | length }} vulnerabilit{{ "y" if vulns|length == 1 else "ies" }}{% if overall_impact %}{{ " rated " + overall_impact if vulns|length == 1 else " with the highest CVSS severity rating being " + overall_impact }}{% endif %}.
|
||||
|
||||
To read more about the impact of {{ "this vulnerability" if vulns|length == 1 else "these vulnerabilities" }} see [PyUp’s {{ pkg }} page]({{ remediation['more_info_url'] }}).
|
||||
|
||||
{{ hint }}
|
||||
|
||||
<details>
|
||||
<summary>Vulnerabilities Fixed</summary>
|
||||
{% for vuln in vulns %}
|
||||
* {{ vuln.advisory }}
|
||||
{% if vuln.severity and vuln.severity.cvssv3 and vuln.severity.cvssv3.base_severity %}
|
||||
* This vulnerability was rated {{ vuln.severity.cvssv3.base_severity }} ({{ vuln.severity.cvssv3.base_score }}) on CVSSv3.
|
||||
{% endif %}
|
||||
* To read more about this vulnerability, see PyUp’s [vulnerability page]({{ vuln.more_info_url }})
|
||||
{% endfor %}
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Changelog</summary>
|
||||
{% if summary_changelog %}
|
||||
The full changelog is too long to post here. See [PyUp’s {{ pkg }} page]({{ remediation['more_info_url'] }}) for more information.
|
||||
{% else %}
|
||||
{% for version, log in changelog.items() %}
|
||||
### {{ version }}
|
||||
|
||||
```
|
||||
{{ log }}
|
||||
```
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Ignoring {{ "This Vulnerability" if vulns|length == 1 else "These Vulnerabilities" }}</summary>
|
||||
|
||||
If you wish to [ignore this vulnerability](https://docs.pyup.io/docs/safety-20-policy-file), you can add the following to `.safety-policy.yml` in this repo:
|
||||
|
||||
```
|
||||
security:
|
||||
ignore-vulnerabilities:{% for vuln in vulns %}
|
||||
{{ vuln.vulnerability_id }}:
|
||||
reason: enter a reason as to why you're ignoring this vulnerability
|
||||
expires: 'YYYY-MM-DD' # datetime string - date this ignore will expire
|
||||
{% endfor %}
|
||||
```
|
||||
|
||||
</details>
|
||||
393
Backend/venv/lib/python3.12/site-packages/safety/alerts/utils.py
Normal file
393
Backend/venv/lib/python3.12/site-packages/safety/alerts/utils.py
Normal file
@@ -0,0 +1,393 @@
|
||||
# type: ignore
|
||||
import hashlib
|
||||
import os
|
||||
import sys
|
||||
from functools import wraps
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from packaging.version import parse as parse_version
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
|
||||
# Jinja2 will only be installed if the optional deps are installed.
|
||||
# It's fine if our functions fail, but don't let this top level
|
||||
# import error out.
|
||||
from safety.models import is_pinned_requirement
|
||||
from safety.output_utils import (
|
||||
get_unpinned_hint,
|
||||
get_specifier_range_info,
|
||||
get_fix_hint_for_unpinned,
|
||||
)
|
||||
|
||||
try:
|
||||
import jinja2
|
||||
except ImportError:
|
||||
jinja2 = None
|
||||
|
||||
import requests
|
||||
from safety.meta import get_meta_http_headers
|
||||
|
||||
|
||||
def highest_base_score(vulns: List[Dict[str, Any]]) -> float:
|
||||
"""
|
||||
Calculates the highest CVSS base score from a list of vulnerabilities.
|
||||
|
||||
Args:
|
||||
vulns (List[Dict[str, Any]]): The list of vulnerabilities.
|
||||
|
||||
Returns:
|
||||
float: The highest CVSS base score.
|
||||
"""
|
||||
highest_base_score = 0
|
||||
for vuln in vulns:
|
||||
if vuln["severity"] is not None:
|
||||
highest_base_score = max(
|
||||
highest_base_score,
|
||||
(vuln["severity"].get("cvssv3", {}) or {}).get("base_score", 10),
|
||||
)
|
||||
|
||||
return highest_base_score
|
||||
|
||||
|
||||
def generate_branch_name(pkg: str, remediation: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Generates a branch name for a given package and remediation.
|
||||
|
||||
Args:
|
||||
pkg (str): The package name.
|
||||
remediation (Dict[str, Any]): The remediation data.
|
||||
|
||||
Returns:
|
||||
str: The generated branch name.
|
||||
"""
|
||||
return f"{pkg}/{remediation['requirement']['specifier']}/{remediation['recommended_version']}"
|
||||
|
||||
|
||||
def generate_issue_title(pkg: str, remediation: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Generates an issue title for a given package and remediation.
|
||||
|
||||
Args:
|
||||
pkg (str): The package name.
|
||||
remediation (Dict[str, Any]): The remediation data.
|
||||
|
||||
Returns:
|
||||
str: The generated issue title.
|
||||
"""
|
||||
return f"Security Vulnerability in {pkg}{remediation['requirement']['specifier']}"
|
||||
|
||||
|
||||
def get_hint(remediation: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Generates a hint for a given remediation.
|
||||
|
||||
Args:
|
||||
remediation (Dict[str, Any]): The remediation data.
|
||||
|
||||
Returns:
|
||||
str: The generated hint.
|
||||
"""
|
||||
pinned = is_pinned_requirement(
|
||||
SpecifierSet(remediation["requirement"]["specifier"])
|
||||
)
|
||||
hint = ""
|
||||
|
||||
if not pinned:
|
||||
fix_hint = get_fix_hint_for_unpinned(remediation)
|
||||
hint = (
|
||||
f"{fix_hint}\n\n{get_unpinned_hint(remediation['requirement']['name'])} "
|
||||
f"{get_specifier_range_info(style=False)}"
|
||||
)
|
||||
|
||||
return hint
|
||||
|
||||
|
||||
def generate_title(
|
||||
pkg: str, remediation: Dict[str, Any], vulns: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""
|
||||
Generates a title for a pull request or issue.
|
||||
|
||||
Args:
|
||||
pkg (str): The package name.
|
||||
remediation (Dict[str, Any]): The remediation data.
|
||||
vulns (List[Dict[str, Any]]): The list of vulnerabilities.
|
||||
|
||||
Returns:
|
||||
str: The generated title.
|
||||
"""
|
||||
suffix = "y" if len(vulns) == 1 else "ies"
|
||||
from_dependency = (
|
||||
remediation["version"]
|
||||
if remediation["version"]
|
||||
else remediation["requirement"]["specifier"]
|
||||
)
|
||||
return f"Update {pkg} from {from_dependency} to {remediation['recommended_version']} to fix {len(vulns)} vulnerabilit{suffix}"
|
||||
|
||||
|
||||
def generate_body(
|
||||
pkg: str, remediation: Dict[str, Any], vulns: List[Dict[str, Any]], *, api_key: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Generates the body content for a pull request.
|
||||
|
||||
Args:
|
||||
pkg (str): The package name.
|
||||
remediation (Dict[str, Any]): The remediation data.
|
||||
vulns (List[Dict[str, Any]]): The list of vulnerabilities.
|
||||
api_key (str): The API key for fetching changelog data.
|
||||
|
||||
Returns:
|
||||
str: The generated body content.
|
||||
"""
|
||||
changelog = fetch_changelog(
|
||||
pkg,
|
||||
remediation["version"],
|
||||
remediation["recommended_version"],
|
||||
api_key=api_key,
|
||||
from_spec=remediation.get("requirement", {}).get("specifier", None),
|
||||
)
|
||||
|
||||
p = Path(__file__).parent / "templates"
|
||||
env = jinja2.Environment(loader=jinja2.FileSystemLoader(Path(p))) # type: ignore
|
||||
template = env.get_template("pr.jinja2")
|
||||
|
||||
overall_impact = cvss3_score_to_label(highest_base_score(vulns))
|
||||
|
||||
context = {
|
||||
"pkg": pkg,
|
||||
"remediation": remediation,
|
||||
"vulns": vulns,
|
||||
"changelog": changelog,
|
||||
"overall_impact": overall_impact,
|
||||
"summary_changelog": False,
|
||||
"hint": get_hint(remediation),
|
||||
}
|
||||
|
||||
result = template.render(context)
|
||||
|
||||
# GitHub has a PR body length limit of 65536. If we're going over that, skip the changelog and just use a link.
|
||||
if len(result) < 65500:
|
||||
return result
|
||||
|
||||
context["summary_changelog"] = True
|
||||
|
||||
return template.render(context)
|
||||
|
||||
|
||||
def generate_issue_body(
|
||||
pkg: str, remediation: Dict[str, Any], vulns: List[Dict[str, Any]], *, api_key: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Generates the body content for an issue.
|
||||
|
||||
Args:
|
||||
pkg (str): The package name.
|
||||
remediation (Dict[str, Any]): The remediation data.
|
||||
vulns (List[Dict[str, Any]]): The list of vulnerabilities.
|
||||
api_key (str): The API key for fetching changelog data.
|
||||
|
||||
Returns:
|
||||
str: The generated body content.
|
||||
"""
|
||||
changelog = fetch_changelog(
|
||||
pkg,
|
||||
remediation["version"],
|
||||
remediation["recommended_version"],
|
||||
api_key=api_key,
|
||||
from_spec=remediation.get("requirement", {}).get("specifier", None),
|
||||
)
|
||||
|
||||
p = Path(__file__).parent / "templates"
|
||||
env = jinja2.Environment(loader=jinja2.FileSystemLoader(Path(p))) # type: ignore
|
||||
template = env.get_template("issue.jinja2")
|
||||
|
||||
overall_impact = cvss3_score_to_label(highest_base_score(vulns))
|
||||
|
||||
context = {
|
||||
"pkg": pkg,
|
||||
"remediation": remediation,
|
||||
"vulns": vulns,
|
||||
"changelog": changelog,
|
||||
"overall_impact": overall_impact,
|
||||
"summary_changelog": False,
|
||||
"hint": get_hint(remediation),
|
||||
}
|
||||
|
||||
result = template.render(context)
|
||||
|
||||
# GitHub has a PR body length limit of 65536. If we're going over that, skip the changelog and just use a link.
|
||||
if len(result) < 65500:
|
||||
return result
|
||||
|
||||
context["summary_changelog"] = True
|
||||
|
||||
return template.render(context)
|
||||
|
||||
|
||||
def generate_commit_message(pkg: str, remediation: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Generates a commit message for a given package and remediation.
|
||||
|
||||
Args:
|
||||
pkg (str): The package name.
|
||||
remediation (Dict[str, Any]): The remediation data.
|
||||
|
||||
Returns:
|
||||
str: The generated commit message.
|
||||
"""
|
||||
from_dependency = (
|
||||
remediation["version"]
|
||||
if remediation["version"]
|
||||
else remediation["requirement"]["specifier"]
|
||||
)
|
||||
|
||||
return (
|
||||
f"Update {pkg} from {from_dependency} to {remediation['recommended_version']}"
|
||||
)
|
||||
|
||||
|
||||
def git_sha1(raw_contents: bytes) -> str:
|
||||
"""
|
||||
Calculates the SHA-1 hash of the given raw contents.
|
||||
|
||||
Args:
|
||||
raw_contents (bytes): The raw contents to hash.
|
||||
|
||||
Returns:
|
||||
str: The SHA-1 hash.
|
||||
"""
|
||||
return hashlib.sha1(
|
||||
b"blob " + str(len(raw_contents)).encode("ascii") + b"\0" + raw_contents
|
||||
).hexdigest()
|
||||
|
||||
|
||||
def fetch_changelog(
|
||||
package: str,
|
||||
from_version: Optional[str],
|
||||
to_version: str,
|
||||
*,
|
||||
api_key: str,
|
||||
from_spec: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Fetches the changelog for a package from a specified version to another version.
|
||||
|
||||
Args:
|
||||
package (str): The package name.
|
||||
from_version (Optional[str]): The starting version.
|
||||
to_version (str): The ending version.
|
||||
api_key (str): The API key for fetching changelog data.
|
||||
from_spec (Optional[str]): The specifier for the starting version.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The fetched changelog data.
|
||||
"""
|
||||
to_version_parsed = parse_version(to_version)
|
||||
|
||||
if from_version:
|
||||
from_version_parsed = parse_version(from_version)
|
||||
else:
|
||||
from_version_parsed = None
|
||||
from_spec = SpecifierSet(from_spec)
|
||||
|
||||
changelog = {}
|
||||
|
||||
headers = {"X-Api-Key": api_key}
|
||||
headers.update(get_meta_http_headers())
|
||||
|
||||
r = requests.get(
|
||||
"https://pyup.io/api/v1/changelogs/{}/".format(package), headers=headers
|
||||
)
|
||||
|
||||
if r.status_code == 200:
|
||||
data = r.json()
|
||||
if data:
|
||||
# sort the changelog by release
|
||||
sorted_log = sorted(
|
||||
data.items(), key=lambda v: parse_version(v[0]), reverse=True
|
||||
)
|
||||
|
||||
# go over each release and add it to the log if it's within the "upgrade
|
||||
# range" e.g. update from 1.2 to 1.3 includes a changelog for 1.2.1 but
|
||||
# not for 0.4.
|
||||
for version, log in sorted_log:
|
||||
parsed_version = parse_version(version)
|
||||
version_check = from_version and (parsed_version > from_version_parsed)
|
||||
spec_check = (
|
||||
from_spec
|
||||
and isinstance(from_spec, SpecifierSet)
|
||||
and from_spec.contains(parsed_version)
|
||||
)
|
||||
|
||||
if version_check or spec_check and parsed_version <= to_version_parsed:
|
||||
changelog[version] = log
|
||||
|
||||
return changelog
|
||||
|
||||
|
||||
def cvss3_score_to_label(score: float) -> Optional[str]:
|
||||
"""
|
||||
Converts a CVSS v3 score to a severity label.
|
||||
|
||||
Args:
|
||||
score (float): The CVSS v3 score.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The severity label.
|
||||
"""
|
||||
if 0.1 <= score <= 3.9:
|
||||
return "low"
|
||||
elif 4.0 <= score <= 6.9:
|
||||
return "medium"
|
||||
elif 7.0 <= score <= 8.9:
|
||||
return "high"
|
||||
elif 9.0 <= score <= 10.0:
|
||||
return "critical"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def require_files_report(func):
|
||||
@wraps(func)
|
||||
def inner(obj: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Decorator that ensures a report is generated against a file.
|
||||
|
||||
Args:
|
||||
obj (Any): The object containing the report.
|
||||
*args (Any): Additional arguments.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Any: The result of the decorated function.
|
||||
"""
|
||||
if obj.report["report_meta"]["scan_target"] != "files":
|
||||
click.secho(
|
||||
"This report was generated against an environment, but this alert command requires "
|
||||
"a scan report that was generated against a file. To learn more about the "
|
||||
"`safety alert` command visit https://docs.pyup.io/docs/safety-2-alerts",
|
||||
fg="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
files = obj.report["report_meta"]["scanned"]
|
||||
obj.requirements_files = {}
|
||||
for f in files:
|
||||
if not os.path.exists(f):
|
||||
cwd = os.getcwd()
|
||||
click.secho(
|
||||
"A requirements file scanned in the report, {}, does not exist (looking in {}).".format(
|
||||
f, cwd
|
||||
),
|
||||
fg="red",
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
obj.requirements_files[f] = open(f, "rb").read()
|
||||
|
||||
return func(obj, *args, **kwargs)
|
||||
|
||||
return inner
|
||||
@@ -0,0 +1,88 @@
|
||||
import sys
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def apply_asyncio_patch():
|
||||
"""
|
||||
Apply a patch to asyncio's exception handling for subprocesses.
|
||||
|
||||
There are some issues with asyncio's exception handling for subprocesses,
|
||||
which causes a RuntimeError to be raised when the event loop was already closed.
|
||||
|
||||
This patch catches the RuntimeError and ignores it, which allows the event loop
|
||||
to be closed properly.
|
||||
|
||||
Similar issues:
|
||||
- https://bugs.python.org/issue39232
|
||||
- https://github.com/python/cpython/issues/92841
|
||||
"""
|
||||
|
||||
import asyncio.base_subprocess
|
||||
|
||||
original_subprocess_del = asyncio.base_subprocess.BaseSubprocessTransport.__del__
|
||||
|
||||
def patched_subprocess_del(self):
|
||||
try:
|
||||
original_subprocess_del(self)
|
||||
except (RuntimeError, ValueError, OSError) as e:
|
||||
if isinstance(e, RuntimeError) and str(e) != "Event loop is closed":
|
||||
raise
|
||||
if isinstance(e, ValueError) and str(e) != "I/O operation on closed pipe":
|
||||
raise
|
||||
if isinstance(e, OSError) and "[WinError 6]" not in str(e):
|
||||
raise
|
||||
logger.debug(f"Patched {original_subprocess_del}")
|
||||
|
||||
asyncio.base_subprocess.BaseSubprocessTransport.__del__ = patched_subprocess_del
|
||||
|
||||
if sys.platform == "win32":
|
||||
import asyncio.proactor_events as proactor_events
|
||||
|
||||
original_pipe_del = proactor_events._ProactorBasePipeTransport.__del__
|
||||
|
||||
def patched_pipe_del(self):
|
||||
try:
|
||||
original_pipe_del(self)
|
||||
except (RuntimeError, ValueError) as e:
|
||||
if isinstance(e, RuntimeError) and str(e) != "Event loop is closed":
|
||||
raise
|
||||
if (
|
||||
isinstance(e, ValueError)
|
||||
and str(e) != "I/O operation on closed pipe"
|
||||
):
|
||||
raise
|
||||
logger.debug(f"Patched {original_pipe_del}")
|
||||
|
||||
original_repr = proactor_events._ProactorBasePipeTransport.__repr__
|
||||
|
||||
def patched_repr(self):
|
||||
try:
|
||||
return original_repr(self)
|
||||
except ValueError as e:
|
||||
if str(e) != "I/O operation on closed pipe":
|
||||
raise
|
||||
logger.debug(f"Patched {original_repr}")
|
||||
return f"<{self.__class__} [closed]>"
|
||||
|
||||
proactor_events._ProactorBasePipeTransport.__del__ = patched_pipe_del
|
||||
proactor_events._ProactorBasePipeTransport.__repr__ = patched_repr
|
||||
|
||||
import subprocess
|
||||
|
||||
original_popen_del = subprocess.Popen.__del__
|
||||
|
||||
def patched_popen_del(self):
|
||||
try:
|
||||
original_popen_del(self)
|
||||
except OSError as e:
|
||||
if "[WinError 6]" not in str(e):
|
||||
raise
|
||||
logger.debug(f"Patched {original_popen_del}")
|
||||
|
||||
subprocess.Popen.__del__ = patched_popen_del
|
||||
|
||||
|
||||
apply_asyncio_patch()
|
||||
@@ -0,0 +1,10 @@
|
||||
from .cli_utils import auth_options, build_client_session, proxy_options, inject_session
|
||||
from .cli import auth
|
||||
|
||||
__all__ = [
|
||||
"build_client_session",
|
||||
"proxy_options",
|
||||
"auth_options",
|
||||
"inject_session",
|
||||
"auth",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
402
Backend/venv/lib/python3.12/site-packages/safety/auth/cli.py
Normal file
402
Backend/venv/lib/python3.12/site-packages/safety/auth/cli.py
Normal file
@@ -0,0 +1,402 @@
|
||||
# type: ignore
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from safety.auth.models import Auth
|
||||
from safety.auth.utils import initialize, is_email_verified
|
||||
from safety.console import main_console as console
|
||||
from safety.constants import (
|
||||
MSG_FINISH_REGISTRATION_TPL,
|
||||
MSG_VERIFICATION_HINT,
|
||||
DEFAULT_EPILOG,
|
||||
)
|
||||
from safety.meta import get_version
|
||||
from safety.decorators import notify
|
||||
|
||||
try:
|
||||
from typing import Annotated
|
||||
except ImportError:
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import typer
|
||||
from rich.padding import Padding
|
||||
from typer import Typer
|
||||
|
||||
from safety.auth.main import (
|
||||
clean_session,
|
||||
get_auth_info,
|
||||
get_authorization_data,
|
||||
get_token,
|
||||
)
|
||||
from safety.auth.server import process_browser_callback
|
||||
from safety.events.utils import emit_auth_started, emit_auth_completed
|
||||
from safety.util import initialize_event_bus
|
||||
from safety.scan.constants import (
|
||||
CLI_AUTH_COMMAND_HELP,
|
||||
CLI_AUTH_HEADLESS_HELP,
|
||||
CLI_AUTH_LOGIN_HELP,
|
||||
CLI_AUTH_LOGOUT_HELP,
|
||||
CLI_AUTH_STATUS_HELP,
|
||||
)
|
||||
|
||||
from ..cli_util import SafetyCLISubGroup, get_command_for, pass_safety_cli_obj
|
||||
from safety.error_handlers import handle_cmd_exception
|
||||
from .constants import (
|
||||
MSG_FAIL_LOGIN_AUTHED,
|
||||
MSG_FAIL_REGISTER_AUTHED,
|
||||
MSG_LOGOUT_DONE,
|
||||
MSG_LOGOUT_FAILED,
|
||||
MSG_NON_AUTHENTICATED,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
auth_app = Typer(rich_markup_mode="rich", name="auth")
|
||||
|
||||
|
||||
CMD_LOGIN_NAME = "login"
|
||||
CMD_REGISTER_NAME = "register"
|
||||
CMD_STATUS_NAME = "status"
|
||||
CMD_LOGOUT_NAME = "logout"
|
||||
DEFAULT_CMD = CMD_LOGIN_NAME
|
||||
|
||||
|
||||
@auth_app.callback(
|
||||
invoke_without_command=True,
|
||||
cls=SafetyCLISubGroup,
|
||||
help=CLI_AUTH_COMMAND_HELP,
|
||||
epilog=DEFAULT_EPILOG,
|
||||
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||
)
|
||||
@pass_safety_cli_obj
|
||||
def auth(ctx: typer.Context) -> None:
|
||||
"""
|
||||
Authenticate Safety CLI with your account.
|
||||
|
||||
Args:
|
||||
ctx (typer.Context): The Typer context object.
|
||||
"""
|
||||
LOG.info("auth started")
|
||||
|
||||
# If no subcommand is invoked, forward to the default command
|
||||
if not ctx.invoked_subcommand:
|
||||
default_command = get_command_for(name=DEFAULT_CMD, typer_instance=auth_app)
|
||||
return ctx.forward(default_command)
|
||||
|
||||
|
||||
def fail_if_authenticated(ctx: typer.Context, with_msg: str) -> None:
|
||||
"""
|
||||
Exits the command if the user is already authenticated.
|
||||
|
||||
Args:
|
||||
ctx (typer.Context): The Typer context object.
|
||||
with_msg (str): The message to display if authenticated.
|
||||
"""
|
||||
info = get_auth_info(ctx)
|
||||
|
||||
if info:
|
||||
console.print()
|
||||
email = f"[green]{ctx.obj.auth.email}[/green]"
|
||||
if not ctx.obj.auth.email_verified:
|
||||
email = f"{email} {render_email_note(ctx.obj.auth)}"
|
||||
|
||||
console.print(with_msg.format(email=email))
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def render_email_note(auth: Auth) -> str:
|
||||
"""
|
||||
Renders a note indicating whether email verification is required.
|
||||
|
||||
Args:
|
||||
auth (Auth): The Auth object.
|
||||
|
||||
Returns:
|
||||
str: The rendered email note.
|
||||
"""
|
||||
return "" if auth.email_verified else "[red](email verification required)[/red]"
|
||||
|
||||
|
||||
def render_successful_login(auth: Auth, organization: Optional[str] = None) -> None:
|
||||
"""
|
||||
Renders a message indicating a successful login.
|
||||
|
||||
Args:
|
||||
auth (Auth): The Auth object.
|
||||
organization (Optional[str]): The organization name.
|
||||
"""
|
||||
DEFAULT = "--"
|
||||
name = auth.name if auth.name else DEFAULT
|
||||
email = auth.email if auth.email else DEFAULT
|
||||
email_note = render_email_note(auth)
|
||||
console.print()
|
||||
console.print("[bold][green]You're authenticated[/green][/bold]")
|
||||
if name and name != email:
|
||||
details = [f"[green][bold]Account:[/bold] {name}, {email}[/green] {email_note}"]
|
||||
else:
|
||||
details = [f"[green][bold]Account:[/bold] {email}[/green] {email_note}"]
|
||||
|
||||
if organization:
|
||||
details.insert(0, f"[green][bold]Organization:[/bold] {organization}[green]")
|
||||
|
||||
for msg in details:
|
||||
console.print(Padding(msg, (0, 0, 0, 1)), emoji=True)
|
||||
|
||||
|
||||
@auth_app.command(name=CMD_LOGIN_NAME, help=CLI_AUTH_LOGIN_HELP)
|
||||
@handle_cmd_exception
|
||||
@notify
|
||||
def login(
|
||||
ctx: typer.Context,
|
||||
headless: Annotated[
|
||||
Optional[bool],
|
||||
typer.Option(
|
||||
"--headless",
|
||||
help=CLI_AUTH_HEADLESS_HELP,
|
||||
),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Authenticate Safety CLI with your safetycli.com account using your default browser.
|
||||
|
||||
Args:
|
||||
ctx (typer.Context): The Typer context object.
|
||||
headless (bool): Whether to run in headless mode.
|
||||
"""
|
||||
LOG.info("login started")
|
||||
headless = headless is True
|
||||
|
||||
# Check if the user is already authenticated
|
||||
fail_if_authenticated(ctx, with_msg=MSG_FAIL_LOGIN_AUTHED)
|
||||
|
||||
console.print()
|
||||
|
||||
info = None
|
||||
|
||||
brief_msg: str = (
|
||||
"Redirecting your browser to log in; once authenticated, "
|
||||
"return here to start using Safety"
|
||||
)
|
||||
|
||||
if ctx.obj.auth.org:
|
||||
console.print(
|
||||
f"Logging into [bold]{ctx.obj.auth.org.name}[/bold] organization."
|
||||
)
|
||||
|
||||
if headless:
|
||||
brief_msg = "Running in headless mode. Please copy and open the following URL in a browser"
|
||||
|
||||
# Get authorization data and generate the authorization URL
|
||||
uri, initial_state = get_authorization_data(
|
||||
client=ctx.obj.auth.client,
|
||||
code_verifier=ctx.obj.auth.code_verifier,
|
||||
organization=ctx.obj.auth.org,
|
||||
headless=headless,
|
||||
)
|
||||
click.secho(brief_msg)
|
||||
click.echo()
|
||||
|
||||
emit_auth_started(ctx.obj.event_bus, ctx)
|
||||
# Process the browser callback to complete the authentication
|
||||
info = process_browser_callback(
|
||||
uri, initial_state=initial_state, ctx=ctx, headless=headless
|
||||
)
|
||||
|
||||
is_success = False
|
||||
error_msg = None
|
||||
|
||||
if info:
|
||||
if info.get("email", None):
|
||||
organization = None
|
||||
if ctx.obj.auth.org and ctx.obj.auth.org.name:
|
||||
organization = ctx.obj.auth.org.name
|
||||
ctx.obj.auth.refresh_from(info)
|
||||
if headless:
|
||||
console.print()
|
||||
|
||||
initialize(ctx, refresh=True)
|
||||
initialize_event_bus(ctx=ctx)
|
||||
render_successful_login(ctx.obj.auth, organization=organization)
|
||||
is_success = True
|
||||
|
||||
console.print()
|
||||
if ctx.obj.auth.org or ctx.obj.auth.email_verified:
|
||||
if not getattr(ctx.obj, "only_auth_msg", False):
|
||||
console.print(
|
||||
"[tip]Tip[/tip]: now try [bold]`safety scan`[/bold] in your project’s root "
|
||||
"folder to run a project scan or [bold]`safety -–help`[/bold] to learn more."
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
MSG_FINISH_REGISTRATION_TPL.format(email=ctx.obj.auth.email)
|
||||
)
|
||||
console.print()
|
||||
console.print(MSG_VERIFICATION_HINT)
|
||||
else:
|
||||
click.secho("Safety is now authenticated but your email is missing.")
|
||||
else:
|
||||
error_msg = ":stop_sign: [red]"
|
||||
if ctx.obj.auth.org:
|
||||
error_msg += (
|
||||
f"Error logging into {ctx.obj.auth.org.name} organization "
|
||||
f"with auth ID: {ctx.obj.auth.org.id}."
|
||||
)
|
||||
else:
|
||||
error_msg += "Error logging into Safety."
|
||||
|
||||
error_msg += (
|
||||
" Please try again, or use [bold]`safety auth -–help`[/bold] "
|
||||
"for more information[/red]"
|
||||
)
|
||||
|
||||
console.print(error_msg, emoji=True)
|
||||
|
||||
emit_auth_completed(
|
||||
ctx.obj.event_bus, ctx, success=is_success, error_message=error_msg
|
||||
)
|
||||
|
||||
|
||||
@auth_app.command(name=CMD_LOGOUT_NAME, help=CLI_AUTH_LOGOUT_HELP)
|
||||
@handle_cmd_exception
|
||||
@notify
|
||||
def logout(ctx: typer.Context) -> None:
|
||||
"""
|
||||
Log out of your current session.
|
||||
|
||||
Args:
|
||||
ctx (typer.Context): The Typer context object.
|
||||
"""
|
||||
LOG.info("logout started")
|
||||
|
||||
id_token = get_token("id_token")
|
||||
|
||||
msg = MSG_NON_AUTHENTICATED
|
||||
|
||||
if id_token:
|
||||
# Clean the session if an ID token is found
|
||||
if clean_session(ctx.obj.auth.client):
|
||||
msg = MSG_LOGOUT_DONE
|
||||
else:
|
||||
msg = MSG_LOGOUT_FAILED
|
||||
|
||||
console.print(msg)
|
||||
|
||||
|
||||
@auth_app.command(name=CMD_STATUS_NAME, help=CLI_AUTH_STATUS_HELP)
|
||||
@click.option(
|
||||
"--ensure-auth/--no-ensure-auth",
|
||||
default=False,
|
||||
help="This will keep running the command until anauthentication is made.",
|
||||
)
|
||||
@click.option(
|
||||
"--login-timeout",
|
||||
"-w",
|
||||
type=int,
|
||||
default=600,
|
||||
help="Max time allowed to wait for an authentication.",
|
||||
)
|
||||
@handle_cmd_exception
|
||||
@notify
|
||||
def status(
|
||||
ctx: typer.Context, ensure_auth: bool = False, login_timeout: int = 600
|
||||
) -> None:
|
||||
"""
|
||||
Display Safety CLI's current authentication status.
|
||||
|
||||
Args:
|
||||
ctx (typer.Context): The Typer context object.
|
||||
ensure_auth (bool): Whether to keep running until authentication is made.
|
||||
login_timeout (int): Max time allowed to wait for authentication.
|
||||
"""
|
||||
LOG.info("status started")
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
safety_version = get_version()
|
||||
console.print(f"[{current_time}]: Safety {safety_version}")
|
||||
|
||||
info = get_auth_info(ctx)
|
||||
|
||||
initialize(ctx, refresh=True)
|
||||
|
||||
if ensure_auth:
|
||||
console.print("running: safety auth status --ensure-auth")
|
||||
console.print()
|
||||
|
||||
if info:
|
||||
verified = is_email_verified(info)
|
||||
email_status = " [red](email not verified)[/red]" if not verified else ""
|
||||
|
||||
console.print(f"[green]Authenticated as {info['email']}[/green]{email_status}")
|
||||
elif ensure_auth:
|
||||
console.print(
|
||||
"Safety is not authenticated. Launching default browser to log in"
|
||||
)
|
||||
console.print()
|
||||
uri, initial_state = get_authorization_data(
|
||||
client=ctx.obj.auth.client,
|
||||
code_verifier=ctx.obj.auth.code_verifier,
|
||||
organization=ctx.obj.auth.org,
|
||||
ensure_auth=ensure_auth,
|
||||
)
|
||||
|
||||
# Process the browser callback to complete the authentication
|
||||
info = process_browser_callback(
|
||||
uri, initial_state=initial_state, timeout=login_timeout, ctx=ctx
|
||||
)
|
||||
|
||||
if not info:
|
||||
console.print(
|
||||
f"[red]Timeout error ({login_timeout} seconds): not successfully authenticated without the timeout period.[/red]"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
organization = None
|
||||
if ctx.obj.auth.org and ctx.obj.auth.org.name:
|
||||
organization = ctx.obj.auth.org.name
|
||||
|
||||
render_successful_login(ctx.obj.auth, organization=organization)
|
||||
console.print()
|
||||
|
||||
else:
|
||||
console.print(MSG_NON_AUTHENTICATED)
|
||||
|
||||
|
||||
@auth_app.command(name=CMD_REGISTER_NAME)
|
||||
@handle_cmd_exception
|
||||
@notify
|
||||
def register(ctx: typer.Context) -> None:
|
||||
"""
|
||||
Create a new user account for the safetycli.com service.
|
||||
|
||||
Args:
|
||||
ctx (typer.Context): The Typer context object.
|
||||
"""
|
||||
LOG.info("register started")
|
||||
|
||||
# Check if the user is already authenticated
|
||||
fail_if_authenticated(ctx, with_msg=MSG_FAIL_REGISTER_AUTHED)
|
||||
|
||||
# Get authorization data and generate the registration URL
|
||||
uri, initial_state = get_authorization_data(
|
||||
client=ctx.obj.auth.client,
|
||||
code_verifier=ctx.obj.auth.code_verifier,
|
||||
sign_up=True,
|
||||
)
|
||||
|
||||
console.print(
|
||||
"\nRedirecting your browser to register for a free account. Once registered, return here to start using Safety."
|
||||
)
|
||||
console.print()
|
||||
|
||||
# Process the browser callback to complete the registration
|
||||
info = process_browser_callback(uri, initial_state=initial_state, ctx=ctx)
|
||||
console.print()
|
||||
|
||||
if info:
|
||||
console.print(f"[green]Successfully registered {info.get('email')}[/green]")
|
||||
console.print()
|
||||
else:
|
||||
console.print("[red]Unable to register in this time, try again.[/red]")
|
||||
@@ -0,0 +1,290 @@
|
||||
import logging
|
||||
from typing import Dict, Optional, Tuple, Any, Callable
|
||||
|
||||
import click
|
||||
|
||||
from .main import (
|
||||
get_auth_info,
|
||||
get_host_config,
|
||||
get_organization,
|
||||
get_proxy_config,
|
||||
get_redirect_url,
|
||||
get_token_data,
|
||||
save_auth_config,
|
||||
get_token,
|
||||
clean_session,
|
||||
)
|
||||
from authlib.common.security import generate_token
|
||||
from safety.auth.constants import CLIENT_ID
|
||||
|
||||
from safety.auth.models import Organization, Auth
|
||||
from safety.auth.utils import (
|
||||
S3PresignedAdapter,
|
||||
SafetyAuthSession,
|
||||
get_keys,
|
||||
is_email_verified,
|
||||
)
|
||||
from safety.scan.constants import (
|
||||
CLI_KEY_HELP,
|
||||
CLI_PROXY_HOST_HELP,
|
||||
CLI_PROXY_PORT_HELP,
|
||||
CLI_PROXY_PROTOCOL_HELP,
|
||||
CLI_STAGE_HELP,
|
||||
)
|
||||
from safety.util import DependentOption, SafetyContext, get_proxy_dict
|
||||
from safety.models import SafetyCLI
|
||||
from safety_schemas.models import Stage
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_client_session(
|
||||
api_key: Optional[str] = None,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[SafetyAuthSession, Dict[str, Any]]:
|
||||
"""
|
||||
Builds and configures the client session for authentication.
|
||||
|
||||
Args:
|
||||
api_key (Optional[str]): The API key for authentication.
|
||||
proxies (Optional[Dict[str, str]]): Proxy configuration.
|
||||
headers (Optional[Dict[str, str]]): Additional headers.
|
||||
|
||||
Returns:
|
||||
Tuple[SafetyAuthSession, Dict[str, Any]]: The configured client session and OpenID configuration.
|
||||
"""
|
||||
|
||||
kwargs = {}
|
||||
target_proxies = proxies
|
||||
|
||||
# Global proxy defined in the config.ini
|
||||
proxy_config, proxy_timeout, proxy_required = get_proxy_config()
|
||||
|
||||
if not proxies:
|
||||
target_proxies = proxy_config
|
||||
|
||||
def update_token(tokens, **kwargs):
|
||||
save_auth_config(
|
||||
access_token=tokens["access_token"],
|
||||
id_token=tokens["id_token"],
|
||||
refresh_token=tokens["refresh_token"],
|
||||
)
|
||||
load_auth_session(click_ctx=click.get_current_context(silent=True)) # type: ignore
|
||||
|
||||
client_session = SafetyAuthSession(
|
||||
client_id=CLIENT_ID,
|
||||
code_challenge_method="S256",
|
||||
redirect_uri=get_redirect_url(),
|
||||
update_token=update_token,
|
||||
scope="openid email profile offline_access",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
client_session.mount("https://pyup.io/static-s3/", S3PresignedAdapter())
|
||||
|
||||
client_session.proxy_required = proxy_required
|
||||
client_session.proxy_timeout = proxy_timeout
|
||||
client_session.proxies = target_proxies # type: ignore
|
||||
client_session.headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
openid_config = client_session.fetch_openid_config()
|
||||
|
||||
client_session.metadata["token_endpoint"] = openid_config.get(
|
||||
"token_endpoint", None
|
||||
)
|
||||
|
||||
if api_key:
|
||||
client_session.api_key = api_key # type: ignore
|
||||
client_session.headers["X-Api-Key"] = api_key
|
||||
|
||||
if headers:
|
||||
client_session.headers.update(headers)
|
||||
|
||||
return client_session, openid_config
|
||||
|
||||
|
||||
def load_auth_session(click_ctx: click.Context) -> None:
|
||||
"""
|
||||
Loads the authentication session from the context.
|
||||
|
||||
Args:
|
||||
click_ctx (click.Context): The Click context object.
|
||||
"""
|
||||
if not click_ctx:
|
||||
LOG.warning("Click context is needed to be able to load the Auth data.")
|
||||
return
|
||||
|
||||
client = click_ctx.obj.auth.client
|
||||
keys = click_ctx.obj.auth.keys
|
||||
|
||||
access_token: str = get_token(name="access_token") # type: ignore
|
||||
refresh_token: str = get_token(name="refresh_token") # type: ignore
|
||||
id_token: str = get_token(name="id_token") # type: ignore
|
||||
|
||||
if access_token and keys:
|
||||
try:
|
||||
token = get_token_data(access_token, keys, silent_if_expired=True)
|
||||
client.token = {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"id_token": id_token,
|
||||
"token_type": "bearer",
|
||||
"expires_at": token.get("exp", None), # type: ignore
|
||||
}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
clean_session(client)
|
||||
|
||||
|
||||
def proxy_options(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator that defines proxy options for Click commands.
|
||||
|
||||
Options defined per command, this will override the proxy settings defined in the
|
||||
config.ini file.
|
||||
|
||||
Args:
|
||||
func (Callable): The Click command function.
|
||||
|
||||
Returns:
|
||||
Callable: The wrapped Click command function with proxy options.
|
||||
"""
|
||||
func = click.option(
|
||||
"--proxy-protocol",
|
||||
type=click.Choice(["http", "https"]),
|
||||
default="https",
|
||||
cls=DependentOption,
|
||||
required_options=["proxy_host"],
|
||||
help=CLI_PROXY_PROTOCOL_HELP,
|
||||
)(func)
|
||||
func = click.option(
|
||||
"--proxy-port",
|
||||
multiple=False,
|
||||
type=int,
|
||||
default=80,
|
||||
cls=DependentOption,
|
||||
required_options=["proxy_host"],
|
||||
help=CLI_PROXY_PORT_HELP,
|
||||
)(func)
|
||||
func = click.option(
|
||||
"--proxy-host", multiple=False, type=str, default=None, help=CLI_PROXY_HOST_HELP
|
||||
)(func)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def auth_options(stage: bool = True) -> Callable:
|
||||
"""
|
||||
Decorator that defines authentication options for Click commands.
|
||||
|
||||
Args:
|
||||
stage (bool): Whether to include the stage option.
|
||||
|
||||
Returns:
|
||||
Callable: The decorator function.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
func = click.option(
|
||||
"--key", default=None, envvar="SAFETY_API_KEY", help=CLI_KEY_HELP
|
||||
)(func)
|
||||
|
||||
if stage:
|
||||
func = click.option(
|
||||
"--stage", default=None, envvar="SAFETY_STAGE", help=CLI_STAGE_HELP
|
||||
)(func)
|
||||
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def inject_session(
|
||||
ctx: click.Context,
|
||||
proxy_protocol: Optional[str] = None,
|
||||
proxy_host: Optional[str] = None,
|
||||
proxy_port: Optional[str] = None,
|
||||
key: Optional[str] = None,
|
||||
stage: Optional[Stage] = None,
|
||||
invoked_command: str = "",
|
||||
) -> Any:
|
||||
org: Optional[Organization] = get_organization()
|
||||
|
||||
if not stage:
|
||||
host_stage = get_host_config(key_name="stage")
|
||||
stage = host_stage if host_stage else Stage.development
|
||||
|
||||
proxy_config: Optional[Dict[str, str]] = get_proxy_dict(
|
||||
proxy_protocol, # type: ignore
|
||||
proxy_host, # type: ignore
|
||||
proxy_port, # type: ignore
|
||||
)
|
||||
|
||||
client_session, openid_config = build_client_session(
|
||||
api_key=key, proxies=proxy_config
|
||||
)
|
||||
keys = get_keys(client_session, openid_config)
|
||||
|
||||
auth = Auth(
|
||||
stage=stage,
|
||||
keys=keys,
|
||||
org=org,
|
||||
client_id=CLIENT_ID, # type: ignore
|
||||
client=client_session,
|
||||
code_verifier=generate_token(48),
|
||||
)
|
||||
|
||||
if not ctx.obj:
|
||||
ctx.obj = SafetyCLI()
|
||||
|
||||
ctx.obj.auth = auth
|
||||
|
||||
load_auth_session(ctx)
|
||||
|
||||
info = get_auth_info(ctx)
|
||||
|
||||
if info:
|
||||
ctx.obj.auth.name = info.get("name")
|
||||
ctx.obj.auth.email = info.get("email")
|
||||
ctx.obj.auth.email_verified = is_email_verified(info) # type: ignore
|
||||
SafetyContext().account = info["email"]
|
||||
else:
|
||||
SafetyContext().account = ""
|
||||
|
||||
@ctx.call_on_close
|
||||
def clean_up_on_close():
|
||||
LOG.debug("Closing requests session.")
|
||||
ctx.obj.auth.client.close()
|
||||
|
||||
if ctx.obj.event_bus:
|
||||
from safety.events.utils import (
|
||||
create_internal_event,
|
||||
InternalEventType,
|
||||
InternalPayload,
|
||||
)
|
||||
|
||||
payload = InternalPayload(ctx=ctx)
|
||||
|
||||
flush_event = create_internal_event(
|
||||
event_type=InternalEventType.FLUSH_SECURITY_TRACES, payload=payload
|
||||
)
|
||||
close_event = create_internal_event(
|
||||
event_type=InternalEventType.CLOSE_RESOURCES, payload=payload
|
||||
)
|
||||
|
||||
flush_future = ctx.obj.event_bus.emit(flush_event)
|
||||
close_future = ctx.obj.event_bus.emit(close_event)
|
||||
|
||||
# Wait for both events to be processed
|
||||
if flush_future and close_future:
|
||||
try:
|
||||
flush_future.result()
|
||||
close_future.result()
|
||||
except Exception as e:
|
||||
LOG.warning(f"Error waiting for events to process: {e}")
|
||||
|
||||
ctx.obj.event_bus.stop()
|
||||
@@ -0,0 +1,35 @@
|
||||
from pathlib import Path
|
||||
|
||||
from safety.constants import USER_CONFIG_DIR, get_config_setting
|
||||
|
||||
AUTH_CONFIG_FILE_NAME = "auth.ini"
|
||||
AUTH_CONFIG_USER = USER_CONFIG_DIR / Path(AUTH_CONFIG_FILE_NAME)
|
||||
|
||||
|
||||
HOST: str = "localhost"
|
||||
|
||||
CLIENT_ID = get_config_setting("CLIENT_ID")
|
||||
AUTH_SERVER_URL = get_config_setting("AUTH_SERVER_URL")
|
||||
SAFETY_PLATFORM_URL = get_config_setting("SAFETY_PLATFORM_URL")
|
||||
|
||||
OPENID_CONFIG_URL = f"{AUTH_SERVER_URL}/.well-known/openid-configuration"
|
||||
|
||||
CLAIM_EMAIL_VERIFIED_API = "https://api.safetycli.com/email_verified"
|
||||
CLAIM_EMAIL_VERIFIED_AUTH_SERVER = "email_verified"
|
||||
|
||||
CLI_AUTH = f"{SAFETY_PLATFORM_URL}/cli/auth"
|
||||
CLI_AUTH_SUCCESS = f"{SAFETY_PLATFORM_URL}/cli/auth/success"
|
||||
CLI_AUTH_LOGOUT = f"{SAFETY_PLATFORM_URL}/cli/logout"
|
||||
CLI_CALLBACK = f"{SAFETY_PLATFORM_URL}/cli/callback"
|
||||
CLI_LOGOUT_SUCCESS = f"{SAFETY_PLATFORM_URL}/cli/logout/success"
|
||||
|
||||
MSG_NON_AUTHENTICATED = (
|
||||
"Safety is not authenticated. Please run 'safety auth login' to log in."
|
||||
)
|
||||
MSG_FAIL_LOGIN_AUTHED = """[green]You are authenticated as[/green] {email}.
|
||||
|
||||
To log into a different account, first logout via: safety auth logout, and then login again."""
|
||||
MSG_FAIL_REGISTER_AUTHED = "You are currently logged in to {email}, please logout using `safety auth logout` before registering a new account."
|
||||
|
||||
MSG_LOGOUT_DONE = "[green]Logout done.[/green]"
|
||||
MSG_LOGOUT_FAILED = "[red]Logout failed. Try again.[/red]"
|
||||
329
Backend/venv/lib/python3.12/site-packages/safety/auth/main.py
Normal file
329
Backend/venv/lib/python3.12/site-packages/safety/auth/main.py
Normal file
@@ -0,0 +1,329 @@
|
||||
import configparser
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
from authlib.oidc.core import CodeIDToken
|
||||
from authlib.jose import jwt
|
||||
from authlib.jose.errors import ExpiredTokenError
|
||||
|
||||
from safety.auth.models import Organization
|
||||
from safety.auth.constants import (
|
||||
CLI_AUTH_LOGOUT,
|
||||
CLI_CALLBACK,
|
||||
AUTH_CONFIG_USER,
|
||||
CLI_AUTH,
|
||||
)
|
||||
from safety.constants import CONFIG
|
||||
from safety_schemas.models import Stage
|
||||
from safety.util import get_proxy_dict
|
||||
|
||||
|
||||
def get_authorization_data(
|
||||
client,
|
||||
code_verifier: str,
|
||||
organization: Optional[Organization] = None,
|
||||
sign_up: bool = False,
|
||||
ensure_auth: bool = False,
|
||||
headless: bool = False,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Generate the authorization URL for the authentication process.
|
||||
|
||||
Args:
|
||||
client: The authentication client.
|
||||
code_verifier (str): The code verifier for the PKCE flow.
|
||||
organization (Optional[Organization]): The organization to authenticate with.
|
||||
sign_up (bool): Whether the URL is for sign-up.
|
||||
ensure_auth (bool): Whether to ensure authentication.
|
||||
headless (bool): Whether to run in headless mode.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: The authorization URL and initial state.
|
||||
"""
|
||||
|
||||
kwargs = {
|
||||
"sign_up": sign_up,
|
||||
"locale": "en",
|
||||
"ensure_auth": ensure_auth,
|
||||
"headless": headless,
|
||||
}
|
||||
if organization:
|
||||
kwargs["organization"] = organization.id
|
||||
|
||||
return client.create_authorization_url(
|
||||
CLI_AUTH, code_verifier=code_verifier, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def get_logout_url(id_token: str) -> str:
|
||||
"""
|
||||
Generate the logout URL.
|
||||
|
||||
Args:
|
||||
id_token (str): The ID token.
|
||||
|
||||
Returns:
|
||||
str: The logout URL.
|
||||
"""
|
||||
return f"{CLI_AUTH_LOGOUT}?id_token={id_token}"
|
||||
|
||||
|
||||
def get_redirect_url() -> str:
|
||||
"""
|
||||
Get the redirect URL for the authentication callback.
|
||||
|
||||
Returns:
|
||||
str: The redirect URL.
|
||||
"""
|
||||
return CLI_CALLBACK
|
||||
|
||||
|
||||
def get_organization() -> Optional[Organization]:
|
||||
"""
|
||||
Retrieve the organization configuration.
|
||||
|
||||
Returns:
|
||||
Optional[Organization]: The organization object, or None if not configured.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config.read(CONFIG)
|
||||
|
||||
org_conf: Union[Dict[str, str], configparser.SectionProxy] = (
|
||||
config["organization"] if "organization" in config.sections() else {}
|
||||
)
|
||||
org_id: Optional[str] = (
|
||||
org_conf["id"].replace('"', "") if org_conf.get("id", None) else None
|
||||
)
|
||||
org_name: Optional[str] = (
|
||||
org_conf["name"].replace('"', "") if org_conf.get("name", None) else None
|
||||
)
|
||||
|
||||
if not org_id:
|
||||
return None
|
||||
|
||||
org = Organization(id=org_id, name=org_name) # type: ignore
|
||||
|
||||
return org
|
||||
|
||||
|
||||
def get_auth_info(ctx) -> Optional[Dict]:
|
||||
"""
|
||||
Retrieve the authentication information.
|
||||
|
||||
Args:
|
||||
ctx: The context object containing authentication data.
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: The authentication information, or None if not authenticated.
|
||||
"""
|
||||
from safety.auth.utils import is_email_verified
|
||||
|
||||
info = None
|
||||
if ctx.obj.auth.client.token:
|
||||
try:
|
||||
info = get_token_data(get_token(name="id_token"), keys=ctx.obj.auth.keys) # type: ignore
|
||||
|
||||
verified = is_email_verified(info) # type: ignore
|
||||
if not verified:
|
||||
user_info = ctx.obj.auth.client.fetch_user_info()
|
||||
verified = is_email_verified(user_info)
|
||||
|
||||
if verified:
|
||||
# refresh only if needed
|
||||
raise ExpiredTokenError
|
||||
|
||||
except ExpiredTokenError:
|
||||
# id_token expired. So fire a manually a refresh
|
||||
try:
|
||||
ctx.obj.auth.client.refresh_token(
|
||||
ctx.obj.auth.client.metadata.get("token_endpoint"),
|
||||
refresh_token=ctx.obj.auth.client.token.get("refresh_token"),
|
||||
)
|
||||
info = get_token_data(
|
||||
get_token(name="id_token"), # type: ignore
|
||||
keys=ctx.obj.auth.keys, # type: ignore
|
||||
)
|
||||
except Exception as _e:
|
||||
clean_session(ctx.obj.auth.client)
|
||||
except Exception as _g:
|
||||
clean_session(ctx.obj.auth.client)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def get_token_data(
|
||||
token: str, keys: Any, silent_if_expired: bool = False
|
||||
) -> Optional[Dict]:
|
||||
"""
|
||||
Decode and validate the token data.
|
||||
|
||||
Args:
|
||||
token (str): The token to decode.
|
||||
keys (Any): The keys to use for decoding.
|
||||
silent_if_expired (bool): Whether to silently ignore expired tokens.
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: The decoded token data, or None if invalid.
|
||||
"""
|
||||
claims = jwt.decode(token, keys, claims_cls=CodeIDToken)
|
||||
try:
|
||||
claims.validate()
|
||||
except ExpiredTokenError as e:
|
||||
if not silent_if_expired:
|
||||
raise e
|
||||
|
||||
return claims
|
||||
|
||||
|
||||
def get_token(name: str = "access_token") -> Optional[str]:
|
||||
""" "
|
||||
Retrieve a token from the local authentication configuration.
|
||||
|
||||
This returns tokens saved in the local auth configuration.
|
||||
There are two types of tokens: access_token and id_token
|
||||
|
||||
Args:
|
||||
name (str): The name of the token to retrieve.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The token value, or None if not found.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config.read(AUTH_CONFIG_USER)
|
||||
|
||||
if "auth" in config.sections() and name in config["auth"]:
|
||||
value = config["auth"][name]
|
||||
if value:
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_host_config(key_name: str) -> Optional[Any]:
|
||||
"""
|
||||
Retrieve a configuration value from the host configuration.
|
||||
|
||||
Args:
|
||||
key_name (str): The name of the configuration key.
|
||||
|
||||
Returns:
|
||||
Optional[Any]: The configuration value, or None if not found.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config.read(CONFIG)
|
||||
|
||||
if not config.has_section("host"):
|
||||
return None
|
||||
|
||||
host_section = dict(config.items("host"))
|
||||
|
||||
if key_name in host_section:
|
||||
if key_name == "stage":
|
||||
# Support old alias in the config.ini
|
||||
if host_section[key_name] == "dev":
|
||||
host_section[key_name] = "development"
|
||||
if host_section[key_name] not in {env.value for env in Stage}:
|
||||
return None
|
||||
return Stage(host_section[key_name])
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def str_to_bool(s: str) -> bool:
|
||||
"""
|
||||
Convert a string to a boolean value.
|
||||
|
||||
Args:
|
||||
s (str): The string to convert.
|
||||
|
||||
Returns:
|
||||
bool: The converted boolean value.
|
||||
|
||||
Raises:
|
||||
ValueError: If the string cannot be converted.
|
||||
"""
|
||||
if s.lower() == "true" or s == "1":
|
||||
return True
|
||||
elif s.lower() == "false" or s == "0":
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"Cannot convert '{s}' to a boolean value.")
|
||||
|
||||
|
||||
def get_proxy_config() -> Tuple[Optional[Dict[str, str]], Optional[int], bool]:
|
||||
"""
|
||||
Retrieve the proxy configuration.
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Dict[str, str]], Optional[int], bool]: The proxy configuration, timeout, and whether it is required.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config.read(CONFIG)
|
||||
|
||||
proxy_dictionary = None
|
||||
required = False
|
||||
timeout = None
|
||||
proxy = None
|
||||
|
||||
if config.has_section("proxy"):
|
||||
proxy = dict(config.items("proxy"))
|
||||
|
||||
if proxy:
|
||||
try:
|
||||
proxy_dictionary = get_proxy_dict(
|
||||
proxy["protocol"],
|
||||
proxy["host"],
|
||||
proxy["port"], # type: ignore
|
||||
)
|
||||
required = str_to_bool(proxy["required"])
|
||||
timeout = proxy["timeout"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return proxy_dictionary, timeout, required # type: ignore
|
||||
|
||||
|
||||
def clean_session(client) -> bool:
|
||||
"""
|
||||
Clean the authentication session.
|
||||
|
||||
Args:
|
||||
client: The authentication client.
|
||||
|
||||
Returns:
|
||||
bool: Always returns True.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config["auth"] = {"access_token": "", "id_token": "", "refresh_token": ""}
|
||||
|
||||
with open(AUTH_CONFIG_USER, "w") as configfile:
|
||||
config.write(configfile)
|
||||
|
||||
client.token = None
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def save_auth_config(
|
||||
access_token: Optional[str] = None,
|
||||
id_token: Optional[str] = None,
|
||||
refresh_token: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Save the authentication configuration.
|
||||
|
||||
Args:
|
||||
access_token (Optional[str]): The access token.
|
||||
id_token (Optional[str]): The ID token.
|
||||
refresh_token (Optional[str]): The refresh token.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config.read(AUTH_CONFIG_USER)
|
||||
config["auth"] = { # type: ignore
|
||||
"access_token": access_token,
|
||||
"id_token": id_token,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
with open(AUTH_CONFIG_USER, "w") as configfile:
|
||||
config.write(configfile) # type: ignore
|
||||
105
Backend/venv/lib/python3.12/site-packages/safety/auth/models.py
Normal file
105
Backend/venv/lib/python3.12/site-packages/safety/auth/models.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
from typing import Any, Optional, Dict
|
||||
|
||||
from authlib.integrations.base_client import BaseOAuth
|
||||
|
||||
from safety_schemas.models import Stage
|
||||
|
||||
|
||||
@dataclass
|
||||
class Organization:
|
||||
id: str
|
||||
name: str
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""
|
||||
Convert the Organization instance to a dictionary.
|
||||
|
||||
Returns:
|
||||
dict: The dictionary representation of the organization.
|
||||
"""
|
||||
return {"id": self.id, "name": self.name}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Auth:
|
||||
org: Optional[Organization]
|
||||
keys: Any
|
||||
client: Any
|
||||
code_verifier: str
|
||||
client_id: str
|
||||
stage: Optional[Stage] = Stage.development
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
email_verified: bool = False
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""
|
||||
Check if the authentication information is valid.
|
||||
|
||||
Returns:
|
||||
bool: True if valid, False otherwise.
|
||||
"""
|
||||
if os.getenv("SAFETY_DB_DIR"):
|
||||
return True
|
||||
|
||||
if not self.client:
|
||||
return False
|
||||
|
||||
if self.client.api_key:
|
||||
return True
|
||||
|
||||
return bool(self.client.token and self.email_verified)
|
||||
|
||||
def refresh_from(self, info: Dict) -> None:
|
||||
"""
|
||||
Refresh the authentication information from the provided info.
|
||||
|
||||
Args:
|
||||
info (dict): The information to refresh from.
|
||||
"""
|
||||
from safety.auth.utils import is_email_verified
|
||||
|
||||
self.name = info.get("name")
|
||||
self.email = info.get("email")
|
||||
self.email_verified = is_email_verified(info) # type: ignore
|
||||
|
||||
def get_auth_method(self) -> str:
|
||||
"""
|
||||
Get the authentication method.
|
||||
|
||||
Returns:
|
||||
str: The authentication method.
|
||||
"""
|
||||
if self.client.api_key:
|
||||
return "API Key"
|
||||
|
||||
if self.client.token:
|
||||
return "Token"
|
||||
|
||||
return "None"
|
||||
|
||||
|
||||
class XAPIKeyAuth(BaseOAuth):
|
||||
def __init__(self, api_key: str) -> None:
|
||||
"""
|
||||
Initialize the XAPIKeyAuth instance.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key to use for authentication.
|
||||
"""
|
||||
self.api_key = api_key
|
||||
|
||||
def __call__(self, r: Any) -> Any:
|
||||
"""
|
||||
Add the API key to the request headers.
|
||||
|
||||
Args:
|
||||
r (Any): The request object.
|
||||
|
||||
Returns:
|
||||
Any: The modified request object.
|
||||
"""
|
||||
r.headers["X-API-Key"] = self.api_key
|
||||
return r
|
||||
308
Backend/venv/lib/python3.12/site-packages/safety/auth/server.py
Normal file
308
Backend/venv/lib/python3.12/site-packages/safety/auth/server.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# type: ignore
|
||||
import http.server
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Optional, Dict, Tuple
|
||||
import urllib.parse
|
||||
import threading
|
||||
import click
|
||||
|
||||
from safety.auth.utils import is_jupyter_notebook
|
||||
from safety.console import main_console as console
|
||||
|
||||
from safety.auth.constants import (
|
||||
AUTH_SERVER_URL,
|
||||
CLI_AUTH_SUCCESS,
|
||||
CLI_LOGOUT_SUCCESS,
|
||||
HOST,
|
||||
)
|
||||
from safety.auth.main import save_auth_config
|
||||
from rich.prompt import Prompt
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_available_port() -> Optional[int]:
|
||||
"""
|
||||
Find an available port on localhost within the dynamic port range (49152-65536).
|
||||
|
||||
Returns:
|
||||
Optional[int]: An available port number, or None if no ports are available.
|
||||
"""
|
||||
# Dynamic ports IANA
|
||||
port_range = list(range(49152, 65536))
|
||||
random.shuffle(port_range)
|
||||
|
||||
for port in port_range:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
try:
|
||||
s.connect(("localhost", port))
|
||||
# If the connect succeeds, the port is already in use
|
||||
except socket.error:
|
||||
# If the connect fails, the port is available
|
||||
return port
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def auth_process(
|
||||
code: str, state: str, initial_state: str, code_verifier: str, client: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Process the authentication callback and exchange the authorization code for tokens.
|
||||
|
||||
Args:
|
||||
code (str): The authorization code.
|
||||
state (str): The state parameter from the callback.
|
||||
initial_state (str): The initial state parameter.
|
||||
code_verifier (str): The code verifier for PKCE.
|
||||
client (Any): The OAuth client.
|
||||
|
||||
Returns:
|
||||
Any: The user information.
|
||||
|
||||
Raises:
|
||||
SystemExit: If there is an error during authentication.
|
||||
"""
|
||||
err = None
|
||||
|
||||
if initial_state is None or initial_state != state:
|
||||
err = (
|
||||
"The state parameter value provided does not match the expected "
|
||||
"value. The state parameter is used to protect against Cross-Site "
|
||||
"Request Forgery (CSRF) attacks. For security reasons, the "
|
||||
"authorization process cannot proceed with an invalid state "
|
||||
"parameter value. Please try again, ensuring that the state "
|
||||
"parameter value provided in the authorization request matches "
|
||||
"the value returned in the callback."
|
||||
)
|
||||
|
||||
if err:
|
||||
click.secho(f"Error: {err}", fg="red")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
tokens = client.fetch_token(
|
||||
url=f"{AUTH_SERVER_URL}/oauth/token",
|
||||
code_verifier=code_verifier,
|
||||
client_id=client.client_id,
|
||||
grant_type="authorization_code",
|
||||
code=code,
|
||||
)
|
||||
|
||||
save_auth_config(
|
||||
access_token=tokens["access_token"],
|
||||
id_token=tokens["id_token"],
|
||||
refresh_token=tokens["refresh_token"],
|
||||
)
|
||||
return client.fetch_user_info()
|
||||
|
||||
except Exception as e:
|
||||
LOG.exception(e)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class CallbackHandler(http.server.BaseHTTPRequestHandler):
|
||||
def auth(self, code: str, state: str, err: str, error_description: str) -> None:
|
||||
"""
|
||||
Handle the authentication callback.
|
||||
|
||||
Args:
|
||||
code (str): The authorization code.
|
||||
state (str): The state parameter.
|
||||
err (str): The error message, if any.
|
||||
error_description (str): The error description, if any.
|
||||
"""
|
||||
initial_state = self.server.initial_state
|
||||
ctx = self.server.ctx
|
||||
|
||||
result = auth_process(
|
||||
code=code,
|
||||
state=state,
|
||||
initial_state=initial_state,
|
||||
code_verifier=ctx.obj.auth.code_verifier,
|
||||
client=ctx.obj.auth.client,
|
||||
)
|
||||
|
||||
self.server.callback = result
|
||||
self.do_redirect(location=CLI_AUTH_SUCCESS, params={})
|
||||
|
||||
def logout(self) -> None:
|
||||
"""
|
||||
Handle the logout callback.
|
||||
"""
|
||||
ctx = self.server.ctx
|
||||
uri = CLI_LOGOUT_SUCCESS
|
||||
|
||||
if ctx.obj.auth.org:
|
||||
uri = f"{uri}&org_id={ctx.obj.auth.org.id}"
|
||||
|
||||
self.do_redirect(location=CLI_LOGOUT_SUCCESS, params={})
|
||||
|
||||
def do_GET(self) -> None:
|
||||
"""
|
||||
Handle GET requests.
|
||||
"""
|
||||
query = urllib.parse.urlparse(self.path).query
|
||||
params = urllib.parse.parse_qs(query)
|
||||
callback_type: Optional[str] = None
|
||||
|
||||
try:
|
||||
c_type = params.get("type", [])
|
||||
if (
|
||||
isinstance(c_type, list)
|
||||
and len(c_type) == 1
|
||||
and isinstance(c_type[0], str)
|
||||
):
|
||||
callback_type = c_type[0]
|
||||
except Exception:
|
||||
msg = "Unable to process the callback, try again."
|
||||
self.send_error(400, msg)
|
||||
click.secho("Unable to process the callback, try again.")
|
||||
return
|
||||
|
||||
if callback_type == "logout":
|
||||
self.logout()
|
||||
return
|
||||
|
||||
code = params.get("code", [""])[0]
|
||||
state = params.get("state", [""])[0]
|
||||
err = params.get("error", [""])[0]
|
||||
error_description = params.get("error_description", [""])[0]
|
||||
|
||||
self.auth(code=code, state=state, err=err, error_description=error_description)
|
||||
|
||||
def do_redirect(self, location: str, params: Dict) -> None:
|
||||
"""
|
||||
Redirect the client to the specified location.
|
||||
|
||||
Args:
|
||||
location (str): The URL to redirect to.
|
||||
params (dict): Additional parameters for the redirection.
|
||||
"""
|
||||
self.send_response(302)
|
||||
self.send_header("Location", location)
|
||||
self.send_header("Connection", "close")
|
||||
self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, format: str, *args: Any) -> None:
|
||||
"""
|
||||
Log an arbitrary message.
|
||||
|
||||
Args:
|
||||
format (str): The format string.
|
||||
args (Any): Arguments for the format string.
|
||||
"""
|
||||
LOG.info(format % args)
|
||||
|
||||
|
||||
def process_browser_callback(uri: str, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Process the browser callback for authentication.
|
||||
|
||||
Args:
|
||||
uri (str): The authorization URL.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Any: The user information.
|
||||
|
||||
Raises:
|
||||
SystemExit: If there is an error during the process.
|
||||
"""
|
||||
|
||||
class ThreadedHTTPServer(http.server.HTTPServer):
|
||||
def __init__(self, server_address: Tuple, RequestHandlerClass: Any) -> None:
|
||||
"""
|
||||
Initialize the ThreadedHTTPServer.
|
||||
Args:
|
||||
server_address (Tuple): The server address as a tuple (host, port).
|
||||
RequestHandlerClass (Any): The request handler class.
|
||||
"""
|
||||
super().__init__(server_address, RequestHandlerClass)
|
||||
self.initial_state = None
|
||||
self.ctx = None
|
||||
self.callback = None
|
||||
self.timeout_reached = False
|
||||
|
||||
def handle_timeout(self) -> None:
|
||||
"""
|
||||
Handle server timeout.
|
||||
"""
|
||||
self.timeout_reached = True
|
||||
return super().handle_timeout()
|
||||
|
||||
PORT = find_available_port()
|
||||
|
||||
if not PORT:
|
||||
click.secho("No available ports.")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
headless = kwargs.get("headless", False)
|
||||
initial_state = kwargs.get("initial_state", None)
|
||||
ctx = kwargs.get("ctx", None)
|
||||
message = "Copy and paste this URL into your browser:\n:icon_warning: Ensure there are no extra spaces, especially at line breaks, as they may break the link."
|
||||
|
||||
if not headless:
|
||||
# Start a threaded HTTP server to handle the callback
|
||||
server = ThreadedHTTPServer((HOST, PORT), CallbackHandler)
|
||||
server.initial_state = initial_state
|
||||
server.timeout = kwargs.get("timeout", 600)
|
||||
server.ctx = ctx
|
||||
server_thread = threading.Thread(target=server.handle_request)
|
||||
server_thread.start()
|
||||
message = "If the browser does not automatically open in 5 seconds, copy and paste this url into your browser:"
|
||||
|
||||
target = uri if headless else f"{uri}&port={PORT}"
|
||||
|
||||
if is_jupyter_notebook():
|
||||
console.print(f"{message} {target}")
|
||||
else:
|
||||
console.print(f"{message} [link={target}]{target}[/link]")
|
||||
|
||||
if headless:
|
||||
# Handle the headless mode where user manually provides the response
|
||||
exchange_data = None
|
||||
while not exchange_data:
|
||||
auth_code_text = Prompt.ask(
|
||||
"Paste the response here", default=None, console=console
|
||||
)
|
||||
try:
|
||||
exchange_data = json.loads(auth_code_text)
|
||||
state = exchange_data["state"]
|
||||
code = exchange_data["code"]
|
||||
except Exception:
|
||||
code = state = None
|
||||
|
||||
return auth_process(
|
||||
code=code,
|
||||
state=state,
|
||||
initial_state=initial_state,
|
||||
code_verifier=ctx.obj.auth.code_verifier,
|
||||
client=ctx.obj.auth.client,
|
||||
)
|
||||
else:
|
||||
# Wait for the browser authentication in non-headless mode
|
||||
console.print()
|
||||
wait_msg = "waiting for browser authentication"
|
||||
with console.status(wait_msg, spinner="bouncingBar"):
|
||||
time.sleep(2)
|
||||
click.launch(target)
|
||||
server_thread.join()
|
||||
|
||||
except OSError as e:
|
||||
if e.errno == socket.errno.EADDRINUSE:
|
||||
reason = f"The port {HOST}:{PORT} is currently being used by another application or process. Please choose a different port or terminate the conflicting application/process to free up the port."
|
||||
else:
|
||||
reason = "An error occurred while performing this operation."
|
||||
|
||||
click.secho(reason)
|
||||
sys.exit(1)
|
||||
|
||||
return server.callback
|
||||
756
Backend/venv/lib/python3.12/site-packages/safety/auth/utils.py
Normal file
756
Backend/venv/lib/python3.12/site-packages/safety/auth/utils.py
Normal file
@@ -0,0 +1,756 @@
|
||||
# type: ignore
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, List, Literal
|
||||
|
||||
import requests
|
||||
from authlib.integrations.base_client.errors import OAuthError
|
||||
from authlib.integrations.requests_client import OAuth2Session
|
||||
from requests.adapters import HTTPAdapter
|
||||
from safety_schemas.models import STAGE_ID_MAPPING, Stage
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from safety.auth.constants import (
|
||||
AUTH_SERVER_URL,
|
||||
OPENID_CONFIG_URL,
|
||||
)
|
||||
from safety.constants import (
|
||||
PLATFORM_API_CHECK_UPDATES_ENDPOINT,
|
||||
PLATFORM_API_INITIALIZE_ENDPOINT,
|
||||
PLATFORM_API_POLICY_ENDPOINT,
|
||||
PLATFORM_API_PROJECT_CHECK_ENDPOINT,
|
||||
PLATFORM_API_PROJECT_ENDPOINT,
|
||||
PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT,
|
||||
PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT,
|
||||
PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT,
|
||||
REQUEST_TIMEOUT,
|
||||
FeatureType,
|
||||
get_config_setting,
|
||||
FIREWALL_AUDIT_PYPI_PACKAGES_ENDPOINT,
|
||||
FIREWALL_AUDIT_NPMJS_PACKAGES_ENDPOINT,
|
||||
)
|
||||
from safety.error_handlers import output_exception
|
||||
from safety.errors import (
|
||||
InvalidCredentialError,
|
||||
NetworkConnectionError,
|
||||
RequestTimeoutError,
|
||||
SafetyError,
|
||||
ServerError,
|
||||
TooManyRequestsError,
|
||||
)
|
||||
from safety.meta import get_meta_http_headers
|
||||
from safety.models import SafetyCLI
|
||||
from safety.scan.util import AuthenticationType
|
||||
from safety.util import SafetyContext
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_keys(
|
||||
client_session: OAuth2Session, openid_config: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve the keys from the OpenID configuration.
|
||||
|
||||
Args:
|
||||
client_session (OAuth2Session): The OAuth2 session.
|
||||
openid_config (Dict[str, Any]): The OpenID configuration.
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: The keys, if available.
|
||||
"""
|
||||
if "jwks_uri" in openid_config:
|
||||
return client_session.get(url=openid_config["jwks_uri"], bearer=False).json() # type: ignore
|
||||
return None
|
||||
|
||||
|
||||
def is_email_verified(info: Dict[str, Any]) -> Optional[bool]:
|
||||
"""
|
||||
Check if the email is verified.
|
||||
|
||||
Args:
|
||||
info (Dict[str, Any]): The user information.
|
||||
|
||||
Returns:
|
||||
bool: True
|
||||
"""
|
||||
# return info.get(CLAIM_EMAIL_VERIFIED_API) or info.get(
|
||||
# CLAIM_EMAIL_VERIFIED_AUTH_SERVER
|
||||
# )
|
||||
|
||||
# Always return True to avoid email verification
|
||||
return True
|
||||
|
||||
|
||||
def extract_detail(response: requests.Response) -> Optional[str]:
|
||||
"""
|
||||
Extract the reason from an HTTP response.
|
||||
|
||||
Args:
|
||||
response (requests.Response): The response.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The reason.
|
||||
"""
|
||||
detail = None
|
||||
|
||||
try:
|
||||
detail = response.json().get("detail")
|
||||
except Exception:
|
||||
LOG.debug("Failed to extract detail from response: %s", response.status_code)
|
||||
|
||||
return detail
|
||||
|
||||
|
||||
def parse_response(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to parse the response from an HTTP request.
|
||||
|
||||
Args:
|
||||
func (Callable): The function to wrap.
|
||||
|
||||
Returns:
|
||||
Callable: The wrapped function.
|
||||
"""
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential_jitter(initial=0.2, max=8.0, exp_base=3, jitter=0.3),
|
||||
reraise=True,
|
||||
retry=retry_if_exception_type(
|
||||
(
|
||||
NetworkConnectionError,
|
||||
RequestTimeoutError,
|
||||
TooManyRequestsError,
|
||||
ServerError,
|
||||
)
|
||||
),
|
||||
before_sleep=before_sleep_log(logging.getLogger("api_client"), logging.WARNING),
|
||||
)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
r = func(*args, **kwargs)
|
||||
except OAuthError as e:
|
||||
LOG.exception("OAuth failed: %s", e)
|
||||
raise InvalidCredentialError(
|
||||
message="Your token authentication expired, try login again."
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise NetworkConnectionError()
|
||||
except requests.exceptions.Timeout:
|
||||
raise RequestTimeoutError()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise e
|
||||
|
||||
# TODO: Handle content as JSON and fallback to text for all responses
|
||||
|
||||
if r.status_code == 403:
|
||||
reason = extract_detail(response=r)
|
||||
|
||||
raise InvalidCredentialError(
|
||||
credential="Failed authentication.", reason=reason
|
||||
)
|
||||
|
||||
if r.status_code == 429:
|
||||
raise TooManyRequestsError(reason=r.text)
|
||||
|
||||
if r.status_code >= 400 and r.status_code < 500:
|
||||
error_code = None
|
||||
try:
|
||||
data = r.json()
|
||||
reason = data.get("detail", "Unable to find reason.")
|
||||
error_code = data.get("error_code", None)
|
||||
except Exception:
|
||||
reason = r.reason
|
||||
|
||||
raise SafetyError(message=reason, error_code=error_code)
|
||||
|
||||
if r.status_code >= 500 and r.status_code < 600:
|
||||
reason = extract_detail(response=r)
|
||||
LOG.debug("ServerError %s -> Response returned: %s", r.status_code, r.text)
|
||||
raise ServerError(reason=reason)
|
||||
|
||||
data = None
|
||||
|
||||
try:
|
||||
data = r.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise ServerError(message=f"Bad JSON response from the server: {e}")
|
||||
|
||||
return data
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class SafetyAuthSession(OAuth2Session):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
Initialize the SafetyAuthSession.
|
||||
|
||||
Args:
|
||||
*args (Any): Positional arguments for the parent class.
|
||||
**kwargs (Any): Keyword arguments for the parent class.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.proxy_required: bool = False
|
||||
self.proxy_timeout: Optional[int] = None
|
||||
self.api_key = None
|
||||
|
||||
def get_credential(self) -> Optional[str]:
|
||||
"""
|
||||
Get the current authentication credential.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The API key, token, or None.
|
||||
"""
|
||||
if self.api_key:
|
||||
return self.api_key
|
||||
|
||||
if self.token:
|
||||
return SafetyContext().account
|
||||
|
||||
return None
|
||||
|
||||
def is_using_auth_credentials(self) -> bool:
|
||||
"""
|
||||
Check if the session is using authentication credentials.
|
||||
|
||||
This does NOT check if the client is authenticated.
|
||||
|
||||
Returns:
|
||||
bool: True if using authentication credentials, False otherwise.
|
||||
"""
|
||||
return self.get_authentication_type() != AuthenticationType.none
|
||||
|
||||
def get_authentication_type(self) -> AuthenticationType:
|
||||
"""
|
||||
Get the type of authentication being used.
|
||||
|
||||
Returns:
|
||||
AuthenticationType: The type of authentication.
|
||||
"""
|
||||
if self.api_key:
|
||||
return AuthenticationType.api_key
|
||||
|
||||
if self.token:
|
||||
return AuthenticationType.token
|
||||
|
||||
return AuthenticationType.none
|
||||
|
||||
def request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
withhold_token: bool = False,
|
||||
auth: Optional[Tuple] = None,
|
||||
bearer: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Make an HTTP request with the appropriate authentication.
|
||||
|
||||
Use the right auth parameter for Safety supported auth types.
|
||||
|
||||
Args:
|
||||
method (str): The HTTP method.
|
||||
url (str): The URL to request.
|
||||
withhold_token (bool): Whether to withhold the token.
|
||||
auth (Optional[Tuple]): The authentication tuple.
|
||||
bearer (bool): Whether to use bearer authentication.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
requests.Response: The HTTP response.
|
||||
|
||||
Raises:
|
||||
Exception: If the request fails.
|
||||
"""
|
||||
# By default use the token_auth
|
||||
TIMEOUT_KEYWARD = "timeout"
|
||||
func_timeout = (
|
||||
kwargs[TIMEOUT_KEYWARD] if TIMEOUT_KEYWARD in kwargs else REQUEST_TIMEOUT
|
||||
)
|
||||
|
||||
if "headers" not in kwargs:
|
||||
kwargs["headers"] = {}
|
||||
|
||||
kwargs["headers"].update(get_meta_http_headers())
|
||||
|
||||
if self.api_key:
|
||||
kwargs["headers"]["X-Api-Key"] = self.api_key
|
||||
|
||||
if not self.token or not bearer:
|
||||
# Fallback to no token auth
|
||||
auth = ()
|
||||
|
||||
# Override proxies
|
||||
if self.proxies:
|
||||
kwargs["proxies"] = self.proxies
|
||||
|
||||
if self.proxy_timeout:
|
||||
kwargs["timeout"] = int(self.proxy_timeout) / 1000
|
||||
|
||||
if ("proxies" not in kwargs or not self.proxies) and self.proxy_required:
|
||||
output_exception(
|
||||
"Proxy connection is required but there is not a proxy setup.", # type: ignore
|
||||
exit_code_output=True,
|
||||
)
|
||||
|
||||
request_func = super(SafetyAuthSession, self).request
|
||||
params = {
|
||||
"method": method,
|
||||
"url": url,
|
||||
"withhold_token": withhold_token,
|
||||
"auth": auth,
|
||||
}
|
||||
params.update(kwargs)
|
||||
|
||||
try:
|
||||
return request_func(**params)
|
||||
except Exception as e:
|
||||
LOG.debug("Request failed: %s", e)
|
||||
|
||||
if self.proxy_required:
|
||||
output_exception(
|
||||
f"Proxy is required but the connection failed because: {e}", # type: ignore
|
||||
exit_code_output=True,
|
||||
)
|
||||
|
||||
if "proxies" in kwargs or self.proxies:
|
||||
params["proxies"] = {}
|
||||
params["timeout"] = func_timeout
|
||||
self.proxies = {}
|
||||
message = (
|
||||
"The proxy configuration failed to function and was disregarded."
|
||||
)
|
||||
LOG.debug(message)
|
||||
if message not in [
|
||||
a["message"] for a in SafetyContext.local_announcements
|
||||
]:
|
||||
SafetyContext.local_announcements.append(
|
||||
{"message": message, "type": "warning", "local": True}
|
||||
)
|
||||
|
||||
return request_func(**params)
|
||||
|
||||
raise e
|
||||
|
||||
def fetch_openid_config(self) -> Any:
|
||||
"""
|
||||
Fetch the OpenID configuration from the authorization server.
|
||||
|
||||
Returns:
|
||||
Any: The OpenID configuration.
|
||||
"""
|
||||
|
||||
try:
|
||||
openid_config = self.get(
|
||||
url=OPENID_CONFIG_URL, timeout=REQUEST_TIMEOUT
|
||||
).json()
|
||||
except Exception as e:
|
||||
LOG.debug("Unable to load the openID config: %s", e)
|
||||
openid_config = {}
|
||||
|
||||
return openid_config
|
||||
|
||||
@parse_response
|
||||
def fetch_user_info(self) -> Any:
|
||||
"""
|
||||
Fetch user information from the authorization server.
|
||||
|
||||
Returns:
|
||||
Any: The user information.
|
||||
"""
|
||||
USER_INFO_ENDPOINT = f"{AUTH_SERVER_URL}/userinfo"
|
||||
|
||||
r = self.get(url=USER_INFO_ENDPOINT)
|
||||
|
||||
return r
|
||||
|
||||
@parse_response
|
||||
def check_project(
|
||||
self,
|
||||
scan_stage: str,
|
||||
safety_source: str,
|
||||
project_slug: Optional[str] = None,
|
||||
git_origin: Optional[str] = None,
|
||||
project_slug_source: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Check project information.
|
||||
|
||||
Args:
|
||||
scan_stage (str): The scan stage.
|
||||
safety_source (str): The safety source.
|
||||
project_slug (Optional[str]): The project slug.
|
||||
git_origin (Optional[str]): The git origin.
|
||||
project_slug_source (Optional[str]): The project slug source.
|
||||
|
||||
Returns:
|
||||
Any: The project information.
|
||||
"""
|
||||
|
||||
data = {
|
||||
"scan_stage": scan_stage,
|
||||
"safety_source": safety_source,
|
||||
"project_slug": project_slug,
|
||||
"project_slug_source": project_slug_source,
|
||||
"git_origin": git_origin,
|
||||
}
|
||||
|
||||
r = self.post(url=PLATFORM_API_PROJECT_CHECK_ENDPOINT, json=data)
|
||||
|
||||
return r
|
||||
|
||||
@parse_response
|
||||
def project(self, project_id: str) -> Any:
|
||||
"""
|
||||
Get project information.
|
||||
|
||||
Args:
|
||||
project_id (str): The project ID.
|
||||
|
||||
Returns:
|
||||
Any: The project information.
|
||||
"""
|
||||
data = {"project": project_id}
|
||||
|
||||
return self.get(url=PLATFORM_API_PROJECT_ENDPOINT, params=data)
|
||||
|
||||
@parse_response
|
||||
def download_policy(
|
||||
self, project_id: Optional[str], stage: Stage, branch: Optional[str]
|
||||
) -> Any:
|
||||
"""
|
||||
Download the project policy.
|
||||
|
||||
Args:
|
||||
project_id (Optional[str]): The project ID.
|
||||
stage (Stage): The stage.
|
||||
branch (Optional[str]): The branch.
|
||||
|
||||
Returns:
|
||||
Any: The policy data.
|
||||
"""
|
||||
data = {
|
||||
"project": project_id,
|
||||
"stage": STAGE_ID_MAPPING[stage],
|
||||
"branch": branch,
|
||||
}
|
||||
|
||||
return self.get(url=PLATFORM_API_POLICY_ENDPOINT, params=data)
|
||||
|
||||
@parse_response
|
||||
def project_scan_request(self, project_id: str) -> Any:
|
||||
"""
|
||||
Request a project scan.
|
||||
|
||||
Args:
|
||||
project_id (str): The project ID.
|
||||
|
||||
Returns:
|
||||
Any: The scan request result.
|
||||
"""
|
||||
data = {"project_id": project_id}
|
||||
|
||||
return self.post(url=PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT, json=data)
|
||||
|
||||
@parse_response
|
||||
def upload_report(self, json_report: str) -> Any:
|
||||
"""
|
||||
Upload a scan report.
|
||||
|
||||
Args:
|
||||
json_report (str): The JSON report.
|
||||
|
||||
Returns:
|
||||
Any: The upload result.
|
||||
"""
|
||||
|
||||
return self.post(
|
||||
url=PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT,
|
||||
data=json_report,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def upload_requirements(self, json_payload: str) -> Any:
|
||||
"""
|
||||
Upload a scan report.
|
||||
Args:
|
||||
json_payload (str): The JSON payload to upload.
|
||||
Returns:
|
||||
Any: The result of the upload operation.
|
||||
"""
|
||||
return self.post(
|
||||
url=PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT,
|
||||
data=json.dumps(json_payload),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
@parse_response
|
||||
def check_updates(
|
||||
self,
|
||||
version: int,
|
||||
safety_version: Optional[str] = None,
|
||||
python_version: Optional[str] = None,
|
||||
os_type: Optional[str] = None,
|
||||
os_release: Optional[str] = None,
|
||||
os_description: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Check for updates.
|
||||
|
||||
Args:
|
||||
version (int): The version.
|
||||
safety_version (Optional[str]): The Safety version.
|
||||
python_version (Optional[str]): The Python version.
|
||||
os_type (Optional[str]): The OS type.
|
||||
os_release (Optional[str]): The OS release.
|
||||
os_description (Optional[str]): The OS description.
|
||||
|
||||
Returns:
|
||||
Any: The update check result.
|
||||
"""
|
||||
data = {
|
||||
"version": version,
|
||||
"safety_version": safety_version,
|
||||
"python_version": python_version,
|
||||
"os_type": os_type,
|
||||
"os_release": os_release,
|
||||
"os_description": os_description,
|
||||
}
|
||||
|
||||
return self.get(url=PLATFORM_API_CHECK_UPDATES_ENDPOINT, params=data)
|
||||
|
||||
@parse_response
|
||||
def audit_packages(
|
||||
self, packages: List[str], ecosystem: Literal["pypi", "npmjs"]
|
||||
) -> Any:
|
||||
"""
|
||||
Audits packages for vulnerabilities
|
||||
Args:
|
||||
packages: list of package specifiers
|
||||
ecosystem: the ecosystem to audit
|
||||
|
||||
Returns:
|
||||
Any: The packages audit result.
|
||||
"""
|
||||
url = (
|
||||
FIREWALL_AUDIT_NPMJS_PACKAGES_ENDPOINT
|
||||
if ecosystem == "npmjs"
|
||||
else FIREWALL_AUDIT_PYPI_PACKAGES_ENDPOINT
|
||||
)
|
||||
|
||||
data = {"packages": [{"package_specifier": package} for package in packages]}
|
||||
|
||||
return self.post(url=url, json=data)
|
||||
|
||||
@parse_response
|
||||
def initialize(self) -> Any:
|
||||
"""
|
||||
Initialize a run.
|
||||
|
||||
Returns:
|
||||
Any: The initialization result.
|
||||
"""
|
||||
try:
|
||||
response = self.get(
|
||||
url=PLATFORM_API_INITIALIZE_ENDPOINT,
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=5,
|
||||
)
|
||||
return response
|
||||
except requests.exceptions.Timeout:
|
||||
LOG.error("Auth request to initialize timed out after 5 seconds.")
|
||||
except Exception:
|
||||
LOG.exception("Exception trying to auth initialize", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
class S3PresignedAdapter(HTTPAdapter):
|
||||
def send( # type: ignore
|
||||
self, request: requests.PreparedRequest, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Send a request, removing the Authorization header.
|
||||
|
||||
Args:
|
||||
request (requests.PreparedRequest): The prepared request.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
requests.Response: The response.
|
||||
"""
|
||||
request.headers.pop("Authorization", None)
|
||||
return super().send(request, **kwargs)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_jupyter_notebook() -> bool:
|
||||
"""
|
||||
Detects if the code is running in a Jupyter notebook environment, including
|
||||
various cloud-hosted Jupyter notebooks.
|
||||
|
||||
Returns:
|
||||
bool: True if the environment is identified as a Jupyter notebook (or
|
||||
equivalent cloud-based environment), False otherwise.
|
||||
|
||||
Supported environments:
|
||||
- Google Colab
|
||||
- Amazon SageMaker
|
||||
- Azure Notebooks
|
||||
- Kaggle Notebooks
|
||||
- Databricks Notebooks
|
||||
- Datalore by JetBrains
|
||||
- Paperspace Gradient Notebooks
|
||||
- Classic Jupyter Notebook and JupyterLab
|
||||
"""
|
||||
if (
|
||||
(
|
||||
importlib.util.find_spec("google")
|
||||
and importlib.util.find_spec("google.colab")
|
||||
)
|
||||
is not None
|
||||
or importlib.util.find_spec("sagemaker") is not None
|
||||
or importlib.util.find_spec("azureml") is not None
|
||||
or importlib.util.find_spec("kaggle") is not None
|
||||
or importlib.util.find_spec("dbutils") is not None
|
||||
or importlib.util.find_spec("datalore") is not None
|
||||
or importlib.util.find_spec("gradient") is not None
|
||||
):
|
||||
return True
|
||||
|
||||
# Detect classic Jupyter Notebook, JupyterLab, and other IPython kernel-based environments
|
||||
try:
|
||||
from IPython import get_ipython # type: ignore
|
||||
|
||||
ipython = get_ipython()
|
||||
if ipython is not None and "IPKernelApp" in ipython.config:
|
||||
return True
|
||||
except (ImportError, AttributeError, NameError):
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def save_flags_config(flags: Dict[FeatureType, bool]) -> None:
|
||||
"""
|
||||
Save feature flags configuration to file.
|
||||
|
||||
This function attempts to save feature flags to the configuration file
|
||||
but will fail silently if unable to do so (e.g., due to permission issues
|
||||
or disk problems). Silent failure is chosen to prevent configuration issues
|
||||
from disrupting core application functionality.
|
||||
|
||||
Note that if saving fails, the application will continue using existing
|
||||
or default flag values until the next restart.
|
||||
|
||||
Args:
|
||||
flags: Dictionary mapping feature types to their enabled/disabled state
|
||||
|
||||
The operation will be logged (with stack trace) if it fails.
|
||||
"""
|
||||
import configparser
|
||||
|
||||
from safety.constants import CONFIG_FILE_USER
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read(CONFIG_FILE_USER)
|
||||
|
||||
flag_settings = {key.name.upper(): str(value) for key, value in flags.items()}
|
||||
|
||||
if not config.has_section("settings"):
|
||||
config.add_section("settings")
|
||||
|
||||
settings = dict(config.items("settings"))
|
||||
settings.update(flag_settings)
|
||||
|
||||
for key, value in settings.items():
|
||||
config.set("settings", key, value)
|
||||
|
||||
try:
|
||||
with open(CONFIG_FILE_USER, "w") as config_file:
|
||||
config.write(config_file)
|
||||
except Exception:
|
||||
LOG.exception("Unable to save flags configuration.")
|
||||
|
||||
|
||||
def get_feature_name(feature: FeatureType, as_attr: bool = False) -> str:
|
||||
"""Returns a formatted feature name with enabled suffix.
|
||||
|
||||
Args:
|
||||
feature: The feature to format the name for
|
||||
as_attr: If True, formats for attribute usage (underscore),
|
||||
otherwise uses hyphen
|
||||
|
||||
Returns:
|
||||
Formatted feature name string with enabled suffix
|
||||
"""
|
||||
name = feature.name.lower()
|
||||
separator = "_" if as_attr else "-"
|
||||
return f"{name}{separator}enabled"
|
||||
|
||||
|
||||
def str_to_bool(value) -> Optional[bool]:
|
||||
"""Convert basic string representations to boolean."""
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
value = value.lower().strip()
|
||||
if value in ("true"):
|
||||
return True
|
||||
if value in ("false"):
|
||||
return False
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def initialize(ctx: Any, refresh: bool = True) -> None:
|
||||
"""
|
||||
Initializes the run by loading settings.
|
||||
|
||||
Args:
|
||||
ctx (Any): The context object.
|
||||
refresh (bool): Whether to refresh settings from the server. Defaults to True.
|
||||
"""
|
||||
settings = None
|
||||
current_values = {}
|
||||
|
||||
if not ctx.obj:
|
||||
ctx.obj = SafetyCLI()
|
||||
|
||||
for feature in FeatureType:
|
||||
value = get_config_setting(feature.name)
|
||||
if value is not None:
|
||||
current_values[feature] = str_to_bool(value)
|
||||
|
||||
if refresh:
|
||||
try:
|
||||
settings = ctx.obj.auth.client.initialize() # type: ignore
|
||||
except Exception:
|
||||
LOG.info("Unable to initialize, continue with default values.")
|
||||
|
||||
if settings:
|
||||
for feature in FeatureType:
|
||||
server_value = str_to_bool(settings.get(feature.config_key))
|
||||
if server_value is not None:
|
||||
if (
|
||||
feature not in current_values
|
||||
or current_values[feature] != server_value
|
||||
):
|
||||
current_values[feature] = server_value
|
||||
|
||||
save_flags_config(current_values)
|
||||
|
||||
for feature, value in current_values.items():
|
||||
if value is not None:
|
||||
setattr(ctx.obj, feature.attr_name, value)
|
||||
1353
Backend/venv/lib/python3.12/site-packages/safety/cli.py
Normal file
1353
Backend/venv/lib/python3.12/site-packages/safety/cli.py
Normal file
File diff suppressed because it is too large
Load Diff
967
Backend/venv/lib/python3.12/site-packages/safety/cli_util.py
Normal file
967
Backend/venv/lib/python3.12/site-packages/safety/cli_util.py
Normal file
@@ -0,0 +1,967 @@
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from functools import wraps
|
||||
|
||||
import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import click
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
from typer.core import MarkupMode, TyperCommand, TyperGroup
|
||||
from click.utils import make_str
|
||||
|
||||
|
||||
from safety.constants import (
|
||||
BETA_PANEL_DESCRIPTION_HELP,
|
||||
MSG_NO_AUTHD_CICD_PROD_STG,
|
||||
MSG_NO_AUTHD_CICD_PROD_STG_ORG,
|
||||
MSG_NO_AUTHD_DEV_STG,
|
||||
MSG_NO_AUTHD_DEV_STG_ORG_PROMPT,
|
||||
MSG_NO_AUTHD_DEV_STG_PROMPT,
|
||||
MSG_NO_AUTHD_NOTE_CICD_PROD_STG_TPL,
|
||||
MSG_NO_VERIFIED_EMAIL_TPL,
|
||||
CONTEXT_COMMAND_TYPE,
|
||||
FeatureType,
|
||||
)
|
||||
from safety.scan.constants import CONSOLE_HELP_THEME
|
||||
from safety.models import SafetyCLI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from click.core import Command, Context
|
||||
from safety.auth.models import Auth
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CommandType(Enum):
|
||||
MAIN = "main"
|
||||
UTILITY = "utility"
|
||||
BETA = "beta"
|
||||
|
||||
|
||||
def custom_print_options_panel(
|
||||
name: str, params: List[Any], ctx: Any, console: Console
|
||||
) -> None:
|
||||
"""
|
||||
Print a panel with options.
|
||||
|
||||
Args:
|
||||
name (str): The title of the panel.
|
||||
params (List[Any]): The list of options/arguments to print.
|
||||
ctx (Any): The context object.
|
||||
markup_mode (str): The markup mode.
|
||||
console (Console): The console to print to.
|
||||
"""
|
||||
table = Table(title=name, show_lines=True)
|
||||
for param in params:
|
||||
opts = getattr(param, "opts", "")
|
||||
help_text = getattr(param, "help", "")
|
||||
table.add_row(str(opts), help_text)
|
||||
console.print(table)
|
||||
|
||||
|
||||
def custom_print_commands_panel(
|
||||
name: str, commands: List[Any], console: Console
|
||||
) -> None:
|
||||
"""
|
||||
Print a panel with commands.
|
||||
|
||||
Args:
|
||||
name (str): The title of the panel.
|
||||
commands (List[Any]): The list of commands to print.
|
||||
console (Console): The console to print to.
|
||||
"""
|
||||
table = Table(title=name, show_lines=True)
|
||||
for command in commands:
|
||||
table.add_row(command.name, command.help or "")
|
||||
console.print(table)
|
||||
|
||||
|
||||
def custom_make_rich_text(text: str) -> Text:
|
||||
"""
|
||||
Create rich text.
|
||||
|
||||
Args:
|
||||
text (str): The text to format.
|
||||
|
||||
Returns:
|
||||
Text: The formatted rich text.
|
||||
"""
|
||||
return Text(text)
|
||||
|
||||
|
||||
def custom_get_help_text(obj: Any) -> Text:
|
||||
"""
|
||||
Get the help text for an object.
|
||||
|
||||
Args:
|
||||
obj (Any): The object to get help text for.
|
||||
|
||||
Returns:
|
||||
Text: The formatted help text.
|
||||
"""
|
||||
return Text(obj.help)
|
||||
|
||||
|
||||
def custom_make_command_help(help_text: str) -> Text:
|
||||
"""
|
||||
Create rich text for command help.
|
||||
|
||||
Args:
|
||||
help_text (str): The help text to format.
|
||||
markup_mode (str): The markup mode.
|
||||
|
||||
Returns:
|
||||
Text: The formatted rich text.
|
||||
"""
|
||||
return Text(help_text)
|
||||
|
||||
|
||||
def get_command_for(name: str, typer_instance: typer.Typer) -> click.Command:
|
||||
"""
|
||||
Retrieve a command by name from a Typer instance.
|
||||
|
||||
Args:
|
||||
name (str): The name of the command.
|
||||
typer_instance (typer.Typer): The Typer instance.
|
||||
|
||||
Returns:
|
||||
click.Command: The found command.
|
||||
"""
|
||||
single_command = next(
|
||||
(
|
||||
command
|
||||
for command in typer_instance.registered_commands
|
||||
if command.name == name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not single_command:
|
||||
raise ValueError("Unable to find the command name.")
|
||||
|
||||
single_command.context_settings = typer_instance.info.context_settings
|
||||
click_command = typer.main.get_command_from_info(
|
||||
single_command,
|
||||
pretty_exceptions_short=typer_instance.pretty_exceptions_short,
|
||||
rich_markup_mode=typer_instance.rich_markup_mode,
|
||||
)
|
||||
if typer_instance._add_completion:
|
||||
click_install_param, click_show_param = (
|
||||
typer.main.get_install_completion_arguments()
|
||||
)
|
||||
click_command.params.append(click_install_param)
|
||||
click_command.params.append(click_show_param)
|
||||
return click_command
|
||||
|
||||
|
||||
def pass_safety_cli_obj(func):
|
||||
"""
|
||||
Decorator to ensure the SafetyCLI object exists for a command.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def inner(ctx, *args, **kwargs):
|
||||
if not ctx.obj:
|
||||
ctx.obj = SafetyCLI()
|
||||
|
||||
return func(ctx, *args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def pretty_format_help(
|
||||
obj: Union[click.Command, click.Group], ctx: click.Context, markup_mode: MarkupMode
|
||||
) -> None:
|
||||
"""
|
||||
Format and print help text in a pretty format.
|
||||
|
||||
Args:
|
||||
obj (Union[click.Command, click.Group]): The Click command or group.
|
||||
ctx (click.Context): The Click context.
|
||||
markup_mode (MarkupMode): The markup mode.
|
||||
"""
|
||||
from rich.align import Align
|
||||
from rich.console import Console
|
||||
from rich.padding import Padding
|
||||
from rich.theme import Theme
|
||||
from typer.rich_utils import (
|
||||
ARGUMENTS_PANEL_TITLE,
|
||||
COMMANDS_PANEL_TITLE,
|
||||
OPTIONS_PANEL_TITLE,
|
||||
STYLE_USAGE_COMMAND,
|
||||
highlighter,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
with console.use_theme(Theme(styles=CONSOLE_HELP_THEME)) as theme_context:
|
||||
console = theme_context.console
|
||||
# Print command / group help if we have some
|
||||
if obj.help:
|
||||
console.print()
|
||||
|
||||
# Print with some padding
|
||||
console.print(
|
||||
Padding(Align(custom_get_help_text(obj=obj), pad=False), (0, 1, 0, 1))
|
||||
)
|
||||
|
||||
# Print usage
|
||||
console.print(
|
||||
Padding(highlighter(obj.get_usage(ctx)), 1), style=STYLE_USAGE_COMMAND
|
||||
)
|
||||
|
||||
if isinstance(obj, click.MultiCommand):
|
||||
panel_to_commands: DefaultDict[str, List[click.Command]] = defaultdict(list)
|
||||
for command_name in obj.list_commands(ctx):
|
||||
command = obj.get_command(ctx, command_name)
|
||||
if command and not command.hidden:
|
||||
panel_name = (
|
||||
getattr(command, "rich_help_panel", None)
|
||||
or COMMANDS_PANEL_TITLE
|
||||
)
|
||||
panel_to_commands[panel_name].append(command)
|
||||
|
||||
# Print each command group panel
|
||||
default_commands = panel_to_commands.get(COMMANDS_PANEL_TITLE, [])
|
||||
custom_print_commands_panel(
|
||||
name=COMMANDS_PANEL_TITLE,
|
||||
commands=default_commands,
|
||||
console=console,
|
||||
)
|
||||
for panel_name, commands in panel_to_commands.items():
|
||||
if panel_name == COMMANDS_PANEL_TITLE:
|
||||
# Already printed above
|
||||
continue
|
||||
custom_print_commands_panel(
|
||||
name=panel_name,
|
||||
commands=commands,
|
||||
console=console,
|
||||
)
|
||||
|
||||
panel_to_arguments: DefaultDict[str, List[click.Argument]] = defaultdict(list)
|
||||
panel_to_options: DefaultDict[str, List[click.Option]] = defaultdict(list)
|
||||
for param in obj.get_params(ctx):
|
||||
# Skip if option is hidden
|
||||
if getattr(param, "hidden", False):
|
||||
continue
|
||||
if isinstance(param, click.Argument):
|
||||
panel_name = (
|
||||
getattr(param, "rich_help_panel", None) or ARGUMENTS_PANEL_TITLE
|
||||
)
|
||||
panel_to_arguments[panel_name].append(param)
|
||||
elif isinstance(param, click.Option):
|
||||
panel_name = (
|
||||
getattr(param, "rich_help_panel", None) or OPTIONS_PANEL_TITLE
|
||||
)
|
||||
panel_to_options[panel_name].append(param)
|
||||
|
||||
default_options = panel_to_options.get(OPTIONS_PANEL_TITLE, [])
|
||||
custom_print_options_panel(
|
||||
name=OPTIONS_PANEL_TITLE,
|
||||
params=default_options,
|
||||
ctx=ctx,
|
||||
console=console,
|
||||
)
|
||||
for panel_name, options in panel_to_options.items():
|
||||
if panel_name == OPTIONS_PANEL_TITLE:
|
||||
# Already printed above
|
||||
continue
|
||||
custom_print_options_panel(
|
||||
name=panel_name,
|
||||
params=options,
|
||||
ctx=ctx,
|
||||
console=console,
|
||||
)
|
||||
|
||||
default_arguments = panel_to_arguments.get(ARGUMENTS_PANEL_TITLE, [])
|
||||
custom_print_options_panel(
|
||||
name=ARGUMENTS_PANEL_TITLE,
|
||||
params=default_arguments,
|
||||
ctx=ctx,
|
||||
console=console,
|
||||
)
|
||||
for panel_name, arguments in panel_to_arguments.items():
|
||||
if panel_name == ARGUMENTS_PANEL_TITLE:
|
||||
# Already printed above
|
||||
continue
|
||||
custom_print_options_panel(
|
||||
name=panel_name,
|
||||
params=arguments,
|
||||
ctx=ctx,
|
||||
console=console,
|
||||
)
|
||||
|
||||
if ctx.parent:
|
||||
params = []
|
||||
for param in ctx.parent.command.params:
|
||||
if isinstance(param, click.Option):
|
||||
params.append(param)
|
||||
|
||||
custom_print_options_panel(
|
||||
name="Global-Options",
|
||||
params=params,
|
||||
ctx=ctx.parent,
|
||||
console=console,
|
||||
)
|
||||
|
||||
# Epilogue if we have it
|
||||
if obj.epilog:
|
||||
# Remove single linebreaks, replace double with single
|
||||
lines = obj.epilog.split("\n\n")
|
||||
epilogue = "\n".join([x.replace("\n", " ").strip() for x in lines])
|
||||
epilogue_text = custom_make_rich_text(text=epilogue)
|
||||
console.print(Padding(Align(epilogue_text, pad=False), 1))
|
||||
|
||||
|
||||
def print_main_command_panels(
|
||||
*,
|
||||
name: str,
|
||||
commands_type: CommandType,
|
||||
commands: List[click.Command],
|
||||
markup_mode: MarkupMode,
|
||||
console,
|
||||
) -> None:
|
||||
"""
|
||||
Print the main command panels.
|
||||
|
||||
Args:
|
||||
name (str): The name of the panel.
|
||||
commands (List[click.Command]): List of commands to display.
|
||||
markup_mode (MarkupMode): The markup mode.
|
||||
console: The Rich console.
|
||||
"""
|
||||
from rich import box
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
from typer.rich_utils import (
|
||||
ALIGN_COMMANDS_PANEL,
|
||||
STYLE_COMMANDS_PANEL_BORDER,
|
||||
STYLE_COMMANDS_TABLE_BORDER_STYLE,
|
||||
STYLE_COMMANDS_TABLE_BOX,
|
||||
STYLE_COMMANDS_TABLE_LEADING,
|
||||
STYLE_COMMANDS_TABLE_PAD_EDGE,
|
||||
STYLE_COMMANDS_TABLE_PADDING,
|
||||
STYLE_COMMANDS_TABLE_ROW_STYLES,
|
||||
STYLE_COMMANDS_TABLE_SHOW_LINES,
|
||||
)
|
||||
|
||||
t_styles: Dict[str, Any] = {
|
||||
"show_lines": STYLE_COMMANDS_TABLE_SHOW_LINES,
|
||||
"leading": STYLE_COMMANDS_TABLE_LEADING,
|
||||
"box": STYLE_COMMANDS_TABLE_BOX,
|
||||
"border_style": STYLE_COMMANDS_TABLE_BORDER_STYLE,
|
||||
"row_styles": STYLE_COMMANDS_TABLE_ROW_STYLES,
|
||||
"pad_edge": STYLE_COMMANDS_TABLE_PAD_EDGE,
|
||||
"padding": STYLE_COMMANDS_TABLE_PADDING,
|
||||
}
|
||||
box_style = getattr(box, t_styles.pop("box"), None)
|
||||
|
||||
commands_table = Table(
|
||||
highlight=False,
|
||||
show_header=False,
|
||||
expand=True,
|
||||
box=box_style,
|
||||
**t_styles,
|
||||
)
|
||||
|
||||
console_width = 80
|
||||
column_width = 25
|
||||
|
||||
if console.size and console.size[0] > 80:
|
||||
console_width = console.size[0]
|
||||
|
||||
from rich.console import Group
|
||||
|
||||
description = None
|
||||
|
||||
if commands_type is CommandType.BETA:
|
||||
description = Group(Text(""), Text(BETA_PANEL_DESCRIPTION_HELP), Text(""))
|
||||
|
||||
commands_table.add_column(
|
||||
style="bold cyan", no_wrap=True, width=column_width, max_width=column_width
|
||||
)
|
||||
commands_table.add_column(width=console_width - column_width)
|
||||
|
||||
rows = []
|
||||
|
||||
for command in commands:
|
||||
helptext = command.short_help or command.help or ""
|
||||
command_name = command.name or ""
|
||||
command_name_text = (
|
||||
Text(command_name, style="")
|
||||
if commands_type is CommandType.BETA
|
||||
else Text(command_name)
|
||||
)
|
||||
rows.append(
|
||||
[
|
||||
command_name_text,
|
||||
custom_make_command_help(
|
||||
help_text=helptext,
|
||||
),
|
||||
]
|
||||
)
|
||||
rows.append([])
|
||||
for row in rows:
|
||||
commands_table.add_row(*row)
|
||||
if commands_table.row_count:
|
||||
renderables = (
|
||||
[description, commands_table]
|
||||
if description is not None
|
||||
else [Text(""), commands_table]
|
||||
)
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Group(*renderables),
|
||||
border_style=STYLE_COMMANDS_PANEL_BORDER,
|
||||
title=name,
|
||||
title_align=ALIGN_COMMANDS_PANEL,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# The help output for the main safety root command: `safety --help`
|
||||
def format_main_help(
|
||||
obj: Union[click.Command, click.Group], ctx: click.Context, markup_mode: MarkupMode
|
||||
) -> None:
|
||||
"""
|
||||
Format the main help output for the safety root command.
|
||||
|
||||
Args:
|
||||
obj (Union[click.Command, click.Group]): The Click command or group.
|
||||
ctx (click.Context): The Click context.
|
||||
markup_mode (MarkupMode): The markup mode.
|
||||
"""
|
||||
from rich.align import Align
|
||||
from rich.console import Console
|
||||
from rich.padding import Padding
|
||||
from rich.theme import Theme
|
||||
from typer.rich_utils import (
|
||||
ARGUMENTS_PANEL_TITLE,
|
||||
COMMANDS_PANEL_TITLE,
|
||||
OPTIONS_PANEL_TITLE,
|
||||
STYLE_USAGE_COMMAND,
|
||||
highlighter,
|
||||
)
|
||||
|
||||
typer_console = Console()
|
||||
|
||||
with typer_console.use_theme(Theme(styles=CONSOLE_HELP_THEME)) as theme_context:
|
||||
console = theme_context.console
|
||||
|
||||
# Print command / group help if we have some
|
||||
if obj.help:
|
||||
console.print()
|
||||
# Print with some padding
|
||||
console.print(
|
||||
Padding(
|
||||
Align(
|
||||
custom_get_help_text(obj=obj),
|
||||
pad=False,
|
||||
),
|
||||
(0, 1, 0, 1),
|
||||
)
|
||||
)
|
||||
|
||||
# Print usage
|
||||
console.print(
|
||||
Padding(highlighter(obj.get_usage(ctx)), 1), style=STYLE_USAGE_COMMAND
|
||||
)
|
||||
|
||||
if isinstance(obj, click.MultiCommand):
|
||||
UTILITY_COMMANDS_PANEL_TITLE = "Utility commands"
|
||||
BETA_COMMANDS_PANEL_TITLE = "Beta Commands"
|
||||
|
||||
COMMANDS_PANEL_TITLE_CONSTANTS = {
|
||||
CommandType.MAIN: COMMANDS_PANEL_TITLE,
|
||||
CommandType.UTILITY: UTILITY_COMMANDS_PANEL_TITLE,
|
||||
CommandType.BETA: BETA_COMMANDS_PANEL_TITLE,
|
||||
}
|
||||
|
||||
panel_to_commands: Dict[CommandType, List[click.Command]] = {}
|
||||
|
||||
# Keep order of panels
|
||||
for command_type in COMMANDS_PANEL_TITLE_CONSTANTS.keys():
|
||||
panel_to_commands[command_type] = []
|
||||
|
||||
for command_name in obj.list_commands(ctx):
|
||||
command = obj.get_command(ctx, command_name)
|
||||
if command and not command.hidden:
|
||||
command_type = command.context_settings.get(
|
||||
CONTEXT_COMMAND_TYPE, CommandType.MAIN
|
||||
)
|
||||
panel_to_commands[command_type].append(command)
|
||||
|
||||
for command_type, commands in panel_to_commands.items():
|
||||
print_main_command_panels(
|
||||
name=COMMANDS_PANEL_TITLE_CONSTANTS[command_type],
|
||||
commands_type=command_type,
|
||||
commands=commands,
|
||||
markup_mode=markup_mode,
|
||||
console=console,
|
||||
)
|
||||
|
||||
panel_to_arguments: DefaultDict[str, List[click.Argument]] = defaultdict(list)
|
||||
panel_to_options: DefaultDict[str, List[click.Option]] = defaultdict(list)
|
||||
for param in obj.get_params(ctx):
|
||||
# Skip if option is hidden
|
||||
if getattr(param, "hidden", False):
|
||||
continue
|
||||
if isinstance(param, click.Argument):
|
||||
panel_name = (
|
||||
getattr(param, "rich_help_panel", None) or ARGUMENTS_PANEL_TITLE
|
||||
)
|
||||
panel_to_arguments[panel_name].append(param)
|
||||
elif isinstance(param, click.Option):
|
||||
panel_name = (
|
||||
getattr(param, "rich_help_panel", None) or OPTIONS_PANEL_TITLE
|
||||
)
|
||||
panel_to_options[panel_name].append(param)
|
||||
default_arguments = panel_to_arguments.get(ARGUMENTS_PANEL_TITLE, [])
|
||||
custom_print_options_panel(
|
||||
name=ARGUMENTS_PANEL_TITLE,
|
||||
params=default_arguments,
|
||||
ctx=ctx,
|
||||
console=console,
|
||||
)
|
||||
for panel_name, arguments in panel_to_arguments.items():
|
||||
if panel_name == ARGUMENTS_PANEL_TITLE:
|
||||
# Already printed above
|
||||
continue
|
||||
custom_print_options_panel(
|
||||
name=panel_name,
|
||||
params=arguments,
|
||||
ctx=ctx,
|
||||
console=console,
|
||||
)
|
||||
default_options = panel_to_options.get(OPTIONS_PANEL_TITLE, [])
|
||||
custom_print_options_panel(
|
||||
name=OPTIONS_PANEL_TITLE,
|
||||
params=default_options,
|
||||
ctx=ctx,
|
||||
console=console,
|
||||
)
|
||||
for panel_name, options in panel_to_options.items():
|
||||
if panel_name == OPTIONS_PANEL_TITLE:
|
||||
# Already printed above
|
||||
continue
|
||||
custom_print_options_panel(
|
||||
name=panel_name,
|
||||
params=options,
|
||||
ctx=ctx,
|
||||
console=console,
|
||||
)
|
||||
|
||||
# Epilogue if we have it
|
||||
if obj.epilog:
|
||||
# Remove single linebreaks, replace double with single
|
||||
lines = obj.epilog.split("\n\n")
|
||||
epilogue = "\n".join([x.replace("\n", " ").strip() for x in lines])
|
||||
epilogue_text = custom_make_rich_text(text=epilogue)
|
||||
console.print(Padding(Align(epilogue_text, pad=False), 1))
|
||||
|
||||
|
||||
def process_auth_status_not_ready(console, auth: "Auth", ctx: typer.Context) -> None:
|
||||
"""
|
||||
Handle the process when the authentication status is not ready.
|
||||
|
||||
Args:
|
||||
console: The Rich console.
|
||||
auth (Auth): The Auth object.
|
||||
ctx (typer.Context): The Typer context.
|
||||
"""
|
||||
from rich.prompt import Confirm, Prompt
|
||||
from safety_schemas.models import Stage
|
||||
from safety.auth.constants import CLI_AUTH, MSG_NON_AUTHENTICATED
|
||||
|
||||
if not auth.client or not auth.client.is_using_auth_credentials():
|
||||
if auth.stage is Stage.development:
|
||||
console.print()
|
||||
if auth.org:
|
||||
confirmed = Confirm.ask(
|
||||
MSG_NO_AUTHD_DEV_STG_ORG_PROMPT,
|
||||
choices=["Y", "N", "y", "n"],
|
||||
show_choices=False,
|
||||
show_default=False,
|
||||
default=True,
|
||||
console=console,
|
||||
)
|
||||
|
||||
if not confirmed:
|
||||
sys.exit(0)
|
||||
|
||||
from safety.auth.cli import auth_app
|
||||
|
||||
login_command = get_command_for(name="login", typer_instance=auth_app)
|
||||
ctx.invoke(login_command)
|
||||
else:
|
||||
console.print(MSG_NO_AUTHD_DEV_STG)
|
||||
console.print()
|
||||
choices = ["L", "R", "l", "r"]
|
||||
next_command = Prompt.ask(
|
||||
MSG_NO_AUTHD_DEV_STG_PROMPT,
|
||||
default=None,
|
||||
choices=choices,
|
||||
show_choices=False,
|
||||
console=console,
|
||||
)
|
||||
|
||||
from safety.auth.cli import auth_app
|
||||
|
||||
login_command = get_command_for(name="login", typer_instance=auth_app)
|
||||
register_command = get_command_for(
|
||||
name="register", typer_instance=auth_app
|
||||
)
|
||||
if next_command is None or next_command.lower() not in choices:
|
||||
sys.exit(0)
|
||||
|
||||
console.print()
|
||||
if next_command.lower() == "r":
|
||||
ctx.invoke(register_command)
|
||||
else:
|
||||
ctx.invoke(login_command)
|
||||
|
||||
if not ctx.obj.auth.email_verified:
|
||||
sys.exit(1)
|
||||
else:
|
||||
if not auth.org:
|
||||
console.print(MSG_NO_AUTHD_CICD_PROD_STG_ORG.format(LOGIN_URL=CLI_AUTH))
|
||||
|
||||
else:
|
||||
console.print(MSG_NO_AUTHD_CICD_PROD_STG)
|
||||
console.print(
|
||||
MSG_NO_AUTHD_NOTE_CICD_PROD_STG_TPL.format(
|
||||
LOGIN_URL=CLI_AUTH, SIGNUP_URL=f"{CLI_AUTH}/?sign_up=True"
|
||||
)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
elif not auth.email_verified:
|
||||
console.print()
|
||||
console.print(
|
||||
MSG_NO_VERIFIED_EMAIL_TPL.format(
|
||||
email=auth.email if auth.email else "Missing email"
|
||||
)
|
||||
)
|
||||
sys.exit(1)
|
||||
else:
|
||||
console.print(MSG_NON_AUTHENTICATED)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class CustomContext(click.Context):
|
||||
def __init__(
|
||||
self,
|
||||
command: "Command",
|
||||
parent: Optional["Context"] = None,
|
||||
command_type: CommandType = CommandType.MAIN,
|
||||
feature_type: Optional[FeatureType] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.command_type = command_type
|
||||
self.feature_type = feature_type
|
||||
self.started_at = time.monotonic()
|
||||
self.command_alias_used: Optional[str] = None
|
||||
super().__init__(command, parent=parent, **kwargs)
|
||||
|
||||
|
||||
class SafetyCLISubGroup(TyperGroup):
|
||||
"""
|
||||
Custom TyperGroup with additional functionality for Safety CLI.
|
||||
"""
|
||||
|
||||
context_class = CustomContext
|
||||
|
||||
def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
|
||||
"""
|
||||
Format help message with rich formatting.
|
||||
|
||||
Args:
|
||||
ctx (click.Context): Click context.
|
||||
formatter (click.HelpFormatter): Click help formatter.
|
||||
"""
|
||||
pretty_format_help(self, ctx, markup_mode=self.rich_markup_mode)
|
||||
|
||||
def format_usage(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
|
||||
"""
|
||||
Format usage message.
|
||||
|
||||
Args:
|
||||
ctx (click.Context): Click context.
|
||||
formatter (click.HelpFormatter): Click help formatter.
|
||||
"""
|
||||
command_path = ctx.command_path
|
||||
pieces = self.collect_usage_pieces(ctx)
|
||||
main_group = ctx.parent
|
||||
if main_group:
|
||||
command_path = (
|
||||
f"{main_group.command_path} [GLOBAL-OPTIONS] {ctx.command.name}"
|
||||
)
|
||||
|
||||
formatter.write_usage(command_path, " ".join(pieces))
|
||||
|
||||
def command(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> click.Command: # type: ignore[override]
|
||||
"""
|
||||
Create a new command.
|
||||
|
||||
Args:
|
||||
*args: Variable length argument list.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
|
||||
Returns:
|
||||
click.Command: The created command.
|
||||
"""
|
||||
super().command(*args, **kwargs)
|
||||
|
||||
|
||||
class SafetyCLICommand(TyperCommand):
|
||||
"""
|
||||
Custom TyperCommand with additional functionality for Safety CLI.
|
||||
"""
|
||||
|
||||
context_class = CustomContext
|
||||
|
||||
def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
|
||||
"""
|
||||
Format help message with rich formatting.
|
||||
|
||||
Args:
|
||||
ctx (click.Context): Click context.
|
||||
formatter (click.HelpFormatter): Click help formatter.
|
||||
"""
|
||||
pretty_format_help(self, ctx, markup_mode=self.rich_markup_mode)
|
||||
|
||||
def format_usage(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
|
||||
"""
|
||||
Format usage message.
|
||||
|
||||
Args:
|
||||
ctx (click.Context): Click context.
|
||||
formatter (click.HelpFormatter): Click help formatter.
|
||||
"""
|
||||
command_path = ctx.command_path
|
||||
pieces = self.collect_usage_pieces(ctx)
|
||||
main_group = ctx.parent
|
||||
if main_group:
|
||||
command_path = (
|
||||
f"{main_group.command_path} [GLOBAL-OPTIONS] {ctx.command.name}"
|
||||
)
|
||||
|
||||
formatter.write_usage(command_path, " ".join(pieces))
|
||||
|
||||
|
||||
class SafetyCLILegacyGroup(click.Group):
|
||||
"""
|
||||
Custom Click Group to handle legacy command-line arguments.
|
||||
"""
|
||||
|
||||
context_class = CustomContext
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.all_commands = {}
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return super().__call__(*args, **kwargs)
|
||||
|
||||
def add_command(self, cmd, name=None) -> None:
|
||||
super().add_command(cmd, name)
|
||||
|
||||
name = name or cmd.name
|
||||
self.all_commands[name] = cmd
|
||||
|
||||
def parse_args(self, ctx: click.Context, args: List[str]) -> List[str]:
|
||||
ctx = cast(CustomContext, ctx)
|
||||
|
||||
if len(args) >= 1:
|
||||
if "pip" in args[0] and ctx:
|
||||
ctx.command_alias_used = args[0]
|
||||
args[0] = "pip"
|
||||
|
||||
parsed_args = super().parse_args(ctx, args)
|
||||
|
||||
args = ctx.args
|
||||
|
||||
# Workaround for legacy check options, that now are global options
|
||||
subcommand_args = set(args)
|
||||
PROXY_HOST_OPTIONS = set(["--proxy-host", "-ph"])
|
||||
if (
|
||||
"check" in ctx.protected_args
|
||||
or "license" in ctx.protected_args
|
||||
and (
|
||||
bool(
|
||||
PROXY_HOST_OPTIONS.intersection(subcommand_args)
|
||||
or "--key" in subcommand_args
|
||||
)
|
||||
)
|
||||
):
|
||||
proxy_options, key = self.parse_legacy_args(args)
|
||||
if proxy_options:
|
||||
ctx.params.update(proxy_options)
|
||||
|
||||
if key:
|
||||
ctx.params.update({"key": key})
|
||||
|
||||
return parsed_args
|
||||
|
||||
def parse_legacy_args(
|
||||
self, args: List[str]
|
||||
) -> Tuple[Optional[Dict[str, str]], Optional[str]]:
|
||||
"""
|
||||
Parse legacy command-line arguments for proxy settings and keys.
|
||||
|
||||
Args:
|
||||
args (List[str]): List of command-line arguments.
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Dict[str, str]], Optional[str]]: Parsed proxy options and key.
|
||||
"""
|
||||
options = {"proxy_protocol": "https", "proxy_port": 80, "proxy_host": None}
|
||||
key = None
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
if arg in ["--proxy-protocol", "-pr"] and i + 1 < len(args):
|
||||
options["proxy_protocol"] = args[i + 1]
|
||||
elif arg in ["--proxy-port", "-pp"] and i + 1 < len(args):
|
||||
options["proxy_port"] = int(args[i + 1])
|
||||
elif arg in ["--proxy-host", "-ph"] and i + 1 < len(args):
|
||||
options["proxy_host"] = args[i + 1]
|
||||
elif arg in ["--key"] and i + 1 < len(args):
|
||||
key = args[i + 1]
|
||||
|
||||
proxy = options if options["proxy_host"] else None
|
||||
return proxy, key
|
||||
|
||||
def get_filtered_commands(self, ctx: click.Context) -> Dict[str, click.Command]:
|
||||
from safety.auth.utils import initialize
|
||||
|
||||
initialize(ctx, refresh=False)
|
||||
|
||||
# Filter commands here:
|
||||
from .constants import CONTEXT_FEATURE_TYPE
|
||||
|
||||
disabled_features = [
|
||||
feature_type
|
||||
for feature_type in FeatureType
|
||||
if not getattr(ctx.obj, feature_type.attr_name, False)
|
||||
]
|
||||
|
||||
return {
|
||||
k: v
|
||||
for k, v in self.commands.items()
|
||||
if v.context_settings.get(CONTEXT_FEATURE_TYPE, None)
|
||||
not in disabled_features
|
||||
or k in ["firewall"]
|
||||
}
|
||||
|
||||
def invoke(self, ctx: click.Context) -> None:
|
||||
"""
|
||||
Invoke the command, handling legacy arguments.
|
||||
|
||||
Args:
|
||||
ctx (click.Context): Click context.
|
||||
"""
|
||||
session_kwargs = {
|
||||
"ctx": ctx,
|
||||
"proxy_protocol": ctx.params.pop("proxy_protocol", None),
|
||||
"proxy_host": ctx.params.pop("proxy_host", None),
|
||||
"proxy_port": ctx.params.pop("proxy_port", None),
|
||||
"key": ctx.params.pop("key", None),
|
||||
"stage": ctx.params.pop("stage", None),
|
||||
}
|
||||
invoked_command = make_str(next(iter(ctx.protected_args), ""))
|
||||
|
||||
from safety.auth.cli_utils import inject_session
|
||||
|
||||
inject_session(**session_kwargs, invoked_command=invoked_command)
|
||||
|
||||
# call initialize if the --key is used.
|
||||
if session_kwargs["key"]:
|
||||
from safety.auth.utils import initialize
|
||||
|
||||
initialize(ctx, refresh=True)
|
||||
|
||||
self.commands = self.get_filtered_commands(ctx)
|
||||
|
||||
# Now, invoke the original behavior
|
||||
super(SafetyCLILegacyGroup, self).invoke(ctx)
|
||||
|
||||
def list_commands(self, ctx: click.Context) -> List[str]:
|
||||
"""Override click.Group.list_commands with custom filtering"""
|
||||
self.commands = self.get_filtered_commands(ctx)
|
||||
|
||||
return super().list_commands(ctx)
|
||||
|
||||
def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
|
||||
"""
|
||||
Format help message with rich formatting.
|
||||
|
||||
Args:
|
||||
ctx (click.Context): Click context.
|
||||
formatter (click.HelpFormatter): Click help formatter.
|
||||
"""
|
||||
# The main `safety --help`
|
||||
if self.name == "cli":
|
||||
format_main_help(self, ctx, markup_mode="rich")
|
||||
# All other help outputs
|
||||
else:
|
||||
pretty_format_help(self, ctx, markup_mode="rich")
|
||||
|
||||
|
||||
class SafetyCLILegacyCommand(click.Command):
|
||||
"""
|
||||
Custom Click Command to handle legacy command-line arguments.
|
||||
"""
|
||||
|
||||
context_class = CustomContext
|
||||
|
||||
def format_help(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
|
||||
"""
|
||||
Format help message with rich formatting.
|
||||
|
||||
Args:
|
||||
ctx (click.Context): Click context.
|
||||
formatter (click.HelpFormatter): Click help formatter.
|
||||
"""
|
||||
pretty_format_help(self, ctx, markup_mode="rich")
|
||||
|
||||
|
||||
def get_git_branch_name() -> Optional[str]:
|
||||
"""
|
||||
Retrieves the current Git branch name.
|
||||
|
||||
Returns:
|
||||
str: The current Git branch name, or None if it cannot be determined.
|
||||
"""
|
||||
try:
|
||||
branch_name = subprocess.check_output(
|
||||
["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
).strip()
|
||||
return branch_name if branch_name else None
|
||||
except Exception:
|
||||
return None
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,177 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from safety.codebase.render import render_initialization_result
|
||||
from safety.errors import SafetyError
|
||||
from safety.events.utils.emission import (
|
||||
emit_codebase_detection_status,
|
||||
emit_codebase_setup_completed,
|
||||
)
|
||||
from safety.init.main import launch_auth_if_needed
|
||||
from safety.tool.main import find_local_tool_files
|
||||
from safety.util import clean_project_id
|
||||
from typing_extensions import Annotated
|
||||
|
||||
import typer
|
||||
from safety.cli_util import SafetyCLISubGroup, SafetyCLICommand
|
||||
from .constants import (
|
||||
CMD_CODEBASE_INIT_NAME,
|
||||
CMD_HELP_CODEBASE_INIT,
|
||||
CMD_HELP_CODEBASE,
|
||||
CMD_CODEBASE_GROUP_NAME,
|
||||
CMD_HELP_CODEBASE_INIT_DISABLE_FIREWALL,
|
||||
CMD_HELP_CODEBASE_INIT_LINK_TO,
|
||||
CMD_HELP_CODEBASE_INIT_NAME,
|
||||
CMD_HELP_CODEBASE_INIT_PATH,
|
||||
)
|
||||
|
||||
from ..cli_util import CommandType, get_command_for
|
||||
from ..error_handlers import handle_cmd_exception
|
||||
from ..decorators import notify
|
||||
from ..constants import CONTEXT_COMMAND_TYPE, DEFAULT_EPILOG
|
||||
from safety.console import main_console as console
|
||||
from .main import initialize_codebase, prepare_unverified_codebase
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
cli_apps_opts = {
|
||||
"rich_markup_mode": "rich",
|
||||
"cls": SafetyCLISubGroup,
|
||||
"name": CMD_CODEBASE_GROUP_NAME,
|
||||
}
|
||||
codebase_app = typer.Typer(**cli_apps_opts)
|
||||
|
||||
DEFAULT_CMD = CMD_CODEBASE_INIT_NAME
|
||||
|
||||
|
||||
@codebase_app.callback(
|
||||
invoke_without_command=True,
|
||||
cls=SafetyCLISubGroup,
|
||||
help=CMD_HELP_CODEBASE,
|
||||
epilog=DEFAULT_EPILOG,
|
||||
context_settings={
|
||||
"allow_extra_args": True,
|
||||
"ignore_unknown_options": True,
|
||||
CONTEXT_COMMAND_TYPE: CommandType.BETA,
|
||||
},
|
||||
)
|
||||
def codebase(
|
||||
ctx: typer.Context,
|
||||
):
|
||||
"""
|
||||
Group command for Safety Codebase (project) operations. Running this command will forward to the default command.
|
||||
"""
|
||||
logger.info("codebase started")
|
||||
|
||||
# If no subcommand is invoked, forward to the default command
|
||||
if not ctx.invoked_subcommand:
|
||||
default_command = get_command_for(name=DEFAULT_CMD, typer_instance=codebase_app)
|
||||
return ctx.forward(default_command)
|
||||
|
||||
|
||||
@codebase_app.command(
|
||||
cls=SafetyCLICommand,
|
||||
help=CMD_HELP_CODEBASE_INIT,
|
||||
name=CMD_CODEBASE_INIT_NAME,
|
||||
epilog=DEFAULT_EPILOG,
|
||||
options_metavar="[OPTIONS]",
|
||||
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||
)
|
||||
@handle_cmd_exception
|
||||
@notify
|
||||
def init(
|
||||
ctx: typer.Context,
|
||||
name: Annotated[
|
||||
Optional[str],
|
||||
typer.Option(
|
||||
help=CMD_HELP_CODEBASE_INIT_NAME,
|
||||
callback=lambda name: clean_project_id(name) if name else None,
|
||||
),
|
||||
] = None,
|
||||
link_to: Annotated[
|
||||
Optional[str],
|
||||
typer.Option(
|
||||
"--link-to",
|
||||
help=CMD_HELP_CODEBASE_INIT_LINK_TO,
|
||||
callback=lambda name: clean_project_id(name) if name else None,
|
||||
),
|
||||
] = None,
|
||||
skip_firewall_setup: Annotated[
|
||||
bool, typer.Option(help=CMD_HELP_CODEBASE_INIT_DISABLE_FIREWALL)
|
||||
] = False,
|
||||
codebase_path: Annotated[
|
||||
Path,
|
||||
typer.Option(
|
||||
"--path",
|
||||
exists=True,
|
||||
file_okay=False,
|
||||
dir_okay=True,
|
||||
writable=True,
|
||||
readable=True,
|
||||
resolve_path=True,
|
||||
show_default=False,
|
||||
help=CMD_HELP_CODEBASE_INIT_PATH,
|
||||
),
|
||||
] = Path("."),
|
||||
):
|
||||
"""
|
||||
Initialize a Safety Codebase. The codebase may be entirely new to Safety Platform,
|
||||
or may already exist in Safety Platform and the user is wanting to initialize it locally.
|
||||
"""
|
||||
logger.info("codebase init started")
|
||||
|
||||
if link_to and name:
|
||||
raise typer.BadParameter("--link-to and --name cannot be used together")
|
||||
|
||||
org_slug = launch_auth_if_needed(ctx, console)
|
||||
|
||||
if not org_slug:
|
||||
raise SafetyError(
|
||||
"Organization not found, please run 'safety auth status' or 'safety auth login'"
|
||||
)
|
||||
|
||||
should_enable_firewall = not skip_firewall_setup and ctx.obj.firewall_enabled
|
||||
|
||||
unverified_codebase = prepare_unverified_codebase(
|
||||
codebase_path=codebase_path,
|
||||
user_provided_name=name,
|
||||
user_provided_link_to=link_to,
|
||||
)
|
||||
|
||||
local_files = find_local_tool_files(codebase_path)
|
||||
|
||||
emit_codebase_detection_status(
|
||||
event_bus=ctx.obj.event_bus,
|
||||
ctx=ctx,
|
||||
detected=any(local_files),
|
||||
detected_files=local_files if local_files else None,
|
||||
)
|
||||
|
||||
project_file_created, project_status = initialize_codebase(
|
||||
ctx=ctx,
|
||||
console=console,
|
||||
codebase_path=codebase_path,
|
||||
unverified_codebase=unverified_codebase,
|
||||
org_slug=org_slug,
|
||||
link_to=link_to,
|
||||
should_enable_firewall=should_enable_firewall,
|
||||
)
|
||||
|
||||
codebase_init_status = (
|
||||
"reinitialized" if unverified_codebase.created else project_status
|
||||
)
|
||||
codebase_id = ctx.obj.project.id if ctx.obj.project and ctx.obj.project.id else None
|
||||
|
||||
render_initialization_result(
|
||||
console=console,
|
||||
codebase_init_status=codebase_init_status,
|
||||
codebase_id=codebase_id,
|
||||
)
|
||||
|
||||
emit_codebase_setup_completed(
|
||||
event_bus=ctx.obj.event_bus,
|
||||
ctx=ctx,
|
||||
is_created=project_file_created,
|
||||
codebase_id=codebase_id,
|
||||
)
|
||||
@@ -0,0 +1,27 @@
|
||||
CMD_HELP_CODEBASE_INIT = "Initialize a Safety Codebase (like git init for security). Sets up a new codebase or connects your local project to an existing one on Safety Platform."
|
||||
CMD_HELP_CODEBASE = (
|
||||
"[BETA] Manage your Safety Codebase integration.\nExample: safety codebase init"
|
||||
)
|
||||
|
||||
CMD_CODEBASE_GROUP_NAME = "codebase"
|
||||
CMD_CODEBASE_INIT_NAME = "init"
|
||||
|
||||
|
||||
# init options help
|
||||
CMD_HELP_CODEBASE_INIT_NAME = "Name of the codebase. Defaults to GIT origin name, parent directory name, or random string if parent directory is unnamed. The value will be normalized for use as an identifier."
|
||||
CMD_HELP_CODEBASE_INIT_LINK_TO = (
|
||||
"Link to an existing codebase using its codebase slug (found in Safety Platform)."
|
||||
)
|
||||
CMD_HELP_CODEBASE_INIT_DISABLE_FIREWALL = "Don't enable Firewall protection for this codebase (enabled by default when available in your organization)"
|
||||
CMD_HELP_CODEBASE_INIT_PATH = (
|
||||
"Path to the codebase directory. Defaults to current directory."
|
||||
)
|
||||
|
||||
|
||||
CODEBASE_INIT_REINITIALIZED = "Reinitialized existing codebase {codebase_name}"
|
||||
CODEBASE_INIT_ALREADY_EXISTS = "A codebase already exists in this directory. Please delete .safety-project.ini and run `safety codebase init` again to initialize a new codebase."
|
||||
CODEBASE_INIT_NOT_FOUND_LINK_TO = "\nError: codebase '{codebase_name}' specified with --link-to does not exist.\n\nTo create a new codebase instead, use one of:\n safety codebase init\n safety codebase init --name \"custom name\"\n\nTo link to an existing codebase, verify the codebase id and try again."
|
||||
CODEBASE_INIT_NOT_FOUND_PROJECT_FILE = "\nError: codebase '{codebase_name}' specified with the current .safety-project.ini file does not exist.\n\nTo create a new codebase instead, delete the corrupted .safety-project.ini file and then use one of:\n safety codebase init\n safety codebase init --name \"custom name\"\n\nTo link to an existing codebase, verify the codebase id and try again."
|
||||
CODEBASE_INIT_LINKED = "Linked to codebase {codebase_name}."
|
||||
CODEBASE_INIT_CREATED = "Created new codebase {codebase_name}."
|
||||
CODEBASE_INIT_ERROR = "Error: unable to initialize the codebase. Please try again."
|
||||
@@ -0,0 +1,115 @@
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from ..codebase_utils import load_unverified_project_from_config
|
||||
from safety.errors import SafetyError, SafetyException
|
||||
from pathlib import Path
|
||||
from safety.codebase.constants import (
|
||||
CODEBASE_INIT_ERROR,
|
||||
CODEBASE_INIT_NOT_FOUND_LINK_TO,
|
||||
CODEBASE_INIT_NOT_FOUND_PROJECT_FILE,
|
||||
)
|
||||
from safety.init.main import create_project
|
||||
from typer import Context
|
||||
from rich.console import Console
|
||||
import sys
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..codebase_utils import UnverifiedProjectModel
|
||||
|
||||
|
||||
def initialize_codebase(
|
||||
ctx: Context,
|
||||
console: Console,
|
||||
codebase_path: Path,
|
||||
unverified_codebase: "UnverifiedProjectModel",
|
||||
org_slug: str,
|
||||
link_to: Optional[str] = None,
|
||||
should_enable_firewall: bool = False,
|
||||
):
|
||||
is_interactive = sys.stdin.isatty()
|
||||
link_behavior = "prompt"
|
||||
create_if_missing = True
|
||||
is_codebase_file_created = unverified_codebase.created
|
||||
|
||||
if link_to or is_codebase_file_created:
|
||||
link_behavior = "always"
|
||||
create_if_missing = False
|
||||
elif not is_interactive:
|
||||
link_behavior = "never"
|
||||
|
||||
project_file_created, project_status = create_project(
|
||||
ctx=ctx,
|
||||
console=console,
|
||||
target=codebase_path,
|
||||
unverified_project=unverified_codebase,
|
||||
create_if_missing=create_if_missing,
|
||||
link_behavior=link_behavior,
|
||||
)
|
||||
|
||||
if project_status == "not_found":
|
||||
codebase_name = "Unknown"
|
||||
msg = "Codebase not found."
|
||||
|
||||
if link_to:
|
||||
msg = CODEBASE_INIT_NOT_FOUND_LINK_TO
|
||||
codebase_name = link_to
|
||||
elif is_codebase_file_created:
|
||||
msg = CODEBASE_INIT_NOT_FOUND_PROJECT_FILE
|
||||
codebase_name = unverified_codebase.id
|
||||
|
||||
raise SafetyError(msg.format(codebase_name=codebase_name))
|
||||
elif project_status == "found" and not is_interactive:
|
||||
# Non-TTY mode: Project exists but we can't link (link_behavior="never")
|
||||
suggested_name = unverified_codebase.id
|
||||
raise SafetyError(
|
||||
f"Project '{suggested_name}' already exists. "
|
||||
f"In non-interactive mode, use --link-to '{suggested_name}' to link to the existing project, "
|
||||
f"or use --name with a different project name to create a new one."
|
||||
)
|
||||
|
||||
if not ctx.obj.project:
|
||||
raise SafetyException(CODEBASE_INIT_ERROR)
|
||||
|
||||
if should_enable_firewall:
|
||||
from ..tool.main import configure_local_directory
|
||||
|
||||
configure_local_directory(codebase_path, org_slug, ctx.obj.project.id)
|
||||
|
||||
return project_file_created, project_status
|
||||
|
||||
|
||||
def fail_if_codebase_name_mismatch(
|
||||
provided_name: str,
|
||||
unverified_codebase: "UnverifiedProjectModel",
|
||||
) -> None:
|
||||
"""
|
||||
Useful to prevent the user from overwriting an existing codebase by mistyping the name.
|
||||
"""
|
||||
if unverified_codebase.id and provided_name != unverified_codebase.id:
|
||||
from safety.codebase.constants import CODEBASE_INIT_ALREADY_EXISTS
|
||||
|
||||
raise SafetyError(CODEBASE_INIT_ALREADY_EXISTS)
|
||||
|
||||
|
||||
def prepare_unverified_codebase(
|
||||
codebase_path: Path,
|
||||
user_provided_name: Optional[str] = None,
|
||||
user_provided_link_to: Optional[str] = None,
|
||||
) -> "UnverifiedProjectModel":
|
||||
"""
|
||||
Prepare the unverified codebase object based on the provided name and link to.
|
||||
"""
|
||||
unverified_codebase = load_unverified_project_from_config(
|
||||
project_root=codebase_path
|
||||
)
|
||||
|
||||
provided_name = user_provided_name or user_provided_link_to
|
||||
|
||||
if provided_name:
|
||||
fail_if_codebase_name_mismatch(
|
||||
provided_name=provided_name,
|
||||
unverified_codebase=unverified_codebase,
|
||||
)
|
||||
|
||||
unverified_codebase.id = provided_name
|
||||
|
||||
return unverified_codebase
|
||||
@@ -0,0 +1,34 @@
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from rich.console import Console
|
||||
|
||||
|
||||
def render_initialization_result(
|
||||
console: "Console",
|
||||
codebase_init_status: Optional[str] = None,
|
||||
codebase_id: Optional[str] = None,
|
||||
):
|
||||
if not codebase_init_status or not codebase_id:
|
||||
console.print("Error: unable to initialize codebase")
|
||||
return
|
||||
|
||||
message = None
|
||||
|
||||
if codebase_init_status == "created":
|
||||
from safety.codebase.constants import CODEBASE_INIT_CREATED
|
||||
|
||||
message = CODEBASE_INIT_CREATED
|
||||
|
||||
if codebase_init_status == "linked":
|
||||
from safety.codebase.constants import CODEBASE_INIT_LINKED
|
||||
|
||||
message = CODEBASE_INIT_LINKED
|
||||
|
||||
if codebase_init_status == "reinitialized":
|
||||
from safety.codebase.constants import CODEBASE_INIT_REINITIALIZED
|
||||
|
||||
message = CODEBASE_INIT_REINITIALIZED
|
||||
|
||||
if message:
|
||||
console.print(message.format(codebase_name=codebase_id))
|
||||
@@ -0,0 +1,89 @@
|
||||
import configparser
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from safety_schemas.models import ProjectModel
|
||||
|
||||
|
||||
PROJECT_CONFIG = ".safety-project.ini"
|
||||
PROJECT_CONFIG_SECTION = "project"
|
||||
PROJECT_CONFIG_ID = "id"
|
||||
PROJECT_CONFIG_URL = "url"
|
||||
PROJECT_CONFIG_NAME = "name"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnverifiedProjectModel:
|
||||
"""
|
||||
Data class representing an unverified project model.
|
||||
"""
|
||||
|
||||
id: Optional[str]
|
||||
project_path: Path
|
||||
created: bool
|
||||
name: Optional[str] = None
|
||||
url_path: Optional[str] = None
|
||||
|
||||
|
||||
def load_unverified_project_from_config(project_root: Path) -> UnverifiedProjectModel:
|
||||
"""
|
||||
Loads an unverified project from the configuration file located at the project root.
|
||||
|
||||
Args:
|
||||
project_root (Path): The root directory of the project.
|
||||
|
||||
Returns:
|
||||
UnverifiedProjectModel: An instance of UnverifiedProjectModel.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
project_path = project_root / PROJECT_CONFIG
|
||||
config.read(project_path)
|
||||
id = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_ID, fallback=None)
|
||||
url = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_URL, fallback=None)
|
||||
name = config.get(PROJECT_CONFIG_SECTION, PROJECT_CONFIG_NAME, fallback=None)
|
||||
created = True
|
||||
if not id:
|
||||
created = False
|
||||
|
||||
return UnverifiedProjectModel(
|
||||
id=id, url_path=url, name=name, project_path=project_path, created=created
|
||||
)
|
||||
|
||||
|
||||
def save_project_info(project: ProjectModel, project_path: Path) -> bool:
|
||||
"""
|
||||
Saves the project information to the configuration file.
|
||||
|
||||
Args:
|
||||
project (ProjectModel): The ProjectModel object containing project
|
||||
information.
|
||||
project_path (Path): The path to the configuration file.
|
||||
|
||||
Returns:
|
||||
bool: True if the project information was saved successfully, False
|
||||
otherwise.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config.read(project_path)
|
||||
|
||||
if PROJECT_CONFIG_SECTION not in config.sections():
|
||||
config[PROJECT_CONFIG_SECTION] = {}
|
||||
|
||||
config[PROJECT_CONFIG_SECTION][PROJECT_CONFIG_ID] = project.id
|
||||
if project.url_path:
|
||||
config[PROJECT_CONFIG_SECTION][PROJECT_CONFIG_URL] = project.url_path
|
||||
if project.name:
|
||||
config[PROJECT_CONFIG_SECTION][PROJECT_CONFIG_NAME] = project.name
|
||||
|
||||
try:
|
||||
with open(project_path, "w") as configfile:
|
||||
config.write(configfile)
|
||||
except Exception:
|
||||
logger.exception("Error saving project info")
|
||||
return False
|
||||
|
||||
return True
|
||||
149
Backend/venv/lib/python3.12/site-packages/safety/console.py
Normal file
149
Backend/venv/lib/python3.12/site-packages/safety/console.py
Normal file
@@ -0,0 +1,149 @@
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union
|
||||
from rich.console import Console
|
||||
from rich.theme import Theme
|
||||
from safety.emoji import load_emoji
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from rich.console import HighlighterType, JustifyMethod, OverflowMethod
|
||||
from rich.style import Style
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def should_use_ascii():
|
||||
"""
|
||||
Check if we should use ASCII alternatives for emojis
|
||||
"""
|
||||
encoding = getattr(sys.stdout, "encoding", "").lower()
|
||||
|
||||
if encoding in {"utf-8", "utf8", "cp65001", "utf-8-sig"}:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_spinner_animation() -> List[str]:
|
||||
"""
|
||||
Get the spinner animation based on the encoding
|
||||
"""
|
||||
if should_use_ascii():
|
||||
spinner = [
|
||||
"[ ]",
|
||||
"[= ]",
|
||||
"[== ]",
|
||||
"[=== ]",
|
||||
"[====]",
|
||||
"[ ===]",
|
||||
"[ ==]",
|
||||
"[ =]",
|
||||
]
|
||||
else:
|
||||
spinner = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
|
||||
return spinner
|
||||
|
||||
|
||||
def replace_non_ascii_chars(text: str):
|
||||
"""
|
||||
Replace non-ascii characters with ascii alternatives
|
||||
"""
|
||||
CHARS_MAP = {
|
||||
"━": "-",
|
||||
"’": "'",
|
||||
}
|
||||
|
||||
for char, replacement in CHARS_MAP.items():
|
||||
text = text.replace(char, replacement)
|
||||
|
||||
try:
|
||||
text.encode("ascii")
|
||||
except UnicodeEncodeError:
|
||||
LOG.warning("No handled non-ascii characters detected, encoding with replace")
|
||||
text = text.encode("ascii", "replace").decode("ascii")
|
||||
|
||||
return text
|
||||
|
||||
|
||||
class SafeConsole(Console):
|
||||
"""
|
||||
Console subclass that handles emoji encoding issues by detecting
|
||||
problematic encoding environments and replacing emojis with ASCII alternatives.
|
||||
Uses string replacement for custom emoji namespace to avoid private API usage.
|
||||
"""
|
||||
|
||||
def render_str(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
style: Union[str, "Style"] = "",
|
||||
justify: Optional["JustifyMethod"] = None,
|
||||
overflow: Optional["OverflowMethod"] = None,
|
||||
emoji: Optional[bool] = None,
|
||||
markup: Optional[bool] = None,
|
||||
highlight: Optional[bool] = None,
|
||||
highlighter: Optional["HighlighterType"] = None,
|
||||
) -> "Text":
|
||||
"""
|
||||
Override render_str to pre-process our custom emojis before Rich handles the text.
|
||||
"""
|
||||
|
||||
use_ascii = should_use_ascii()
|
||||
text = load_emoji(text, use_ascii=use_ascii)
|
||||
|
||||
if use_ascii:
|
||||
text = replace_non_ascii_chars(text)
|
||||
|
||||
# Let Rich handle everything else normally
|
||||
return super().render_str(
|
||||
text,
|
||||
style=style,
|
||||
justify=justify,
|
||||
overflow=overflow,
|
||||
emoji=emoji,
|
||||
markup=markup,
|
||||
highlight=highlight,
|
||||
highlighter=highlighter,
|
||||
)
|
||||
|
||||
|
||||
SAFETY_THEME = {
|
||||
"file_title": "bold default on default",
|
||||
"dep_name": "bold yellow on default",
|
||||
"scan_meta_title": "bold default on default",
|
||||
"vuln_brief": "red on default",
|
||||
"rem_brief": "bold green on default",
|
||||
"rem_severity": "bold red on default",
|
||||
"brief_severity": "bold default on default",
|
||||
"status.spinner": "green",
|
||||
"recommended_ver": "bold cyan on default",
|
||||
"vuln_id": "bold default on default",
|
||||
"number": "bold cyan on default",
|
||||
"link": "underline bright_blue on default",
|
||||
"tip": "bold default on default",
|
||||
"specifier": "bold cyan on default",
|
||||
"vulns_found_number": "red on default",
|
||||
}
|
||||
|
||||
|
||||
non_interactive = os.getenv("NON_INTERACTIVE") == "1"
|
||||
|
||||
console_kwargs: Dict[str, Any] = {
|
||||
"theme": Theme(SAFETY_THEME, inherit=False),
|
||||
"emoji": not should_use_ascii(),
|
||||
}
|
||||
|
||||
if non_interactive:
|
||||
LOG.info(
|
||||
"NON_INTERACTIVE environment variable is set, forcing non-interactive mode"
|
||||
)
|
||||
console_kwargs["force_terminal"] = True
|
||||
console_kwargs["force_interactive"] = False
|
||||
|
||||
main_console = SafeConsole(**console_kwargs)
|
||||
259
Backend/venv/lib/python3.12/site-packages/safety/constants.py
Normal file
259
Backend/venv/lib/python3.12/site-packages/safety/constants.py
Normal file
@@ -0,0 +1,259 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import configparser
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from safety.meta import get_version
|
||||
|
||||
JSON_SCHEMA_VERSION = "2.0.0"
|
||||
|
||||
# TODO fix this
|
||||
OPEN_MIRRORS = [
|
||||
f"https://pyup.io/aws/safety/free/{JSON_SCHEMA_VERSION}/",
|
||||
]
|
||||
|
||||
DIR_NAME = ".safety"
|
||||
|
||||
|
||||
def get_system_dir() -> Path:
|
||||
"""
|
||||
Get the system directory for the safety configuration.
|
||||
|
||||
Returns:
|
||||
Path: The system directory path.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
|
||||
raw_dir = os.getenv("SAFETY_SYSTEM_CONFIG_PATH")
|
||||
app_data = os.environ.get("ALLUSERSPROFILE", None)
|
||||
|
||||
if not raw_dir:
|
||||
if sys.platform.startswith("win") and app_data:
|
||||
raw_dir = app_data
|
||||
elif sys.platform.startswith("darwin"):
|
||||
raw_dir = "/Library/Application Support"
|
||||
elif sys.platform.startswith("linux"):
|
||||
raw_dir = "/etc"
|
||||
else:
|
||||
raw_dir = "/"
|
||||
|
||||
return Path(raw_dir, DIR_NAME)
|
||||
|
||||
|
||||
def get_user_dir() -> Path:
|
||||
"""
|
||||
Get the user directory for the safety configuration.
|
||||
|
||||
Returns:
|
||||
Path: The user directory path.
|
||||
"""
|
||||
path = Path("~", DIR_NAME).expanduser()
|
||||
return path
|
||||
|
||||
|
||||
USER_CONFIG_DIR = get_user_dir()
|
||||
SYSTEM_CONFIG_DIR = get_system_dir()
|
||||
|
||||
CACHE_FILE_DIR = USER_CONFIG_DIR / f"{JSON_SCHEMA_VERSION.replace('.', '')}"
|
||||
DB_CACHE_FILE = CACHE_FILE_DIR / "cache.json"
|
||||
|
||||
CONFIG_FILE_NAME = "config.ini"
|
||||
CONFIG_FILE_SYSTEM = SYSTEM_CONFIG_DIR / CONFIG_FILE_NAME if SYSTEM_CONFIG_DIR else None
|
||||
CONFIG_FILE_USER = USER_CONFIG_DIR / CONFIG_FILE_NAME
|
||||
|
||||
CONFIG = (
|
||||
CONFIG_FILE_SYSTEM
|
||||
if CONFIG_FILE_SYSTEM and CONFIG_FILE_SYSTEM.exists()
|
||||
else CONFIG_FILE_USER
|
||||
)
|
||||
|
||||
SAFETY_POLICY_FILE_NAME = ".safety-policy.yml"
|
||||
SYSTEM_POLICY_FILE = SYSTEM_CONFIG_DIR / SAFETY_POLICY_FILE_NAME
|
||||
USER_POLICY_FILE = USER_CONFIG_DIR / SAFETY_POLICY_FILE_NAME
|
||||
|
||||
DEFAULT_DOMAIN = "safetycli.com"
|
||||
DEFAULT_EMAIL = f"support@{DEFAULT_DOMAIN}"
|
||||
|
||||
|
||||
class URLSettings(Enum):
|
||||
PLATFORM_API_BASE_URL = f"https://platform.{DEFAULT_DOMAIN}/cli/api/v1"
|
||||
DATA_API_BASE_URL = f"https://data.{DEFAULT_DOMAIN}/api/v1/safety/"
|
||||
CLIENT_ID = "AWnwFBMr9DdZbxbDwYxjm4Gb24pFTnMp"
|
||||
AUTH_SERVER_URL = f"https://auth.{DEFAULT_DOMAIN}"
|
||||
SAFETY_PLATFORM_URL = f"https://platform.{DEFAULT_DOMAIN}"
|
||||
FIREWALL_API_BASE_URL = "https://pkgs.safetycli.com"
|
||||
|
||||
|
||||
class FeatureType(Enum):
|
||||
"""
|
||||
Defines server-controlled features for dynamic feature management.
|
||||
|
||||
Each enum value represents a toggleable feature controlled through
|
||||
server-side configuration, enabling gradual rollouts to different user
|
||||
segments. Features are cached during CLI initialization.
|
||||
|
||||
History:
|
||||
Created to support progressive feature rollouts and A/B testing without
|
||||
disturbing users.
|
||||
"""
|
||||
|
||||
FIREWALL = "firewall"
|
||||
PLATFORM = "platform"
|
||||
EVENTS = "events"
|
||||
|
||||
@property
|
||||
def config_key(self) -> str:
|
||||
"""For JSON/config lookup e.g. 'feature-a-enabled'"""
|
||||
return f"{self.name.lower()}-enabled"
|
||||
|
||||
@property
|
||||
def attr_name(self) -> str:
|
||||
"""For Python attribute access e.g. 'feature_a_enabled'"""
|
||||
return f"{self.name.lower()}_enabled"
|
||||
|
||||
|
||||
def get_config_setting(name: str, default=None) -> Optional[str]:
|
||||
"""
|
||||
Get the configuration setting from the config file or defaults.
|
||||
|
||||
Args:
|
||||
name (str): The name of the setting to retrieve.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The value of the setting if found, otherwise None.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config.read(CONFIG)
|
||||
|
||||
if name in [setting.name for setting in URLSettings]:
|
||||
default = URLSettings[name]
|
||||
|
||||
if "settings" in config.sections() and name in config["settings"]:
|
||||
value = config["settings"][name]
|
||||
if value:
|
||||
return value
|
||||
|
||||
return default.value if default else default
|
||||
|
||||
|
||||
DATA_API_BASE_URL = get_config_setting("DATA_API_BASE_URL")
|
||||
PLATFORM_API_BASE_URL = get_config_setting("PLATFORM_API_BASE_URL")
|
||||
|
||||
PLATFORM_API_PROJECT_ENDPOINT = f"{PLATFORM_API_BASE_URL}/project"
|
||||
PLATFORM_API_PROJECT_CHECK_ENDPOINT = f"{PLATFORM_API_BASE_URL}/project-check"
|
||||
PLATFORM_API_POLICY_ENDPOINT = f"{PLATFORM_API_BASE_URL}/policy"
|
||||
PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT = (
|
||||
f"{PLATFORM_API_BASE_URL}/project-scan-request"
|
||||
)
|
||||
PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT = f"{PLATFORM_API_BASE_URL}/scan"
|
||||
PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT = (
|
||||
f"{PLATFORM_API_BASE_URL}/process_files"
|
||||
)
|
||||
PLATFORM_API_CHECK_UPDATES_ENDPOINT = f"{PLATFORM_API_BASE_URL}/versions-and-configs"
|
||||
PLATFORM_API_INITIALIZE_ENDPOINT = f"{PLATFORM_API_BASE_URL}/initialize"
|
||||
PLATFORM_API_EVENTS_ENDPOINT = f"{PLATFORM_API_BASE_URL}/events"
|
||||
|
||||
FIREWALL_API_BASE_URL = get_config_setting("FIREWALL_API_BASE_URL")
|
||||
FIREWALL_AUDIT_PYPI_PACKAGES_ENDPOINT = f"{FIREWALL_API_BASE_URL}/audit/pypi/packages/"
|
||||
FIREWALL_AUDIT_NPMJS_PACKAGES_ENDPOINT = (
|
||||
f"{FIREWALL_API_BASE_URL}/audit/npmjs/packages/"
|
||||
)
|
||||
|
||||
API_MIRRORS = [DATA_API_BASE_URL]
|
||||
|
||||
# Fetch the REQUEST_TIMEOUT from the environment variable, defaulting to 30 if not set
|
||||
REQUEST_TIMEOUT = int(os.getenv("SAFETY_REQUEST_TIMEOUT", 30))
|
||||
|
||||
# Colors
|
||||
YELLOW = "yellow"
|
||||
RED = "red"
|
||||
GREEN = "green"
|
||||
|
||||
# MESSAGES
|
||||
IGNORE_UNPINNED_REQ_REASON = (
|
||||
"This vulnerability is being ignored due to the 'ignore-unpinned-requirements' flag (default True). "
|
||||
"To change this, set 'ignore-unpinned-requirements' to False under 'security' in your policy file. "
|
||||
"See https://docs.pyup.io/docs/safety-20-policy-file for more information."
|
||||
)
|
||||
|
||||
# REGEXES
|
||||
HASH_REGEX_GROUPS = r"--hash[=| ](\w+):(\w+)"
|
||||
|
||||
DOCS_API_KEY_URL = "https://docs.safetycli.com/cli/api-keys"
|
||||
MSG_NO_AUTHD_DEV_STG = "Please login or register Safety CLI [bold](free forever)[/bold] to scan and secure your projects with Safety"
|
||||
MSG_NO_AUTHD_DEV_STG_PROMPT = "(R)egister for a free account in 30 seconds, or (L)ogin with an existing account to continue (R/L)"
|
||||
MSG_NO_AUTHD_DEV_STG_ORG_PROMPT = "Please log in to secure your projects with Safety. Press enter to continue to log in (Y/N)"
|
||||
MSG_NO_AUTHD_CICD_PROD_STG = "Enter your Safety API key to scan projects in CI/CD using the --key argument or setting your API key in the SAFETY_API_KEY environment variable."
|
||||
MSG_NO_AUTHD_CICD_PROD_STG_ORG = f"""
|
||||
Login to get your API key
|
||||
|
||||
To log in: [link]{{LOGIN_URL}}[/link]
|
||||
|
||||
Read more at: [link]{DOCS_API_KEY_URL}[/link]
|
||||
"""
|
||||
|
||||
MSG_NO_AUTHD_NOTE_CICD_PROD_STG_TPL = f"""
|
||||
Login or register for a free account to get your API key
|
||||
|
||||
To log in: [link]{{LOGIN_URL}}[/link]
|
||||
To register: [link]{{SIGNUP_URL}}[/link]
|
||||
|
||||
Read more at: [link]{DOCS_API_KEY_URL}[/link]
|
||||
"""
|
||||
|
||||
MSG_FINISH_REGISTRATION_TPL = (
|
||||
"To complete your account open the “verify your email” email sent to {email}"
|
||||
)
|
||||
|
||||
MSG_VERIFICATION_HINT = "Can’t find the verification email? Login at [link]`https://platform.safetycli.com/login/`[/link] to resend the verification email"
|
||||
|
||||
MSG_NO_VERIFIED_EMAIL_TPL = f"""Email verification is required for {{email}}
|
||||
|
||||
{MSG_FINISH_REGISTRATION_TPL}
|
||||
|
||||
{MSG_VERIFICATION_HINT}"""
|
||||
|
||||
# Exit codes
|
||||
EXIT_CODE_OK = 0
|
||||
EXIT_CODE_FAILURE = 1
|
||||
EXIT_CODE_VULNERABILITIES_FOUND = 64
|
||||
EXIT_CODE_INVALID_AUTH_CREDENTIAL = 65
|
||||
EXIT_CODE_TOO_MANY_REQUESTS = 66
|
||||
EXIT_CODE_UNABLE_TO_LOAD_LOCAL_VULNERABILITY_DB = 67
|
||||
EXIT_CODE_UNABLE_TO_FETCH_VULNERABILITY_DB = 68
|
||||
EXIT_CODE_MALFORMED_DB = 69
|
||||
EXIT_CODE_INVALID_PROVIDED_REPORT = 70
|
||||
EXIT_CODE_INVALID_REQUIREMENT = 71
|
||||
EXIT_CODE_EMAIL_NOT_VERIFIED = 72
|
||||
|
||||
# For Depreciated Messages
|
||||
BAR_LINE = "+===========================================================================================================================================================================================+"
|
||||
|
||||
BETA_PANEL_DESCRIPTION_HELP = "These commands are experimental and part of our commitment to delivering innovative features. As we refine functionality, they may be significantly altered or, in rare cases, removed without prior notice. We welcome your feedback and encourage cautious use."
|
||||
|
||||
CONTEXT_COMMAND_TYPE = "command_type"
|
||||
CONTEXT_FEATURE_TYPE = "feature_type"
|
||||
|
||||
CLI_VERSION = get_version()
|
||||
CLI_WEBSITE_URL = "https://safetycli.com"
|
||||
CLI_DOCUMENTATION_URL = "https://docs.safetycli.com"
|
||||
CLI_SUPPORT_EMAIL = "support@safetycli.com"
|
||||
|
||||
# Main Safety --help data:
|
||||
CLI_MAIN_INTRODUCTION = (
|
||||
f"Safety CLI 3 - Vulnerability Scanning for Secure Python Development\n\n"
|
||||
"Leverage the most comprehensive vulnerability data available to secure your projects against vulnerable and malicious packages. Safety CLI is a Python dependency vulnerability scanner that enhances software supply chain security at every stage of development.\n\n"
|
||||
f"Documentation: {CLI_DOCUMENTATION_URL}\n"
|
||||
f"Contact: {CLI_SUPPORT_EMAIL}\n"
|
||||
)
|
||||
|
||||
DEFAULT_EPILOG = (
|
||||
f"\nSafety CLI version: {CLI_VERSION}\n"
|
||||
f"\nDocumentation: {CLI_DOCUMENTATION_URL}\n\n\n\n"
|
||||
"Made with love by Safety Cybersecurity\n\n"
|
||||
f"{CLI_WEBSITE_URL}\n\n"
|
||||
f"{CLI_SUPPORT_EMAIL}\n"
|
||||
)
|
||||
@@ -0,0 +1,40 @@
|
||||
from functools import wraps
|
||||
|
||||
from safety.events.utils import emit_command_executed
|
||||
|
||||
|
||||
def notify(func):
|
||||
"""
|
||||
A decorator that wraps a function to emit events.
|
||||
|
||||
Args:
|
||||
func (callable): The function to be wrapped by the decorator.
|
||||
|
||||
Returns:
|
||||
callable: The wrapped function with notification logic.
|
||||
|
||||
The decorator ensures that the `emit_command_executed` function is called
|
||||
after the wrapped function completes, regardless of whether it exits
|
||||
normally or via a `SystemExit` exception.
|
||||
|
||||
Example:
|
||||
@notify
|
||||
def my_function(ctx, *args, **kwargs):
|
||||
# function implementation
|
||||
pass
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def inner(ctx, *args, **kwargs):
|
||||
try:
|
||||
result = func(ctx, *args, **kwargs)
|
||||
emit_command_executed(ctx.obj.event_bus, ctx, returned_code=0)
|
||||
return result
|
||||
except SystemExit as e:
|
||||
# Handle sys.exit() case
|
||||
exit_code = e.code if isinstance(e.code, int) else 1
|
||||
emit_command_executed(ctx.obj.event_bus, ctx, returned_code=exit_code)
|
||||
raise
|
||||
# Any other exceptions will bypass notification and propagate normally
|
||||
|
||||
return inner
|
||||
104
Backend/venv/lib/python3.12/site-packages/safety/emoji.py
Normal file
104
Backend/venv/lib/python3.12/site-packages/safety/emoji.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Custom emoji namespace mapping
|
||||
import re
|
||||
from typing import Match
|
||||
|
||||
|
||||
CUSTOM_EMOJI_MAP = {
|
||||
"icon_check": "✓",
|
||||
"icon_warning": "⚠️",
|
||||
"icon_info": "ℹ️",
|
||||
}
|
||||
|
||||
# ASCII fallback mapping for problematic environments
|
||||
ASCII_FALLBACK_MAP = {
|
||||
"icon_check": "+",
|
||||
"icon_warning": "!",
|
||||
"icon_info": "i",
|
||||
"white_heavy_check_mark": "++",
|
||||
"white_check_mark": "+",
|
||||
"check_mark": "+",
|
||||
"heavy_check_mark": "+",
|
||||
"shield": "[SHIELD]",
|
||||
"x": "X",
|
||||
"lock": "[LOCK]",
|
||||
"key": "[KEY]",
|
||||
"pencil": "[EDIT]",
|
||||
"arrow_up": "^",
|
||||
"stop_sign": "[STOP]",
|
||||
"warning": "!",
|
||||
"locked": "[LOCK]",
|
||||
"pushpin": "[PIN]",
|
||||
"magnifying_glass_tilted_left": "[SCAN]",
|
||||
"fire": "[CRIT]",
|
||||
"yellow_circle": "[HIGH]",
|
||||
"sparkles": "*",
|
||||
"mag_right": "[VIEW]",
|
||||
"link": "->",
|
||||
"light_bulb": "[TIP]",
|
||||
"trophy": "[DONE]",
|
||||
"rocket": ">>",
|
||||
"busts_in_silhouette": "[TEAM]",
|
||||
"floppy_disk": "[SAVE]",
|
||||
"heavy_plus_sign": "[ADD]",
|
||||
"books": "[DOCS]",
|
||||
"speech_balloon": "[HELP]",
|
||||
}
|
||||
|
||||
# Pre-compiled regex for emoji processing (Rich-style)
|
||||
CUSTOM_EMOJI_PATTERN = re.compile(r"(:icon_\w+:)")
|
||||
|
||||
|
||||
def process_custom_emojis(text: str, use_ascii: bool = False) -> str:
|
||||
"""
|
||||
Pre-process our custom emoji namespace before Rich handles the text.
|
||||
This only handles our custom :icon_*: emojis.
|
||||
"""
|
||||
if not isinstance(text, str) or ":icon_" not in text:
|
||||
return text
|
||||
|
||||
def replace_custom_emoji(match: Match[str]) -> str:
|
||||
emoji_code = match.group(1) # :icon_check:
|
||||
emoji_name = emoji_code[1:-1] # icon_check
|
||||
|
||||
# If we should use ASCII, use the fallback
|
||||
if use_ascii:
|
||||
return ASCII_FALLBACK_MAP.get(emoji_name, emoji_code)
|
||||
|
||||
return CUSTOM_EMOJI_MAP.get(emoji_name, emoji_code)
|
||||
|
||||
return CUSTOM_EMOJI_PATTERN.sub(replace_custom_emoji, text)
|
||||
|
||||
|
||||
def process_rich_emojis_fallback(text: str) -> str:
|
||||
"""
|
||||
Replace Rich emoji codes with ASCII alternatives when in problematic environments.
|
||||
"""
|
||||
# Simple pattern to match Rich emoji codes like :emoji_name:
|
||||
emoji_pattern = re.compile(r":([a-zA-Z0-9_]+):")
|
||||
|
||||
def replace_with_ascii(match: Match[str]) -> str:
|
||||
emoji_name = match.group(1)
|
||||
# Check if we have an ASCII fallback
|
||||
ascii_replacement = ASCII_FALLBACK_MAP.get(emoji_name, None)
|
||||
if ascii_replacement:
|
||||
return ascii_replacement
|
||||
|
||||
# Otherwise keep the original
|
||||
return match.group(0)
|
||||
|
||||
return emoji_pattern.sub(replace_with_ascii, text)
|
||||
|
||||
|
||||
def load_emoji(text: str, use_ascii: bool = False) -> str:
|
||||
"""
|
||||
Load emoji from text if emoji is present.
|
||||
"""
|
||||
|
||||
# Pre-process our custom emojis
|
||||
text = process_custom_emojis(text, use_ascii)
|
||||
|
||||
# If we need ASCII fallbacks, also process Rich emoji codes
|
||||
if use_ascii:
|
||||
text = process_rich_emojis_fallback(text)
|
||||
|
||||
return text
|
||||
28
Backend/venv/lib/python3.12/site-packages/safety/encoding.py
Normal file
28
Backend/venv/lib/python3.12/site-packages/safety/encoding.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def detect_encoding(file_path: Path) -> str:
|
||||
"""
|
||||
UTF-8 is the most common encoding standard, this is a simple
|
||||
way to improve the support for related Windows based files.
|
||||
|
||||
Handles the most common cases efficiently.
|
||||
"""
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
# Read first 3 bytes for BOM detection
|
||||
bom = f.read(3)
|
||||
|
||||
# Check most common Windows patterns first
|
||||
if bom[:2] in (b"\xff\xfe", b"\xfe\xff"):
|
||||
return "utf-16"
|
||||
elif bom.startswith(b"\xef\xbb\xbf"):
|
||||
return "utf-8-sig"
|
||||
|
||||
return "utf-8"
|
||||
except Exception:
|
||||
logger.exception("Error detecting encoding")
|
||||
return "utf-8"
|
||||
@@ -0,0 +1,89 @@
|
||||
# Standard library imports
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
# Third-party imports
|
||||
import click
|
||||
|
||||
# Local imports
|
||||
from safety.constants import EXIT_CODE_FAILURE, EXIT_CODE_OK
|
||||
from safety.errors import SafetyError, SafetyException
|
||||
|
||||
from safety.events.utils import emit_command_error
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from safety.scan.models import ScanOutput
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def output_exception(exception: Exception, exit_code_output: bool = True) -> None:
|
||||
"""
|
||||
Output an exception message to the console and exit.
|
||||
|
||||
Args:
|
||||
exception (Exception): The exception to output.
|
||||
exit_code_output (bool): Whether to output the exit code.
|
||||
|
||||
Exits:
|
||||
Exits the program with the appropriate exit code.
|
||||
"""
|
||||
click.secho(str(exception), fg="red", file=sys.stderr)
|
||||
|
||||
if exit_code_output:
|
||||
exit_code = EXIT_CODE_FAILURE
|
||||
if hasattr(exception, "get_exit_code"):
|
||||
exit_code = exception.get_exit_code()
|
||||
else:
|
||||
exit_code = EXIT_CODE_OK
|
||||
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
def handle_cmd_exception(func):
|
||||
"""
|
||||
Decorator to handle exceptions in command functions.
|
||||
|
||||
Args:
|
||||
func: The command function to wrap.
|
||||
|
||||
Returns:
|
||||
The wrapped function.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def inner(ctx, output: Optional["ScanOutput"] = None, *args, **kwargs):
|
||||
if output:
|
||||
from safety.scan.models import ScanOutput
|
||||
|
||||
kwargs.update({"output": output})
|
||||
|
||||
if output is ScanOutput.NONE:
|
||||
return func(ctx, *args, **kwargs)
|
||||
|
||||
try:
|
||||
return func(ctx, *args, **kwargs)
|
||||
except click.ClickException as e:
|
||||
emit_command_error(
|
||||
ctx.obj.event_bus, ctx, message=str(e), traceback=traceback.format_exc()
|
||||
)
|
||||
raise e
|
||||
except SafetyError as e:
|
||||
LOG.exception("Expected SafetyError happened: %s", e)
|
||||
emit_command_error(
|
||||
ctx.obj.event_bus, ctx, message=str(e), traceback=traceback.format_exc()
|
||||
)
|
||||
output_exception(e, exit_code_output=True)
|
||||
except Exception as e:
|
||||
emit_command_error(
|
||||
ctx.obj.event_bus, ctx, message=str(e), traceback=traceback.format_exc()
|
||||
)
|
||||
LOG.exception("Unexpected Exception happened: %s", e)
|
||||
exception = e if isinstance(e, SafetyException) else SafetyException(info=e)
|
||||
output_exception(exception, exit_code_output=True)
|
||||
|
||||
return inner
|
||||
287
Backend/venv/lib/python3.12/site-packages/safety/errors.py
Normal file
287
Backend/venv/lib/python3.12/site-packages/safety/errors.py
Normal file
@@ -0,0 +1,287 @@
|
||||
from typing import Optional
|
||||
|
||||
from safety.constants import (
|
||||
EXIT_CODE_EMAIL_NOT_VERIFIED,
|
||||
EXIT_CODE_FAILURE,
|
||||
EXIT_CODE_INVALID_AUTH_CREDENTIAL,
|
||||
EXIT_CODE_INVALID_PROVIDED_REPORT,
|
||||
EXIT_CODE_INVALID_REQUIREMENT,
|
||||
EXIT_CODE_MALFORMED_DB,
|
||||
EXIT_CODE_TOO_MANY_REQUESTS,
|
||||
EXIT_CODE_UNABLE_TO_FETCH_VULNERABILITY_DB,
|
||||
EXIT_CODE_UNABLE_TO_LOAD_LOCAL_VULNERABILITY_DB,
|
||||
)
|
||||
|
||||
|
||||
class SafetyException(Exception):
|
||||
"""
|
||||
Base exception for Safety CLI errors.
|
||||
|
||||
Args:
|
||||
message (str): The error message template.
|
||||
info (str): Additional information to include in the error message.
|
||||
"""
|
||||
def __init__(self, message: str = "Unhandled exception happened: {info}", info: str = ""):
|
||||
self.message = message.format(info=info)
|
||||
super().__init__(self.message)
|
||||
|
||||
def get_exit_code(self) -> int:
|
||||
"""
|
||||
Get the exit code associated with this exception.
|
||||
|
||||
Returns:
|
||||
int: The exit code.
|
||||
"""
|
||||
return EXIT_CODE_FAILURE
|
||||
|
||||
|
||||
class SafetyError(Exception):
|
||||
"""
|
||||
Generic Safety CLI error.
|
||||
|
||||
Args:
|
||||
message (str): The error message.
|
||||
error_code (Optional[int]): The error code.
|
||||
"""
|
||||
def __init__(self, message: str = "Unhandled Safety generic error", error_code: Optional[int] = None):
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
super().__init__(self.message)
|
||||
|
||||
def get_exit_code(self) -> int:
|
||||
"""
|
||||
Get the exit code associated with this error.
|
||||
|
||||
Returns:
|
||||
int: The exit code.
|
||||
"""
|
||||
return EXIT_CODE_FAILURE
|
||||
|
||||
|
||||
class MalformedDatabase(SafetyError):
|
||||
"""
|
||||
Error raised when the vulnerability database is malformed.
|
||||
|
||||
Args:
|
||||
reason (Optional[str]): The reason for the error.
|
||||
fetched_from (str): The source of the fetched data.
|
||||
message (str): The error message template.
|
||||
"""
|
||||
def __init__(self, reason: Optional[str] = None, fetched_from: str = "server",
|
||||
message: str = "Sorry, something went wrong.\n"
|
||||
"Safety CLI cannot read the data fetched from {fetched_from} because it is malformed.\n"):
|
||||
info = f"Reason, {reason}" if reason else ""
|
||||
info = "Reason, {reason}".format(reason=reason)
|
||||
self.message = message.format(fetched_from=fetched_from) + (info if reason else "")
|
||||
super().__init__(self.message)
|
||||
|
||||
def get_exit_code(self) -> int:
|
||||
"""
|
||||
Get the exit code associated with this error.
|
||||
|
||||
Returns:
|
||||
int: The exit code.
|
||||
"""
|
||||
return EXIT_CODE_MALFORMED_DB
|
||||
|
||||
|
||||
class DatabaseFetchError(SafetyError):
|
||||
"""
|
||||
Error raised when the vulnerability database cannot be fetched.
|
||||
|
||||
Args:
|
||||
message (str): The error message.
|
||||
"""
|
||||
def __init__(self, message: str = "Unable to load vulnerability database"):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
def get_exit_code(self) -> int:
|
||||
"""
|
||||
Get the exit code associated with this error.
|
||||
|
||||
Returns:
|
||||
int: The exit code.
|
||||
"""
|
||||
return EXIT_CODE_UNABLE_TO_FETCH_VULNERABILITY_DB
|
||||
|
||||
|
||||
class InvalidProvidedReportError(SafetyError):
|
||||
"""
|
||||
Error raised when the provided report is invalid for applying fixes.
|
||||
|
||||
Args:
|
||||
message (str): The error message.
|
||||
"""
|
||||
def __init__(self, message: str = "Unable to apply fix: the report needs to be generated from a file. "
|
||||
"Environment isn't supported yet."):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
def get_exit_code(self) -> int:
|
||||
"""
|
||||
Get the exit code associated with this error.
|
||||
|
||||
Returns:
|
||||
int: The exit code.
|
||||
"""
|
||||
return EXIT_CODE_INVALID_PROVIDED_REPORT
|
||||
|
||||
|
||||
class InvalidRequirementError(SafetyError):
|
||||
"""
|
||||
Error raised when a requirement is invalid.
|
||||
|
||||
Args:
|
||||
message (str): The error message template.
|
||||
line (str): The invalid requirement line.
|
||||
"""
|
||||
def __init__(self, message: str = "Unable to parse the requirement: {line}", line: str = ""):
|
||||
self.message = message.format(line=line)
|
||||
super().__init__(self.message)
|
||||
|
||||
def get_exit_code(self) -> int:
|
||||
"""
|
||||
Get the exit code associated with this error.
|
||||
|
||||
Returns:
|
||||
int: The exit code.
|
||||
"""
|
||||
return EXIT_CODE_INVALID_REQUIREMENT
|
||||
|
||||
|
||||
class DatabaseFileNotFoundError(DatabaseFetchError):
|
||||
"""
|
||||
Error raised when the vulnerability database file is not found.
|
||||
|
||||
Args:
|
||||
db (Optional[str]): The database file path.
|
||||
message (str): The error message template.
|
||||
"""
|
||||
def __init__(self, db: Optional[str] = None, message: str = "Unable to find vulnerability database in {db}"):
|
||||
self.db = db
|
||||
self.message = message.format(db=db)
|
||||
super().__init__(self.message)
|
||||
|
||||
def get_exit_code(self) -> int:
|
||||
"""
|
||||
Get the exit code associated with this error.
|
||||
|
||||
Returns:
|
||||
int: The exit code.
|
||||
"""
|
||||
return EXIT_CODE_UNABLE_TO_LOAD_LOCAL_VULNERABILITY_DB
|
||||
|
||||
|
||||
class InvalidCredentialError(DatabaseFetchError):
|
||||
"""
|
||||
Error raised when authentication credentials are invalid.
|
||||
|
||||
Args:
|
||||
credential (Optional[str]): The invalid credential.
|
||||
message (str): The error message template.
|
||||
reason (Optional[str]): The reason for the error.
|
||||
"""
|
||||
|
||||
def __init__(self, credential: Optional[str] = None,
|
||||
message: str = "Your authentication credential{credential}is invalid. See {link}.",
|
||||
reason: Optional[str] = None):
|
||||
self.credential = credential
|
||||
self.link = 'https://docs.safetycli.com/safety-docs/support/invalid-api-key-error'
|
||||
self.message = message.format(credential=f" '{self.credential}' ", link=self.link) if self.credential else message.format(credential=' ', link=self.link)
|
||||
info = f" Reason: {reason}"
|
||||
self.message = self.message + (info if reason else "")
|
||||
super().__init__(self.message)
|
||||
|
||||
def get_exit_code(self) -> int:
|
||||
"""
|
||||
Get the exit code associated with this error.
|
||||
|
||||
Returns:
|
||||
int: The exit code.
|
||||
"""
|
||||
return EXIT_CODE_INVALID_AUTH_CREDENTIAL
|
||||
|
||||
class NotVerifiedEmailError(SafetyError):
|
||||
"""
|
||||
Error raised when the user's email is not verified.
|
||||
|
||||
Args:
|
||||
message (str): The error message.
|
||||
"""
|
||||
def __init__(self, message: str = "email is not verified"):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
def get_exit_code(self) -> int:
|
||||
"""
|
||||
Get the exit code associated with this error.
|
||||
|
||||
Returns:
|
||||
int: The exit code.
|
||||
"""
|
||||
return EXIT_CODE_EMAIL_NOT_VERIFIED
|
||||
|
||||
class TooManyRequestsError(DatabaseFetchError):
|
||||
"""
|
||||
Error raised when too many requests are made to the server.
|
||||
|
||||
Args:
|
||||
reason (Optional[str]): The reason for the error.
|
||||
message (str): The error message template.
|
||||
"""
|
||||
def __init__(self, reason: Optional[str] = None,
|
||||
message: str = "Too many requests."):
|
||||
info = f" Reason: {reason}"
|
||||
self.message = message + (info if reason else "")
|
||||
super().__init__(self.message)
|
||||
|
||||
def get_exit_code(self) -> int:
|
||||
"""
|
||||
Get the exit code associated with this error.
|
||||
|
||||
Returns:
|
||||
int: The exit code.
|
||||
"""
|
||||
return EXIT_CODE_TOO_MANY_REQUESTS
|
||||
|
||||
|
||||
class NetworkConnectionError(DatabaseFetchError):
|
||||
"""
|
||||
Error raised when there is a network connection issue.
|
||||
|
||||
Args:
|
||||
message (str): The error message.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "Check your network connection, unable to reach the server."):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class RequestTimeoutError(DatabaseFetchError):
|
||||
"""
|
||||
Error raised when a request times out.
|
||||
|
||||
Args:
|
||||
message (str): The error message.
|
||||
"""
|
||||
def __init__(self, message: str = "Check your network connection, the request timed out."):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class ServerError(DatabaseFetchError):
|
||||
"""
|
||||
Error raised when there is a server issue.
|
||||
|
||||
Args:
|
||||
reason (Optional[str]): The reason for the error.
|
||||
message (str): The error message template.
|
||||
"""
|
||||
def __init__(self, reason: Optional[str] = None,
|
||||
message: str = "Sorry, something went wrong.\n"
|
||||
"Our engineers are working quickly to resolve the issue."):
|
||||
info = f" Reason: {reason}"
|
||||
self.message = message + (info if reason else "")
|
||||
super().__init__(self.message)
|
||||
@@ -0,0 +1,25 @@
|
||||
from .handlers import EventHandler
|
||||
from .types import (
|
||||
CloseResourcesEvent,
|
||||
CommandErrorEvent,
|
||||
CommandExecutedEvent,
|
||||
FirewallConfiguredEvent,
|
||||
FirewallDisabledEvent,
|
||||
FirewallHeartbeatEvent,
|
||||
FlushSecurityTracesEvent,
|
||||
PackageInstalledEvent,
|
||||
PackageUninstalledEvent,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EventHandler",
|
||||
"CloseResourcesEvent",
|
||||
"FlushSecurityTracesEvent",
|
||||
"CommandExecutedEvent",
|
||||
"CommandErrorEvent",
|
||||
"PackageInstalledEvent",
|
||||
"PackageUninstalledEvent",
|
||||
"FirewallHeartbeatEvent",
|
||||
"FirewallConfiguredEvent",
|
||||
"FirewallDisabledEvent",
|
||||
]
|
||||
Binary file not shown.
@@ -0,0 +1,7 @@
|
||||
from .bus import EventBus
|
||||
from .utils import start_event_bus
|
||||
|
||||
__all__ = [
|
||||
"EventBus",
|
||||
"start_event_bus",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
Core EventBus implementation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import logging
|
||||
from concurrent.futures import Future
|
||||
from typing import Dict, List, Any, Optional, Callable, TypeVar
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..handlers import EventHandler
|
||||
|
||||
from safety_schemas.models.events import Event, EventTypeBase, PayloadBase
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventBusMetrics:
|
||||
"""
|
||||
Metrics for the event bus.
|
||||
"""
|
||||
|
||||
events_emitted: int = 0
|
||||
events_processed: int = 0
|
||||
events_failed: int = 0
|
||||
queue_high_water_mark: int = 0
|
||||
handler_durations: Dict[str, List[float]] = field(default_factory=dict)
|
||||
|
||||
|
||||
E = TypeVar("E", bound=Event)
|
||||
|
||||
# Define bounded type variables
|
||||
EventTypeT = TypeVar("EventTypeT", bound=EventTypeBase)
|
||||
PayloadT = TypeVar("PayloadT", bound=PayloadBase)
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""
|
||||
Event bus that runs in a separate thread with its own asyncio event loop.
|
||||
|
||||
This class manages event subscription and dispatching across threads.
|
||||
|
||||
This is an approach to leverage asyncio without migrating current codebase
|
||||
to async.
|
||||
"""
|
||||
|
||||
def __init__(self, max_queue_size: int = 1000):
|
||||
"""
|
||||
Initialize the event bus.
|
||||
|
||||
Args:
|
||||
max_queue_size: Maximum number of events that can be queued
|
||||
"""
|
||||
self._handlers: Dict[EventTypeBase, List[EventHandler[Any]]] = {}
|
||||
|
||||
# Queue for passing events from main thread to event bus thread
|
||||
self._event_queue: queue.Queue = queue.Queue(maxsize=max_queue_size)
|
||||
|
||||
# Thread management
|
||||
self._running = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._shutdown_event = threading.Event()
|
||||
|
||||
# Setup logging
|
||||
self.logger = logging.getLogger("event_bus")
|
||||
|
||||
# Metrics
|
||||
self.metrics = EventBusMetrics()
|
||||
|
||||
def subscribe(
|
||||
self, event_types: List[EventTypeBase], handler: EventHandler[E]
|
||||
) -> None:
|
||||
"""
|
||||
Subscribe a handler to one or more event types.
|
||||
|
||||
Args:
|
||||
event_types: The list of event types to subscribe to
|
||||
handler: The handler to register
|
||||
"""
|
||||
for event_type in event_types:
|
||||
if event_type not in self._handlers:
|
||||
self._handlers[event_type] = []
|
||||
|
||||
self.logger.info(
|
||||
f"Registering handler {handler.__class__.__name__} "
|
||||
f"for event type {event_type}"
|
||||
)
|
||||
self._handlers[event_type].append(handler)
|
||||
|
||||
def emit(
|
||||
self,
|
||||
event: Event[EventTypeT, PayloadT],
|
||||
block: bool = False,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Optional[Future]:
|
||||
"""
|
||||
Emit an event to be processed by the event bus.
|
||||
|
||||
Args:
|
||||
event: The event to emit
|
||||
block: Whether to block if the queue is full
|
||||
timeout: How long to wait if blocking
|
||||
|
||||
Returns:
|
||||
Future that will contain the results, or None if the event couldn't be queued
|
||||
"""
|
||||
if not self._running:
|
||||
self.logger.warning("Event bus is not running, but an event was emitted")
|
||||
|
||||
self.metrics.events_emitted += 1
|
||||
|
||||
# Create a future to track the results
|
||||
future = Future()
|
||||
|
||||
try:
|
||||
# Track queue size for metrics
|
||||
current_size = self._event_queue.qsize()
|
||||
self.metrics.queue_high_water_mark = max(
|
||||
current_size, self.metrics.queue_high_water_mark
|
||||
)
|
||||
|
||||
# Put the event in the queue
|
||||
self._event_queue.put((event, future), block=block, timeout=timeout)
|
||||
self.logger.debug("Emitted %s (%s)", event.type, event.id)
|
||||
return future
|
||||
|
||||
except queue.Full:
|
||||
self.logger.error(f"Event queue is full, dropping event: {event}")
|
||||
future.set_exception(RuntimeError("Event queue is full"))
|
||||
return future
|
||||
|
||||
def start(self):
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._shutdown_event.clear()
|
||||
self._thread = threading.Thread(target=self._run_event_loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self, timeout=5.0):
|
||||
if not self._running:
|
||||
return True
|
||||
|
||||
self._running = False
|
||||
self._event_queue.put((None, None), block=False) # Send sentinel
|
||||
return self._shutdown_event.wait(timeout)
|
||||
|
||||
def _run_event_loop(self):
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
|
||||
async def main():
|
||||
pending_tasks = set()
|
||||
|
||||
# Process the queue until shutdown
|
||||
while self._running or not self._event_queue.empty():
|
||||
try:
|
||||
# Get the next event with a short timeout
|
||||
try:
|
||||
event, future = self._event_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
# Check for shutdown sentinel
|
||||
if event is None:
|
||||
self.logger.info("Received shutdown sentinel")
|
||||
break
|
||||
|
||||
# Process the event
|
||||
task = asyncio.create_task(self._dispatch_event(event, future))
|
||||
self.logger.debug(f"Dispatching {event.type} ({event.id})")
|
||||
pending_tasks.add(task)
|
||||
task.add_done_callback(lambda t: pending_tasks.discard(t))
|
||||
except Exception as e:
|
||||
self.logger.exception(f"Error processing event: {e}")
|
||||
|
||||
# Wait for any pending tasks before exiting
|
||||
if pending_tasks:
|
||||
self.logger.info(f"Waiting for {len(pending_tasks)} pending tasks")
|
||||
await asyncio.gather(*pending_tasks, return_exceptions=True)
|
||||
|
||||
try:
|
||||
# Single run_until_complete call for the entire lifecycle
|
||||
self._loop.run_until_complete(main())
|
||||
finally:
|
||||
self._loop.close()
|
||||
self._shutdown_event.set()
|
||||
|
||||
async def _dispatch_event(
|
||||
self, event: Event[EventTypeBase, PayloadBase], future: Future
|
||||
) -> None:
|
||||
"""
|
||||
Dispatch an event to all registered handlers.
|
||||
|
||||
Args:
|
||||
event: The event to dispatch
|
||||
future: Future to set with the results
|
||||
"""
|
||||
results = []
|
||||
|
||||
handlers = self._handlers.get(event.type, [])
|
||||
|
||||
if not handlers:
|
||||
self.logger.warning(f"No handlers registered for event type {event.type}")
|
||||
future.set_result([])
|
||||
return
|
||||
|
||||
# Create tasks for all handlers and run them concurrently
|
||||
tasks = []
|
||||
for handler in handlers:
|
||||
task = asyncio.create_task(self._handle_event(handler, event))
|
||||
tasks.append(task)
|
||||
|
||||
trace_id = event.correlation_id if event.correlation_id else "-"
|
||||
|
||||
self.logger.debug(
|
||||
"Event %s | %s | %s Handler(s) Task(s)", trace_id, event.type, len(tasks)
|
||||
)
|
||||
|
||||
# Wait for all handlers to complete
|
||||
handler_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
self.logger.info(
|
||||
"Event %s | %s | %s Handler(s) Completed",
|
||||
trace_id,
|
||||
event.type,
|
||||
len(handler_results),
|
||||
)
|
||||
|
||||
# Process results
|
||||
for i, result in enumerate(handler_results):
|
||||
if isinstance(result, Exception):
|
||||
self.logger.error(
|
||||
"Event %s | %s | Handler %d failed: %s",
|
||||
trace_id,
|
||||
event.type,
|
||||
i,
|
||||
str(result),
|
||||
exc_info=result,
|
||||
)
|
||||
else:
|
||||
self.logger.debug(
|
||||
"Event %s | %s | Handler %d succeeded: %s",
|
||||
trace_id,
|
||||
event.type,
|
||||
i,
|
||||
str(result),
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
# Set the result on the future
|
||||
if not future.done():
|
||||
future.set_result(results)
|
||||
|
||||
async def _handle_event(self, handler: EventHandler[E], event: E) -> Any:
|
||||
"""
|
||||
Handle a single event with error handling and metrics.
|
||||
|
||||
Args:
|
||||
handler: The handler to use
|
||||
event: The event to handle
|
||||
|
||||
Returns:
|
||||
The result from the handler
|
||||
"""
|
||||
handler_name = handler.__class__.__name__
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Call the handler
|
||||
result = await handler.handle(event)
|
||||
|
||||
# Record successful processing
|
||||
self.metrics.events_processed += 1
|
||||
|
||||
# Record timing
|
||||
duration = time.time() - start_time
|
||||
if handler_name not in self.metrics.handler_durations:
|
||||
self.metrics.handler_durations[handler_name] = []
|
||||
self.metrics.handler_durations[handler_name].append(duration)
|
||||
|
||||
self.logger.debug(
|
||||
f"Handler {handler_name} processed {event.__class__.__name__} "
|
||||
f"in {duration:.3f}s"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# Record failure
|
||||
self.metrics.events_failed += 1
|
||||
|
||||
self.logger.exception(
|
||||
f"Handler {handler_name} failed to process {event.__class__.__name__}: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
def get_metrics(self) -> dict:
|
||||
"""
|
||||
Get the current metrics for the event bus.
|
||||
|
||||
Returns:
|
||||
Dictionary of metrics
|
||||
"""
|
||||
metrics: dict[str, Any] = {
|
||||
"events_emitted": self.metrics.events_emitted,
|
||||
"events_processed": self.metrics.events_processed,
|
||||
"events_failed": self.metrics.events_failed,
|
||||
"current_queue_size": self._event_queue.qsize(),
|
||||
"queue_high_water_mark": self.metrics.queue_high_water_mark,
|
||||
}
|
||||
|
||||
# Add handler metrics
|
||||
handler_metrics = {}
|
||||
for handler_name, durations in self.metrics.handler_durations.items():
|
||||
if not durations:
|
||||
continue
|
||||
|
||||
handler_metrics[handler_name] = {
|
||||
"count": len(durations),
|
||||
"avg_duration": sum(durations) / len(durations),
|
||||
"max_duration": max(durations),
|
||||
"min_duration": min(durations),
|
||||
}
|
||||
|
||||
metrics["handlers"] = handler_metrics
|
||||
return metrics
|
||||
|
||||
def emit_with_callback(
|
||||
self, event: Event, callback: Callable[[List[Any]], None]
|
||||
) -> None:
|
||||
"""
|
||||
Emit an event and register a callback for when it completes.
|
||||
|
||||
Args:
|
||||
event: The event to emit
|
||||
callback: Function to call with the results when complete
|
||||
"""
|
||||
future = self.emit(event)
|
||||
if future:
|
||||
future.add_done_callback(
|
||||
lambda f: callback(f.result()) if not f.exception() else None
|
||||
)
|
||||
@@ -0,0 +1,57 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from .bus import EventBus
|
||||
|
||||
from safety_schemas.models.events import EventType
|
||||
from safety.events.types import InternalEventType
|
||||
from safety.events.handlers import SecurityEventsHandler
|
||||
|
||||
from safety.constants import PLATFORM_API_EVENTS_ENDPOINT
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from safety.models import SafetyCLI
|
||||
from safety.auth.utils import SafetyAuthSession
|
||||
|
||||
|
||||
def start_event_bus(obj: "SafetyCLI", session: "SafetyAuthSession"):
|
||||
"""
|
||||
Initializes the event bus with the default security events handler
|
||||
for authenticated users.
|
||||
This function creates an instance of the EventBus, starts it,
|
||||
and assigns it to the `event_bus` attribute of the provided `obj`.
|
||||
It also initializes the `security_events_handler` with the necessary
|
||||
parameters and subscribes it to a predefined list of events.
|
||||
|
||||
Args:
|
||||
obj (SafetyCLI): The main application object.
|
||||
session (SafetyAuthSession): The authentication session containing
|
||||
the necessary credentials and proxies.
|
||||
|
||||
"""
|
||||
event_bus = EventBus()
|
||||
event_bus.start()
|
||||
obj.event_bus = event_bus
|
||||
|
||||
token = session.token.get("access_token") if session.token else None
|
||||
|
||||
obj.security_events_handler = SecurityEventsHandler(
|
||||
api_endpoint=PLATFORM_API_EVENTS_ENDPOINT,
|
||||
proxies=session.proxies, # type: ignore
|
||||
auth_token=token,
|
||||
api_key=session.api_key,
|
||||
)
|
||||
|
||||
events = [
|
||||
EventType.AUTH_STARTED,
|
||||
EventType.AUTH_COMPLETED,
|
||||
EventType.COMMAND_EXECUTED,
|
||||
EventType.COMMAND_ERROR,
|
||||
InternalEventType.CLOSE_RESOURCES,
|
||||
InternalEventType.FLUSH_SECURITY_TRACES,
|
||||
]
|
||||
|
||||
event_bus.subscribe(events, obj.security_events_handler)
|
||||
|
||||
if obj.firewall_enabled:
|
||||
from safety.firewall.events.utils import register_event_handlers
|
||||
|
||||
register_event_handlers(obj.event_bus, obj=obj)
|
||||
@@ -0,0 +1,5 @@
|
||||
from .base import EventHandler
|
||||
from .common import SecurityEventsHandler
|
||||
|
||||
|
||||
__all__ = ["EventHandler", "SecurityEventsHandler"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Event handler definitions for the event bus system.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TypeVar, Generic
|
||||
|
||||
from safety_schemas.models.events import Event
|
||||
|
||||
# Type variable for event types
|
||||
EventType = TypeVar("EventType", bound=Event)
|
||||
|
||||
|
||||
class EventHandler(Generic[EventType], ABC):
|
||||
"""
|
||||
Abstract base class for event handlers.
|
||||
|
||||
Concrete handlers should implement the handle method.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, event: EventType) -> Any:
|
||||
"""
|
||||
Handle an event asynchronously.
|
||||
|
||||
Args:
|
||||
event: The event to handle
|
||||
|
||||
Returns:
|
||||
Any result from handling the event
|
||||
"""
|
||||
pass
|
||||
@@ -0,0 +1,333 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
import tenacity
|
||||
|
||||
from safety.meta import get_identifier, get_meta_http_headers, get_version
|
||||
|
||||
from ..types import (
|
||||
CommandErrorEvent,
|
||||
CommandExecutedEvent,
|
||||
CloseResourcesEvent,
|
||||
InternalEventType,
|
||||
FlushSecurityTracesEvent,
|
||||
)
|
||||
from ..handlers import EventHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from safety_schemas.models.events import EventContext
|
||||
from safety.events.utils import InternalPayload
|
||||
from safety.models import SafetyCLI
|
||||
|
||||
SecurityEventTypes = Union[
|
||||
CommandExecutedEvent,
|
||||
CommandErrorEvent,
|
||||
FlushSecurityTracesEvent,
|
||||
CloseResourcesEvent,
|
||||
]
|
||||
|
||||
|
||||
class SecurityEventsHandler(EventHandler[SecurityEventTypes]):
|
||||
"""
|
||||
Handler that collects events in memory and flushes them when requested.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_endpoint: str,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the telemetry handler.
|
||||
|
||||
Args:
|
||||
api_endpoint: URL to send events to
|
||||
proxies: Optional dictionary of proxy settings
|
||||
auth_token: Optional authentication token for the API
|
||||
api_key: Optional API key for authentication
|
||||
"""
|
||||
self.api_endpoint = api_endpoint
|
||||
self.proxies = proxies
|
||||
self.auth_token = auth_token
|
||||
self.api_key = api_key
|
||||
|
||||
# Storage for collected events
|
||||
self.collected_events: List[Dict[str, Any]] = []
|
||||
|
||||
# HTTP client (created when needed)
|
||||
self.http_client = None
|
||||
|
||||
# Logging
|
||||
self.logger = logging.getLogger("security_events_handler")
|
||||
|
||||
# Event types that should not be collected (to avoid recursion)
|
||||
self.excluded_event_types = [
|
||||
InternalEventType.FLUSH_SECURITY_TRACES,
|
||||
]
|
||||
|
||||
async def handle(self, event: SecurityEventTypes) -> Dict[str, Any]:
|
||||
"""
|
||||
Handle an event - either collect it or process a flush request.
|
||||
|
||||
Args:
|
||||
event: The event to handle
|
||||
|
||||
Returns:
|
||||
Status dictionary
|
||||
"""
|
||||
|
||||
if event.type is InternalEventType.CLOSE_RESOURCES:
|
||||
self.logger.info("Received request to close resources")
|
||||
await self.close_async()
|
||||
return {"closed": True}
|
||||
|
||||
if event.type is InternalEventType.FLUSH_SECURITY_TRACES:
|
||||
self.logger.info(f"Received flush request from {event.source}")
|
||||
return await self.flush(event_payload=event.payload)
|
||||
|
||||
# Don't collect excluded event types
|
||||
if any(event == t for t in self.excluded_event_types):
|
||||
return {"skipped": True, "reason": "excluded_event_type"}
|
||||
|
||||
try:
|
||||
event_data = event.model_dump(mode="json")
|
||||
except Exception:
|
||||
return {"collected": False, "event_count": len(self.collected_events)}
|
||||
|
||||
# Add to in-memory collection
|
||||
self.collected_events.append(event_data)
|
||||
event_count = len(self.collected_events)
|
||||
|
||||
self.logger.debug(
|
||||
f"Collected event: {event.type}, total event count: {event_count}"
|
||||
)
|
||||
|
||||
return {"collected": True, "event_count": event_count}
|
||||
|
||||
async def _build_context_data(self, obj: Optional["SafetyCLI"]) -> "EventContext":
|
||||
"""
|
||||
Generate context data for telemetry events.
|
||||
|
||||
Returns:
|
||||
Dict containing context information about client, runtime, etc.
|
||||
"""
|
||||
from safety_schemas.models.events.types import SourceType
|
||||
from safety.events.utils.context import create_event_context
|
||||
|
||||
project = getattr(obj, "project", None)
|
||||
tags = None
|
||||
try:
|
||||
if obj and obj.auth and obj.auth.stage:
|
||||
tags = [obj.auth.stage.value]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
version = get_version() or "unknown"
|
||||
|
||||
path = ""
|
||||
try:
|
||||
path = sys.argv[0]
|
||||
except (IndexError, TypeError):
|
||||
pass
|
||||
|
||||
context = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
create_event_context,
|
||||
SourceType(get_identifier()),
|
||||
version,
|
||||
path,
|
||||
project,
|
||||
tags,
|
||||
),
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential(multiplier=0.1, min=0.2, max=1.0),
|
||||
retry=retry_if_exception_type(
|
||||
(httpx.NetworkError, httpx.TimeoutException, httpx.HTTPStatusError)
|
||||
),
|
||||
before_sleep=before_sleep_log(logging.getLogger("api_client"), logging.WARNING),
|
||||
)
|
||||
async def _send_events(
|
||||
self, payload: dict, headers: dict
|
||||
) -> Optional[httpx.Response]:
|
||||
"""
|
||||
Send events to the API with retry logic.
|
||||
|
||||
Args:
|
||||
payload: The data payload to send
|
||||
headers: The HTTP headers to include
|
||||
|
||||
Returns:
|
||||
Response from the API or None if http_client is not initialized
|
||||
|
||||
Raises:
|
||||
Exception if all retries fail
|
||||
"""
|
||||
if self.http_client is None:
|
||||
self.logger.warning("Cannot send events: HTTP client not initialized")
|
||||
return None
|
||||
|
||||
TIMEOUT = int(os.getenv("SAFETY_REQUEST_TIMEOUT_EVENTS", 10))
|
||||
|
||||
response = await self.http_client.post(
|
||||
self.api_endpoint, json=payload, headers=headers, timeout=TIMEOUT
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
async def flush(self, event_payload: "InternalPayload") -> Dict[str, Any]:
|
||||
"""
|
||||
Send all collected events to the API endpoint.
|
||||
|
||||
Returns:
|
||||
Status dictionary
|
||||
"""
|
||||
# If no events, just return early
|
||||
if not self.collected_events:
|
||||
return {"status": "no_events", "count": 0}
|
||||
|
||||
# Get a copy of events and clear the original list
|
||||
events_to_send = self.collected_events.copy()
|
||||
self.collected_events.clear()
|
||||
|
||||
event_count = len(events_to_send)
|
||||
self.logger.info(
|
||||
"[Flush] -> Sending %s events to %s", event_count, self.api_endpoint
|
||||
)
|
||||
IDEMPOTENCY_KEY = str(uuid.uuid4())
|
||||
|
||||
# Get context data that will be shared across all events
|
||||
obj = event_payload.ctx.obj if event_payload.ctx else None
|
||||
context = await self._build_context_data(obj=obj)
|
||||
self.logger.info("Context data built")
|
||||
|
||||
for event_data in events_to_send:
|
||||
event_data["context"] = context.model_dump(mode="json")
|
||||
|
||||
payload = {"events": events_to_send}
|
||||
|
||||
# Create HTTP client if needed
|
||||
if self.http_client is None:
|
||||
# TODO: Add proxy support
|
||||
self.http_client = httpx.AsyncClient(proxy=None)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-Idempotency-Key": IDEMPOTENCY_KEY,
|
||||
}
|
||||
headers.update(get_meta_http_headers())
|
||||
|
||||
# Add authentication
|
||||
if self.api_key:
|
||||
headers["X-Api-Key"] = self.api_key
|
||||
elif self.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||||
|
||||
try:
|
||||
# Send the request with retries
|
||||
response = await self._send_events(payload, headers)
|
||||
|
||||
# Handle case where http_client was None and _send_events returned None
|
||||
if response is None:
|
||||
self.logger.warning("Events not sent: HTTP client not initialized")
|
||||
# Put events back in collection
|
||||
self.collected_events = events_to_send + self.collected_events
|
||||
return {
|
||||
"status": "error",
|
||||
"count": event_count,
|
||||
"error": "HTTP client not initialized",
|
||||
}
|
||||
|
||||
self.logger.info(
|
||||
f"Successfully sent {event_count} events, status: {response.status_code}"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"count": event_count,
|
||||
"http_status": response.status_code,
|
||||
}
|
||||
except tenacity.RetryError as retry_exc:
|
||||
# Put events back in collection
|
||||
self.collected_events = events_to_send + self.collected_events
|
||||
exc = retry_exc.last_attempt.exception()
|
||||
|
||||
status_code = None
|
||||
if hasattr(exc, "response"):
|
||||
status_code = exc.response # type: ignore
|
||||
|
||||
self.logger.error(f"Failed after retries: {exc}")
|
||||
|
||||
result = {"status": "error", "count": event_count, "error": repr(exc)}
|
||||
if status_code:
|
||||
result["http_status"] = status_code
|
||||
|
||||
return result
|
||||
except Exception as exc:
|
||||
# Handle any other unexpected exceptions
|
||||
self.collected_events = events_to_send + self.collected_events
|
||||
self.logger.exception(f"Unexpected error: {exc}")
|
||||
|
||||
return {"status": "error", "count": event_count, "error": repr(exc)}
|
||||
|
||||
async def close_async(self):
|
||||
"""Close the HTTP client asynchronously."""
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
self.http_client = None
|
||||
self.logger.debug("HTTP client closed")
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Synchronous method to close the handler.
|
||||
|
||||
This is a non-blocking method that initiates closure but doesn't wait for it.
|
||||
The event bus will handle the actual closing asynchronously.
|
||||
"""
|
||||
self.logger.info("Initiating telemetry handler shutdown")
|
||||
# The actual close will happen when the event loop processes events
|
||||
# Just log the intent and let the event loop handle it
|
||||
return {"status": "shutdown_initiated"}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current telemetry statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary of statistics
|
||||
"""
|
||||
event_count = len(self.collected_events)
|
||||
|
||||
# Group events by type
|
||||
event_types = {}
|
||||
for event in self.collected_events:
|
||||
event_type = event.get("event_type", "unknown")
|
||||
if event_type not in event_types:
|
||||
event_types[event_type] = 0
|
||||
event_types[event_type] += 1
|
||||
|
||||
return {
|
||||
"events_in_memory": event_count,
|
||||
"event_types": event_types,
|
||||
"api_endpoint": self.api_endpoint,
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
from .aliases import (
|
||||
CloseResourcesEvent,
|
||||
CommandErrorEvent,
|
||||
CommandExecutedEvent,
|
||||
FirewallConfiguredEvent,
|
||||
FirewallDisabledEvent,
|
||||
FirewallHeartbeatEvent,
|
||||
FlushSecurityTracesEvent,
|
||||
PackageInstalledEvent,
|
||||
PackageUninstalledEvent,
|
||||
EventBusReadyEvent,
|
||||
)
|
||||
from .base import InternalEventType, InternalPayload
|
||||
|
||||
__all__ = [
|
||||
"CloseResourcesEvent",
|
||||
"FlushSecurityTracesEvent",
|
||||
"InternalEventType",
|
||||
"InternalPayload",
|
||||
"CommandExecutedEvent",
|
||||
"CommandErrorEvent",
|
||||
"PackageInstalledEvent",
|
||||
"PackageUninstalledEvent",
|
||||
"FirewallHeartbeatEvent",
|
||||
"FirewallConfiguredEvent",
|
||||
"FirewallDisabledEvent",
|
||||
"EventBusReadyEvent",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,81 @@
|
||||
from typing import Literal
|
||||
|
||||
from safety_schemas.models.events import Event, EventType
|
||||
from safety_schemas.models.events.payloads import (
|
||||
AuthCompletedPayload,
|
||||
AuthStartedPayload,
|
||||
CodebaseSetupCompletedPayload,
|
||||
CodebaseSetupResponseCreatedPayload,
|
||||
FirewallConfiguredPayload,
|
||||
FirewallDisabledPayload,
|
||||
FirewallSetupCompletedPayload,
|
||||
FirewallSetupResponseCreatedPayload,
|
||||
InitScanCompletedPayload,
|
||||
InitStartedPayload,
|
||||
PackageInstalledPayload,
|
||||
PackageUninstalledPayload,
|
||||
CommandExecutedPayload,
|
||||
CommandErrorPayload,
|
||||
FirewallHeartbeatPayload,
|
||||
CodebaseDetectionStatusPayload,
|
||||
)
|
||||
|
||||
from .base import InternalEventType, InternalPayload
|
||||
|
||||
CommandExecutedEvent = Event[
|
||||
Literal[EventType.COMMAND_EXECUTED], CommandExecutedPayload
|
||||
]
|
||||
CommandErrorEvent = Event[Literal[EventType.COMMAND_ERROR], CommandErrorPayload]
|
||||
PackageInstalledEvent = Event[
|
||||
Literal[EventType.PACKAGE_INSTALLED], PackageInstalledPayload
|
||||
]
|
||||
PackageUninstalledEvent = Event[
|
||||
Literal[EventType.PACKAGE_UNINSTALLED], PackageUninstalledPayload
|
||||
]
|
||||
FirewallHeartbeatEvent = Event[
|
||||
Literal[EventType.FIREWALL_HEARTBEAT], FirewallHeartbeatPayload
|
||||
]
|
||||
FirewallConfiguredEvent = Event[
|
||||
Literal[EventType.FIREWALL_CONFIGURED], FirewallConfiguredPayload
|
||||
]
|
||||
FirewallDisabledEvent = Event[
|
||||
Literal[EventType.FIREWALL_DISABLED], FirewallDisabledPayload
|
||||
]
|
||||
|
||||
|
||||
InitStartedEvent = Event[Literal[EventType.INIT_STARTED], InitStartedPayload]
|
||||
AuthStartedEvent = Event[Literal[EventType.AUTH_STARTED], AuthStartedPayload]
|
||||
AuthCompletedEvent = Event[Literal[EventType.AUTH_COMPLETED], AuthCompletedPayload]
|
||||
|
||||
# Firewall setup events
|
||||
FirewallSetupResponseCreatedEvent = Event[
|
||||
Literal[EventType.FIREWALL_SETUP_RESPONSE_CREATED],
|
||||
FirewallSetupResponseCreatedPayload,
|
||||
]
|
||||
FirewallSetupCompletedEvent = Event[
|
||||
Literal[EventType.FIREWALL_SETUP_COMPLETED], FirewallSetupCompletedPayload
|
||||
]
|
||||
|
||||
# Codebase setup events
|
||||
CodebaseDetectionStatusEvent = Event[
|
||||
Literal[EventType.CODEBASE_DETECTION_STATUS], CodebaseDetectionStatusPayload
|
||||
]
|
||||
CodebaseSetupResponseCreatedEvent = Event[
|
||||
Literal[EventType.CODEBASE_SETUP_RESPONSE_CREATED],
|
||||
CodebaseSetupResponseCreatedPayload,
|
||||
]
|
||||
CodebaseSetupCompletedEvent = Event[
|
||||
Literal[EventType.CODEBASE_SETUP_COMPLETED], CodebaseSetupCompletedPayload
|
||||
]
|
||||
|
||||
# Scan events
|
||||
InitScanCompletedEvent = Event[
|
||||
Literal[EventType.INIT_SCAN_COMPLETED], InitScanCompletedPayload
|
||||
]
|
||||
|
||||
# Internal events
|
||||
CloseResourcesEvent = Event[InternalEventType.CLOSE_RESOURCES, InternalPayload]
|
||||
FlushSecurityTracesEvent = Event[
|
||||
InternalEventType.FLUSH_SECURITY_TRACES, InternalPayload
|
||||
]
|
||||
EventBusReadyEvent = Event[Literal[InternalEventType.EVENT_BUS_READY], InternalPayload]
|
||||
@@ -0,0 +1,24 @@
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing_extensions import Annotated
|
||||
from pydantic import ConfigDict
|
||||
from safety_schemas.models.events import EventTypeBase, PayloadBase
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class InternalEventType(EventTypeBase):
|
||||
"""
|
||||
Internal event types.
|
||||
"""
|
||||
|
||||
CLOSE_RESOURCES = "com.safetycli.close_resources"
|
||||
FLUSH_SECURITY_TRACES = "com.safetycli.flush_security_traces"
|
||||
EVENT_BUS_READY = "com.safetycli.event_bus_ready"
|
||||
|
||||
|
||||
class InternalPayload(PayloadBase):
|
||||
ctx: Optional[Annotated[Any, "CustomContext"]] = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
@@ -0,0 +1,34 @@
|
||||
from .emission import (
|
||||
emit_command_error,
|
||||
emit_command_executed,
|
||||
emit_firewall_disabled,
|
||||
emit_diff_operations,
|
||||
emit_firewall_configured,
|
||||
emit_tool_command_executed,
|
||||
emit_firewall_heartbeat,
|
||||
emit_init_started,
|
||||
emit_auth_started,
|
||||
emit_auth_completed,
|
||||
)
|
||||
|
||||
from .creation import (
|
||||
create_internal_event,
|
||||
InternalEventType,
|
||||
InternalPayload,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"emit_command_error",
|
||||
"emit_command_executed",
|
||||
"emit_firewall_disabled",
|
||||
"create_internal_event",
|
||||
"InternalEventType",
|
||||
"InternalPayload",
|
||||
"emit_firewall_configured",
|
||||
"emit_diff_operations",
|
||||
"emit_init_started",
|
||||
"emit_auth_started",
|
||||
"emit_auth_completed",
|
||||
"emit_tool_command_executed",
|
||||
"emit_firewall_heartbeat",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,79 @@
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, TypeVar, cast, overload
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from safety.events.event_bus import EventBus
|
||||
from safety.cli_util import CustomContext
|
||||
|
||||
|
||||
def should_emit(
|
||||
event_bus: Optional["EventBus"], ctx: Optional["CustomContext"]
|
||||
) -> bool:
|
||||
"""
|
||||
Common conditions that apply to all event emissions.
|
||||
"""
|
||||
if event_bus is None:
|
||||
return False
|
||||
|
||||
# Be aware that ctx depends on the command being parsed, if the emit func
|
||||
# is called from the entrypoint group command, ctx will not have
|
||||
# the command parsed yet.
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def should_emit_firewall_heartbeat(ctx: Optional["CustomContext"]) -> bool:
|
||||
"""
|
||||
Condition to check if the firewall is enabled.
|
||||
"""
|
||||
if ctx and ctx.obj.firewall_enabled:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# Define TypeVars for better typing
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
@overload
|
||||
def conditional_emitter(emit_func: F, *, conditions: None = None) -> F: ...
|
||||
|
||||
|
||||
@overload
|
||||
def conditional_emitter(
|
||||
emit_func: None = None,
|
||||
*,
|
||||
conditions: Optional[List[Callable[[Optional["CustomContext"]], bool]]] = None,
|
||||
) -> Callable[[F], F]: ...
|
||||
|
||||
|
||||
def conditional_emitter(
|
||||
emit_func=None,
|
||||
*,
|
||||
conditions: Optional[List[Callable[[Optional["CustomContext"]], bool]]] = None,
|
||||
):
|
||||
"""
|
||||
A decorator that conditionally calls the decorated function based on conditions.
|
||||
Only executes the decorated function if all conditions evaluate to True.
|
||||
"""
|
||||
|
||||
def decorator(func: F) -> F:
|
||||
@wraps(func)
|
||||
def wrapper(event_bus, ctx=None, *args, **kwargs):
|
||||
if not should_emit(event_bus, ctx):
|
||||
return None
|
||||
|
||||
if conditions:
|
||||
if all(condition(ctx) for condition in conditions):
|
||||
return func(event_bus, ctx, *args, **kwargs)
|
||||
return None
|
||||
return func(event_bus, ctx, *args, **kwargs)
|
||||
|
||||
return cast(F, wrapper) # Cast to help type checker
|
||||
|
||||
if emit_func is None:
|
||||
return decorator
|
||||
return decorator(emit_func)
|
||||
@@ -0,0 +1,163 @@
|
||||
import getpass
|
||||
import os
|
||||
from pathlib import Path
|
||||
import site
|
||||
import socket
|
||||
import sys
|
||||
import platform
|
||||
from typing import List, Optional
|
||||
from safety_schemas.models.events.context import (
|
||||
ClientInfo,
|
||||
EventContext,
|
||||
HostInfo,
|
||||
OsInfo,
|
||||
ProjectInfo,
|
||||
PythonInfo,
|
||||
RuntimeInfo,
|
||||
UserInfo,
|
||||
)
|
||||
|
||||
from safety_schemas.models.events.types import SourceType
|
||||
from safety_schemas.models import ProjectModel
|
||||
|
||||
|
||||
def get_user_info() -> UserInfo:
|
||||
"""
|
||||
Collect information about the current user.
|
||||
"""
|
||||
return UserInfo(name=getpass.getuser(), home_dir=str(Path.home()))
|
||||
|
||||
|
||||
def get_os_info() -> OsInfo:
|
||||
"""
|
||||
Get basic OS information using only the platform module.
|
||||
Returns a dictionary with architecture, platform, name, version, and kernel_version.
|
||||
"""
|
||||
# Initialize with required fields
|
||||
os_info = {
|
||||
"architecture": platform.machine(),
|
||||
"platform": platform.system(),
|
||||
"name": None,
|
||||
"version": None,
|
||||
"kernel_version": None,
|
||||
}
|
||||
|
||||
python_version = sys.version_info
|
||||
|
||||
if sys.platform == "wind32":
|
||||
os_info["version"] = platform.release()
|
||||
os_info["kernel_version"] = platform.version()
|
||||
os_info["name"] = "windows"
|
||||
|
||||
elif sys.platform == "darwin":
|
||||
os_info["version"] = platform.mac_ver()[0]
|
||||
os_info["kernel_version"] = platform.release()
|
||||
os_info["name"] = "macos"
|
||||
|
||||
elif sys.platform == "linux":
|
||||
os_info["kernel_version"] = platform.release()
|
||||
if python_version >= (3, 10):
|
||||
try:
|
||||
os_release = platform.freedesktop_os_release()
|
||||
# Use ID for name (more consistent for programmatic use)
|
||||
os_info["name"] = os_release.get("ID", "linux")
|
||||
os_info["version"] = os_release.get("VERSION_ID")
|
||||
except (OSError, AttributeError):
|
||||
# If freedesktop_os_release fails, keep values as is
|
||||
pass
|
||||
|
||||
return OsInfo(**os_info)
|
||||
|
||||
|
||||
def get_host_info() -> HostInfo:
|
||||
"""
|
||||
Collect information about the host machine.
|
||||
"""
|
||||
hostname = socket.gethostname()
|
||||
|
||||
ipv4_addresses = set()
|
||||
ipv6_addresses = set()
|
||||
try:
|
||||
host_info = socket.getaddrinfo(hostname, None)
|
||||
for info in host_info:
|
||||
ip_family = info[0]
|
||||
ip = str(info[4][0])
|
||||
|
||||
if ip_family == socket.AF_INET:
|
||||
if not ip.startswith("127."):
|
||||
ipv4_addresses.add(ip)
|
||||
elif ip_family == socket.AF_INET6:
|
||||
if not ip.startswith("::1") and ip != "fe80::1":
|
||||
ipv6_addresses.add(ip)
|
||||
|
||||
# Prioritize addresses
|
||||
primary_ipv4 = next(
|
||||
(ip for ip in ipv4_addresses),
|
||||
next(iter(ipv4_addresses)) if ipv4_addresses else None,
|
||||
)
|
||||
|
||||
primary_ipv6 = next(
|
||||
(ip for ip in ipv6_addresses if not ip.startswith("fe80:")),
|
||||
next(iter(ipv6_addresses)) if ipv6_addresses else None,
|
||||
)
|
||||
|
||||
except socket.gaierror:
|
||||
primary_ipv4 = None
|
||||
primary_ipv6 = None
|
||||
|
||||
return HostInfo(name=hostname, ipv4=primary_ipv4, ipv6=primary_ipv6, timezone=None)
|
||||
|
||||
|
||||
def get_python_info() -> PythonInfo:
|
||||
"""
|
||||
Collect detailed information about the Python environment.
|
||||
"""
|
||||
# Get site-packages directories
|
||||
site_packages_dirs = site.getsitepackages()
|
||||
|
||||
user_site_enabled = bool(site.ENABLE_USER_SITE)
|
||||
user_site_packages = site.getusersitepackages()
|
||||
|
||||
return PythonInfo(
|
||||
version=f"{sys.version_info.major}.{sys.version_info.minor}",
|
||||
path=sys.executable,
|
||||
sys_path=sys.path,
|
||||
implementation=platform.python_implementation(),
|
||||
implementation_version=platform.python_version(),
|
||||
sys_prefix=sys.prefix,
|
||||
site_packages=site_packages_dirs,
|
||||
user_site_enabled=user_site_enabled,
|
||||
user_site_packages=user_site_packages,
|
||||
encoding=sys.getdefaultencoding(),
|
||||
filesystem_encoding=sys.getfilesystemencoding(),
|
||||
)
|
||||
|
||||
|
||||
def create_event_context(
|
||||
client_identifier: SourceType,
|
||||
client_version: str,
|
||||
client_path: str,
|
||||
project: Optional[ProjectModel] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> EventContext:
|
||||
client = ClientInfo(
|
||||
identifier=client_identifier, version=client_version, path=client_path
|
||||
)
|
||||
|
||||
project_info = None
|
||||
|
||||
if project:
|
||||
project_info = ProjectInfo(
|
||||
id=project.id,
|
||||
url=project.url_path,
|
||||
)
|
||||
|
||||
runtime = RuntimeInfo(
|
||||
workdir=os.getcwd(),
|
||||
user=get_user_info(),
|
||||
os=get_os_info(),
|
||||
host=get_host_info(),
|
||||
python=get_python_info(),
|
||||
)
|
||||
|
||||
return EventContext(client=client, runtime=runtime, project=project_info, tags=tags)
|
||||
@@ -0,0 +1,48 @@
|
||||
import time
|
||||
from typing import Optional, TypeVar
|
||||
|
||||
from safety_schemas.models.events import Event, EventTypeBase, PayloadBase, SourceType
|
||||
|
||||
from safety.meta import get_identifier
|
||||
|
||||
from ..types import InternalEventType, InternalPayload
|
||||
|
||||
PayloadBaseT = TypeVar("PayloadBaseT", bound=PayloadBase)
|
||||
EventTypeBaseT = TypeVar("EventTypeBaseT", bound=EventTypeBase)
|
||||
|
||||
|
||||
def create_event(
|
||||
payload: PayloadBaseT,
|
||||
event_type: EventTypeBaseT,
|
||||
source: SourceType = SourceType(get_identifier()),
|
||||
timestamp: int = int(time.time()),
|
||||
correlation_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Event[EventTypeBaseT, PayloadBaseT]:
|
||||
"""
|
||||
Generic factory function for creating any type of event.
|
||||
"""
|
||||
|
||||
return Event(
|
||||
timestamp=timestamp,
|
||||
payload=payload,
|
||||
type=event_type,
|
||||
source=source,
|
||||
correlation_id=correlation_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def create_internal_event(
|
||||
event_type: InternalEventType,
|
||||
payload: InternalPayload,
|
||||
) -> Event[InternalEventType, InternalPayload]:
|
||||
"""
|
||||
Create an internal event.
|
||||
"""
|
||||
return Event(
|
||||
type=event_type,
|
||||
timestamp=int(time.time()),
|
||||
source=SourceType(get_identifier()),
|
||||
payload=payload,
|
||||
)
|
||||
@@ -0,0 +1,110 @@
|
||||
import re
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from click.core import ParameterSource as ClickParameterSource
|
||||
|
||||
from safety_schemas.models.events.types import ParamSource
|
||||
|
||||
|
||||
def is_sensitive_parameter(param_name: str) -> bool:
|
||||
"""
|
||||
Determine if a parameter name likely contains sensitive information.
|
||||
"""
|
||||
sensitive_patterns = [
|
||||
r"(?i)pass(word)?", # password, pass
|
||||
r"(?i)token", # token, auth_token
|
||||
r"(?i)key", # key, apikey
|
||||
r"(?i)auth", # auth, authorization
|
||||
]
|
||||
|
||||
return any(re.search(pattern, param_name) for pattern in sensitive_patterns)
|
||||
|
||||
|
||||
def scrub_sensitive_value(value: str) -> str:
|
||||
"""
|
||||
Detect if a value appears to be sensitive information based on
|
||||
specific patterns.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
result = value
|
||||
|
||||
if re.match(r"^-{1,2}[\w-]+$", value) and "=" not in value:
|
||||
return value
|
||||
|
||||
# Patterns to detect and replace
|
||||
patterns = [
|
||||
# This will replace ports too, but that's fine
|
||||
(r"\b\w+:\w+\b", "-:-"),
|
||||
(r"Basic\s+[A-Za-z0-9+/=]+", "Basic -"),
|
||||
(r"Bearer\s+[A-Za-z0-9._~+/=-]+", "Bearer -"),
|
||||
(r"\b[A-Za-z0-9_-]{20,}\b", "-"),
|
||||
(
|
||||
r"((?:token|api|apikey|key|auth|secret|password|access|jwt|bearer|credential|pwd)=)([^&\s]+)",
|
||||
r"\1-",
|
||||
),
|
||||
]
|
||||
|
||||
# Apply each pattern and replace matches
|
||||
for pattern, repl in patterns:
|
||||
result = re.sub(pattern, repl, result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def clean_parameter(param_name: str, param_value: Any) -> Any:
|
||||
"""
|
||||
Scrub a parameter value if it's sensitive.
|
||||
"""
|
||||
if not isinstance(param_value, str):
|
||||
return param_value
|
||||
|
||||
if is_sensitive_parameter(param_name):
|
||||
return "-"
|
||||
|
||||
return scrub_sensitive_value(param_value)
|
||||
|
||||
|
||||
def get_command_path(ctx) -> List[str]:
|
||||
hierarchy = []
|
||||
current = ctx
|
||||
|
||||
while current is not None:
|
||||
if current.command:
|
||||
name = current.command.name
|
||||
if name == "cli":
|
||||
name = "safety"
|
||||
hierarchy.append(name)
|
||||
current = current.parent
|
||||
|
||||
# Reverse to get top-level first
|
||||
hierarchy.reverse()
|
||||
|
||||
return hierarchy
|
||||
|
||||
|
||||
def get_root_context(ctx):
|
||||
"""
|
||||
Get the top-level parent context.
|
||||
"""
|
||||
current = ctx
|
||||
while current.parent is not None:
|
||||
current = current.parent
|
||||
return current
|
||||
|
||||
|
||||
def translate_param_source(source: Optional[ClickParameterSource]) -> ParamSource:
|
||||
"""
|
||||
Translate Click's ParameterSource enum to our ParameterSource enum
|
||||
"""
|
||||
mapping = {
|
||||
ClickParameterSource.COMMANDLINE: ParamSource.COMMANDLINE,
|
||||
ClickParameterSource.ENVIRONMENT: ParamSource.ENVIRONMENT,
|
||||
ClickParameterSource.DEFAULT: ParamSource.DEFAULT,
|
||||
# In newer Click versions
|
||||
getattr(ClickParameterSource, "PROMPT", None): ParamSource.PROMPT,
|
||||
getattr(ClickParameterSource, "CONFIG_FILE", None): ParamSource.CONFIG,
|
||||
}
|
||||
|
||||
return mapping.get(source, ParamSource.UNKNOWN)
|
||||
@@ -0,0 +1,681 @@
|
||||
from concurrent.futures import Future
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
import uuid
|
||||
|
||||
from safety.utils.pyapp_utils import get_path, get_env
|
||||
|
||||
from safety_schemas.models.events import Event, EventType
|
||||
from safety_schemas.models.events.types import ToolType
|
||||
from safety_schemas.models.events.payloads import (
|
||||
CodebaseDetectionStatusPayload,
|
||||
CodebaseSetupCompletedPayload,
|
||||
CodebaseSetupResponseCreatedPayload,
|
||||
DependencyFile,
|
||||
FirewallConfiguredPayload,
|
||||
FirewallDisabledPayload,
|
||||
FirewallSetupCompletedPayload,
|
||||
FirewallSetupResponseCreatedPayload,
|
||||
InitExitStep,
|
||||
InitExitedPayload,
|
||||
InitScanCompletedPayload,
|
||||
PackageInstalledPayload,
|
||||
PackageUninstalledPayload,
|
||||
PackageUpdatedPayload,
|
||||
CommandExecutedPayload,
|
||||
ToolCommandExecutedPayload,
|
||||
CommandErrorPayload,
|
||||
AliasConfig,
|
||||
IndexConfig,
|
||||
ToolStatus,
|
||||
CommandParam,
|
||||
ProcessStatus,
|
||||
FirewallHeartbeatPayload,
|
||||
InitStartedPayload,
|
||||
AuthStartedPayload,
|
||||
AuthCompletedPayload,
|
||||
)
|
||||
import typer
|
||||
|
||||
from ..event_bus import EventBus
|
||||
from ..types.base import InternalEventType, InternalPayload
|
||||
|
||||
from .creation import (
|
||||
create_event,
|
||||
)
|
||||
from .data import (
|
||||
clean_parameter,
|
||||
get_command_path,
|
||||
get_root_context,
|
||||
scrub_sensitive_value,
|
||||
translate_param_source,
|
||||
)
|
||||
from .conditions import conditional_emitter, should_emit_firewall_heartbeat
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from safety.models import SafetyCLI, ToolResult
|
||||
from safety.cli_util import CustomContext
|
||||
from safety.init.types import FirewallConfigStatus
|
||||
from safety.tool.environment_diff import PackageLocation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def send_and_flush(event_bus: "EventBus", event: Event) -> Optional[Future]:
|
||||
"""
|
||||
Emit an event and immediately flush the event bus without closing it.
|
||||
|
||||
Args:
|
||||
event_bus: The event bus to emit on
|
||||
event: The event to emit
|
||||
"""
|
||||
future = event_bus.emit(event)
|
||||
|
||||
# Create and emit flush event
|
||||
flush_payload = InternalPayload()
|
||||
flush_event = create_event(
|
||||
payload=flush_payload, event_type=InternalEventType.FLUSH_SECURITY_TRACES
|
||||
)
|
||||
|
||||
# Emit flush event and wait for it to complete
|
||||
flush_future = event_bus.emit(flush_event)
|
||||
|
||||
# Wait for both events to complete
|
||||
if future:
|
||||
try:
|
||||
future.result(timeout=0.5)
|
||||
except Exception:
|
||||
logger.error("Emit Failed %s (%s)", event.type, event.id)
|
||||
|
||||
if flush_future:
|
||||
try:
|
||||
return flush_future.result(timeout=0.5)
|
||||
except Exception:
|
||||
logger.error("Flush Failed for event %s", event.id)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@conditional_emitter(conditions=[should_emit_firewall_heartbeat])
|
||||
def emit_firewall_heartbeat(
|
||||
event_bus: "EventBus", ctx: Optional["CustomContext"], *, tools: List[ToolStatus]
|
||||
):
|
||||
payload = FirewallHeartbeatPayload(tools=tools)
|
||||
event = create_event(payload=payload, event_type=EventType.FIREWALL_HEARTBEAT)
|
||||
|
||||
event_bus.emit(event)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_firewall_disabled(
|
||||
event_bus: "EventBus",
|
||||
ctx: Optional["CustomContext"] = None,
|
||||
*,
|
||||
reason: Optional[str],
|
||||
):
|
||||
payload = FirewallDisabledPayload(reason=reason)
|
||||
event = create_event(payload=payload, event_type=EventType.FIREWALL_DISABLED)
|
||||
|
||||
event_bus.emit(event)
|
||||
|
||||
|
||||
def status_to_tool_status(status: "FirewallConfigStatus") -> List[ToolStatus]:
|
||||
filtered_path = get_path()
|
||||
tools = []
|
||||
for tool_type, configs in status.items():
|
||||
alias_config = (
|
||||
configs["alias"] if isinstance(configs["alias"], AliasConfig) else None
|
||||
)
|
||||
index_config = (
|
||||
configs["index"] if isinstance(configs["index"], IndexConfig) else None
|
||||
)
|
||||
|
||||
tool = tool_type.value
|
||||
command_path = shutil.which(tool, path=filtered_path)
|
||||
reachable = False
|
||||
version = "unknown"
|
||||
|
||||
if command_path:
|
||||
args = [command_path, "--version"]
|
||||
result = subprocess.run(args, capture_output=True, text=True, env=get_env())
|
||||
|
||||
if result.returncode == 0:
|
||||
output = result.stdout
|
||||
reachable = True
|
||||
|
||||
# Extract version
|
||||
version_match = re.search(r"(\d+\.\d+(?:\.\d+)?)", output)
|
||||
if version_match:
|
||||
version = version_match.group(1)
|
||||
else:
|
||||
command_path = tool
|
||||
|
||||
tool = ToolStatus(
|
||||
type=tool_type,
|
||||
command_path=command_path,
|
||||
version=version,
|
||||
reachable=reachable,
|
||||
alias_config=alias_config,
|
||||
index_config=index_config,
|
||||
)
|
||||
tools.append(tool)
|
||||
|
||||
return tools
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_firewall_configured(
|
||||
event_bus: "EventBus",
|
||||
ctx: Optional["CustomContext"] = None,
|
||||
*,
|
||||
status: "FirewallConfigStatus",
|
||||
):
|
||||
tools = status_to_tool_status(status)
|
||||
|
||||
payload = FirewallConfiguredPayload(tools=tools)
|
||||
|
||||
event = create_event(payload=payload, event_type=EventType.FIREWALL_CONFIGURED)
|
||||
|
||||
event_bus.emit(event)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_diff_operations(
|
||||
event_bus: "EventBus",
|
||||
ctx: "CustomContext",
|
||||
*,
|
||||
added: Dict["PackageLocation", str],
|
||||
removed: Dict["PackageLocation", str],
|
||||
updated: Dict["PackageLocation", Tuple[str, str]],
|
||||
tool_path: Optional[str],
|
||||
by_tool: ToolType,
|
||||
):
|
||||
obj: "SafetyCLI" = ctx.obj
|
||||
correlation_id = obj.correlation_id
|
||||
|
||||
kwargs = {
|
||||
"tool_path": tool_path,
|
||||
"tool": by_tool,
|
||||
}
|
||||
|
||||
if (added or removed or updated) and not correlation_id:
|
||||
correlation_id = obj.correlation_id = str(uuid.uuid4())
|
||||
|
||||
def emit_package_event(event_bus, correlation_id, payload, event_type):
|
||||
event = create_event(
|
||||
payload=payload,
|
||||
event_type=event_type,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
event_bus.emit(event)
|
||||
|
||||
for package, version in added.items():
|
||||
emit_package_event(
|
||||
event_bus,
|
||||
correlation_id,
|
||||
PackageInstalledPayload(
|
||||
package_name=package.name,
|
||||
location=package.location,
|
||||
version=version,
|
||||
**kwargs,
|
||||
),
|
||||
EventType.PACKAGE_INSTALLED,
|
||||
)
|
||||
|
||||
for package, version in removed.items():
|
||||
emit_package_event(
|
||||
event_bus,
|
||||
correlation_id,
|
||||
PackageUninstalledPayload(
|
||||
package_name=package.name,
|
||||
location=package.location,
|
||||
version=version,
|
||||
**kwargs,
|
||||
),
|
||||
EventType.PACKAGE_UNINSTALLED,
|
||||
)
|
||||
|
||||
for package, (previous_version, current_version) in updated.items():
|
||||
emit_package_event(
|
||||
event_bus,
|
||||
correlation_id,
|
||||
PackageUpdatedPayload(
|
||||
package_name=package.name,
|
||||
location=package.location,
|
||||
previous_version=previous_version,
|
||||
current_version=current_version,
|
||||
**kwargs,
|
||||
),
|
||||
EventType.PACKAGE_UPDATED,
|
||||
)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_tool_command_executed(
|
||||
event_bus: "EventBus", ctx: "CustomContext", *, tool: ToolType, result: "ToolResult"
|
||||
) -> None:
|
||||
correlation_id = ctx.obj.correlation_id
|
||||
|
||||
if not correlation_id:
|
||||
correlation_id = ctx.obj.correlation_id = str(uuid.uuid4())
|
||||
|
||||
process = result.process
|
||||
|
||||
payload = ToolCommandExecutedPayload(
|
||||
tool=tool,
|
||||
tool_path=result.tool_path,
|
||||
raw_command=[clean_parameter("", arg) for arg in process.args],
|
||||
duration_ms=result.duration_ms,
|
||||
status=ProcessStatus(
|
||||
stdout=process.stdout, stderr=process.stderr, return_code=process.returncode
|
||||
),
|
||||
)
|
||||
|
||||
# Scrub after binary coercion to str
|
||||
if payload.status.stdout:
|
||||
payload.status.stdout = scrub_sensitive_value(payload.status.stdout)
|
||||
if payload.status.stderr:
|
||||
payload.status.stderr = scrub_sensitive_value(payload.status.stderr)
|
||||
|
||||
event = create_event(
|
||||
correlation_id=correlation_id,
|
||||
payload=payload,
|
||||
event_type=EventType.TOOL_COMMAND_EXECUTED,
|
||||
)
|
||||
|
||||
event_bus.emit(event)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_command_executed(
|
||||
event_bus: "EventBus", ctx: "CustomContext", *, returned_code: int
|
||||
) -> None:
|
||||
root_context = get_root_context(ctx)
|
||||
NA = ""
|
||||
|
||||
started_at = getattr(root_context, "started_at", None) if root_context else None
|
||||
if started_at is not None:
|
||||
duration_ms = int((time.monotonic() - started_at) * 1000)
|
||||
else:
|
||||
duration_ms = 1
|
||||
|
||||
command_name = ctx.command.name if ctx.command.name is not None else NA
|
||||
raw_command = [clean_parameter("", arg) for arg in sys.argv]
|
||||
|
||||
params: List[CommandParam] = []
|
||||
|
||||
for idx, param in enumerate(ctx.command.params):
|
||||
param_name = param.name if param.name is not None else NA
|
||||
param_value = ctx.params.get(param_name)
|
||||
|
||||
# Scrub the parameter value if sensitive
|
||||
scrubbed_value = clean_parameter(param_name, param_value)
|
||||
|
||||
# Determine parameter source using Click's API
|
||||
click_source = ctx.get_parameter_source(param_name)
|
||||
source = translate_param_source(click_source)
|
||||
|
||||
display_name = param_name if param_name else None
|
||||
|
||||
params.append(
|
||||
CommandParam(
|
||||
position=idx, name=display_name, value=scrubbed_value, source=source
|
||||
)
|
||||
)
|
||||
|
||||
payload = CommandExecutedPayload(
|
||||
command_name=command_name,
|
||||
command_path=get_command_path(ctx),
|
||||
raw_command=raw_command,
|
||||
parameters=params,
|
||||
duration_ms=duration_ms,
|
||||
status=ProcessStatus(
|
||||
return_code=returned_code,
|
||||
),
|
||||
)
|
||||
|
||||
event = create_event(
|
||||
correlation_id=ctx.obj.correlation_id,
|
||||
payload=payload,
|
||||
event_type=EventType.COMMAND_EXECUTED,
|
||||
)
|
||||
|
||||
try:
|
||||
if future := event_bus.emit(event):
|
||||
future.result(timeout=0.5)
|
||||
except Exception:
|
||||
logger.error("Emit Failed %s (%s)", event.type, event.id)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_command_error(
|
||||
event_bus: "EventBus",
|
||||
ctx: "CustomContext",
|
||||
*,
|
||||
message: str,
|
||||
traceback: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Emit a CommandErrorEvent with sensitive data scrubbed.
|
||||
"""
|
||||
# Get command name from context if available
|
||||
command_name = getattr(ctx, "command", None)
|
||||
if command_name and command_name.name:
|
||||
command_name = command_name.name
|
||||
|
||||
scrub_traceback = None
|
||||
if traceback:
|
||||
scrub_traceback = scrub_sensitive_value(traceback)
|
||||
|
||||
command_path = get_command_path(ctx)
|
||||
raw_command = [scrub_sensitive_value(arg) for arg in sys.argv]
|
||||
|
||||
payload = CommandErrorPayload(
|
||||
command_name=command_name,
|
||||
raw_command=raw_command,
|
||||
command_path=command_path,
|
||||
error_message=scrub_sensitive_value(message),
|
||||
stacktrace=scrub_traceback,
|
||||
)
|
||||
|
||||
event = create_event(
|
||||
payload=payload,
|
||||
event_type=EventType.COMMAND_ERROR,
|
||||
)
|
||||
|
||||
event_bus.emit(event)
|
||||
|
||||
|
||||
def emit_init_started(
|
||||
event_bus: "EventBus", ctx: Union["CustomContext", typer.Context]
|
||||
) -> None:
|
||||
"""
|
||||
Emit an InitStartedEvent and store it as a pending event in SafetyCLI object.
|
||||
|
||||
Args:
|
||||
event_bus: The event bus to emit on
|
||||
ctx: The Click context containing the SafetyCLI object
|
||||
"""
|
||||
obj: "SafetyCLI" = ctx.obj
|
||||
|
||||
if not obj.correlation_id:
|
||||
obj.correlation_id = str(uuid.uuid4())
|
||||
|
||||
payload = InitStartedPayload()
|
||||
event = create_event(
|
||||
correlation_id=obj.correlation_id,
|
||||
payload=payload,
|
||||
event_type=EventType.INIT_STARTED,
|
||||
)
|
||||
|
||||
if not send_and_flush(event_bus, event):
|
||||
# Store as pending event
|
||||
obj.pending_events.append(event)
|
||||
|
||||
|
||||
def emit_auth_started(event_bus: "EventBus", ctx: "CustomContext") -> None:
|
||||
"""
|
||||
Emit an AuthStartedEvent and store it as a pending event in SafetyCLI object.
|
||||
|
||||
Args:
|
||||
event_bus: The event bus to emit on
|
||||
ctx: The Click context containing the SafetyCLI object
|
||||
"""
|
||||
obj: "SafetyCLI" = ctx.obj
|
||||
|
||||
if not obj.correlation_id:
|
||||
obj.correlation_id = str(uuid.uuid4())
|
||||
|
||||
payload = AuthStartedPayload()
|
||||
event = create_event(
|
||||
correlation_id=obj.correlation_id,
|
||||
payload=payload,
|
||||
event_type=EventType.AUTH_STARTED,
|
||||
)
|
||||
|
||||
if not send_and_flush(event_bus, event):
|
||||
# Store as pending event
|
||||
obj.pending_events.append(event)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_auth_completed(
|
||||
event_bus: "EventBus",
|
||||
ctx: "CustomContext",
|
||||
*,
|
||||
success: bool = True,
|
||||
error_message: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Emit an AuthCompletedEvent and submit all pending events together.
|
||||
|
||||
Args:
|
||||
event_bus: The event bus to emit on
|
||||
ctx: The Click context containing the SafetyCLI object
|
||||
success: Whether authentication was successful
|
||||
error_message: Optional error message if authentication failed
|
||||
"""
|
||||
obj: "SafetyCLI" = ctx.obj
|
||||
|
||||
if not obj.correlation_id:
|
||||
obj.correlation_id = str(uuid.uuid4())
|
||||
|
||||
payload = AuthCompletedPayload(success=success, error_message=error_message)
|
||||
|
||||
event = create_event(
|
||||
correlation_id=obj.correlation_id,
|
||||
payload=payload,
|
||||
event_type=EventType.AUTH_COMPLETED,
|
||||
)
|
||||
|
||||
for pending_event in obj.pending_events:
|
||||
event_bus.emit(pending_event)
|
||||
|
||||
obj.pending_events.clear()
|
||||
|
||||
# Emit auth completed event and flush
|
||||
send_and_flush(event_bus, event)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_firewall_setup_response_created(
|
||||
event_bus: "EventBus",
|
||||
ctx: Union["CustomContext", typer.Context],
|
||||
*,
|
||||
user_consent_requested: bool,
|
||||
user_consent: Optional[bool] = None,
|
||||
) -> None:
|
||||
obj: "SafetyCLI" = ctx.obj
|
||||
|
||||
if not obj.correlation_id:
|
||||
obj.correlation_id = str(uuid.uuid4())
|
||||
|
||||
payload = FirewallSetupResponseCreatedPayload(
|
||||
user_consent_requested=user_consent_requested, user_consent=user_consent
|
||||
)
|
||||
|
||||
event = create_event(
|
||||
correlation_id=obj.correlation_id,
|
||||
payload=payload,
|
||||
event_type=EventType.FIREWALL_SETUP_RESPONSE_CREATED,
|
||||
)
|
||||
|
||||
# Emit and flush
|
||||
send_and_flush(event_bus, event)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_codebase_setup_response_created(
|
||||
event_bus: "EventBus",
|
||||
ctx: Union["CustomContext", typer.Context],
|
||||
*,
|
||||
user_consent_requested: bool,
|
||||
user_consent: Optional[bool] = None,
|
||||
) -> None:
|
||||
obj: "SafetyCLI" = ctx.obj
|
||||
|
||||
if not obj.correlation_id:
|
||||
obj.correlation_id = str(uuid.uuid4())
|
||||
|
||||
payload = CodebaseSetupResponseCreatedPayload(
|
||||
user_consent_requested=user_consent_requested, user_consent=user_consent
|
||||
)
|
||||
|
||||
event = create_event(
|
||||
correlation_id=obj.correlation_id,
|
||||
payload=payload,
|
||||
event_type=EventType.CODEBASE_SETUP_RESPONSE_CREATED,
|
||||
)
|
||||
|
||||
# Emit and flush
|
||||
send_and_flush(event_bus, event)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_codebase_detection_status(
|
||||
event_bus: "EventBus",
|
||||
ctx: Union["CustomContext", typer.Context],
|
||||
*,
|
||||
detected: bool,
|
||||
detected_files: Optional[List[Path]] = None,
|
||||
) -> None:
|
||||
obj: "SafetyCLI" = ctx.obj
|
||||
|
||||
if not obj.correlation_id:
|
||||
obj.correlation_id = str(uuid.uuid4())
|
||||
|
||||
payload = CodebaseDetectionStatusPayload(
|
||||
detected=detected,
|
||||
dependency_files=[
|
||||
DependencyFile(file_path=str(file)) for file in detected_files
|
||||
]
|
||||
if detected_files
|
||||
else None,
|
||||
)
|
||||
|
||||
event = create_event(
|
||||
correlation_id=obj.correlation_id,
|
||||
payload=payload,
|
||||
event_type=EventType.CODEBASE_DETECTION_STATUS,
|
||||
)
|
||||
|
||||
# Emit and flush
|
||||
send_and_flush(event_bus, event)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_init_scan_completed(
|
||||
event_bus: "EventBus",
|
||||
ctx: Union["CustomContext", typer.Context],
|
||||
*,
|
||||
scan_id: Optional[str],
|
||||
) -> None:
|
||||
obj: "SafetyCLI" = ctx.obj
|
||||
|
||||
if not obj.correlation_id:
|
||||
obj.correlation_id = str(uuid.uuid4())
|
||||
|
||||
payload = InitScanCompletedPayload(scan_id=scan_id)
|
||||
|
||||
event = create_event(
|
||||
correlation_id=obj.correlation_id,
|
||||
payload=payload,
|
||||
event_type=EventType.INIT_SCAN_COMPLETED,
|
||||
)
|
||||
|
||||
# Emit and flush
|
||||
send_and_flush(event_bus, event)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_codebase_setup_completed(
|
||||
event_bus: "EventBus",
|
||||
ctx: Union["CustomContext", typer.Context],
|
||||
*,
|
||||
is_created: bool,
|
||||
codebase_id: Optional[str] = None,
|
||||
) -> None:
|
||||
obj: "SafetyCLI" = ctx.obj
|
||||
|
||||
if not obj.correlation_id:
|
||||
obj.correlation_id = str(uuid.uuid4())
|
||||
|
||||
payload = CodebaseSetupCompletedPayload(
|
||||
is_created=is_created, codebase_id=codebase_id
|
||||
)
|
||||
|
||||
event = create_event(
|
||||
correlation_id=obj.correlation_id,
|
||||
payload=payload,
|
||||
event_type=EventType.CODEBASE_SETUP_COMPLETED,
|
||||
)
|
||||
|
||||
# Emit and flush
|
||||
send_and_flush(event_bus, event)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_firewall_setup_completed(
|
||||
event_bus: "EventBus",
|
||||
ctx: "CustomContext",
|
||||
*,
|
||||
status: "FirewallConfigStatus",
|
||||
) -> None:
|
||||
obj: "SafetyCLI" = ctx.obj
|
||||
|
||||
if not obj.correlation_id:
|
||||
obj.correlation_id = str(uuid.uuid4())
|
||||
|
||||
tools = status_to_tool_status(status)
|
||||
|
||||
payload = FirewallSetupCompletedPayload(
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
event = create_event(
|
||||
correlation_id=obj.correlation_id,
|
||||
payload=payload,
|
||||
event_type=EventType.FIREWALL_SETUP_COMPLETED,
|
||||
)
|
||||
|
||||
# Emit and flush
|
||||
send_and_flush(event_bus, event)
|
||||
|
||||
|
||||
@conditional_emitter
|
||||
def emit_init_exited(
|
||||
event_bus: "EventBus",
|
||||
ctx: Union["CustomContext", typer.Context],
|
||||
*,
|
||||
exit_step: InitExitStep,
|
||||
) -> None:
|
||||
obj: "SafetyCLI" = ctx.obj
|
||||
|
||||
if not obj.correlation_id:
|
||||
obj.correlation_id = str(uuid.uuid4())
|
||||
|
||||
payload = InitExitedPayload(exit_step=exit_step)
|
||||
|
||||
event = create_event(
|
||||
correlation_id=obj.correlation_id,
|
||||
payload=payload,
|
||||
event_type=EventType.INIT_EXITED,
|
||||
)
|
||||
|
||||
# Emit and flush
|
||||
send_and_flush(event_bus, event)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user