This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,10 @@
from .cli_utils import auth_options, build_client_session, proxy_options, inject_session
from .cli import auth
__all__ = [
"build_client_session",
"proxy_options",
"auth_options",
"inject_session",
"auth",
]

View File

@@ -0,0 +1,402 @@
# type: ignore
import logging
import sys
from datetime import datetime
from safety.auth.models import Auth
from safety.auth.utils import initialize, is_email_verified
from safety.console import main_console as console
from safety.constants import (
MSG_FINISH_REGISTRATION_TPL,
MSG_VERIFICATION_HINT,
DEFAULT_EPILOG,
)
from safety.meta import get_version
from safety.decorators import notify
try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated
from typing import Optional
import click
import typer
from rich.padding import Padding
from typer import Typer
from safety.auth.main import (
clean_session,
get_auth_info,
get_authorization_data,
get_token,
)
from safety.auth.server import process_browser_callback
from safety.events.utils import emit_auth_started, emit_auth_completed
from safety.util import initialize_event_bus
from safety.scan.constants import (
CLI_AUTH_COMMAND_HELP,
CLI_AUTH_HEADLESS_HELP,
CLI_AUTH_LOGIN_HELP,
CLI_AUTH_LOGOUT_HELP,
CLI_AUTH_STATUS_HELP,
)
from ..cli_util import SafetyCLISubGroup, get_command_for, pass_safety_cli_obj
from safety.error_handlers import handle_cmd_exception
from .constants import (
MSG_FAIL_LOGIN_AUTHED,
MSG_FAIL_REGISTER_AUTHED,
MSG_LOGOUT_DONE,
MSG_LOGOUT_FAILED,
MSG_NON_AUTHENTICATED,
)
LOG = logging.getLogger(__name__)
auth_app = Typer(rich_markup_mode="rich", name="auth")
CMD_LOGIN_NAME = "login"
CMD_REGISTER_NAME = "register"
CMD_STATUS_NAME = "status"
CMD_LOGOUT_NAME = "logout"
DEFAULT_CMD = CMD_LOGIN_NAME
@auth_app.callback(
invoke_without_command=True,
cls=SafetyCLISubGroup,
help=CLI_AUTH_COMMAND_HELP,
epilog=DEFAULT_EPILOG,
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
)
@pass_safety_cli_obj
def auth(ctx: typer.Context) -> None:
"""
Authenticate Safety CLI with your account.
Args:
ctx (typer.Context): The Typer context object.
"""
LOG.info("auth started")
# If no subcommand is invoked, forward to the default command
if not ctx.invoked_subcommand:
default_command = get_command_for(name=DEFAULT_CMD, typer_instance=auth_app)
return ctx.forward(default_command)
def fail_if_authenticated(ctx: typer.Context, with_msg: str) -> None:
"""
Exits the command if the user is already authenticated.
Args:
ctx (typer.Context): The Typer context object.
with_msg (str): The message to display if authenticated.
"""
info = get_auth_info(ctx)
if info:
console.print()
email = f"[green]{ctx.obj.auth.email}[/green]"
if not ctx.obj.auth.email_verified:
email = f"{email} {render_email_note(ctx.obj.auth)}"
console.print(with_msg.format(email=email))
sys.exit(0)
def render_email_note(auth: Auth) -> str:
"""
Renders a note indicating whether email verification is required.
Args:
auth (Auth): The Auth object.
Returns:
str: The rendered email note.
"""
return "" if auth.email_verified else "[red](email verification required)[/red]"
def render_successful_login(auth: Auth, organization: Optional[str] = None) -> None:
"""
Renders a message indicating a successful login.
Args:
auth (Auth): The Auth object.
organization (Optional[str]): The organization name.
"""
DEFAULT = "--"
name = auth.name if auth.name else DEFAULT
email = auth.email if auth.email else DEFAULT
email_note = render_email_note(auth)
console.print()
console.print("[bold][green]You're authenticated[/green][/bold]")
if name and name != email:
details = [f"[green][bold]Account:[/bold] {name}, {email}[/green] {email_note}"]
else:
details = [f"[green][bold]Account:[/bold] {email}[/green] {email_note}"]
if organization:
details.insert(0, f"[green][bold]Organization:[/bold] {organization}[green]")
for msg in details:
console.print(Padding(msg, (0, 0, 0, 1)), emoji=True)
@auth_app.command(name=CMD_LOGIN_NAME, help=CLI_AUTH_LOGIN_HELP)
@handle_cmd_exception
@notify
def login(
ctx: typer.Context,
headless: Annotated[
Optional[bool],
typer.Option(
"--headless",
help=CLI_AUTH_HEADLESS_HELP,
),
] = None,
) -> None:
"""
Authenticate Safety CLI with your safetycli.com account using your default browser.
Args:
ctx (typer.Context): The Typer context object.
headless (bool): Whether to run in headless mode.
"""
LOG.info("login started")
headless = headless is True
# Check if the user is already authenticated
fail_if_authenticated(ctx, with_msg=MSG_FAIL_LOGIN_AUTHED)
console.print()
info = None
brief_msg: str = (
"Redirecting your browser to log in; once authenticated, "
"return here to start using Safety"
)
if ctx.obj.auth.org:
console.print(
f"Logging into [bold]{ctx.obj.auth.org.name}[/bold] organization."
)
if headless:
brief_msg = "Running in headless mode. Please copy and open the following URL in a browser"
# Get authorization data and generate the authorization URL
uri, initial_state = get_authorization_data(
client=ctx.obj.auth.client,
code_verifier=ctx.obj.auth.code_verifier,
organization=ctx.obj.auth.org,
headless=headless,
)
click.secho(brief_msg)
click.echo()
emit_auth_started(ctx.obj.event_bus, ctx)
# Process the browser callback to complete the authentication
info = process_browser_callback(
uri, initial_state=initial_state, ctx=ctx, headless=headless
)
is_success = False
error_msg = None
if info:
if info.get("email", None):
organization = None
if ctx.obj.auth.org and ctx.obj.auth.org.name:
organization = ctx.obj.auth.org.name
ctx.obj.auth.refresh_from(info)
if headless:
console.print()
initialize(ctx, refresh=True)
initialize_event_bus(ctx=ctx)
render_successful_login(ctx.obj.auth, organization=organization)
is_success = True
console.print()
if ctx.obj.auth.org or ctx.obj.auth.email_verified:
if not getattr(ctx.obj, "only_auth_msg", False):
console.print(
"[tip]Tip[/tip]: now try [bold]`safety scan`[/bold] in your projects root "
"folder to run a project scan or [bold]`safety -help`[/bold] to learn more."
)
else:
console.print(
MSG_FINISH_REGISTRATION_TPL.format(email=ctx.obj.auth.email)
)
console.print()
console.print(MSG_VERIFICATION_HINT)
else:
click.secho("Safety is now authenticated but your email is missing.")
else:
error_msg = ":stop_sign: [red]"
if ctx.obj.auth.org:
error_msg += (
f"Error logging into {ctx.obj.auth.org.name} organization "
f"with auth ID: {ctx.obj.auth.org.id}."
)
else:
error_msg += "Error logging into Safety."
error_msg += (
" Please try again, or use [bold]`safety auth -help`[/bold] "
"for more information[/red]"
)
console.print(error_msg, emoji=True)
emit_auth_completed(
ctx.obj.event_bus, ctx, success=is_success, error_message=error_msg
)
@auth_app.command(name=CMD_LOGOUT_NAME, help=CLI_AUTH_LOGOUT_HELP)
@handle_cmd_exception
@notify
def logout(ctx: typer.Context) -> None:
"""
Log out of your current session.
Args:
ctx (typer.Context): The Typer context object.
"""
LOG.info("logout started")
id_token = get_token("id_token")
msg = MSG_NON_AUTHENTICATED
if id_token:
# Clean the session if an ID token is found
if clean_session(ctx.obj.auth.client):
msg = MSG_LOGOUT_DONE
else:
msg = MSG_LOGOUT_FAILED
console.print(msg)
@auth_app.command(name=CMD_STATUS_NAME, help=CLI_AUTH_STATUS_HELP)
@click.option(
"--ensure-auth/--no-ensure-auth",
default=False,
help="This will keep running the command until anauthentication is made.",
)
@click.option(
"--login-timeout",
"-w",
type=int,
default=600,
help="Max time allowed to wait for an authentication.",
)
@handle_cmd_exception
@notify
def status(
ctx: typer.Context, ensure_auth: bool = False, login_timeout: int = 600
) -> None:
"""
Display Safety CLI's current authentication status.
Args:
ctx (typer.Context): The Typer context object.
ensure_auth (bool): Whether to keep running until authentication is made.
login_timeout (int): Max time allowed to wait for authentication.
"""
LOG.info("status started")
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
safety_version = get_version()
console.print(f"[{current_time}]: Safety {safety_version}")
info = get_auth_info(ctx)
initialize(ctx, refresh=True)
if ensure_auth:
console.print("running: safety auth status --ensure-auth")
console.print()
if info:
verified = is_email_verified(info)
email_status = " [red](email not verified)[/red]" if not verified else ""
console.print(f"[green]Authenticated as {info['email']}[/green]{email_status}")
elif ensure_auth:
console.print(
"Safety is not authenticated. Launching default browser to log in"
)
console.print()
uri, initial_state = get_authorization_data(
client=ctx.obj.auth.client,
code_verifier=ctx.obj.auth.code_verifier,
organization=ctx.obj.auth.org,
ensure_auth=ensure_auth,
)
# Process the browser callback to complete the authentication
info = process_browser_callback(
uri, initial_state=initial_state, timeout=login_timeout, ctx=ctx
)
if not info:
console.print(
f"[red]Timeout error ({login_timeout} seconds): not successfully authenticated without the timeout period.[/red]"
)
sys.exit(1)
organization = None
if ctx.obj.auth.org and ctx.obj.auth.org.name:
organization = ctx.obj.auth.org.name
render_successful_login(ctx.obj.auth, organization=organization)
console.print()
else:
console.print(MSG_NON_AUTHENTICATED)
@auth_app.command(name=CMD_REGISTER_NAME)
@handle_cmd_exception
@notify
def register(ctx: typer.Context) -> None:
"""
Create a new user account for the safetycli.com service.
Args:
ctx (typer.Context): The Typer context object.
"""
LOG.info("register started")
# Check if the user is already authenticated
fail_if_authenticated(ctx, with_msg=MSG_FAIL_REGISTER_AUTHED)
# Get authorization data and generate the registration URL
uri, initial_state = get_authorization_data(
client=ctx.obj.auth.client,
code_verifier=ctx.obj.auth.code_verifier,
sign_up=True,
)
console.print(
"\nRedirecting your browser to register for a free account. Once registered, return here to start using Safety."
)
console.print()
# Process the browser callback to complete the registration
info = process_browser_callback(uri, initial_state=initial_state, ctx=ctx)
console.print()
if info:
console.print(f"[green]Successfully registered {info.get('email')}[/green]")
console.print()
else:
console.print("[red]Unable to register in this time, try again.[/red]")

