This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,756 @@
# type: ignore
import importlib.util
import json
import logging
from functools import lru_cache
from typing import Any, Callable, Dict, Optional, Tuple, List, Literal
import requests
from authlib.integrations.base_client.errors import OAuthError
from authlib.integrations.requests_client import OAuth2Session
from requests.adapters import HTTPAdapter
from safety_schemas.models import STAGE_ID_MAPPING, Stage
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from safety.auth.constants import (
AUTH_SERVER_URL,
OPENID_CONFIG_URL,
)
from safety.constants import (
PLATFORM_API_CHECK_UPDATES_ENDPOINT,
PLATFORM_API_INITIALIZE_ENDPOINT,
PLATFORM_API_POLICY_ENDPOINT,
PLATFORM_API_PROJECT_CHECK_ENDPOINT,
PLATFORM_API_PROJECT_ENDPOINT,
PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT,
PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT,
PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT,
REQUEST_TIMEOUT,
FeatureType,
get_config_setting,
FIREWALL_AUDIT_PYPI_PACKAGES_ENDPOINT,
FIREWALL_AUDIT_NPMJS_PACKAGES_ENDPOINT,
)
from safety.error_handlers import output_exception
from safety.errors import (
InvalidCredentialError,
NetworkConnectionError,
RequestTimeoutError,
SafetyError,
ServerError,
TooManyRequestsError,
)
from safety.meta import get_meta_http_headers
from safety.models import SafetyCLI
from safety.scan.util import AuthenticationType
from safety.util import SafetyContext
LOG = logging.getLogger(__name__)
def get_keys(
client_session: OAuth2Session, openid_config: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
Retrieve the keys from the OpenID configuration.
Args:
client_session (OAuth2Session): The OAuth2 session.
openid_config (Dict[str, Any]): The OpenID configuration.
Returns:
Optional[Dict[str, Any]]: The keys, if available.
"""
if "jwks_uri" in openid_config:
return client_session.get(url=openid_config["jwks_uri"], bearer=False).json() # type: ignore
return None
def is_email_verified(info: Dict[str, Any]) -> Optional[bool]:
"""
Check if the email is verified.
Args:
info (Dict[str, Any]): The user information.
Returns:
bool: True
"""
# return info.get(CLAIM_EMAIL_VERIFIED_API) or info.get(
# CLAIM_EMAIL_VERIFIED_AUTH_SERVER
# )
# Always return True to avoid email verification
return True
def extract_detail(response: requests.Response) -> Optional[str]:
"""
Extract the reason from an HTTP response.
Args:
response (requests.Response): The response.
Returns:
Optional[str]: The reason.
"""
detail = None
try:
detail = response.json().get("detail")
except Exception:
LOG.debug("Failed to extract detail from response: %s", response.status_code)
return detail
def parse_response(func: Callable) -> Callable:
"""
Decorator to parse the response from an HTTP request.
Args:
func (Callable): The function to wrap.
Returns:
Callable: The wrapped function.
"""
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential_jitter(initial=0.2, max=8.0, exp_base=3, jitter=0.3),
reraise=True,
retry=retry_if_exception_type(
(
NetworkConnectionError,
RequestTimeoutError,
TooManyRequestsError,
ServerError,
)
),
before_sleep=before_sleep_log(logging.getLogger("api_client"), logging.WARNING),
)
def wrapper(*args, **kwargs):
try:
r = func(*args, **kwargs)
except OAuthError as e:
LOG.exception("OAuth failed: %s", e)
raise InvalidCredentialError(
message="Your token authentication expired, try login again."
)
except requests.exceptions.ConnectionError:
raise NetworkConnectionError()
except requests.exceptions.Timeout:
raise RequestTimeoutError()
except requests.exceptions.RequestException as e:
raise e
# TODO: Handle content as JSON and fallback to text for all responses
if r.status_code == 403:
reason = extract_detail(response=r)
raise InvalidCredentialError(
credential="Failed authentication.", reason=reason
)
if r.status_code == 429:
raise TooManyRequestsError(reason=r.text)
if r.status_code >= 400 and r.status_code < 500:
error_code = None
try:
data = r.json()
reason = data.get("detail", "Unable to find reason.")
error_code = data.get("error_code", None)
except Exception:
reason = r.reason
raise SafetyError(message=reason, error_code=error_code)
if r.status_code >= 500 and r.status_code < 600:
reason = extract_detail(response=r)
LOG.debug("ServerError %s -> Response returned: %s", r.status_code, r.text)
raise ServerError(reason=reason)
data = None
try:
data = r.json()
except json.JSONDecodeError as e:
raise ServerError(message=f"Bad JSON response from the server: {e}")
return data
return wrapper
class SafetyAuthSession(OAuth2Session):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
Initialize the SafetyAuthSession.
Args:
*args (Any): Positional arguments for the parent class.
**kwargs (Any): Keyword arguments for the parent class.
"""
super().__init__(*args, **kwargs)
self.proxy_required: bool = False
self.proxy_timeout: Optional[int] = None
self.api_key = None
def get_credential(self) -> Optional[str]:
"""
Get the current authentication credential.
Returns:
Optional[str]: The API key, token, or None.
"""
if self.api_key:
return self.api_key
if self.token:
return SafetyContext().account
return None
def is_using_auth_credentials(self) -> bool:
"""
Check if the session is using authentication credentials.
This does NOT check if the client is authenticated.
Returns:
bool: True if using authentication credentials, False otherwise.
"""
return self.get_authentication_type() != AuthenticationType.none
def get_authentication_type(self) -> AuthenticationType:
"""
Get the type of authentication being used.
Returns:
AuthenticationType: The type of authentication.
"""
if self.api_key:
return AuthenticationType.api_key
if self.token:
return AuthenticationType.token
return AuthenticationType.none
def request(
self,
method: str,
url: str,
withhold_token: bool = False,
auth: Optional[Tuple] = None,
bearer: bool = True,
**kwargs: Any,
) -> requests.Response:
"""
Make an HTTP request with the appropriate authentication.
Use the right auth parameter for Safety supported auth types.
Args:
method (str): The HTTP method.
url (str): The URL to request.
withhold_token (bool): Whether to withhold the token.
auth (Optional[Tuple]): The authentication tuple.
bearer (bool): Whether to use bearer authentication.
**kwargs (Any): Additional keyword arguments.
Returns:
requests.Response: The HTTP response.
Raises:
Exception: If the request fails.
"""
# By default use the token_auth
TIMEOUT_KEYWARD = "timeout"
func_timeout = (
kwargs[TIMEOUT_KEYWARD] if TIMEOUT_KEYWARD in kwargs else REQUEST_TIMEOUT
)
if "headers" not in kwargs:
kwargs["headers"] = {}
kwargs["headers"].update(get_meta_http_headers())
if self.api_key:
kwargs["headers"]["X-Api-Key"] = self.api_key
if not self.token or not bearer:
# Fallback to no token auth
auth = ()
# Override proxies
if self.proxies:
kwargs["proxies"] = self.proxies
if self.proxy_timeout:
kwargs["timeout"] = int(self.proxy_timeout) / 1000
if ("proxies" not in kwargs or not self.proxies) and self.proxy_required:
output_exception(
"Proxy connection is required but there is not a proxy setup.", # type: ignore
exit_code_output=True,
)
request_func = super(SafetyAuthSession, self).request
params = {
"method": method,
"url": url,
"withhold_token": withhold_token,
"auth": auth,
}
params.update(kwargs)
try:
return request_func(**params)
except Exception as e:
LOG.debug("Request failed: %s", e)
if self.proxy_required:
output_exception(
f"Proxy is required but the connection failed because: {e}", # type: ignore
exit_code_output=True,
)
if "proxies" in kwargs or self.proxies:
params["proxies"] = {}
params["timeout"] = func_timeout
self.proxies = {}
message = (
"The proxy configuration failed to function and was disregarded."
)
LOG.debug(message)
if message not in [
a["message"] for a in SafetyContext.local_announcements
]:
SafetyContext.local_announcements.append(
{"message": message, "type": "warning", "local": True}
)
return request_func(**params)
raise e
def fetch_openid_config(self) -> Any:
"""
Fetch the OpenID configuration from the authorization server.
Returns:
Any: The OpenID configuration.
"""
try:
openid_config = self.get(
url=OPENID_CONFIG_URL, timeout=REQUEST_TIMEOUT
).json()
except Exception as e:
LOG.debug("Unable to load the openID config: %s", e)
openid_config = {}
return openid_config
@parse_response
def fetch_user_info(self) -> Any:
"""
Fetch user information from the authorization server.
Returns:
Any: The user information.
"""
USER_INFO_ENDPOINT = f"{AUTH_SERVER_URL}/userinfo"
r = self.get(url=USER_INFO_ENDPOINT)
return r
@parse_response
def check_project(
self,
scan_stage: str,
safety_source: str,
project_slug: Optional[str] = None,
git_origin: Optional[str] = None,
project_slug_source: Optional[str] = None,
) -> Any:
"""
Check project information.
Args:
scan_stage (str): The scan stage.
safety_source (str): The safety source.
project_slug (Optional[str]): The project slug.
git_origin (Optional[str]): The git origin.
project_slug_source (Optional[str]): The project slug source.
Returns:
Any: The project information.
"""
data = {
"scan_stage": scan_stage,
"safety_source": safety_source,
"project_slug": project_slug,
"project_slug_source": project_slug_source,
"git_origin": git_origin,
}
r = self.post(url=PLATFORM_API_PROJECT_CHECK_ENDPOINT, json=data)
return r
@parse_response
def project(self, project_id: str) -> Any:
"""
Get project information.
Args:
project_id (str): The project ID.
Returns:
Any: The project information.
"""
data = {"project": project_id}
return self.get(url=PLATFORM_API_PROJECT_ENDPOINT, params=data)
@parse_response
def download_policy(
self, project_id: Optional[str], stage: Stage, branch: Optional[str]
) -> Any:
"""
Download the project policy.
Args:
project_id (Optional[str]): The project ID.
stage (Stage): The stage.
branch (Optional[str]): The branch.
Returns:
Any: The policy data.
"""
data = {
"project": project_id,
"stage": STAGE_ID_MAPPING[stage],
"branch": branch,
}
return self.get(url=PLATFORM_API_POLICY_ENDPOINT, params=data)
@parse_response
def project_scan_request(self, project_id: str) -> Any:
"""
Request a project scan.
Args:
project_id (str): The project ID.
Returns:
Any: The scan request result.
"""
data = {"project_id": project_id}
return self.post(url=PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT, json=data)
@parse_response
def upload_report(self, json_report: str) -> Any:
"""
Upload a scan report.
Args:
json_report (str): The JSON report.
Returns:
Any: The upload result.
"""
return self.post(
url=PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT,
data=json_report,
headers={"Content-Type": "application/json"},
)
def upload_requirements(self, json_payload: str) -> Any:
"""
Upload a scan report.
Args:
json_payload (str): The JSON payload to upload.
Returns:
Any: The result of the upload operation.
"""
return self.post(
url=PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT,
data=json.dumps(json_payload),
headers={"Content-Type": "application/json"},
)
@parse_response
def check_updates(
self,
version: int,
safety_version: Optional[str] = None,
python_version: Optional[str] = None,
os_type: Optional[str] = None,
os_release: Optional[str] = None,
os_description: Optional[str] = None,
) -> Any:
"""
Check for updates.
Args:
version (int): The version.
safety_version (Optional[str]): The Safety version.
python_version (Optional[str]): The Python version.
os_type (Optional[str]): The OS type.
os_release (Optional[str]): The OS release.
os_description (Optional[str]): The OS description.
Returns:
Any: The update check result.
"""
data = {
"version": version,
"safety_version": safety_version,
"python_version": python_version,
"os_type": os_type,
"os_release": os_release,
"os_description": os_description,
}
return self.get(url=PLATFORM_API_CHECK_UPDATES_ENDPOINT, params=data)
@parse_response
def audit_packages(
self, packages: List[str], ecosystem: Literal["pypi", "npmjs"]
) -> Any:
"""
Audits packages for vulnerabilities
Args:
packages: list of package specifiers
ecosystem: the ecosystem to audit
Returns:
Any: The packages audit result.
"""
url = (
FIREWALL_AUDIT_NPMJS_PACKAGES_ENDPOINT
if ecosystem == "npmjs"
else FIREWALL_AUDIT_PYPI_PACKAGES_ENDPOINT
)
data = {"packages": [{"package_specifier": package} for package in packages]}
return self.post(url=url, json=data)
@parse_response
def initialize(self) -> Any:
"""
Initialize a run.
Returns:
Any: The initialization result.
"""
try:
response = self.get(
url=PLATFORM_API_INITIALIZE_ENDPOINT,
headers={"Content-Type": "application/json"},
timeout=5,
)
return response
except requests.exceptions.Timeout:
LOG.error("Auth request to initialize timed out after 5 seconds.")
except Exception:
LOG.exception("Exception trying to auth initialize", exc_info=True)
return None
class S3PresignedAdapter(HTTPAdapter):
def send( # type: ignore
self, request: requests.PreparedRequest, **kwargs: Any
) -> requests.Response:
"""
Send a request, removing the Authorization header.
Args:
request (requests.PreparedRequest): The prepared request.
**kwargs (Any): Additional keyword arguments.
Returns:
requests.Response: The response.
"""
request.headers.pop("Authorization", None)
return super().send(request, **kwargs)
@lru_cache(maxsize=1)
def is_jupyter_notebook() -> bool:
"""
Detects if the code is running in a Jupyter notebook environment, including
various cloud-hosted Jupyter notebooks.
Returns:
bool: True if the environment is identified as a Jupyter notebook (or
equivalent cloud-based environment), False otherwise.
Supported environments:
- Google Colab
- Amazon SageMaker
- Azure Notebooks
- Kaggle Notebooks
- Databricks Notebooks
- Datalore by JetBrains
- Paperspace Gradient Notebooks
- Classic Jupyter Notebook and JupyterLab
"""
if (
(
importlib.util.find_spec("google")
and importlib.util.find_spec("google.colab")
)
is not None
or importlib.util.find_spec("sagemaker") is not None
or importlib.util.find_spec("azureml") is not None
or importlib.util.find_spec("kaggle") is not None
or importlib.util.find_spec("dbutils") is not None
or importlib.util.find_spec("datalore") is not None
or importlib.util.find_spec("gradient") is not None
):
return True
# Detect classic Jupyter Notebook, JupyterLab, and other IPython kernel-based environments
try:
from IPython import get_ipython # type: ignore
ipython = get_ipython()
if ipython is not None and "IPKernelApp" in ipython.config:
return True
except (ImportError, AttributeError, NameError):
pass
return False
def save_flags_config(flags: Dict[FeatureType, bool]) -> None:
"""
Save feature flags configuration to file.
This function attempts to save feature flags to the configuration file
but will fail silently if unable to do so (e.g., due to permission issues
or disk problems). Silent failure is chosen to prevent configuration issues
from disrupting core application functionality.
Note that if saving fails, the application will continue using existing
or default flag values until the next restart.
Args:
flags: Dictionary mapping feature types to their enabled/disabled state
The operation will be logged (with stack trace) if it fails.
"""
import configparser
from safety.constants import CONFIG_FILE_USER
config = configparser.ConfigParser()
config.read(CONFIG_FILE_USER)
flag_settings = {key.name.upper(): str(value) for key, value in flags.items()}
if not config.has_section("settings"):
config.add_section("settings")
settings = dict(config.items("settings"))
settings.update(flag_settings)
for key, value in settings.items():
config.set("settings", key, value)
try:
with open(CONFIG_FILE_USER, "w") as config_file:
config.write(config_file)
except Exception:
LOG.exception("Unable to save flags configuration.")
def get_feature_name(feature: FeatureType, as_attr: bool = False) -> str:
"""Returns a formatted feature name with enabled suffix.
Args:
feature: The feature to format the name for
as_attr: If True, formats for attribute usage (underscore),
otherwise uses hyphen
Returns:
Formatted feature name string with enabled suffix
"""
name = feature.name.lower()
separator = "_" if as_attr else "-"
return f"{name}{separator}enabled"
def str_to_bool(value) -> Optional[bool]:
"""Convert basic string representations to boolean."""
if isinstance(value, bool):
return value
if isinstance(value, str):
value = value.lower().strip()
if value in ("true"):
return True
if value in ("false"):
return False
return None
def initialize(ctx: Any, refresh: bool = True) -> None:
"""
Initializes the run by loading settings.
Args:
ctx (Any): The context object.
refresh (bool): Whether to refresh settings from the server. Defaults to True.
"""
settings = None
current_values = {}
if not ctx.obj:
ctx.obj = SafetyCLI()
for feature in FeatureType:
value = get_config_setting(feature.name)
if value is not None:
current_values[feature] = str_to_bool(value)
if refresh:
try:
settings = ctx.obj.auth.client.initialize() # type: ignore
except Exception:
LOG.info("Unable to initialize, continue with default values.")
if settings:
for feature in FeatureType:
server_value = str_to_bool(settings.get(feature.config_key))
if server_value is not None:
if (
feature not in current_values
or current_values[feature] != server_value
):
current_values[feature] = server_value
save_flags_config(current_values)
for feature, value in current_values.items():
if value is not None:
setattr(ctx.obj, feature.attr_name, value)