This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -0,0 +1,329 @@
import configparser
from typing import Any, Dict, Optional, Tuple, Union
from authlib.oidc.core import CodeIDToken
from authlib.jose import jwt
from authlib.jose.errors import ExpiredTokenError
from safety.auth.models import Organization
from safety.auth.constants import (
CLI_AUTH_LOGOUT,
CLI_CALLBACK,
AUTH_CONFIG_USER,
CLI_AUTH,
)
from safety.constants import CONFIG
from safety_schemas.models import Stage
from safety.util import get_proxy_dict
def get_authorization_data(
client,
code_verifier: str,
organization: Optional[Organization] = None,
sign_up: bool = False,
ensure_auth: bool = False,
headless: bool = False,
) -> Tuple[str, str]:
"""
Generate the authorization URL for the authentication process.
Args:
client: The authentication client.
code_verifier (str): The code verifier for the PKCE flow.
organization (Optional[Organization]): The organization to authenticate with.
sign_up (bool): Whether the URL is for sign-up.
ensure_auth (bool): Whether to ensure authentication.
headless (bool): Whether to run in headless mode.
Returns:
Tuple[str, str]: The authorization URL and initial state.
"""
kwargs = {
"sign_up": sign_up,
"locale": "en",
"ensure_auth": ensure_auth,
"headless": headless,
}
if organization:
kwargs["organization"] = organization.id
return client.create_authorization_url(
CLI_AUTH, code_verifier=code_verifier, **kwargs
)
def get_logout_url(id_token: str) -> str:
"""
Generate the logout URL.
Args:
id_token (str): The ID token.
Returns:
str: The logout URL.
"""
return f"{CLI_AUTH_LOGOUT}?id_token={id_token}"
def get_redirect_url() -> str:
"""
Get the redirect URL for the authentication callback.
Returns:
str: The redirect URL.
"""
return CLI_CALLBACK
def get_organization() -> Optional[Organization]:
"""
Retrieve the organization configuration.
Returns:
Optional[Organization]: The organization object, or None if not configured.
"""
config = configparser.ConfigParser()
config.read(CONFIG)
org_conf: Union[Dict[str, str], configparser.SectionProxy] = (
config["organization"] if "organization" in config.sections() else {}
)
org_id: Optional[str] = (
org_conf["id"].replace('"', "") if org_conf.get("id", None) else None
)
org_name: Optional[str] = (
org_conf["name"].replace('"', "") if org_conf.get("name", None) else None
)
if not org_id:
return None
org = Organization(id=org_id, name=org_name) # type: ignore
return org
def get_auth_info(ctx) -> Optional[Dict]:
"""
Retrieve the authentication information.
Args:
ctx: The context object containing authentication data.
Returns:
Optional[Dict]: The authentication information, or None if not authenticated.
"""
from safety.auth.utils import is_email_verified
info = None
if ctx.obj.auth.client.token:
try:
info = get_token_data(get_token(name="id_token"), keys=ctx.obj.auth.keys) # type: ignore
verified = is_email_verified(info) # type: ignore
if not verified:
user_info = ctx.obj.auth.client.fetch_user_info()
verified = is_email_verified(user_info)
if verified:
# refresh only if needed
raise ExpiredTokenError
except ExpiredTokenError:
# id_token expired. So fire a manually a refresh
try:
ctx.obj.auth.client.refresh_token(
ctx.obj.auth.client.metadata.get("token_endpoint"),
refresh_token=ctx.obj.auth.client.token.get("refresh_token"),
)
info = get_token_data(
get_token(name="id_token"), # type: ignore
keys=ctx.obj.auth.keys, # type: ignore
)
except Exception as _e:
clean_session(ctx.obj.auth.client)
except Exception as _g:
clean_session(ctx.obj.auth.client)
return info
def get_token_data(
token: str, keys: Any, silent_if_expired: bool = False
) -> Optional[Dict]:
"""
Decode and validate the token data.
Args:
token (str): The token to decode.
keys (Any): The keys to use for decoding.
silent_if_expired (bool): Whether to silently ignore expired tokens.
Returns:
Optional[Dict]: The decoded token data, or None if invalid.
"""
claims = jwt.decode(token, keys, claims_cls=CodeIDToken)
try:
claims.validate()
except ExpiredTokenError as e:
if not silent_if_expired:
raise e
return claims
def get_token(name: str = "access_token") -> Optional[str]:
""" "
Retrieve a token from the local authentication configuration.
This returns tokens saved in the local auth configuration.
There are two types of tokens: access_token and id_token
Args:
name (str): The name of the token to retrieve.
Returns:
Optional[str]: The token value, or None if not found.
"""
config = configparser.ConfigParser()
config.read(AUTH_CONFIG_USER)
if "auth" in config.sections() and name in config["auth"]:
value = config["auth"][name]
if value:
return value
return None
def get_host_config(key_name: str) -> Optional[Any]:
"""
Retrieve a configuration value from the host configuration.
Args:
key_name (str): The name of the configuration key.
Returns:
Optional[Any]: The configuration value, or None if not found.
"""
config = configparser.ConfigParser()
config.read(CONFIG)
if not config.has_section("host"):
return None
host_section = dict(config.items("host"))
if key_name in host_section:
if key_name == "stage":
# Support old alias in the config.ini
if host_section[key_name] == "dev":
host_section[key_name] = "development"
if host_section[key_name] not in {env.value for env in Stage}:
return None
return Stage(host_section[key_name])
return None
def str_to_bool(s: str) -> bool:
"""
Convert a string to a boolean value.
Args:
s (str): The string to convert.
Returns:
bool: The converted boolean value.
Raises:
ValueError: If the string cannot be converted.
"""
if s.lower() == "true" or s == "1":
return True
elif s.lower() == "false" or s == "0":
return False
else:
raise ValueError(f"Cannot convert '{s}' to a boolean value.")
def get_proxy_config() -> Tuple[Optional[Dict[str, str]], Optional[int], bool]:
"""
Retrieve the proxy configuration.
Returns:
Tuple[Optional[Dict[str, str]], Optional[int], bool]: The proxy configuration, timeout, and whether it is required.
"""
config = configparser.ConfigParser()
config.read(CONFIG)
proxy_dictionary = None
required = False
timeout = None
proxy = None
if config.has_section("proxy"):
proxy = dict(config.items("proxy"))
if proxy:
try:
proxy_dictionary = get_proxy_dict(
proxy["protocol"],
proxy["host"],
proxy["port"], # type: ignore
)
required = str_to_bool(proxy["required"])
timeout = proxy["timeout"]
except Exception:
pass
return proxy_dictionary, timeout, required # type: ignore
def clean_session(client) -> bool:
"""
Clean the authentication session.
Args:
client: The authentication client.
Returns:
bool: Always returns True.
"""
config = configparser.ConfigParser()
config["auth"] = {"access_token": "", "id_token": "", "refresh_token": ""}
with open(AUTH_CONFIG_USER, "w") as configfile:
config.write(configfile)
client.token = None
return True
def save_auth_config(
access_token: Optional[str] = None,
id_token: Optional[str] = None,
refresh_token: Optional[str] = None,
) -> None:
"""
Save the authentication configuration.
Args:
access_token (Optional[str]): The access token.
id_token (Optional[str]): The ID token.
refresh_token (Optional[str]): The refresh token.
"""
config = configparser.ConfigParser()
config.read(AUTH_CONFIG_USER)
config["auth"] = { # type: ignore
"access_token": access_token,
"id_token": id_token,
"refresh_token": refresh_token,
}
with open(AUTH_CONFIG_USER, "w") as configfile:
config.write(configfile) # type: ignore