updates
This commit is contained in:
@@ -0,0 +1,10 @@
|
||||
from .cli_utils import auth_options, build_client_session, proxy_options, inject_session
|
||||
from .cli import auth
|
||||
|
||||
__all__ = [
|
||||
"build_client_session",
|
||||
"proxy_options",
|
||||
"auth_options",
|
||||
"inject_session",
|
||||
"auth",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
402
Backend/venv/lib/python3.12/site-packages/safety/auth/cli.py
Normal file
402
Backend/venv/lib/python3.12/site-packages/safety/auth/cli.py
Normal file
@@ -0,0 +1,402 @@
|
||||
# type: ignore
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from safety.auth.models import Auth
|
||||
from safety.auth.utils import initialize, is_email_verified
|
||||
from safety.console import main_console as console
|
||||
from safety.constants import (
|
||||
MSG_FINISH_REGISTRATION_TPL,
|
||||
MSG_VERIFICATION_HINT,
|
||||
DEFAULT_EPILOG,
|
||||
)
|
||||
from safety.meta import get_version
|
||||
from safety.decorators import notify
|
||||
|
||||
try:
|
||||
from typing import Annotated
|
||||
except ImportError:
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import typer
|
||||
from rich.padding import Padding
|
||||
from typer import Typer
|
||||
|
||||
from safety.auth.main import (
|
||||
clean_session,
|
||||
get_auth_info,
|
||||
get_authorization_data,
|
||||
get_token,
|
||||
)
|
||||
from safety.auth.server import process_browser_callback
|
||||
from safety.events.utils import emit_auth_started, emit_auth_completed
|
||||
from safety.util import initialize_event_bus
|
||||
from safety.scan.constants import (
|
||||
CLI_AUTH_COMMAND_HELP,
|
||||
CLI_AUTH_HEADLESS_HELP,
|
||||
CLI_AUTH_LOGIN_HELP,
|
||||
CLI_AUTH_LOGOUT_HELP,
|
||||
CLI_AUTH_STATUS_HELP,
|
||||
)
|
||||
|
||||
from ..cli_util import SafetyCLISubGroup, get_command_for, pass_safety_cli_obj
|
||||
from safety.error_handlers import handle_cmd_exception
|
||||
from .constants import (
|
||||
MSG_FAIL_LOGIN_AUTHED,
|
||||
MSG_FAIL_REGISTER_AUTHED,
|
||||
MSG_LOGOUT_DONE,
|
||||
MSG_LOGOUT_FAILED,
|
||||
MSG_NON_AUTHENTICATED,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
auth_app = Typer(rich_markup_mode="rich", name="auth")
|
||||
|
||||
|
||||
CMD_LOGIN_NAME = "login"
|
||||
CMD_REGISTER_NAME = "register"
|
||||
CMD_STATUS_NAME = "status"
|
||||
CMD_LOGOUT_NAME = "logout"
|
||||
DEFAULT_CMD = CMD_LOGIN_NAME
|
||||
|
||||
|
||||
@auth_app.callback(
|
||||
invoke_without_command=True,
|
||||
cls=SafetyCLISubGroup,
|
||||
help=CLI_AUTH_COMMAND_HELP,
|
||||
epilog=DEFAULT_EPILOG,
|
||||
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||
)
|
||||
@pass_safety_cli_obj
|
||||
def auth(ctx: typer.Context) -> None:
|
||||
"""
|
||||
Authenticate Safety CLI with your account.
|
||||
|
||||
Args:
|
||||
ctx (typer.Context): The Typer context object.
|
||||
"""
|
||||
LOG.info("auth started")
|
||||
|
||||
# If no subcommand is invoked, forward to the default command
|
||||
if not ctx.invoked_subcommand:
|
||||
default_command = get_command_for(name=DEFAULT_CMD, typer_instance=auth_app)
|
||||
return ctx.forward(default_command)
|
||||
|
||||
|
||||
def fail_if_authenticated(ctx: typer.Context, with_msg: str) -> None:
|
||||
"""
|
||||
Exits the command if the user is already authenticated.
|
||||
|
||||
Args:
|
||||
ctx (typer.Context): The Typer context object.
|
||||
with_msg (str): The message to display if authenticated.
|
||||
"""
|
||||
info = get_auth_info(ctx)
|
||||
|
||||
if info:
|
||||
console.print()
|
||||
email = f"[green]{ctx.obj.auth.email}[/green]"
|
||||
if not ctx.obj.auth.email_verified:
|
||||
email = f"{email} {render_email_note(ctx.obj.auth)}"
|
||||
|
||||
console.print(with_msg.format(email=email))
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def render_email_note(auth: Auth) -> str:
|
||||
"""
|
||||
Renders a note indicating whether email verification is required.
|
||||
|
||||
Args:
|
||||
auth (Auth): The Auth object.
|
||||
|
||||
Returns:
|
||||
str: The rendered email note.
|
||||
"""
|
||||
return "" if auth.email_verified else "[red](email verification required)[/red]"
|
||||
|
||||
|
||||
def render_successful_login(auth: Auth, organization: Optional[str] = None) -> None:
|
||||
"""
|
||||
Renders a message indicating a successful login.
|
||||
|
||||
Args:
|
||||
auth (Auth): The Auth object.
|
||||
organization (Optional[str]): The organization name.
|
||||
"""
|
||||
DEFAULT = "--"
|
||||
name = auth.name if auth.name else DEFAULT
|
||||
email = auth.email if auth.email else DEFAULT
|
||||
email_note = render_email_note(auth)
|
||||
console.print()
|
||||
console.print("[bold][green]You're authenticated[/green][/bold]")
|
||||
if name and name != email:
|
||||
details = [f"[green][bold]Account:[/bold] {name}, {email}[/green] {email_note}"]
|
||||
else:
|
||||
details = [f"[green][bold]Account:[/bold] {email}[/green] {email_note}"]
|
||||
|
||||
if organization:
|
||||
details.insert(0, f"[green][bold]Organization:[/bold] {organization}[green]")
|
||||
|
||||
for msg in details:
|
||||
console.print(Padding(msg, (0, 0, 0, 1)), emoji=True)
|
||||
|
||||
|
||||
@auth_app.command(name=CMD_LOGIN_NAME, help=CLI_AUTH_LOGIN_HELP)
|
||||
@handle_cmd_exception
|
||||
@notify
|
||||
def login(
|
||||
ctx: typer.Context,
|
||||
headless: Annotated[
|
||||
Optional[bool],
|
||||
typer.Option(
|
||||
"--headless",
|
||||
help=CLI_AUTH_HEADLESS_HELP,
|
||||
),
|
||||
] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Authenticate Safety CLI with your safetycli.com account using your default browser.
|
||||
|
||||
Args:
|
||||
ctx (typer.Context): The Typer context object.
|
||||
headless (bool): Whether to run in headless mode.
|
||||
"""
|
||||
LOG.info("login started")
|
||||
headless = headless is True
|
||||
|
||||
# Check if the user is already authenticated
|
||||
fail_if_authenticated(ctx, with_msg=MSG_FAIL_LOGIN_AUTHED)
|
||||
|
||||
console.print()
|
||||
|
||||
info = None
|
||||
|
||||
brief_msg: str = (
|
||||
"Redirecting your browser to log in; once authenticated, "
|
||||
"return here to start using Safety"
|
||||
)
|
||||
|
||||
if ctx.obj.auth.org:
|
||||
console.print(
|
||||
f"Logging into [bold]{ctx.obj.auth.org.name}[/bold] organization."
|
||||
)
|
||||
|
||||
if headless:
|
||||
brief_msg = "Running in headless mode. Please copy and open the following URL in a browser"
|
||||
|
||||
# Get authorization data and generate the authorization URL
|
||||
uri, initial_state = get_authorization_data(
|
||||
client=ctx.obj.auth.client,
|
||||
code_verifier=ctx.obj.auth.code_verifier,
|
||||
organization=ctx.obj.auth.org,
|
||||
headless=headless,
|
||||
)
|
||||
click.secho(brief_msg)
|
||||
click.echo()
|
||||
|
||||
emit_auth_started(ctx.obj.event_bus, ctx)
|
||||
# Process the browser callback to complete the authentication
|
||||
info = process_browser_callback(
|
||||
uri, initial_state=initial_state, ctx=ctx, headless=headless
|
||||
)
|
||||
|
||||
is_success = False
|
||||
error_msg = None
|
||||
|
||||
if info:
|
||||
if info.get("email", None):
|
||||
organization = None
|
||||
if ctx.obj.auth.org and ctx.obj.auth.org.name:
|
||||
organization = ctx.obj.auth.org.name
|
||||
ctx.obj.auth.refresh_from(info)
|
||||
if headless:
|
||||
console.print()
|
||||
|
||||
initialize(ctx, refresh=True)
|
||||
initialize_event_bus(ctx=ctx)
|
||||
render_successful_login(ctx.obj.auth, organization=organization)
|
||||
is_success = True
|
||||
|
||||
console.print()
|
||||
if ctx.obj.auth.org or ctx.obj.auth.email_verified:
|
||||
if not getattr(ctx.obj, "only_auth_msg", False):
|
||||
console.print(
|
||||
"[tip]Tip[/tip]: now try [bold]`safety scan`[/bold] in your project’s root "
|
||||
"folder to run a project scan or [bold]`safety -–help`[/bold] to learn more."
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
MSG_FINISH_REGISTRATION_TPL.format(email=ctx.obj.auth.email)
|
||||
)
|
||||
console.print()
|
||||
console.print(MSG_VERIFICATION_HINT)
|
||||
else:
|
||||
click.secho("Safety is now authenticated but your email is missing.")
|
||||
else:
|
||||
error_msg = ":stop_sign: [red]"
|
||||
if ctx.obj.auth.org:
|
||||
error_msg += (
|
||||
f"Error logging into {ctx.obj.auth.org.name} organization "
|
||||
f"with auth ID: {ctx.obj.auth.org.id}."
|
||||
)
|
||||
else:
|
||||
error_msg += "Error logging into Safety."
|
||||
|
||||
error_msg += (
|
||||
" Please try again, or use [bold]`safety auth -–help`[/bold] "
|
||||
"for more information[/red]"
|
||||
)
|
||||
|
||||
console.print(error_msg, emoji=True)
|
||||
|
||||
emit_auth_completed(
|
||||
ctx.obj.event_bus, ctx, success=is_success, error_message=error_msg
|
||||
)
|
||||
|
||||
|
||||
@auth_app.command(name=CMD_LOGOUT_NAME, help=CLI_AUTH_LOGOUT_HELP)
|
||||
@handle_cmd_exception
|
||||
@notify
|
||||
def logout(ctx: typer.Context) -> None:
|
||||
"""
|
||||
Log out of your current session.
|
||||
|
||||
Args:
|
||||
ctx (typer.Context): The Typer context object.
|
||||
"""
|
||||
LOG.info("logout started")
|
||||
|
||||
id_token = get_token("id_token")
|
||||
|
||||
msg = MSG_NON_AUTHENTICATED
|
||||
|
||||
if id_token:
|
||||
# Clean the session if an ID token is found
|
||||
if clean_session(ctx.obj.auth.client):
|
||||
msg = MSG_LOGOUT_DONE
|
||||
else:
|
||||
msg = MSG_LOGOUT_FAILED
|
||||
|
||||
console.print(msg)
|
||||
|
||||
|
||||
@auth_app.command(name=CMD_STATUS_NAME, help=CLI_AUTH_STATUS_HELP)
|
||||
@click.option(
|
||||
"--ensure-auth/--no-ensure-auth",
|
||||
default=False,
|
||||
help="This will keep running the command until anauthentication is made.",
|
||||
)
|
||||
@click.option(
|
||||
"--login-timeout",
|
||||
"-w",
|
||||
type=int,
|
||||
default=600,
|
||||
help="Max time allowed to wait for an authentication.",
|
||||
)
|
||||
@handle_cmd_exception
|
||||
@notify
|
||||
def status(
|
||||
ctx: typer.Context, ensure_auth: bool = False, login_timeout: int = 600
|
||||
) -> None:
|
||||
"""
|
||||
Display Safety CLI's current authentication status.
|
||||
|
||||
Args:
|
||||
ctx (typer.Context): The Typer context object.
|
||||
ensure_auth (bool): Whether to keep running until authentication is made.
|
||||
login_timeout (int): Max time allowed to wait for authentication.
|
||||
"""
|
||||
LOG.info("status started")
|
||||
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
safety_version = get_version()
|
||||
console.print(f"[{current_time}]: Safety {safety_version}")
|
||||
|
||||
info = get_auth_info(ctx)
|
||||
|
||||
initialize(ctx, refresh=True)
|
||||
|
||||
if ensure_auth:
|
||||
console.print("running: safety auth status --ensure-auth")
|
||||
console.print()
|
||||
|
||||
if info:
|
||||
verified = is_email_verified(info)
|
||||
email_status = " [red](email not verified)[/red]" if not verified else ""
|
||||
|
||||
console.print(f"[green]Authenticated as {info['email']}[/green]{email_status}")
|
||||
elif ensure_auth:
|
||||
console.print(
|
||||
"Safety is not authenticated. Launching default browser to log in"
|
||||
)
|
||||
console.print()
|
||||
uri, initial_state = get_authorization_data(
|
||||
client=ctx.obj.auth.client,
|
||||
code_verifier=ctx.obj.auth.code_verifier,
|
||||
organization=ctx.obj.auth.org,
|
||||
ensure_auth=ensure_auth,
|
||||
)
|
||||
|
||||
# Process the browser callback to complete the authentication
|
||||
info = process_browser_callback(
|
||||
uri, initial_state=initial_state, timeout=login_timeout, ctx=ctx
|
||||
)
|
||||
|
||||
if not info:
|
||||
console.print(
|
||||
f"[red]Timeout error ({login_timeout} seconds): not successfully authenticated without the timeout period.[/red]"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
organization = None
|
||||
if ctx.obj.auth.org and ctx.obj.auth.org.name:
|
||||
organization = ctx.obj.auth.org.name
|
||||
|
||||
render_successful_login(ctx.obj.auth, organization=organization)
|
||||
console.print()
|
||||
|
||||
else:
|
||||
console.print(MSG_NON_AUTHENTICATED)
|
||||
|
||||
|
||||
@auth_app.command(name=CMD_REGISTER_NAME)
|
||||
@handle_cmd_exception
|
||||
@notify
|
||||
def register(ctx: typer.Context) -> None:
|
||||
"""
|
||||
Create a new user account for the safetycli.com service.
|
||||
|
||||
Args:
|
||||
ctx (typer.Context): The Typer context object.
|
||||
"""
|
||||
LOG.info("register started")
|
||||
|
||||
# Check if the user is already authenticated
|
||||
fail_if_authenticated(ctx, with_msg=MSG_FAIL_REGISTER_AUTHED)
|
||||
|
||||
# Get authorization data and generate the registration URL
|
||||
uri, initial_state = get_authorization_data(
|
||||
client=ctx.obj.auth.client,
|
||||
code_verifier=ctx.obj.auth.code_verifier,
|
||||
sign_up=True,
|
||||
)
|
||||
|
||||
console.print(
|
||||
"\nRedirecting your browser to register for a free account. Once registered, return here to start using Safety."
|
||||
)
|
||||
console.print()
|
||||
|
||||
# Process the browser callback to complete the registration
|
||||
info = process_browser_callback(uri, initial_state=initial_state, ctx=ctx)
|
||||
console.print()
|
||||
|
||||
if info:
|
||||
console.print(f"[green]Successfully registered {info.get('email')}[/green]")
|
||||
console.print()
|
||||
else:
|
||||
console.print("[red]Unable to register in this time, try again.[/red]")
|
||||
@@ -0,0 +1,290 @@
|
||||
import logging
|
||||
from typing import Dict, Optional, Tuple, Any, Callable
|
||||
|
||||
import click
|
||||
|
||||
from .main import (
|
||||
get_auth_info,
|
||||
get_host_config,
|
||||
get_organization,
|
||||
get_proxy_config,
|
||||
get_redirect_url,
|
||||
get_token_data,
|
||||
save_auth_config,
|
||||
get_token,
|
||||
clean_session,
|
||||
)
|
||||
from authlib.common.security import generate_token
|
||||
from safety.auth.constants import CLIENT_ID
|
||||
|
||||
from safety.auth.models import Organization, Auth
|
||||
from safety.auth.utils import (
|
||||
S3PresignedAdapter,
|
||||
SafetyAuthSession,
|
||||
get_keys,
|
||||
is_email_verified,
|
||||
)
|
||||
from safety.scan.constants import (
|
||||
CLI_KEY_HELP,
|
||||
CLI_PROXY_HOST_HELP,
|
||||
CLI_PROXY_PORT_HELP,
|
||||
CLI_PROXY_PROTOCOL_HELP,
|
||||
CLI_STAGE_HELP,
|
||||
)
|
||||
from safety.util import DependentOption, SafetyContext, get_proxy_dict
|
||||
from safety.models import SafetyCLI
|
||||
from safety_schemas.models import Stage
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_client_session(
|
||||
api_key: Optional[str] = None,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[SafetyAuthSession, Dict[str, Any]]:
|
||||
"""
|
||||
Builds and configures the client session for authentication.
|
||||
|
||||
Args:
|
||||
api_key (Optional[str]): The API key for authentication.
|
||||
proxies (Optional[Dict[str, str]]): Proxy configuration.
|
||||
headers (Optional[Dict[str, str]]): Additional headers.
|
||||
|
||||
Returns:
|
||||
Tuple[SafetyAuthSession, Dict[str, Any]]: The configured client session and OpenID configuration.
|
||||
"""
|
||||
|
||||
kwargs = {}
|
||||
target_proxies = proxies
|
||||
|
||||
# Global proxy defined in the config.ini
|
||||
proxy_config, proxy_timeout, proxy_required = get_proxy_config()
|
||||
|
||||
if not proxies:
|
||||
target_proxies = proxy_config
|
||||
|
||||
def update_token(tokens, **kwargs):
|
||||
save_auth_config(
|
||||
access_token=tokens["access_token"],
|
||||
id_token=tokens["id_token"],
|
||||
refresh_token=tokens["refresh_token"],
|
||||
)
|
||||
load_auth_session(click_ctx=click.get_current_context(silent=True)) # type: ignore
|
||||
|
||||
client_session = SafetyAuthSession(
|
||||
client_id=CLIENT_ID,
|
||||
code_challenge_method="S256",
|
||||
redirect_uri=get_redirect_url(),
|
||||
update_token=update_token,
|
||||
scope="openid email profile offline_access",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
client_session.mount("https://pyup.io/static-s3/", S3PresignedAdapter())
|
||||
|
||||
client_session.proxy_required = proxy_required
|
||||
client_session.proxy_timeout = proxy_timeout
|
||||
client_session.proxies = target_proxies # type: ignore
|
||||
client_session.headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
openid_config = client_session.fetch_openid_config()
|
||||
|
||||
client_session.metadata["token_endpoint"] = openid_config.get(
|
||||
"token_endpoint", None
|
||||
)
|
||||
|
||||
if api_key:
|
||||
client_session.api_key = api_key # type: ignore
|
||||
client_session.headers["X-Api-Key"] = api_key
|
||||
|
||||
if headers:
|
||||
client_session.headers.update(headers)
|
||||
|
||||
return client_session, openid_config
|
||||
|
||||
|
||||
def load_auth_session(click_ctx: click.Context) -> None:
|
||||
"""
|
||||
Loads the authentication session from the context.
|
||||
|
||||
Args:
|
||||
click_ctx (click.Context): The Click context object.
|
||||
"""
|
||||
if not click_ctx:
|
||||
LOG.warning("Click context is needed to be able to load the Auth data.")
|
||||
return
|
||||
|
||||
client = click_ctx.obj.auth.client
|
||||
keys = click_ctx.obj.auth.keys
|
||||
|
||||
access_token: str = get_token(name="access_token") # type: ignore
|
||||
refresh_token: str = get_token(name="refresh_token") # type: ignore
|
||||
id_token: str = get_token(name="id_token") # type: ignore
|
||||
|
||||
if access_token and keys:
|
||||
try:
|
||||
token = get_token_data(access_token, keys, silent_if_expired=True)
|
||||
client.token = {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"id_token": id_token,
|
||||
"token_type": "bearer",
|
||||
"expires_at": token.get("exp", None), # type: ignore
|
||||
}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
clean_session(client)
|
||||
|
||||
|
||||
def proxy_options(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator that defines proxy options for Click commands.
|
||||
|
||||
Options defined per command, this will override the proxy settings defined in the
|
||||
config.ini file.
|
||||
|
||||
Args:
|
||||
func (Callable): The Click command function.
|
||||
|
||||
Returns:
|
||||
Callable: The wrapped Click command function with proxy options.
|
||||
"""
|
||||
func = click.option(
|
||||
"--proxy-protocol",
|
||||
type=click.Choice(["http", "https"]),
|
||||
default="https",
|
||||
cls=DependentOption,
|
||||
required_options=["proxy_host"],
|
||||
help=CLI_PROXY_PROTOCOL_HELP,
|
||||
)(func)
|
||||
func = click.option(
|
||||
"--proxy-port",
|
||||
multiple=False,
|
||||
type=int,
|
||||
default=80,
|
||||
cls=DependentOption,
|
||||
required_options=["proxy_host"],
|
||||
help=CLI_PROXY_PORT_HELP,
|
||||
)(func)
|
||||
func = click.option(
|
||||
"--proxy-host", multiple=False, type=str, default=None, help=CLI_PROXY_HOST_HELP
|
||||
)(func)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def auth_options(stage: bool = True) -> Callable:
|
||||
"""
|
||||
Decorator that defines authentication options for Click commands.
|
||||
|
||||
Args:
|
||||
stage (bool): Whether to include the stage option.
|
||||
|
||||
Returns:
|
||||
Callable: The decorator function.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
func = click.option(
|
||||
"--key", default=None, envvar="SAFETY_API_KEY", help=CLI_KEY_HELP
|
||||
)(func)
|
||||
|
||||
if stage:
|
||||
func = click.option(
|
||||
"--stage", default=None, envvar="SAFETY_STAGE", help=CLI_STAGE_HELP
|
||||
)(func)
|
||||
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def inject_session(
|
||||
ctx: click.Context,
|
||||
proxy_protocol: Optional[str] = None,
|
||||
proxy_host: Optional[str] = None,
|
||||
proxy_port: Optional[str] = None,
|
||||
key: Optional[str] = None,
|
||||
stage: Optional[Stage] = None,
|
||||
invoked_command: str = "",
|
||||
) -> Any:
|
||||
org: Optional[Organization] = get_organization()
|
||||
|
||||
if not stage:
|
||||
host_stage = get_host_config(key_name="stage")
|
||||
stage = host_stage if host_stage else Stage.development
|
||||
|
||||
proxy_config: Optional[Dict[str, str]] = get_proxy_dict(
|
||||
proxy_protocol, # type: ignore
|
||||
proxy_host, # type: ignore
|
||||
proxy_port, # type: ignore
|
||||
)
|
||||
|
||||
client_session, openid_config = build_client_session(
|
||||
api_key=key, proxies=proxy_config
|
||||
)
|
||||
keys = get_keys(client_session, openid_config)
|
||||
|
||||
auth = Auth(
|
||||
stage=stage,
|
||||
keys=keys,
|
||||
org=org,
|
||||
client_id=CLIENT_ID, # type: ignore
|
||||
client=client_session,
|
||||
code_verifier=generate_token(48),
|
||||
)
|
||||
|
||||
if not ctx.obj:
|
||||
ctx.obj = SafetyCLI()
|
||||
|
||||
ctx.obj.auth = auth
|
||||
|
||||
load_auth_session(ctx)
|
||||
|
||||
info = get_auth_info(ctx)
|
||||
|
||||
if info:
|
||||
ctx.obj.auth.name = info.get("name")
|
||||
ctx.obj.auth.email = info.get("email")
|
||||
ctx.obj.auth.email_verified = is_email_verified(info) # type: ignore
|
||||
SafetyContext().account = info["email"]
|
||||
else:
|
||||
SafetyContext().account = ""
|
||||
|
||||
@ctx.call_on_close
|
||||
def clean_up_on_close():
|
||||
LOG.debug("Closing requests session.")
|
||||
ctx.obj.auth.client.close()
|
||||
|
||||
if ctx.obj.event_bus:
|
||||
from safety.events.utils import (
|
||||
create_internal_event,
|
||||
InternalEventType,
|
||||
InternalPayload,
|
||||
)
|
||||
|
||||
payload = InternalPayload(ctx=ctx)
|
||||
|
||||
flush_event = create_internal_event(
|
||||
event_type=InternalEventType.FLUSH_SECURITY_TRACES, payload=payload
|
||||
)
|
||||
close_event = create_internal_event(
|
||||
event_type=InternalEventType.CLOSE_RESOURCES, payload=payload
|
||||
)
|
||||
|
||||
flush_future = ctx.obj.event_bus.emit(flush_event)
|
||||
close_future = ctx.obj.event_bus.emit(close_event)
|
||||
|
||||
# Wait for both events to be processed
|
||||
if flush_future and close_future:
|
||||
try:
|
||||
flush_future.result()
|
||||
close_future.result()
|
||||
except Exception as e:
|
||||
LOG.warning(f"Error waiting for events to process: {e}")
|
||||
|
||||
ctx.obj.event_bus.stop()
|
||||
@@ -0,0 +1,35 @@
|
||||
from pathlib import Path
|
||||
|
||||
from safety.constants import USER_CONFIG_DIR, get_config_setting
|
||||
|
||||
AUTH_CONFIG_FILE_NAME = "auth.ini"
|
||||
AUTH_CONFIG_USER = USER_CONFIG_DIR / Path(AUTH_CONFIG_FILE_NAME)
|
||||
|
||||
|
||||
HOST: str = "localhost"
|
||||
|
||||
CLIENT_ID = get_config_setting("CLIENT_ID")
|
||||
AUTH_SERVER_URL = get_config_setting("AUTH_SERVER_URL")
|
||||
SAFETY_PLATFORM_URL = get_config_setting("SAFETY_PLATFORM_URL")
|
||||
|
||||
OPENID_CONFIG_URL = f"{AUTH_SERVER_URL}/.well-known/openid-configuration"
|
||||
|
||||
CLAIM_EMAIL_VERIFIED_API = "https://api.safetycli.com/email_verified"
|
||||
CLAIM_EMAIL_VERIFIED_AUTH_SERVER = "email_verified"
|
||||
|
||||
CLI_AUTH = f"{SAFETY_PLATFORM_URL}/cli/auth"
|
||||
CLI_AUTH_SUCCESS = f"{SAFETY_PLATFORM_URL}/cli/auth/success"
|
||||
CLI_AUTH_LOGOUT = f"{SAFETY_PLATFORM_URL}/cli/logout"
|
||||
CLI_CALLBACK = f"{SAFETY_PLATFORM_URL}/cli/callback"
|
||||
CLI_LOGOUT_SUCCESS = f"{SAFETY_PLATFORM_URL}/cli/logout/success"
|
||||
|
||||
MSG_NON_AUTHENTICATED = (
|
||||
"Safety is not authenticated. Please run 'safety auth login' to log in."
|
||||
)
|
||||
MSG_FAIL_LOGIN_AUTHED = """[green]You are authenticated as[/green] {email}.
|
||||
|
||||
To log into a different account, first logout via: safety auth logout, and then login again."""
|
||||
MSG_FAIL_REGISTER_AUTHED = "You are currently logged in to {email}, please logout using `safety auth logout` before registering a new account."
|
||||
|
||||
MSG_LOGOUT_DONE = "[green]Logout done.[/green]"
|
||||
MSG_LOGOUT_FAILED = "[red]Logout failed. Try again.[/red]"
|
||||
329
Backend/venv/lib/python3.12/site-packages/safety/auth/main.py
Normal file
329
Backend/venv/lib/python3.12/site-packages/safety/auth/main.py
Normal file
@@ -0,0 +1,329 @@
|
||||
import configparser
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
from authlib.oidc.core import CodeIDToken
|
||||
from authlib.jose import jwt
|
||||
from authlib.jose.errors import ExpiredTokenError
|
||||
|
||||
from safety.auth.models import Organization
|
||||
from safety.auth.constants import (
|
||||
CLI_AUTH_LOGOUT,
|
||||
CLI_CALLBACK,
|
||||
AUTH_CONFIG_USER,
|
||||
CLI_AUTH,
|
||||
)
|
||||
from safety.constants import CONFIG
|
||||
from safety_schemas.models import Stage
|
||||
from safety.util import get_proxy_dict
|
||||
|
||||
|
||||
def get_authorization_data(
|
||||
client,
|
||||
code_verifier: str,
|
||||
organization: Optional[Organization] = None,
|
||||
sign_up: bool = False,
|
||||
ensure_auth: bool = False,
|
||||
headless: bool = False,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Generate the authorization URL for the authentication process.
|
||||
|
||||
Args:
|
||||
client: The authentication client.
|
||||
code_verifier (str): The code verifier for the PKCE flow.
|
||||
organization (Optional[Organization]): The organization to authenticate with.
|
||||
sign_up (bool): Whether the URL is for sign-up.
|
||||
ensure_auth (bool): Whether to ensure authentication.
|
||||
headless (bool): Whether to run in headless mode.
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: The authorization URL and initial state.
|
||||
"""
|
||||
|
||||
kwargs = {
|
||||
"sign_up": sign_up,
|
||||
"locale": "en",
|
||||
"ensure_auth": ensure_auth,
|
||||
"headless": headless,
|
||||
}
|
||||
if organization:
|
||||
kwargs["organization"] = organization.id
|
||||
|
||||
return client.create_authorization_url(
|
||||
CLI_AUTH, code_verifier=code_verifier, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def get_logout_url(id_token: str) -> str:
|
||||
"""
|
||||
Generate the logout URL.
|
||||
|
||||
Args:
|
||||
id_token (str): The ID token.
|
||||
|
||||
Returns:
|
||||
str: The logout URL.
|
||||
"""
|
||||
return f"{CLI_AUTH_LOGOUT}?id_token={id_token}"
|
||||
|
||||
|
||||
def get_redirect_url() -> str:
|
||||
"""
|
||||
Get the redirect URL for the authentication callback.
|
||||
|
||||
Returns:
|
||||
str: The redirect URL.
|
||||
"""
|
||||
return CLI_CALLBACK
|
||||
|
||||
|
||||
def get_organization() -> Optional[Organization]:
|
||||
"""
|
||||
Retrieve the organization configuration.
|
||||
|
||||
Returns:
|
||||
Optional[Organization]: The organization object, or None if not configured.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config.read(CONFIG)
|
||||
|
||||
org_conf: Union[Dict[str, str], configparser.SectionProxy] = (
|
||||
config["organization"] if "organization" in config.sections() else {}
|
||||
)
|
||||
org_id: Optional[str] = (
|
||||
org_conf["id"].replace('"', "") if org_conf.get("id", None) else None
|
||||
)
|
||||
org_name: Optional[str] = (
|
||||
org_conf["name"].replace('"', "") if org_conf.get("name", None) else None
|
||||
)
|
||||
|
||||
if not org_id:
|
||||
return None
|
||||
|
||||
org = Organization(id=org_id, name=org_name) # type: ignore
|
||||
|
||||
return org
|
||||
|
||||
|
||||
def get_auth_info(ctx) -> Optional[Dict]:
|
||||
"""
|
||||
Retrieve the authentication information.
|
||||
|
||||
Args:
|
||||
ctx: The context object containing authentication data.
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: The authentication information, or None if not authenticated.
|
||||
"""
|
||||
from safety.auth.utils import is_email_verified
|
||||
|
||||
info = None
|
||||
if ctx.obj.auth.client.token:
|
||||
try:
|
||||
info = get_token_data(get_token(name="id_token"), keys=ctx.obj.auth.keys) # type: ignore
|
||||
|
||||
verified = is_email_verified(info) # type: ignore
|
||||
if not verified:
|
||||
user_info = ctx.obj.auth.client.fetch_user_info()
|
||||
verified = is_email_verified(user_info)
|
||||
|
||||
if verified:
|
||||
# refresh only if needed
|
||||
raise ExpiredTokenError
|
||||
|
||||
except ExpiredTokenError:
|
||||
# id_token expired. So fire a manually a refresh
|
||||
try:
|
||||
ctx.obj.auth.client.refresh_token(
|
||||
ctx.obj.auth.client.metadata.get("token_endpoint"),
|
||||
refresh_token=ctx.obj.auth.client.token.get("refresh_token"),
|
||||
)
|
||||
info = get_token_data(
|
||||
get_token(name="id_token"), # type: ignore
|
||||
keys=ctx.obj.auth.keys, # type: ignore
|
||||
)
|
||||
except Exception as _e:
|
||||
clean_session(ctx.obj.auth.client)
|
||||
except Exception as _g:
|
||||
clean_session(ctx.obj.auth.client)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def get_token_data(
|
||||
token: str, keys: Any, silent_if_expired: bool = False
|
||||
) -> Optional[Dict]:
|
||||
"""
|
||||
Decode and validate the token data.
|
||||
|
||||
Args:
|
||||
token (str): The token to decode.
|
||||
keys (Any): The keys to use for decoding.
|
||||
silent_if_expired (bool): Whether to silently ignore expired tokens.
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: The decoded token data, or None if invalid.
|
||||
"""
|
||||
claims = jwt.decode(token, keys, claims_cls=CodeIDToken)
|
||||
try:
|
||||
claims.validate()
|
||||
except ExpiredTokenError as e:
|
||||
if not silent_if_expired:
|
||||
raise e
|
||||
|
||||
return claims
|
||||
|
||||
|
||||
def get_token(name: str = "access_token") -> Optional[str]:
|
||||
""" "
|
||||
Retrieve a token from the local authentication configuration.
|
||||
|
||||
This returns tokens saved in the local auth configuration.
|
||||
There are two types of tokens: access_token and id_token
|
||||
|
||||
Args:
|
||||
name (str): The name of the token to retrieve.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The token value, or None if not found.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config.read(AUTH_CONFIG_USER)
|
||||
|
||||
if "auth" in config.sections() and name in config["auth"]:
|
||||
value = config["auth"][name]
|
||||
if value:
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_host_config(key_name: str) -> Optional[Any]:
|
||||
"""
|
||||
Retrieve a configuration value from the host configuration.
|
||||
|
||||
Args:
|
||||
key_name (str): The name of the configuration key.
|
||||
|
||||
Returns:
|
||||
Optional[Any]: The configuration value, or None if not found.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config.read(CONFIG)
|
||||
|
||||
if not config.has_section("host"):
|
||||
return None
|
||||
|
||||
host_section = dict(config.items("host"))
|
||||
|
||||
if key_name in host_section:
|
||||
if key_name == "stage":
|
||||
# Support old alias in the config.ini
|
||||
if host_section[key_name] == "dev":
|
||||
host_section[key_name] = "development"
|
||||
if host_section[key_name] not in {env.value for env in Stage}:
|
||||
return None
|
||||
return Stage(host_section[key_name])
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def str_to_bool(s: str) -> bool:
|
||||
"""
|
||||
Convert a string to a boolean value.
|
||||
|
||||
Args:
|
||||
s (str): The string to convert.
|
||||
|
||||
Returns:
|
||||
bool: The converted boolean value.
|
||||
|
||||
Raises:
|
||||
ValueError: If the string cannot be converted.
|
||||
"""
|
||||
if s.lower() == "true" or s == "1":
|
||||
return True
|
||||
elif s.lower() == "false" or s == "0":
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"Cannot convert '{s}' to a boolean value.")
|
||||
|
||||
|
||||
def get_proxy_config() -> Tuple[Optional[Dict[str, str]], Optional[int], bool]:
|
||||
"""
|
||||
Retrieve the proxy configuration.
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[Dict[str, str]], Optional[int], bool]: The proxy configuration, timeout, and whether it is required.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config.read(CONFIG)
|
||||
|
||||
proxy_dictionary = None
|
||||
required = False
|
||||
timeout = None
|
||||
proxy = None
|
||||
|
||||
if config.has_section("proxy"):
|
||||
proxy = dict(config.items("proxy"))
|
||||
|
||||
if proxy:
|
||||
try:
|
||||
proxy_dictionary = get_proxy_dict(
|
||||
proxy["protocol"],
|
||||
proxy["host"],
|
||||
proxy["port"], # type: ignore
|
||||
)
|
||||
required = str_to_bool(proxy["required"])
|
||||
timeout = proxy["timeout"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return proxy_dictionary, timeout, required # type: ignore
|
||||
|
||||
|
||||
def clean_session(client) -> bool:
|
||||
"""
|
||||
Clean the authentication session.
|
||||
|
||||
Args:
|
||||
client: The authentication client.
|
||||
|
||||
Returns:
|
||||
bool: Always returns True.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config["auth"] = {"access_token": "", "id_token": "", "refresh_token": ""}
|
||||
|
||||
with open(AUTH_CONFIG_USER, "w") as configfile:
|
||||
config.write(configfile)
|
||||
|
||||
client.token = None
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def save_auth_config(
|
||||
access_token: Optional[str] = None,
|
||||
id_token: Optional[str] = None,
|
||||
refresh_token: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Save the authentication configuration.
|
||||
|
||||
Args:
|
||||
access_token (Optional[str]): The access token.
|
||||
id_token (Optional[str]): The ID token.
|
||||
refresh_token (Optional[str]): The refresh token.
|
||||
"""
|
||||
config = configparser.ConfigParser()
|
||||
config.read(AUTH_CONFIG_USER)
|
||||
config["auth"] = { # type: ignore
|
||||
"access_token": access_token,
|
||||
"id_token": id_token,
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
with open(AUTH_CONFIG_USER, "w") as configfile:
|
||||
config.write(configfile) # type: ignore
|
||||
105
Backend/venv/lib/python3.12/site-packages/safety/auth/models.py
Normal file
105
Backend/venv/lib/python3.12/site-packages/safety/auth/models.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
from typing import Any, Optional, Dict
|
||||
|
||||
from authlib.integrations.base_client import BaseOAuth
|
||||
|
||||
from safety_schemas.models import Stage
|
||||
|
||||
|
||||
@dataclass
|
||||
class Organization:
|
||||
id: str
|
||||
name: str
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""
|
||||
Convert the Organization instance to a dictionary.
|
||||
|
||||
Returns:
|
||||
dict: The dictionary representation of the organization.
|
||||
"""
|
||||
return {"id": self.id, "name": self.name}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Auth:
|
||||
org: Optional[Organization]
|
||||
keys: Any
|
||||
client: Any
|
||||
code_verifier: str
|
||||
client_id: str
|
||||
stage: Optional[Stage] = Stage.development
|
||||
email: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
email_verified: bool = False
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""
|
||||
Check if the authentication information is valid.
|
||||
|
||||
Returns:
|
||||
bool: True if valid, False otherwise.
|
||||
"""
|
||||
if os.getenv("SAFETY_DB_DIR"):
|
||||
return True
|
||||
|
||||
if not self.client:
|
||||
return False
|
||||
|
||||
if self.client.api_key:
|
||||
return True
|
||||
|
||||
return bool(self.client.token and self.email_verified)
|
||||
|
||||
def refresh_from(self, info: Dict) -> None:
|
||||
"""
|
||||
Refresh the authentication information from the provided info.
|
||||
|
||||
Args:
|
||||
info (dict): The information to refresh from.
|
||||
"""
|
||||
from safety.auth.utils import is_email_verified
|
||||
|
||||
self.name = info.get("name")
|
||||
self.email = info.get("email")
|
||||
self.email_verified = is_email_verified(info) # type: ignore
|
||||
|
||||
def get_auth_method(self) -> str:
|
||||
"""
|
||||
Get the authentication method.
|
||||
|
||||
Returns:
|
||||
str: The authentication method.
|
||||
"""
|
||||
if self.client.api_key:
|
||||
return "API Key"
|
||||
|
||||
if self.client.token:
|
||||
return "Token"
|
||||
|
||||
return "None"
|
||||
|
||||
|
||||
class XAPIKeyAuth(BaseOAuth):
|
||||
def __init__(self, api_key: str) -> None:
|
||||
"""
|
||||
Initialize the XAPIKeyAuth instance.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key to use for authentication.
|
||||
"""
|
||||
self.api_key = api_key
|
||||
|
||||
def __call__(self, r: Any) -> Any:
|
||||
"""
|
||||
Add the API key to the request headers.
|
||||
|
||||
Args:
|
||||
r (Any): The request object.
|
||||
|
||||
Returns:
|
||||
Any: The modified request object.
|
||||
"""
|
||||
r.headers["X-API-Key"] = self.api_key
|
||||
return r
|
||||
308
Backend/venv/lib/python3.12/site-packages/safety/auth/server.py
Normal file
308
Backend/venv/lib/python3.12/site-packages/safety/auth/server.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# type: ignore
|
||||
import http.server
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Optional, Dict, Tuple
|
||||
import urllib.parse
|
||||
import threading
|
||||
import click
|
||||
|
||||
from safety.auth.utils import is_jupyter_notebook
|
||||
from safety.console import main_console as console
|
||||
|
||||
from safety.auth.constants import (
|
||||
AUTH_SERVER_URL,
|
||||
CLI_AUTH_SUCCESS,
|
||||
CLI_LOGOUT_SUCCESS,
|
||||
HOST,
|
||||
)
|
||||
from safety.auth.main import save_auth_config
|
||||
from rich.prompt import Prompt
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_available_port() -> Optional[int]:
|
||||
"""
|
||||
Find an available port on localhost within the dynamic port range (49152-65536).
|
||||
|
||||
Returns:
|
||||
Optional[int]: An available port number, or None if no ports are available.
|
||||
"""
|
||||
# Dynamic ports IANA
|
||||
port_range = list(range(49152, 65536))
|
||||
random.shuffle(port_range)
|
||||
|
||||
for port in port_range:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
try:
|
||||
s.connect(("localhost", port))
|
||||
# If the connect succeeds, the port is already in use
|
||||
except socket.error:
|
||||
# If the connect fails, the port is available
|
||||
return port
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def auth_process(
|
||||
code: str, state: str, initial_state: str, code_verifier: str, client: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Process the authentication callback and exchange the authorization code for tokens.
|
||||
|
||||
Args:
|
||||
code (str): The authorization code.
|
||||
state (str): The state parameter from the callback.
|
||||
initial_state (str): The initial state parameter.
|
||||
code_verifier (str): The code verifier for PKCE.
|
||||
client (Any): The OAuth client.
|
||||
|
||||
Returns:
|
||||
Any: The user information.
|
||||
|
||||
Raises:
|
||||
SystemExit: If there is an error during authentication.
|
||||
"""
|
||||
err = None
|
||||
|
||||
if initial_state is None or initial_state != state:
|
||||
err = (
|
||||
"The state parameter value provided does not match the expected "
|
||||
"value. The state parameter is used to protect against Cross-Site "
|
||||
"Request Forgery (CSRF) attacks. For security reasons, the "
|
||||
"authorization process cannot proceed with an invalid state "
|
||||
"parameter value. Please try again, ensuring that the state "
|
||||
"parameter value provided in the authorization request matches "
|
||||
"the value returned in the callback."
|
||||
)
|
||||
|
||||
if err:
|
||||
click.secho(f"Error: {err}", fg="red")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
tokens = client.fetch_token(
|
||||
url=f"{AUTH_SERVER_URL}/oauth/token",
|
||||
code_verifier=code_verifier,
|
||||
client_id=client.client_id,
|
||||
grant_type="authorization_code",
|
||||
code=code,
|
||||
)
|
||||
|
||||
save_auth_config(
|
||||
access_token=tokens["access_token"],
|
||||
id_token=tokens["id_token"],
|
||||
refresh_token=tokens["refresh_token"],
|
||||
)
|
||||
return client.fetch_user_info()
|
||||
|
||||
except Exception as e:
|
||||
LOG.exception(e)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class CallbackHandler(http.server.BaseHTTPRequestHandler):
|
||||
def auth(self, code: str, state: str, err: str, error_description: str) -> None:
|
||||
"""
|
||||
Handle the authentication callback.
|
||||
|
||||
Args:
|
||||
code (str): The authorization code.
|
||||
state (str): The state parameter.
|
||||
err (str): The error message, if any.
|
||||
error_description (str): The error description, if any.
|
||||
"""
|
||||
initial_state = self.server.initial_state
|
||||
ctx = self.server.ctx
|
||||
|
||||
result = auth_process(
|
||||
code=code,
|
||||
state=state,
|
||||
initial_state=initial_state,
|
||||
code_verifier=ctx.obj.auth.code_verifier,
|
||||
client=ctx.obj.auth.client,
|
||||
)
|
||||
|
||||
self.server.callback = result
|
||||
self.do_redirect(location=CLI_AUTH_SUCCESS, params={})
|
||||
|
||||
def logout(self) -> None:
|
||||
"""
|
||||
Handle the logout callback.
|
||||
"""
|
||||
ctx = self.server.ctx
|
||||
uri = CLI_LOGOUT_SUCCESS
|
||||
|
||||
if ctx.obj.auth.org:
|
||||
uri = f"{uri}&org_id={ctx.obj.auth.org.id}"
|
||||
|
||||
self.do_redirect(location=CLI_LOGOUT_SUCCESS, params={})
|
||||
|
||||
def do_GET(self) -> None:
|
||||
"""
|
||||
Handle GET requests.
|
||||
"""
|
||||
query = urllib.parse.urlparse(self.path).query
|
||||
params = urllib.parse.parse_qs(query)
|
||||
callback_type: Optional[str] = None
|
||||
|
||||
try:
|
||||
c_type = params.get("type", [])
|
||||
if (
|
||||
isinstance(c_type, list)
|
||||
and len(c_type) == 1
|
||||
and isinstance(c_type[0], str)
|
||||
):
|
||||
callback_type = c_type[0]
|
||||
except Exception:
|
||||
msg = "Unable to process the callback, try again."
|
||||
self.send_error(400, msg)
|
||||
click.secho("Unable to process the callback, try again.")
|
||||
return
|
||||
|
||||
if callback_type == "logout":
|
||||
self.logout()
|
||||
return
|
||||
|
||||
code = params.get("code", [""])[0]
|
||||
state = params.get("state", [""])[0]
|
||||
err = params.get("error", [""])[0]
|
||||
error_description = params.get("error_description", [""])[0]
|
||||
|
||||
self.auth(code=code, state=state, err=err, error_description=error_description)
|
||||
|
||||
def do_redirect(self, location: str, params: Dict) -> None:
|
||||
"""
|
||||
Redirect the client to the specified location.
|
||||
|
||||
Args:
|
||||
location (str): The URL to redirect to.
|
||||
params (dict): Additional parameters for the redirection.
|
||||
"""
|
||||
self.send_response(302)
|
||||
self.send_header("Location", location)
|
||||
self.send_header("Connection", "close")
|
||||
self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, format: str, *args: Any) -> None:
|
||||
"""
|
||||
Log an arbitrary message.
|
||||
|
||||
Args:
|
||||
format (str): The format string.
|
||||
args (Any): Arguments for the format string.
|
||||
"""
|
||||
LOG.info(format % args)
|
||||
|
||||
|
||||
def process_browser_callback(uri: str, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Process the browser callback for authentication.
|
||||
|
||||
Args:
|
||||
uri (str): The authorization URL.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
Any: The user information.
|
||||
|
||||
Raises:
|
||||
SystemExit: If there is an error during the process.
|
||||
"""
|
||||
|
||||
class ThreadedHTTPServer(http.server.HTTPServer):
|
||||
def __init__(self, server_address: Tuple, RequestHandlerClass: Any) -> None:
|
||||
"""
|
||||
Initialize the ThreadedHTTPServer.
|
||||
Args:
|
||||
server_address (Tuple): The server address as a tuple (host, port).
|
||||
RequestHandlerClass (Any): The request handler class.
|
||||
"""
|
||||
super().__init__(server_address, RequestHandlerClass)
|
||||
self.initial_state = None
|
||||
self.ctx = None
|
||||
self.callback = None
|
||||
self.timeout_reached = False
|
||||
|
||||
def handle_timeout(self) -> None:
|
||||
"""
|
||||
Handle server timeout.
|
||||
"""
|
||||
self.timeout_reached = True
|
||||
return super().handle_timeout()
|
||||
|
||||
PORT = find_available_port()
|
||||
|
||||
if not PORT:
|
||||
click.secho("No available ports.")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
headless = kwargs.get("headless", False)
|
||||
initial_state = kwargs.get("initial_state", None)
|
||||
ctx = kwargs.get("ctx", None)
|
||||
message = "Copy and paste this URL into your browser:\n:icon_warning: Ensure there are no extra spaces, especially at line breaks, as they may break the link."
|
||||
|
||||
if not headless:
|
||||
# Start a threaded HTTP server to handle the callback
|
||||
server = ThreadedHTTPServer((HOST, PORT), CallbackHandler)
|
||||
server.initial_state = initial_state
|
||||
server.timeout = kwargs.get("timeout", 600)
|
||||
server.ctx = ctx
|
||||
server_thread = threading.Thread(target=server.handle_request)
|
||||
server_thread.start()
|
||||
message = "If the browser does not automatically open in 5 seconds, copy and paste this url into your browser:"
|
||||
|
||||
target = uri if headless else f"{uri}&port={PORT}"
|
||||
|
||||
if is_jupyter_notebook():
|
||||
console.print(f"{message} {target}")
|
||||
else:
|
||||
console.print(f"{message} [link={target}]{target}[/link]")
|
||||
|
||||
if headless:
|
||||
# Handle the headless mode where user manually provides the response
|
||||
exchange_data = None
|
||||
while not exchange_data:
|
||||
auth_code_text = Prompt.ask(
|
||||
"Paste the response here", default=None, console=console
|
||||
)
|
||||
try:
|
||||
exchange_data = json.loads(auth_code_text)
|
||||
state = exchange_data["state"]
|
||||
code = exchange_data["code"]
|
||||
except Exception:
|
||||
code = state = None
|
||||
|
||||
return auth_process(
|
||||
code=code,
|
||||
state=state,
|
||||
initial_state=initial_state,
|
||||
code_verifier=ctx.obj.auth.code_verifier,
|
||||
client=ctx.obj.auth.client,
|
||||
)
|
||||
else:
|
||||
# Wait for the browser authentication in non-headless mode
|
||||
console.print()
|
||||
wait_msg = "waiting for browser authentication"
|
||||
with console.status(wait_msg, spinner="bouncingBar"):
|
||||
time.sleep(2)
|
||||
click.launch(target)
|
||||
server_thread.join()
|
||||
|
||||
except OSError as e:
|
||||
if e.errno == socket.errno.EADDRINUSE:
|
||||
reason = f"The port {HOST}:{PORT} is currently being used by another application or process. Please choose a different port or terminate the conflicting application/process to free up the port."
|
||||
else:
|
||||
reason = "An error occurred while performing this operation."
|
||||
|
||||
click.secho(reason)
|
||||
sys.exit(1)
|
||||
|
||||
return server.callback
|
||||
756
Backend/venv/lib/python3.12/site-packages/safety/auth/utils.py
Normal file
756
Backend/venv/lib/python3.12/site-packages/safety/auth/utils.py
Normal file
@@ -0,0 +1,756 @@
|
||||
# type: ignore
|
||||
import importlib.util
|
||||
import json
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, List, Literal
|
||||
|
||||
import requests
|
||||
from authlib.integrations.base_client.errors import OAuthError
|
||||
from authlib.integrations.requests_client import OAuth2Session
|
||||
from requests.adapters import HTTPAdapter
|
||||
from safety_schemas.models import STAGE_ID_MAPPING, Stage
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from safety.auth.constants import (
|
||||
AUTH_SERVER_URL,
|
||||
OPENID_CONFIG_URL,
|
||||
)
|
||||
from safety.constants import (
|
||||
PLATFORM_API_CHECK_UPDATES_ENDPOINT,
|
||||
PLATFORM_API_INITIALIZE_ENDPOINT,
|
||||
PLATFORM_API_POLICY_ENDPOINT,
|
||||
PLATFORM_API_PROJECT_CHECK_ENDPOINT,
|
||||
PLATFORM_API_PROJECT_ENDPOINT,
|
||||
PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT,
|
||||
PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT,
|
||||
PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT,
|
||||
REQUEST_TIMEOUT,
|
||||
FeatureType,
|
||||
get_config_setting,
|
||||
FIREWALL_AUDIT_PYPI_PACKAGES_ENDPOINT,
|
||||
FIREWALL_AUDIT_NPMJS_PACKAGES_ENDPOINT,
|
||||
)
|
||||
from safety.error_handlers import output_exception
|
||||
from safety.errors import (
|
||||
InvalidCredentialError,
|
||||
NetworkConnectionError,
|
||||
RequestTimeoutError,
|
||||
SafetyError,
|
||||
ServerError,
|
||||
TooManyRequestsError,
|
||||
)
|
||||
from safety.meta import get_meta_http_headers
|
||||
from safety.models import SafetyCLI
|
||||
from safety.scan.util import AuthenticationType
|
||||
from safety.util import SafetyContext
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_keys(
|
||||
client_session: OAuth2Session, openid_config: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve the keys from the OpenID configuration.
|
||||
|
||||
Args:
|
||||
client_session (OAuth2Session): The OAuth2 session.
|
||||
openid_config (Dict[str, Any]): The OpenID configuration.
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: The keys, if available.
|
||||
"""
|
||||
if "jwks_uri" in openid_config:
|
||||
return client_session.get(url=openid_config["jwks_uri"], bearer=False).json() # type: ignore
|
||||
return None
|
||||
|
||||
|
||||
def is_email_verified(info: Dict[str, Any]) -> Optional[bool]:
|
||||
"""
|
||||
Check if the email is verified.
|
||||
|
||||
Args:
|
||||
info (Dict[str, Any]): The user information.
|
||||
|
||||
Returns:
|
||||
bool: True
|
||||
"""
|
||||
# return info.get(CLAIM_EMAIL_VERIFIED_API) or info.get(
|
||||
# CLAIM_EMAIL_VERIFIED_AUTH_SERVER
|
||||
# )
|
||||
|
||||
# Always return True to avoid email verification
|
||||
return True
|
||||
|
||||
|
||||
def extract_detail(response: requests.Response) -> Optional[str]:
|
||||
"""
|
||||
Extract the reason from an HTTP response.
|
||||
|
||||
Args:
|
||||
response (requests.Response): The response.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The reason.
|
||||
"""
|
||||
detail = None
|
||||
|
||||
try:
|
||||
detail = response.json().get("detail")
|
||||
except Exception:
|
||||
LOG.debug("Failed to extract detail from response: %s", response.status_code)
|
||||
|
||||
return detail
|
||||
|
||||
|
||||
def parse_response(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to parse the response from an HTTP request.
|
||||
|
||||
Args:
|
||||
func (Callable): The function to wrap.
|
||||
|
||||
Returns:
|
||||
Callable: The wrapped function.
|
||||
"""
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_exponential_jitter(initial=0.2, max=8.0, exp_base=3, jitter=0.3),
|
||||
reraise=True,
|
||||
retry=retry_if_exception_type(
|
||||
(
|
||||
NetworkConnectionError,
|
||||
RequestTimeoutError,
|
||||
TooManyRequestsError,
|
||||
ServerError,
|
||||
)
|
||||
),
|
||||
before_sleep=before_sleep_log(logging.getLogger("api_client"), logging.WARNING),
|
||||
)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
r = func(*args, **kwargs)
|
||||
except OAuthError as e:
|
||||
LOG.exception("OAuth failed: %s", e)
|
||||
raise InvalidCredentialError(
|
||||
message="Your token authentication expired, try login again."
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise NetworkConnectionError()
|
||||
except requests.exceptions.Timeout:
|
||||
raise RequestTimeoutError()
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise e
|
||||
|
||||
# TODO: Handle content as JSON and fallback to text for all responses
|
||||
|
||||
if r.status_code == 403:
|
||||
reason = extract_detail(response=r)
|
||||
|
||||
raise InvalidCredentialError(
|
||||
credential="Failed authentication.", reason=reason
|
||||
)
|
||||
|
||||
if r.status_code == 429:
|
||||
raise TooManyRequestsError(reason=r.text)
|
||||
|
||||
if r.status_code >= 400 and r.status_code < 500:
|
||||
error_code = None
|
||||
try:
|
||||
data = r.json()
|
||||
reason = data.get("detail", "Unable to find reason.")
|
||||
error_code = data.get("error_code", None)
|
||||
except Exception:
|
||||
reason = r.reason
|
||||
|
||||
raise SafetyError(message=reason, error_code=error_code)
|
||||
|
||||
if r.status_code >= 500 and r.status_code < 600:
|
||||
reason = extract_detail(response=r)
|
||||
LOG.debug("ServerError %s -> Response returned: %s", r.status_code, r.text)
|
||||
raise ServerError(reason=reason)
|
||||
|
||||
data = None
|
||||
|
||||
try:
|
||||
data = r.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise ServerError(message=f"Bad JSON response from the server: {e}")
|
||||
|
||||
return data
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class SafetyAuthSession(OAuth2Session):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""
|
||||
Initialize the SafetyAuthSession.
|
||||
|
||||
Args:
|
||||
*args (Any): Positional arguments for the parent class.
|
||||
**kwargs (Any): Keyword arguments for the parent class.
|
||||
"""
|
||||
super().__init__(*args, **kwargs)
|
||||
self.proxy_required: bool = False
|
||||
self.proxy_timeout: Optional[int] = None
|
||||
self.api_key = None
|
||||
|
||||
def get_credential(self) -> Optional[str]:
|
||||
"""
|
||||
Get the current authentication credential.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The API key, token, or None.
|
||||
"""
|
||||
if self.api_key:
|
||||
return self.api_key
|
||||
|
||||
if self.token:
|
||||
return SafetyContext().account
|
||||
|
||||
return None
|
||||
|
||||
def is_using_auth_credentials(self) -> bool:
|
||||
"""
|
||||
Check if the session is using authentication credentials.
|
||||
|
||||
This does NOT check if the client is authenticated.
|
||||
|
||||
Returns:
|
||||
bool: True if using authentication credentials, False otherwise.
|
||||
"""
|
||||
return self.get_authentication_type() != AuthenticationType.none
|
||||
|
||||
def get_authentication_type(self) -> AuthenticationType:
|
||||
"""
|
||||
Get the type of authentication being used.
|
||||
|
||||
Returns:
|
||||
AuthenticationType: The type of authentication.
|
||||
"""
|
||||
if self.api_key:
|
||||
return AuthenticationType.api_key
|
||||
|
||||
if self.token:
|
||||
return AuthenticationType.token
|
||||
|
||||
return AuthenticationType.none
|
||||
|
||||
def request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
withhold_token: bool = False,
|
||||
auth: Optional[Tuple] = None,
|
||||
bearer: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Make an HTTP request with the appropriate authentication.
|
||||
|
||||
Use the right auth parameter for Safety supported auth types.
|
||||
|
||||
Args:
|
||||
method (str): The HTTP method.
|
||||
url (str): The URL to request.
|
||||
withhold_token (bool): Whether to withhold the token.
|
||||
auth (Optional[Tuple]): The authentication tuple.
|
||||
bearer (bool): Whether to use bearer authentication.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
requests.Response: The HTTP response.
|
||||
|
||||
Raises:
|
||||
Exception: If the request fails.
|
||||
"""
|
||||
# By default use the token_auth
|
||||
TIMEOUT_KEYWARD = "timeout"
|
||||
func_timeout = (
|
||||
kwargs[TIMEOUT_KEYWARD] if TIMEOUT_KEYWARD in kwargs else REQUEST_TIMEOUT
|
||||
)
|
||||
|
||||
if "headers" not in kwargs:
|
||||
kwargs["headers"] = {}
|
||||
|
||||
kwargs["headers"].update(get_meta_http_headers())
|
||||
|
||||
if self.api_key:
|
||||
kwargs["headers"]["X-Api-Key"] = self.api_key
|
||||
|
||||
if not self.token or not bearer:
|
||||
# Fallback to no token auth
|
||||
auth = ()
|
||||
|
||||
# Override proxies
|
||||
if self.proxies:
|
||||
kwargs["proxies"] = self.proxies
|
||||
|
||||
if self.proxy_timeout:
|
||||
kwargs["timeout"] = int(self.proxy_timeout) / 1000
|
||||
|
||||
if ("proxies" not in kwargs or not self.proxies) and self.proxy_required:
|
||||
output_exception(
|
||||
"Proxy connection is required but there is not a proxy setup.", # type: ignore
|
||||
exit_code_output=True,
|
||||
)
|
||||
|
||||
request_func = super(SafetyAuthSession, self).request
|
||||
params = {
|
||||
"method": method,
|
||||
"url": url,
|
||||
"withhold_token": withhold_token,
|
||||
"auth": auth,
|
||||
}
|
||||
params.update(kwargs)
|
||||
|
||||
try:
|
||||
return request_func(**params)
|
||||
except Exception as e:
|
||||
LOG.debug("Request failed: %s", e)
|
||||
|
||||
if self.proxy_required:
|
||||
output_exception(
|
||||
f"Proxy is required but the connection failed because: {e}", # type: ignore
|
||||
exit_code_output=True,
|
||||
)
|
||||
|
||||
if "proxies" in kwargs or self.proxies:
|
||||
params["proxies"] = {}
|
||||
params["timeout"] = func_timeout
|
||||
self.proxies = {}
|
||||
message = (
|
||||
"The proxy configuration failed to function and was disregarded."
|
||||
)
|
||||
LOG.debug(message)
|
||||
if message not in [
|
||||
a["message"] for a in SafetyContext.local_announcements
|
||||
]:
|
||||
SafetyContext.local_announcements.append(
|
||||
{"message": message, "type": "warning", "local": True}
|
||||
)
|
||||
|
||||
return request_func(**params)
|
||||
|
||||
raise e
|
||||
|
||||
def fetch_openid_config(self) -> Any:
|
||||
"""
|
||||
Fetch the OpenID configuration from the authorization server.
|
||||
|
||||
Returns:
|
||||
Any: The OpenID configuration.
|
||||
"""
|
||||
|
||||
try:
|
||||
openid_config = self.get(
|
||||
url=OPENID_CONFIG_URL, timeout=REQUEST_TIMEOUT
|
||||
).json()
|
||||
except Exception as e:
|
||||
LOG.debug("Unable to load the openID config: %s", e)
|
||||
openid_config = {}
|
||||
|
||||
return openid_config
|
||||
|
||||
@parse_response
|
||||
def fetch_user_info(self) -> Any:
|
||||
"""
|
||||
Fetch user information from the authorization server.
|
||||
|
||||
Returns:
|
||||
Any: The user information.
|
||||
"""
|
||||
USER_INFO_ENDPOINT = f"{AUTH_SERVER_URL}/userinfo"
|
||||
|
||||
r = self.get(url=USER_INFO_ENDPOINT)
|
||||
|
||||
return r
|
||||
|
||||
@parse_response
|
||||
def check_project(
|
||||
self,
|
||||
scan_stage: str,
|
||||
safety_source: str,
|
||||
project_slug: Optional[str] = None,
|
||||
git_origin: Optional[str] = None,
|
||||
project_slug_source: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Check project information.
|
||||
|
||||
Args:
|
||||
scan_stage (str): The scan stage.
|
||||
safety_source (str): The safety source.
|
||||
project_slug (Optional[str]): The project slug.
|
||||
git_origin (Optional[str]): The git origin.
|
||||
project_slug_source (Optional[str]): The project slug source.
|
||||
|
||||
Returns:
|
||||
Any: The project information.
|
||||
"""
|
||||
|
||||
data = {
|
||||
"scan_stage": scan_stage,
|
||||
"safety_source": safety_source,
|
||||
"project_slug": project_slug,
|
||||
"project_slug_source": project_slug_source,
|
||||
"git_origin": git_origin,
|
||||
}
|
||||
|
||||
r = self.post(url=PLATFORM_API_PROJECT_CHECK_ENDPOINT, json=data)
|
||||
|
||||
return r
|
||||
|
||||
@parse_response
|
||||
def project(self, project_id: str) -> Any:
|
||||
"""
|
||||
Get project information.
|
||||
|
||||
Args:
|
||||
project_id (str): The project ID.
|
||||
|
||||
Returns:
|
||||
Any: The project information.
|
||||
"""
|
||||
data = {"project": project_id}
|
||||
|
||||
return self.get(url=PLATFORM_API_PROJECT_ENDPOINT, params=data)
|
||||
|
||||
@parse_response
|
||||
def download_policy(
|
||||
self, project_id: Optional[str], stage: Stage, branch: Optional[str]
|
||||
) -> Any:
|
||||
"""
|
||||
Download the project policy.
|
||||
|
||||
Args:
|
||||
project_id (Optional[str]): The project ID.
|
||||
stage (Stage): The stage.
|
||||
branch (Optional[str]): The branch.
|
||||
|
||||
Returns:
|
||||
Any: The policy data.
|
||||
"""
|
||||
data = {
|
||||
"project": project_id,
|
||||
"stage": STAGE_ID_MAPPING[stage],
|
||||
"branch": branch,
|
||||
}
|
||||
|
||||
return self.get(url=PLATFORM_API_POLICY_ENDPOINT, params=data)
|
||||
|
||||
@parse_response
|
||||
def project_scan_request(self, project_id: str) -> Any:
|
||||
"""
|
||||
Request a project scan.
|
||||
|
||||
Args:
|
||||
project_id (str): The project ID.
|
||||
|
||||
Returns:
|
||||
Any: The scan request result.
|
||||
"""
|
||||
data = {"project_id": project_id}
|
||||
|
||||
return self.post(url=PLATFORM_API_PROJECT_SCAN_REQUEST_ENDPOINT, json=data)
|
||||
|
||||
@parse_response
|
||||
def upload_report(self, json_report: str) -> Any:
|
||||
"""
|
||||
Upload a scan report.
|
||||
|
||||
Args:
|
||||
json_report (str): The JSON report.
|
||||
|
||||
Returns:
|
||||
Any: The upload result.
|
||||
"""
|
||||
|
||||
return self.post(
|
||||
url=PLATFORM_API_PROJECT_UPLOAD_SCAN_ENDPOINT,
|
||||
data=json_report,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def upload_requirements(self, json_payload: str) -> Any:
|
||||
"""
|
||||
Upload a scan report.
|
||||
Args:
|
||||
json_payload (str): The JSON payload to upload.
|
||||
Returns:
|
||||
Any: The result of the upload operation.
|
||||
"""
|
||||
return self.post(
|
||||
url=PLATFORM_API_REQUIREMENTS_UPLOAD_SCAN_ENDPOINT,
|
||||
data=json.dumps(json_payload),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
@parse_response
|
||||
def check_updates(
|
||||
self,
|
||||
version: int,
|
||||
safety_version: Optional[str] = None,
|
||||
python_version: Optional[str] = None,
|
||||
os_type: Optional[str] = None,
|
||||
os_release: Optional[str] = None,
|
||||
os_description: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Check for updates.
|
||||
|
||||
Args:
|
||||
version (int): The version.
|
||||
safety_version (Optional[str]): The Safety version.
|
||||
python_version (Optional[str]): The Python version.
|
||||
os_type (Optional[str]): The OS type.
|
||||
os_release (Optional[str]): The OS release.
|
||||
os_description (Optional[str]): The OS description.
|
||||
|
||||
Returns:
|
||||
Any: The update check result.
|
||||
"""
|
||||
data = {
|
||||
"version": version,
|
||||
"safety_version": safety_version,
|
||||
"python_version": python_version,
|
||||
"os_type": os_type,
|
||||
"os_release": os_release,
|
||||
"os_description": os_description,
|
||||
}
|
||||
|
||||
return self.get(url=PLATFORM_API_CHECK_UPDATES_ENDPOINT, params=data)
|
||||
|
||||
@parse_response
|
||||
def audit_packages(
|
||||
self, packages: List[str], ecosystem: Literal["pypi", "npmjs"]
|
||||
) -> Any:
|
||||
"""
|
||||
Audits packages for vulnerabilities
|
||||
Args:
|
||||
packages: list of package specifiers
|
||||
ecosystem: the ecosystem to audit
|
||||
|
||||
Returns:
|
||||
Any: The packages audit result.
|
||||
"""
|
||||
url = (
|
||||
FIREWALL_AUDIT_NPMJS_PACKAGES_ENDPOINT
|
||||
if ecosystem == "npmjs"
|
||||
else FIREWALL_AUDIT_PYPI_PACKAGES_ENDPOINT
|
||||
)
|
||||
|
||||
data = {"packages": [{"package_specifier": package} for package in packages]}
|
||||
|
||||
return self.post(url=url, json=data)
|
||||
|
||||
@parse_response
|
||||
def initialize(self) -> Any:
|
||||
"""
|
||||
Initialize a run.
|
||||
|
||||
Returns:
|
||||
Any: The initialization result.
|
||||
"""
|
||||
try:
|
||||
response = self.get(
|
||||
url=PLATFORM_API_INITIALIZE_ENDPOINT,
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=5,
|
||||
)
|
||||
return response
|
||||
except requests.exceptions.Timeout:
|
||||
LOG.error("Auth request to initialize timed out after 5 seconds.")
|
||||
except Exception:
|
||||
LOG.exception("Exception trying to auth initialize", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
class S3PresignedAdapter(HTTPAdapter):
|
||||
def send( # type: ignore
|
||||
self, request: requests.PreparedRequest, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Send a request, removing the Authorization header.
|
||||
|
||||
Args:
|
||||
request (requests.PreparedRequest): The prepared request.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
requests.Response: The response.
|
||||
"""
|
||||
request.headers.pop("Authorization", None)
|
||||
return super().send(request, **kwargs)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_jupyter_notebook() -> bool:
|
||||
"""
|
||||
Detects if the code is running in a Jupyter notebook environment, including
|
||||
various cloud-hosted Jupyter notebooks.
|
||||
|
||||
Returns:
|
||||
bool: True if the environment is identified as a Jupyter notebook (or
|
||||
equivalent cloud-based environment), False otherwise.
|
||||
|
||||
Supported environments:
|
||||
- Google Colab
|
||||
- Amazon SageMaker
|
||||
- Azure Notebooks
|
||||
- Kaggle Notebooks
|
||||
- Databricks Notebooks
|
||||
- Datalore by JetBrains
|
||||
- Paperspace Gradient Notebooks
|
||||
- Classic Jupyter Notebook and JupyterLab
|
||||
"""
|
||||
if (
|
||||
(
|
||||
importlib.util.find_spec("google")
|
||||
and importlib.util.find_spec("google.colab")
|
||||
)
|
||||
is not None
|
||||
or importlib.util.find_spec("sagemaker") is not None
|
||||
or importlib.util.find_spec("azureml") is not None
|
||||
or importlib.util.find_spec("kaggle") is not None
|
||||
or importlib.util.find_spec("dbutils") is not None
|
||||
or importlib.util.find_spec("datalore") is not None
|
||||
or importlib.util.find_spec("gradient") is not None
|
||||
):
|
||||
return True
|
||||
|
||||
# Detect classic Jupyter Notebook, JupyterLab, and other IPython kernel-based environments
|
||||
try:
|
||||
from IPython import get_ipython # type: ignore
|
||||
|
||||
ipython = get_ipython()
|
||||
if ipython is not None and "IPKernelApp" in ipython.config:
|
||||
return True
|
||||
except (ImportError, AttributeError, NameError):
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def save_flags_config(flags: Dict[FeatureType, bool]) -> None:
|
||||
"""
|
||||
Save feature flags configuration to file.
|
||||
|
||||
This function attempts to save feature flags to the configuration file
|
||||
but will fail silently if unable to do so (e.g., due to permission issues
|
||||
or disk problems). Silent failure is chosen to prevent configuration issues
|
||||
from disrupting core application functionality.
|
||||
|
||||
Note that if saving fails, the application will continue using existing
|
||||
or default flag values until the next restart.
|
||||
|
||||
Args:
|
||||
flags: Dictionary mapping feature types to their enabled/disabled state
|
||||
|
||||
The operation will be logged (with stack trace) if it fails.
|
||||
"""
|
||||
import configparser
|
||||
|
||||
from safety.constants import CONFIG_FILE_USER
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read(CONFIG_FILE_USER)
|
||||
|
||||
flag_settings = {key.name.upper(): str(value) for key, value in flags.items()}
|
||||
|
||||
if not config.has_section("settings"):
|
||||
config.add_section("settings")
|
||||
|
||||
settings = dict(config.items("settings"))
|
||||
settings.update(flag_settings)
|
||||
|
||||
for key, value in settings.items():
|
||||
config.set("settings", key, value)
|
||||
|
||||
try:
|
||||
with open(CONFIG_FILE_USER, "w") as config_file:
|
||||
config.write(config_file)
|
||||
except Exception:
|
||||
LOG.exception("Unable to save flags configuration.")
|
||||
|
||||
|
||||
def get_feature_name(feature: FeatureType, as_attr: bool = False) -> str:
|
||||
"""Returns a formatted feature name with enabled suffix.
|
||||
|
||||
Args:
|
||||
feature: The feature to format the name for
|
||||
as_attr: If True, formats for attribute usage (underscore),
|
||||
otherwise uses hyphen
|
||||
|
||||
Returns:
|
||||
Formatted feature name string with enabled suffix
|
||||
"""
|
||||
name = feature.name.lower()
|
||||
separator = "_" if as_attr else "-"
|
||||
return f"{name}{separator}enabled"
|
||||
|
||||
|
||||
def str_to_bool(value) -> Optional[bool]:
|
||||
"""Convert basic string representations to boolean."""
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
value = value.lower().strip()
|
||||
if value in ("true"):
|
||||
return True
|
||||
if value in ("false"):
|
||||
return False
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def initialize(ctx: Any, refresh: bool = True) -> None:
|
||||
"""
|
||||
Initializes the run by loading settings.
|
||||
|
||||
Args:
|
||||
ctx (Any): The context object.
|
||||
refresh (bool): Whether to refresh settings from the server. Defaults to True.
|
||||
"""
|
||||
settings = None
|
||||
current_values = {}
|
||||
|
||||
if not ctx.obj:
|
||||
ctx.obj = SafetyCLI()
|
||||
|
||||
for feature in FeatureType:
|
||||
value = get_config_setting(feature.name)
|
||||
if value is not None:
|
||||
current_values[feature] = str_to_bool(value)
|
||||
|
||||
if refresh:
|
||||
try:
|
||||
settings = ctx.obj.auth.client.initialize() # type: ignore
|
||||
except Exception:
|
||||
LOG.info("Unable to initialize, continue with default values.")
|
||||
|
||||
if settings:
|
||||
for feature in FeatureType:
|
||||
server_value = str_to_bool(settings.get(feature.config_key))
|
||||
if server_value is not None:
|
||||
if (
|
||||
feature not in current_values
|
||||
or current_values[feature] != server_value
|
||||
):
|
||||
current_values[feature] = server_value
|
||||
|
||||
save_flags_config(current_values)
|
||||
|
||||
for feature, value in current_values.items():
|
||||
if value is not None:
|
||||
setattr(ctx.obj, feature.attr_name, value)
|
||||
Reference in New Issue
Block a user