View File

@@ -0,0 +1,290 @@
import logging
from typing import Dict, Optional, Tuple, Any, Callable
import click
from .main import (
get_auth_info,
get_host_config,
get_organization,
get_proxy_config,
get_redirect_url,
get_token_data,
save_auth_config,
get_token,
clean_session,
)
from authlib.common.security import generate_token
from safety.auth.constants import CLIENT_ID
from safety.auth.models import Organization, Auth
from safety.auth.utils import (
S3PresignedAdapter,
SafetyAuthSession,
get_keys,
is_email_verified,
)
from safety.scan.constants import (
CLI_KEY_HELP,
CLI_PROXY_HOST_HELP,
CLI_PROXY_PORT_HELP,
CLI_PROXY_PROTOCOL_HELP,
CLI_STAGE_HELP,
)
from safety.util import DependentOption, SafetyContext, get_proxy_dict
from safety.models import SafetyCLI
from safety_schemas.models import Stage
LOG = logging.getLogger(__name__)
def build_client_session(
api_key: Optional[str] = None,
proxies: Optional[Dict[str, str]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Tuple[SafetyAuthSession, Dict[str, Any]]:
"""
Builds and configures the client session for authentication.
Args:
api_key (Optional[str]): The API key for authentication.
proxies (Optional[Dict[str, str]]): Proxy configuration.
headers (Optional[Dict[str, str]]): Additional headers.
Returns:
Tuple[SafetyAuthSession, Dict[str, Any]]: The configured client session and OpenID configuration.
"""
kwargs = {}
target_proxies = proxies
# Global proxy defined in the config.ini
proxy_config, proxy_timeout, proxy_required = get_proxy_config()
if not proxies:
target_proxies = proxy_config
def update_token(tokens, **kwargs):
save_auth_config(
access_token=tokens["access_token"],
id_token=tokens["id_token"],
refresh_token=tokens["refresh_token"],
)
load_auth_session(click_ctx=click.get_current_context(silent=True)) # type: ignore
client_session = SafetyAuthSession(
client_id=CLIENT_ID,
code_challenge_method="S256",
redirect_uri=get_redirect_url(),
update_token=update_token,
scope="openid email profile offline_access",
**kwargs,
)
client_session.mount("https://pyup.io/static-s3/", S3PresignedAdapter())
client_session.proxy_required = proxy_required
client_session.proxy_timeout = proxy_timeout
client_session.proxies = target_proxies # type: ignore
client_session.headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
openid_config = client_session.fetch_openid_config()
client_session.metadata["token_endpoint"] = openid_config.get(
"token_endpoint", None
)
if api_key:
client_session.api_key = api_key # type: ignore
client_session.headers["X-Api-Key"] = api_key
if headers:
client_session.headers.update(headers)
return client_session, openid_config
def load_auth_session(click_ctx: click.Context) -> None:
"""
Loads the authentication session from the context.
Args:
click_ctx (click.Context): The Click context object.
"""
if not click_ctx:
LOG.warning("Click context is needed to be able to load the Auth data.")
return
client = click_ctx.obj.auth.client
keys = click_ctx.obj.auth.keys
access_token: str = get_token(name="access_token") # type: ignore
refresh_token: str = get_token(name="refresh_token") # type: ignore
id_token: str = get_token(name="id_token") # type: ignore
if access_token and keys:
try:
token = get_token_data(access_token, keys, silent_if_expired=True)
client.token = {
"access_token": access_token,
"refresh_token": refresh_token,
"id_token": id_token,
"token_type": "bearer",
"expires_at": token.get("exp", None), # type: ignore
}
except Exception as e:
print(e)
clean_session(client)
def proxy_options(func: Callable) -> Callable:
"""
Decorator that defines proxy options for Click commands.
Options defined per command, this will override the proxy settings defined in the
config.ini file.
Args:
func (Callable): The Click command function.
Returns:
Callable: The wrapped Click command function with proxy options.
"""
func = click.option(
"--proxy-protocol",
type=click.Choice(["http", "https"]),
default="https",
cls=DependentOption,
required_options=["proxy_host"],
help=CLI_PROXY_PROTOCOL_HELP,
)(func)
func = click.option(
"--proxy-port",
multiple=False,
type=int,
default=80,
cls=DependentOption,
required_options=["proxy_host"],
help=CLI_PROXY_PORT_HELP,
)(func)
func = click.option(
"--proxy-host", multiple=False, type=str, default=None, help=CLI_PROXY_HOST_HELP
)(func)
return func
def auth_options(stage: bool = True) -> Callable:
"""
Decorator that defines authentication options for Click commands.
Args:
stage (bool): Whether to include the stage option.
Returns:
Callable: The decorator function.
"""
def decorator(func: Callable) -> Callable:
func = click.option(
"--key", default=None, envvar="SAFETY_API_KEY", help=CLI_KEY_HELP
)(func)
if stage:
func = click.option(
"--stage", default=None, envvar="SAFETY_STAGE", help=CLI_STAGE_HELP
)(func)
return func
return decorator
def inject_session(
ctx: click.Context,
proxy_protocol: Optional[str] = None,
proxy_host: Optional[str] = None,
proxy_port: Optional[str] = None,
key: Optional[str] = None,
stage: Optional[Stage] = None,
invoked_command: str = "",
) -> Any:
org: Optional[Organization] = get_organization()
if not stage:
host_stage = get_host_config(key_name="stage")
stage = host_stage if host_stage else Stage.development
proxy_config: Optional[Dict[str, str]] = get_proxy_dict(
proxy_protocol, # type: ignore
proxy_host, # type: ignore
proxy_port, # type: ignore
)
client_session, openid_config = build_client_session(
api_key=key, proxies=proxy_config
)
keys = get_keys(client_session, openid_config)
auth = Auth(
stage=stage,
keys=keys,
org=org,
client_id=CLIENT_ID, # type: ignore
client=client_session,
code_verifier=generate_token(48),
)
if not ctx.obj:
ctx.obj = SafetyCLI()
ctx.obj.auth = auth
load_auth_session(ctx)
info = get_auth_info(ctx)
if info:
ctx.obj.auth.name = info.get("name")
ctx.obj.auth.email = info.get("email")
ctx.obj.auth.email_verified = is_email_verified(info) # type: ignore
SafetyContext().account = info["email"]
else:
SafetyContext().account = ""
@ctx.call_on_close
def clean_up_on_close():
LOG.debug("Closing requests session.")
ctx.obj.auth.client.close()
if ctx.obj.event_bus:
from safety.events.utils import (
create_internal_event,
InternalEventType,
InternalPayload,
)
payload = InternalPayload(ctx=ctx)
flush_event = create_internal_event(
event_type=InternalEventType.FLUSH_SECURITY_TRACES, payload=payload
)
close_event = create_internal_event(
event_type=InternalEventType.CLOSE_RESOURCES, payload=payload
)
flush_future = ctx.obj.event_bus.emit(flush_event)
close_future = ctx.obj.event_bus.emit(close_event)
# Wait for both events to be processed
if flush_future and close_future:
try:
flush_future.result()
close_future.result()
except Exception as e:
LOG.warning(f"Error waiting for events to process: {e}")
ctx.obj.event_bus.stop()

View File

@@ -0,0 +1,35 @@
from pathlib import Path
from safety.constants import USER_CONFIG_DIR, get_config_setting
AUTH_CONFIG_FILE_NAME = "auth.ini"
AUTH_CONFIG_USER = USER_CONFIG_DIR / Path(AUTH_CONFIG_FILE_NAME)
HOST: str = "localhost"
CLIENT_ID = get_config_setting("CLIENT_ID")
AUTH_SERVER_URL = get_config_setting("AUTH_SERVER_URL")
SAFETY_PLATFORM_URL = get_config_setting("SAFETY_PLATFORM_URL")
OPENID_CONFIG_URL = f"{AUTH_SERVER_URL}/.well-known/openid-configuration"
CLAIM_EMAIL_VERIFIED_API = "https://api.safetycli.com/email_verified"
CLAIM_EMAIL_VERIFIED_AUTH_SERVER = "email_verified"
CLI_AUTH = f"{SAFETY_PLATFORM_URL}/cli/auth"
CLI_AUTH_SUCCESS = f"{SAFETY_PLATFORM_URL}/cli/auth/success"
CLI_AUTH_LOGOUT = f"{SAFETY_PLATFORM_URL}/cli/logout"
CLI_CALLBACK = f"{SAFETY_PLATFORM_URL}/cli/callback"
CLI_LOGOUT_SUCCESS = f"{SAFETY_PLATFORM_URL}/cli/logout/success"
MSG_NON_AUTHENTICATED = (
"Safety is not authenticated. Please run 'safety auth login' to log in."
)
MSG_FAIL_LOGIN_AUTHED = """[green]You are authenticated as[/green] {email}.
To log into a different account, first logout via: safety auth logout, and then login again."""
MSG_FAIL_REGISTER_AUTHED = "You are currently logged in to {email}, please logout using `safety auth logout` before registering a new account."
MSG_LOGOUT_DONE = "[green]Logout done.[/green]"
MSG_LOGOUT_FAILED = "[red]Logout failed. Try again.[/red]"

View File

@@ -0,0 +1,329 @@
import configparser
from typing import Any, Dict, Optional, Tuple, Union
from authlib.oidc.core import CodeIDToken
from authlib.jose import jwt
from authlib.jose.errors import ExpiredTokenError
from safety.auth.models import Organization
from safety.auth.constants import (
CLI_AUTH_LOGOUT,
CLI_CALLBACK,
AUTH_CONFIG_USER,
CLI_AUTH,
)
from safety.constants import CONFIG
from safety_schemas.models import Stage
from safety.util import get_proxy_dict
def get_authorization_data(
client,
code_verifier: str,
organization: Optional[Organization] = None,
sign_up: bool = False,
ensure_auth: bool = False,
headless: bool = False,
) -> Tuple[str, str]:
"""
Generate the authorization URL for the authentication process.
Args:
client: The authentication client.
code_verifier (str): The code verifier for the PKCE flow.
organization (Optional[Organization]): The organization to authenticate with.
sign_up (bool): Whether the URL is for sign-up.
ensure_auth (bool): Whether to ensure authentication.
headless (bool): Whether to run in headless mode.
Returns:
Tuple[str, str]: The authorization URL and initial state.
"""
kwargs = {
"sign_up": sign_up,
"locale": "en",
"ensure_auth": ensure_auth,
"headless": headless,
}
if organization:
kwargs["organization"] = organization.id
return client.create_authorization_url(
CLI_AUTH, code_verifier=code_verifier, **kwargs
)
def get_logout_url(id_token: str) -> str:
"""
Generate the logout URL.
Args:
id_token (str): The ID token.
Returns:
str: The logout URL.
"""
return f"{CLI_AUTH_LOGOUT}?id_token={id_token}"
def get_redirect_url() -> str:
"""
Get the redirect URL for the authentication callback.
Returns:
str: The redirect URL.
"""
return CLI_CALLBACK
def get_organization() -> Optional[Organization]:
"""
Retrieve the organization configuration.
Returns:
Optional[Organization]: The organization object, or None if not configured.
"""
config = configparser.ConfigParser()
config.read(CONFIG)
org_conf: Union[Dict[str, str], configparser.SectionProxy] = (
config["organization"] if "organization" in config.sections() else {}
)
org_id: Optional[str] = (
org_conf["id"].replace('"', "") if org_conf.get("id", None) else None
)
org_name: Optional[str] = (
org_conf["name"].replace('"', "") if org_conf.get("name", None) else None
)
if not org_id:
return None
org = Organization(id=org_id, name=org_name) # type: ignore
return org
def get_auth_info(ctx) -> Optional[Dict]:
"""
Retrieve the authentication information.
Args:
ctx: The context object containing authentication data.
Returns:
Optional[Dict]: The authentication information, or None if not authenticated.
"""
from safety.auth.utils import is_email_verified
info = None
if ctx.obj.auth.client.token:
try:
info = get_token_data(get_token(name="id_token"), keys=ctx.obj.auth.keys) # type: ignore
verified = is_email_verified(info) # type: ignore
if not verified:
user_info = ctx.obj.auth.client.fetch_user_info()
verified = is_email_verified(user_info)
if verified:
# refresh only if needed
raise ExpiredTokenError
except ExpiredTokenError:
# id_token expired. So fire a manually a refresh
try:
ctx.obj.auth.client.refresh_token(
ctx.obj.auth.client.metadata.get("token_endpoint"),
refresh_token=ctx.obj.auth.client.token.get("refresh_token"),
)
info = get_token_data(
get_token(name="id_token"), # type: ignore
keys=ctx.obj.auth.keys, # type: ignore
)
except Exception as _e:
clean_session(ctx.obj.auth.client)
except Exception as _g:
clean_session(ctx.obj.auth.client)
return info
def get_token_data(
token: str, keys: Any, silent_if_expired: bool = False
) -> Optional[Dict]:
"""
Decode and validate the token data.
Args:
token (str): The token to decode.
keys (Any): The keys to use for decoding.
silent_if_expired (bool): Whether to silently ignore expired tokens.
Returns:
Optional[Dict]: The decoded token data, or None if invalid.
"""
claims = jwt.decode(token, keys, claims_cls=CodeIDToken)
try:
claims.validate()
except ExpiredTokenError as e:
if not silent_if_expired:
raise e
return claims
def get_token(name: str = "access_token") -> Optional[str]:
""" "
Retrieve a token from the local authentication configuration.
This returns tokens saved in the local auth configuration.
There are two types of tokens: access_token and id_token
Args:
name (str): The name of the token to retrieve.
Returns:
Optional[str]: The token value, or None if not found.
"""
config = configparser.ConfigParser()
config.read(AUTH_CONFIG_USER)
if "auth" in config.sections() and name in config["auth"]:
value = config["auth"][name]
if value:
return value
return None
def get_host_config(key_name: str) -> Optional[Any]:
"""
Retrieve a configuration value from the host configuration.
Args:
key_name (str): The name of the configuration key.
Returns:
Optional[Any]: The configuration value, or None if not found.
"""
config = configparser.ConfigParser()
config.read(CONFIG)
if not config.has_section("host"):
return None
host_section = dict(config.items("host"))
if key_name in host_section:
if key_name == "stage":
# Support old alias in the config.ini
if host_section[key_name] == "dev":
host_section[key_name] = "development"
if host_section[key_name] not in {env.value for env in Stage}:
return None
return Stage(host_section[key_name])
return None
def str_to_bool(s: str) -> bool:
"""
Convert a string to a boolean value.
Args:
s (str): The string to convert.
Returns:
bool: The converted boolean value.
Raises:
ValueError: If the string cannot be converted.
"""
if s.lower() == "true" or s == "1":
return True
elif s.lower() == "false" or s == "0":
return False
else:
raise ValueError(f"Cannot convert '{s}' to a boolean value.")
def get_proxy_config() -> Tuple[Optional[Dict[str, str]], Optional[int], bool]:
"""
Retrieve the proxy configuration.
Returns:
Tuple[Optional[Dict[str, str]], Optional[int], bool]: The proxy configuration, timeout, and whether it is required.
"""
config = configparser.ConfigParser()
config.read(CONFIG)
proxy_dictionary = None
required = False
timeout = None
proxy = None
if config.has_section("proxy"):
proxy = dict(config.items("proxy"))
if proxy:
try:
proxy_dictionary = get_proxy_dict(
proxy["protocol"],
proxy["host"],
proxy["port"], # type: ignore
)
required = str_to_bool(proxy["required"])
timeout = proxy["timeout"]
except Exception:
pass
return proxy_dictionary, timeout, required # type: ignore
def clean_session(client) -> bool:
"""
Clean the authentication session.
Args:
client: The authentication client.
Returns:
bool: Always returns True.
"""
config = configparser.ConfigParser()
config["auth"] = {"access_token": "", "id_token": "", "refresh_token": ""}
with open(AUTH_CONFIG_USER, "w") as configfile:
config.write(configfile)
client.token = None
return True
def save_auth_config(
access_token: Optional[str] = None,
id_token: Optional[str] = None,
refresh_token: Optional[str] = None,
) -> None:
"""
Save the authentication configuration.
Args:
access_token (Optional[str]): The access token.
id_token (Optional[str]): The ID token.
refresh_token (Optional[str]): The refresh token.
"""
config = configparser.ConfigParser()
config.read(AUTH_CONFIG_USER)
config["auth"] = { # type: ignore
"access_token": access_token,
"id_token": id_token,
"refresh_token": refresh_token,
}
with open(AUTH_CONFIG_USER, "w") as configfile:
config.write(configfile) # type: ignore

View File

@@ -0,0 +1,105 @@
from dataclasses import dataclass
import os
from typing import Any, Optional, Dict
from authlib.integrations.base_client import BaseOAuth
from safety_schemas.models import Stage
@dataclass
class Organization:
id: str
name: str
def to_dict(self) -> Dict:
"""
Convert the Organization instance to a dictionary.
Returns:
dict: The dictionary representation of the organization.
"""
return {"id": self.id, "name": self.name}
@dataclass
class Auth:
org: Optional[Organization]
keys: Any
client: Any
code_verifier: str
client_id: str
stage: Optional[Stage] = Stage.development
email: Optional[str] = None
name: Optional[str] = None
email_verified: bool = False
def is_valid(self) -> bool:
"""
Check if the authentication information is valid.
Returns:
bool: True if valid, False otherwise.
"""
if os.getenv("SAFETY_DB_DIR"):
return True
if not self.client:
return False
if self.client.api_key:
return True
return bool(self.client.token and self.email_verified)
def refresh_from(self, info: Dict) -> None:
"""
Refresh the authentication information from the provided info.
Args:
info (dict): The information to refresh from.
"""
from safety.auth.utils import is_email_verified
self.name = info.get("name")
self.email = info.get("email")
self.email_verified = is_email_verified(info) # type: ignore
def get_auth_method(self) -> str:
"""
Get the authentication method.
Returns:
str: The authentication method.
"""
if self.client.api_key:
return "API Key"
if self.client.token:
return "Token"
return "None"
class XAPIKeyAuth(BaseOAuth):
def __init__(self, api_key: str) -> None:
"""
Initialize the XAPIKeyAuth instance.
Args:
api_key (str): The API key to use for authentication.
"""
self.api_key = api_key
def __call__(self, r: Any) -> Any:
"""
Add the API key to the request headers.
Args:
r (Any): The request object.
Returns:
Any: The modified request object.
"""
r.headers["X-API-Key"] = self.api_key
return r

View File

@@ -0,0 +1,308 @@
# type: ignore
import http.server
import json
import logging
import random
import socket
import sys
import time
from typing import Any, Optional, Dict, Tuple
import urllib.parse
import threading
import click
from safety.auth.utils import is_jupyter_notebook
from safety.console import main_console as console
from safety.auth.constants import (
AUTH_SERVER_URL,
CLI_AUTH_SUCCESS,
CLI_LOGOUT_SUCCESS,
HOST,
)
from safety.auth.main import save_auth_config
from rich.prompt import Prompt
LOG = logging.getLogger(__name__)
def find_available_port() -> Optional[int]:
"""
Find an available port on localhost within the dynamic port range (49152-65536).
Returns:
Optional[int]: An available port number, or None if no ports are available.
"""
# Dynamic ports IANA
port_range = list(range(49152, 65536))
random.shuffle(port_range)
for port in port_range:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
try:
s.connect(("localhost", port))
# If the connect succeeds, the port is already in use
except socket.error:
# If the connect fails, the port is available
return port
return None
def auth_process(
code: str, state: str, initial_state: str, code_verifier: str, client: Any
) -> Any:
"""
Process the authentication callback and exchange the authorization code for tokens.
Args:
code (str): The authorization code.
state (str): The state parameter from the callback.
initial_state (str): The initial state parameter.
code_verifier (str): The code verifier for PKCE.
client (Any): The OAuth client.
Returns:
Any: The user information.
Raises:
SystemExit: If there is an error during authentication.
"""
err = None
if initial_state is None or initial_state != state:
err = (
"The state parameter value provided does not match the expected "
"value. The state parameter is used to protect against Cross-Site "
"Request Forgery (CSRF) attacks. For security reasons, the "
"authorization process cannot proceed with an invalid state "
"parameter value. Please try again, ensuring that the state "
"parameter value provided in the authorization request matches "
"the value returned in the callback."
)
if err:
click.secho(f"Error: {err}", fg="red")
sys.exit(1)
try:
tokens = client.fetch_token(
url=f"{AUTH_SERVER_URL}/oauth/token",
code_verifier=code_verifier,
client_id=client.client_id,
grant_type="authorization_code",
code=code,
)
save_auth_config(
access_token=tokens["access_token"],
id_token=tokens["id_token"],
refresh_token=tokens["refresh_token"],
)
return client.fetch_user_info()
except Exception as e:
LOG.exception(e)
sys.exit(1)
class CallbackHandler(http.server.BaseHTTPRequestHandler):
def auth(self, code: str, state: str, err: str, error_description: str) -> None:
"""
Handle the authentication callback.
Args:
code (str): The authorization code.
state (str): The state parameter.
err (str): The error message, if any.
error_description (str): The error description, if any.
"""
initial_state = self.server.initial_state
ctx = self.server.ctx
result = auth_process(
code=code,
state=state,
initial_state=initial_state,
code_verifier=ctx.obj.auth.code_verifier,
client=ctx.obj.auth.client,
)
self.server.callback = result
self.do_redirect(location=CLI_AUTH_SUCCESS, params={})
def logout(self) -> None:
"""
Handle the logout callback.
"""
ctx = self.server.ctx
uri = CLI_LOGOUT_SUCCESS
if ctx.obj.auth.org:
uri = f"{uri}&org_id={ctx.obj.auth.org.id}"
self.do_redirect(location=CLI_LOGOUT_SUCCESS, params={})
def do_GET(self) -> None:
"""
Handle GET requests.
"""
query = urllib.parse.urlparse(self.path).query
params = urllib.parse.parse_qs(query)
callback_type: Optional[str] = None
try:
c_type = params.get("type", [])
if (
isinstance(c_type, list)
and len(c_type) == 1
and isinstance(c_type[0], str)
):
callback_type = c_type[0]
except Exception:
msg = "Unable to process the callback, try again."
self.send_error(400, msg)
click.secho("Unable to process the callback, try again.")
return
if callback_type == "logout":
self.logout()
return
code = params.get("code", [""])[0]
state = params.get("state", [""])[0]
err = params.get("error", [""])[0]
error_description = params.get("error_description", [""])[0]
self.auth(code=code, state=state, err=err, error_description=error_description)
def do_redirect(self, location: str, params: Dict) -> None:
"""
Redirect the client to the specified location.
Args:
location (str): The URL to redirect to.
params (dict): Additional parameters for the redirection.
"""
self.send_response(302)
self.send_header("Location", location)
self.send_header("Connection", "close")
self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
self.end_headers()
def log_message(self, format: str, *args: Any) -> None:
"""
Log an arbitrary message.
Args:
format (str): The format string.
args (Any): Arguments for the format string.
"""
LOG.info(format % args)
def process_browser_callback(uri: str, **kwargs: Any) -> Any:
"""
Process the browser callback for authentication.
Args:
uri (str): The authorization URL.
**kwargs (Any): Additional keyword arguments.
Returns:
Any: The user information.
Raises:
SystemExit: If there is an error during the process.
"""
class ThreadedHTTPServer(http.server.HTTPServer):
def __init__(self, server_address: Tuple, RequestHandlerClass: Any) -> None:
"""
Initialize the ThreadedHTTPServer.
Args:
server_address (Tuple): The server address as a tuple (host, port).
RequestHandlerClass (Any): The request handler class.
"""
super().__init__(server_address, RequestHandlerClass)
self.initial_state = None
self.ctx = None
self.callback = None
self.timeout_reached = False
def handle_timeout(self) -> None:
"""
Handle server timeout.
"""
self.timeout_reached = True
return super().handle_timeout()
PORT = find_available_port()
if not PORT:
click.secho("No available ports.")
sys.exit(1)
try:
headless = kwargs.get("headless", False)
initial_state = kwargs.get("initial_state", None)
ctx = kwargs.get("ctx", None)
message = "Copy and paste this URL into your browser:\n:icon_warning: Ensure there are no extra spaces, especially at line breaks, as they may break the link."
if not headless:
# Start a threaded HTTP server to handle the callback
server = ThreadedHTTPServer((HOST, PORT), CallbackHandler)
server.initial_state = initial_state
server.timeout = kwargs.get("timeout", 600)
server.ctx = ctx
server_thread = threading.Thread(target=server.handle_request)
server_thread.start()
message = "If the browser does not automatically open in 5 seconds, copy and paste this url into your browser:"
target = uri if headless else f"{uri}&port={PORT}"
if is_jupyter_notebook():
console.print(f"{message} {target}")
else:
console.print(f"{message} [link={target}]{target}[/link]")
if headless:
# Handle the headless mode where user manually provides the response
exchange_data = None
while not exchange_data:
auth_code_text = Prompt.ask(
"Paste the response here", default=None, console=console
)
try:
exchange_data = json.loads(auth_code_text)
state = exchange_data["state"]
code = exchange_data["code"]
except Exception:
code = state = None
return auth_process(
code=code,
state=state,
initial_state=initial_state,
code_verifier=ctx.obj.auth.code_verifier,
client=ctx.obj.auth.client,
)
else:
# Wait for the browser authentication in non-headless mode
console.print()
wait_msg = "waiting for browser authentication"
with console.status(wait_msg, spinner="bouncingBar"):
time.sleep(2)
click.launch(target)
server_thread.join()
except OSError as e:
if e.errno == socket.errno.EADDRINUSE:
reason = f"The port {HOST}:{PORT} is currently being used by another application or process. Please choose a different port or terminate the conflicting application/process to free up the port."
else:
reason = "An error occurred while performing this operation."
click.secho(reason)
sys.exit(1)
return server.callback

View File

@@ -0,0 +1,756 @@
# type: ignore
import importlib.util
import json
import logging
from functools import lru_cache
from typing import Any, Callable, Dict, Optional, Tuple, List, Literal
import requests
from authlib.integrations.base_client.errors import OAuthError
from authlib.integrations.requests_client import OAuth2Session
from requests.adapters import HTTPAdapter
from safety_schemas.models import STAGE_ID_MAPPING, Stage
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from safety.auth.constants import (
AUTH_SERVER_URL,
OPENID_CONFIG_URL,
)
from safety.constants import (
PLATFORM_API_CHECK_UPDATES_ENDPOINT,
PLATFORM_API_INITIALIZE_ENDPOINT,
PLATFORM_API_POLICY_ENDPOINT,
PLATFORM_API_PROJECT_CHECK_ENDPOINT,
PLATFORM_API_PROJECT_ENDPOINT,
PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT,
PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT,
PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT,
REQUEST_TIMEOUT,
FeatureType,
get_config_setting,
FIREWALL_AUDIT_PYPI_PACKAGES_ENDPOINT,
FIREWALL_AUDIT_NPMJS_PACKAGES_ENDPOINT,
)
from safety.error_handlers import output_exception
from safety.errors import (
InvalidCredentialError,
NetworkConnectionError,
RequestTimeoutError,
SafetyError,
ServerError,
TooManyRequestsError,
)
from safety.meta import get_meta_http_headers
from safety.models import SafetyCLI
from safety.scan.util import AuthenticationType
from safety.util import SafetyContext
LOG = logging.getLogger(__name__)
def get_keys(
client_session: OAuth2Session, openid_config: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
Retrieve the keys from the OpenID configuration.
Args:
client_session (OAuth2Session): The OAuth2 session.
openid_config (Dict[str, Any]): The OpenID configuration.
Returns:
Optional[Dict[str, Any]]: The keys, if available.
"""
if "jwks_uri" in openid_config:
return client_session.get(url=openid_config["jwks_uri"], bearer=False).json() # type: ignore
return None
def is_email_verified(info: Dict[str, Any]) -> Optional[bool]:
"""
Check if the email is verified.
Args:
info (Dict[str, Any]): The user information.
Returns:
bool: True
"""
# return info.get(CLAIM_EMAIL_VERIFIED_API) or info.get(
# CLAIM_EMAIL_VERIFIED_AUTH_SERVER
# )
# Always return True to avoid email verification
return True
def extract_detail(response: requests.Response) -> Optional[str]:
"""
Extract the reason from an HTTP response.
Args:
response (requests.Response): The response.
Returns:
Optional[str]: The reason.
"""
detail = None
try:
detail = response.json().get("detail")
except Exception:
LOG.debug("Failed to extract detail from response: %s", response.status_code)
return detail
def parse_response(func: Callable) -> Callable:
"""
Decorator to parse the response from an HTTP request.
Args:
func (Callable): The function to wrap.
Returns:
Callable: The wrapped function.
"""
@retry(
stop=stop_after_attempt(3),
wait=wait_exponential_jitter(initial=0.2, max=8.0, exp_base=3, jitter=0.3),
reraise=True,
retry=retry_if_exception_type(
(
NetworkConnectionError,
RequestTimeoutError,
TooManyRequestsError,
ServerError,
)
),
before_sleep=before_sleep_log(logging.getLogger("api_client"), logging.WARNING),
)
def wrapper(*args, **kwargs):
try:
r = func(*args, **kwargs)
except OAuthError as e:
LOG.exception("OAuth failed: %s", e)
raise InvalidCredentialError(
message="Your token authentication expired, try login again."
)
except requests.exceptions.ConnectionError:
raise NetworkConnectionError()
except requests.exceptions.Timeout:
raise RequestTimeoutError()
except requests.exceptions.RequestException as e:
raise e
# TODO: Handle content as JSON and fallback to text for all responses
if r.status_code == 403:
reason = extract_detail(response=r)
raise InvalidCredentialError(
credential="Failed authentication.", reason=reason
)
if r.status_code == 429:
raise TooManyRequestsError(reason=r.text)
if r.status_code >= 400 and r.status_code < 500:
error_code = None
try:
data = r.json()
reason = data.get("detail", "Unable to find reason.")
error_code = data.get("error_code", None)
except Exception:
reason = r.reason
raise SafetyError(message=reason, error_code=error_code)
if r.status_code >= 500 and r.status_code < 600:
reason = extract_detail(response=r)
LOG.debug("ServerError %s -> Response returned: %s", r.status_code, r.text)
raise ServerError(reason=reason)
data = None
try:
data = r.json()
except json.JSONDecodeError as e:
raise ServerError(message=f"Bad JSON response from the server: {e}")
return data
return wrapper
class SafetyAuthSession(OAuth2Session):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
Initialize the SafetyAuthSession.
Args:
*args (Any): Positional arguments for the parent class.
**kwargs (Any): Keyword arguments for the parent class.
"""
super().__init__(*args, **kwargs)
self.proxy_required: bool = False
self.proxy_timeout: Optional[int] = None
self.api_key = None
def get_credential(self) -> Optional[str]:
"""
Get the current authentication credential.
Returns:
Optional[str]: The API key, token, or None.
"""
if self.api_key:
return self.api_key
if self.token:
return SafetyContext().account
return None
def is_using_auth_credentials(self) -> bool:
"""
Check if the session is using authentication credentials.
This does NOT check if the client is authenticated.
Returns:
bool: True if using authentication credentials, False otherwise.
"""
return self.get_authentication_type() != AuthenticationType.none
def get_authentication_type(self) -> AuthenticationType:
"""
Get the type of authentication being used.
Returns:
AuthenticationType: The type of authentication.
"""
if self.api_key:
return AuthenticationType.api_key
if self.token:
return AuthenticationType.token
return AuthenticationType.none
def request(
self,
method: str,
url: str,
withhold_token: bool = False,
auth: Optional[Tuple] = None,
bearer: bool = True,
**kwargs: Any,
) -> requests.Response:
"""
Make an HTTP request with the appropriate authentication.
Use the right auth parameter for Safety supported auth types.
Args:
method (str): The HTTP method.
url (str): The URL to request.
withhold_token (bool): Whether to withhold the token.
auth (Optional[Tuple]): The authentication tuple.
bearer (bool): Whether to use bearer authentication.
**kwargs (Any): Additional keyword arguments.
Returns:
requests.Response: The HTTP response.
Raises:
Exception: If the request fails.
"""
# By default use the token_auth
TIMEOUT_KEYWARD = "timeout"
func_timeout = (
kwargs[TIMEOUT_KEYWARD] if TIMEOUT_KEYWARD in kwargs else REQUEST_TIMEOUT
)
if "headers" not in kwargs:
kwargs["headers"] = {}
kwargs["headers"].update(get_meta_http_headers())
if self.api_key:
kwargs["headers"]["X-Api-Key"] = self.api_key
if not self.token or not bearer:
# Fallback to no token auth
auth = ()
# Override proxies
if self.proxies:
kwargs["proxies"] = self.proxies
if self.proxy_timeout:
kwargs["timeout"] = int(self.proxy_timeout) / 1000
if ("proxies" not in kwargs or not self.proxies) and self.proxy_required:
output_exception(
"Proxy connection is required but there is not a proxy setup.", # type: ignore
exit_code_output=True,
)
request_func = super(SafetyAuthSession, self).request
params = {
"method": method,
"url": url,
"withhold_token": withhold_token,
"auth": auth,
}
params.update(kwargs)
try:
return request_func(**params)
except Exception as e:
LOG.debug("Request failed: %s", e)
if self.proxy_required:
output_exception(
f"Proxy is required but the connection failed because: {e}", # type: ignore
exit_code_output=True,
)
if "proxies" in kwargs or self.proxies:
params["proxies"] = {}
params["timeout"] = func_timeout
self.proxies = {}
message = (
"The proxy configuration failed to function and was disregarded."
)
LOG.debug(message)
if message not in [
a["message"] for a in SafetyContext.local_announcements
]:
SafetyContext.local_announcements.append(
{"message": message, "type": "warning", "local": True}
)
return request_func(**params)
raise e
def fetch_openid_config(self) -> Any:
"""
Fetch the OpenID configuration from the authorization server.
Returns:
Any: The OpenID configuration.
"""
try:
openid_config = self.get(
url=OPENID_CONFIG_URL, timeout=REQUEST_TIMEOUT
).json()
except Exception as e:
LOG.debug("Unable to load the openID config: %s", e)
openid_config = {}
return openid_config
@parse_response
def fetch_user_info(self) -> Any:
"""
Fetch user information from the authorization server.
Returns:
Any: The user information.
"""
USER_INFO_ENDPOINT = f"{AUTH_SERVER_URL}/userinfo"
r = self.get(url=USER_INFO_ENDPOINT)
return r
@parse_response
def check_project(
self,
scan_stage: str,
safety_source: str,
project_slug: Optional[str] = None,
git_origin: Optional[str] = None,
project_slug_source: Optional[str] = None,
) -> Any:
"""
Check project information.
Args:
scan_stage (str): The scan stage.
safety_source (str): The safety source.
project_slug (Optional[str]): The project slug.
git_origin (Optional[str]): The git origin.
project_slug_source (Optional[str]): The project slug source.
Returns:
Any: The project information.
"""
data = {
"scan_stage": scan_stage,
"safety_source": safety_source,
"project_slug": project_slug,
"project_slug_source": project_slug_source,
"git_origin": git_origin,
}
r = self.post(url=PLATFORM_API_PROJECT_CHECK_ENDPOINT, json=data)
return r
@parse_response
def project(self, project_id: str) -> Any:
"""
Get project information.
Args:
project_id (str): The project ID.
Returns:
Any: The project information.
"""
data = {"project": project_id}
return self.get(url=PLATFORM_API_PROJECT_ENDPOINT, params=data)
@parse_response
def download_policy(
self, project_id: Optional[str], stage: Stage, branch: Optional[str]
) -> Any:
"""
Download the project policy.
Args:
project_id (Optional[str]): The project ID.
stage (Stage): The stage.
branch (Optional[str]): The branch.
Returns:
Any: The policy data.
"""
data = {
"project": project_id,
"stage": STAGE_ID_MAPPING[stage],
"branch": branch,
}
return self.get(url=PLATFORM_API_POLICY_ENDPOINT, params=data)
@parse_response
def project_scan_request(self, project_id: str) -> Any:
"""
Request a project scan.
Args:
project_id (str): The project ID.
Returns:
Any: The scan request result.
"""
data = {"project_id": project_id}
return self.post(url=PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT, json=data)
@parse_response
def upload_report(self, json_report: str) -> Any:
"""
Upload a scan report.
Args:
json_report (str): The JSON report.
Returns:
Any: The upload result.
"""
return self.post(
url=PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT,
data=json_report,
headers={"Content-Type": "application/json"},
)
def upload_requirements(self, json_payload: str) -> Any:
"""
Upload a scan report.
Args:
json_payload (str): The JSON payload to upload.
Returns:
Any: The result of the upload operation.
"""
return self.post(
url=PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT,
data=json.dumps(json_payload),
headers={"Content-Type": "application/json"},
)
@parse_response
def check_updates(
self,
version: int,
safety_version: Optional[str] = None,
python_version: Optional[str] = None,
os_type: Optional[str] = None,
os_release: Optional[str] = None,
os_description: Optional[str] = None,
) -> Any:
"""
Check for updates.
Args:
version (int): The version.
safety_version (Optional[str]): The Safety version.
python_version (Optional[str]): The Python version.
os_type (Optional[str]): The OS type.
os_release (Optional[str]): The OS release.
os_description (Optional[str]): The OS description.
Returns:
Any: The update check result.
"""
data = {
"version": version,
"safety_version": safety_version,
"python_version": python_version,
"os_type": os_type,
"os_release": os_release,
"os_description": os_description,
}
return self.get(url=PLATFORM_API_CHECK_UPDATES_ENDPOINT, params=data)
@parse_response
def audit_packages(
self, packages: List[str], ecosystem: Literal["pypi", "npmjs"]
) -> Any:
"""
Audits packages for vulnerabilities
Args:
packages: list of package specifiers
ecosystem: the ecosystem to audit
Returns:
Any: The packages audit result.
"""
url = (
FIREWALL_AUDIT_NPMJS_PACKAGES_ENDPOINT
if ecosystem == "npmjs"
else FIREWALL_AUDIT_PYPI_PACKAGES_ENDPOINT
)
data = {"packages": [{"package_specifier": package} for package in packages]}
return self.post(url=url, json=data)
@parse_response
def initialize(self) -> Any:
"""
Initialize a run.
Returns:
Any: The initialization result.
"""
try:
response = self.get(
url=PLATFORM_API_INITIALIZE_ENDPOINT,
headers={"Content-Type": "application/json"},
timeout=5,
)
return response
except requests.exceptions.Timeout:
LOG.error("Auth request to initialize timed out after 5 seconds.")
except Exception:
LOG.exception("Exception trying to auth initialize", exc_info=True)
return None
class S3PresignedAdapter(HTTPAdapter):
def send( # type: ignore
self, request: requests.PreparedRequest, **kwargs: Any
) -> requests.Response:
"""
Send a request, removing the Authorization header.
Args:
request (requests.PreparedRequest): The prepared request.
**kwargs (Any): Additional keyword arguments.
Returns:
requests.Response: The response.
"""
request.headers.pop("Authorization", None)
return super().send(request, **kwargs)
@lru_cache(maxsize=1)
def is_jupyter_notebook() -> bool:
"""
Detects if the code is running in a Jupyter notebook environment, including
various cloud-hosted Jupyter notebooks.
Returns:
bool: True if the environment is identified as a Jupyter notebook (or
equivalent cloud-based environment), False otherwise.
Supported environments:
- Google Colab
- Amazon SageMaker
- Azure Notebooks
- Kaggle Notebooks
- Databricks Notebooks
- Datalore by JetBrains
- Paperspace Gradient Notebooks
- Classic Jupyter Notebook and JupyterLab
"""
if (
(
importlib.util.find_spec("google")
and importlib.util.find_spec("google.colab")
)
is not None
or importlib.util.find_spec("sagemaker") is not None
or importlib.util.find_spec("azureml") is not None
or importlib.util.find_spec("kaggle") is not None
or importlib.util.find_spec("dbutils") is not None
or importlib.util.find_spec("datalore") is not None
or importlib.util.find_spec("gradient") is not None
):
return True
# Detect classic Jupyter Notebook, JupyterLab, and other IPython kernel-based environments
try:
from IPython import get_ipython # type: ignore
ipython = get_ipython()
if ipython is not None and "IPKernelApp" in ipython.config:
return True
except (ImportError, AttributeError, NameError):
pass
return False
def save_flags_config(flags: Dict[FeatureType, bool]) -> None:
"""
Save feature flags configuration to file.
This function attempts to save feature flags to the configuration file
but will fail silently if unable to do so (e.g., due to permission issues
or disk problems). Silent failure is chosen to prevent configuration issues
from disrupting core application functionality.
Note that if saving fails, the application will continue using existing
or default flag values until the next restart.
Args:
flags: Dictionary mapping feature types to their enabled/disabled state
The operation will be logged (with stack trace) if it fails.
"""
import configparser
from safety.constants import CONFIG_FILE_USER
config = configparser.ConfigParser()
config.read(CONFIG_FILE_USER)
flag_settings = {key.name.upper(): str(value) for key, value in flags.items()}
if not config.has_section("settings"):
config.add_section("settings")
settings = dict(config.items("settings"))
settings.update(flag_settings)
for key, value in settings.items():
config.set("settings", key, value)
try:
with open(CONFIG_FILE_USER, "w") as config_file:
config.write(config_file)
except Exception:
LOG.exception("Unable to save flags configuration.")
def get_feature_name(feature: FeatureType, as_attr: bool = False) -> str:
"""Returns a formatted feature name with enabled suffix.
Args:
feature: The feature to format the name for
as_attr: If True, formats for attribute usage (underscore),
otherwise uses hyphen
Returns:
Formatted feature name string with enabled suffix
"""
name = feature.name.lower()
separator = "_" if as_attr else "-"
return f"{name}{separator}enabled"
def str_to_bool(value) -> Optional[bool]:
"""Convert basic string representations to boolean."""
if isinstance(value, bool):
return value
if isinstance(value, str):
value = value.lower().strip()
if value in ("true"):
return True
if value in ("false"):
return False
return None
def initialize(ctx: Any, refresh: bool = True) -> None:
"""
Initializes the run by loading settings.
Args:
ctx (Any): The context object.
refresh (bool): Whether to refresh settings from the server. Defaults to True.
"""
settings = None
current_values = {}
if not ctx.obj:
ctx.obj = SafetyCLI()
for feature in FeatureType:
value = get_config_setting(feature.name)
if value is not None:
current_values[feature] = str_to_bool(value)
if refresh:
try:
settings = ctx.obj.auth.client.initialize() # type: ignore
except Exception:
LOG.info("Unable to initialize, continue with default values.")
if settings:
for feature in FeatureType:
server_value = str_to_bool(settings.get(feature.config_key))
if server_value is not None:
if (
feature not in current_values
or current_values[feature] != server_value
):
current_values[feature] = server_value
save_flags_config(current_values)
for feature, value in current_values.items():
if value is not None:
setattr(ctx.obj, feature.attr_name, value)