import configparser
from typing import Any, Dict, Optional, Tuple, Union

from authlib.oidc.core import CodeIDToken
from authlib.jose import jwt
from authlib.jose.errors import ExpiredTokenError

from safety.auth.models import Organization
from safety.auth.constants import (
    CLI_AUTH_LOGOUT,
    CLI_CALLBACK,
    AUTH_CONFIG_USER,
    CLI_AUTH,
)
from safety.constants import CONFIG
from safety_schemas.models import Stage
from safety.util import get_proxy_dict


def get_authorization_data(
    client,
    code_verifier: str,
    organization: Optional[Organization] = None,
    sign_up: bool = False,
    ensure_auth: bool = False,
    headless: bool = False,
) -> Tuple[str, str]:
    """
    Generate the authorization URL for the authentication process.

    Args:
        client: The authentication client.
        code_verifier (str): The code verifier for the PKCE flow.
        organization (Optional[Organization]): The organization to authenticate with.
        sign_up (bool): Whether the URL is for sign-up.
        ensure_auth (bool): Whether to ensure authentication.
        headless (bool): Whether to run in headless mode.

    Returns:
        Tuple[str, str]: The authorization URL and initial state.
    """
    kwargs: Dict[str, Any] = {
        "sign_up": sign_up,
        "locale": "en",
        "ensure_auth": ensure_auth,
        "headless": headless,
    }

    # The organization parameter is only sent when one was provided.
    if organization:
        kwargs["organization"] = organization.id

    return client.create_authorization_url(
        CLI_AUTH, code_verifier=code_verifier, **kwargs
    )


def get_logout_url(id_token: str) -> str:
    """
    Generate the logout URL.

    Args:
        id_token (str): The ID token, appended as the ``id_token`` query parameter.

    Returns:
        str: The logout URL.
    """
    return f"{CLI_AUTH_LOGOUT}?id_token={id_token}"


def get_redirect_url() -> str:
    """
    Get the redirect URL for the authentication callback.

    Returns:
        str: The redirect URL.
    """
    return CLI_CALLBACK


def get_organization() -> Optional[Organization]:
    """
    Retrieve the organization configuration.

    Reads the ``organization`` section of the CONFIG file. Double quotes are
    stripped from the ``id`` and ``name`` values.

    Returns:
        Optional[Organization]: The organization object, or None if no
            organization id is configured.
    """
    config = configparser.ConfigParser()
    config.read(CONFIG)

    org_conf: Union[Dict[str, str], configparser.SectionProxy] = (
        config["organization"] if "organization" in config.sections() else {}
    )

    org_id: Optional[str] = (
        org_conf["id"].replace('"', "") if org_conf.get("id", None) else None
    )
    org_name: Optional[str] = (
        org_conf["name"].replace('"', "") if org_conf.get("name", None) else None
    )

    # Without an id there is no usable organization, even if a name is set.
    if not org_id:
        return None

    return Organization(id=org_id, name=org_name)  # type: ignore


def get_auth_info(ctx) -> Optional[Dict]:
    """
    Retrieve the authentication information.

    Decodes the locally stored ``id_token``. If the token is expired, or the
    token says the email is unverified while the user-info endpoint says it is
    verified, a token refresh is attempted and the refreshed ``id_token`` is
    decoded instead. Any failure clears the local session via
    ``clean_session``.

    Args:
        ctx: The context object containing authentication data.

    Returns:
        Optional[Dict]: The authentication information, or None if not authenticated.
    """
    # Function-scope import; presumably avoids a circular import between
    # safety.auth modules -- TODO confirm.
    from safety.auth.utils import is_email_verified

    info = None
    if ctx.obj.auth.client.token:
        try:
            info = get_token_data(get_token(name="id_token"), keys=ctx.obj.auth.keys)  # type: ignore

            verified = is_email_verified(info)  # type: ignore
            if not verified:
                user_info = ctx.obj.auth.client.fetch_user_info()
                verified = is_email_verified(user_info)

                if verified:
                    # The stored token is stale with respect to the email
                    # verification state: reuse the expired-token path to
                    # refresh it (refresh only if needed).
                    raise ExpiredTokenError

        except ExpiredTokenError:
            # id_token expired (or is stale), so manually trigger a refresh.
            try:
                ctx.obj.auth.client.refresh_token(
                    ctx.obj.auth.client.metadata.get("token_endpoint"),
                    refresh_token=ctx.obj.auth.client.token.get("refresh_token"),
                )
                info = get_token_data(
                    get_token(name="id_token"),  # type: ignore
                    keys=ctx.obj.auth.keys,  # type: ignore
                )
            except Exception:
                # Best effort: an unusable session is wiped rather than raised.
                clean_session(ctx.obj.auth.client)
        except Exception:
            clean_session(ctx.obj.auth.client)

    # NOTE(review): if a failure happens after `info` was first assigned, the
    # session is cleaned but the previously decoded `info` is still returned.
    # Confirm whether callers rely on this.
    return info


def get_token_data(
    token: str, keys: Any, silent_if_expired: bool = False
) -> Optional[Dict]:
    """
    Decode and validate the token data.

    Args:
        token (str): The token to decode.
        keys (Any): The keys to use for decoding.
        silent_if_expired (bool): Whether to silently ignore expired tokens.

    Returns:
        Optional[Dict]: The decoded token claims. When ``silent_if_expired``
            is True, the claims of an expired token are still returned.

    Raises:
        ExpiredTokenError: If the token is expired and ``silent_if_expired``
            is False.

    Any other error raised while decoding or validating the token propagates
    to the caller unchanged.
    """
    claims = jwt.decode(token, keys, claims_cls=CodeIDToken)

    try:
        claims.validate()
    except ExpiredTokenError:
        if not silent_if_expired:
            raise

    return claims


def get_token(name: str = "access_token") -> Optional[str]:
    """
    Retrieve a token from the local authentication configuration.

    This returns tokens saved in the ``auth`` section of the local auth
    configuration file, for example ``access_token`` or ``id_token``.

    Args:
        name (str): The name of the token to retrieve.

    Returns:
        Optional[str]: The token value, or None if not found or empty.
    """
    config = configparser.ConfigParser()
    config.read(AUTH_CONFIG_USER)

    if "auth" in config.sections() and name in config["auth"]:
        value = config["auth"][name]
        # An empty string is how a cleared token is stored; treat it as absent.
        if value:
            return value

    return None


def get_host_config(key_name: str) -> Optional[Any]:
    """
    Retrieve a configuration value from the host configuration.

    Only the ``stage`` key is currently resolved; any other key yields None
    even if it is present in the ``host`` section.

    Args:
        key_name (str): The name of the configuration key.

    Returns:
        Optional[Any]: The ``Stage`` member for a valid ``stage`` value, or
            None if the section/key is missing, the key is unsupported, or the
            value is not a known stage.
    """
    config = configparser.ConfigParser()
    config.read(CONFIG)

    if not config.has_section("host"):
        return None

    host_section = dict(config.items("host"))

    if key_name != "stage" or key_name not in host_section:
        return None

    stage_value = host_section[key_name]

    # Support old alias in the config.ini
    if stage_value == "dev":
        stage_value = "development"

    if stage_value not in {env.value for env in Stage}:
        return None

    return Stage(stage_value)


def str_to_bool(s: str) -> bool:
    """
    Convert a string to a boolean value.

    ``"true"``/``"false"`` are matched case-insensitively; ``"1"`` and ``"0"``
    are also accepted.

    Args:
        s (str): The string to convert.

    Returns:
        bool: The converted boolean value.

    Raises:
        ValueError: If the string cannot be converted.
    """
    if s.lower() == "true" or s == "1":
        return True

    if s.lower() == "false" or s == "0":
        return False

    raise ValueError(f"Cannot convert '{s}' to a boolean value.")


def get_proxy_config() -> Tuple[Optional[Dict[str, str]], Optional[int], bool]:
    """
    Retrieve the proxy configuration.

    Reads the ``proxy`` section of the CONFIG file. Parsing is best effort:
    an unusable proxy definition yields ``(None, None, False)``, and the
    ``required`` and ``timeout`` options are parsed independently so that a
    missing or invalid value for one does not discard the other.

    Returns:
        Tuple[Optional[Dict[str, str]], Optional[int], bool]: The proxy
            configuration, the timeout as an integer (None if missing or not
            an integer), and whether the proxy is required (False if missing
            or not a recognised boolean).
    """
    config = configparser.ConfigParser()
    config.read(CONFIG)

    if not config.has_section("proxy"):
        return None, None, False

    proxy = dict(config.items("proxy"))
    if not proxy:
        return None, None, False

    try:
        proxy_dictionary = get_proxy_dict(
            proxy["protocol"],
            proxy["host"],
            proxy["port"],  # type: ignore
        )
    except Exception:
        # Deliberately broad: get_proxy_dict is defined elsewhere and a broken
        # proxy section must not crash configuration loading.
        return None, None, False

    required = False
    try:
        required = str_to_bool(proxy["required"])
    except (KeyError, ValueError):
        # Missing or unrecognised value: keep the non-required default.
        pass

    timeout: Optional[int] = None
    try:
        # configparser values are strings; honour the declared Optional[int].
        timeout = int(proxy["timeout"])
    except (KeyError, ValueError):
        # Missing or non-integer value: leave the timeout unset.
        pass

    return proxy_dictionary, timeout, required


def clean_session(client) -> bool:
    """
    Clean the authentication session.

    Overwrites the local auth configuration file with empty tokens and resets
    the token held by the client.

    Args:
        client: The authentication client.

    Returns:
        bool: Always returns True.
    """
    config = configparser.ConfigParser()
    config["auth"] = {"access_token": "", "id_token": "", "refresh_token": ""}

    with open(AUTH_CONFIG_USER, "w") as configfile:
        config.write(configfile)

    client.token = None

    return True


def save_auth_config(
    access_token: Optional[str] = None,
    id_token: Optional[str] = None,
    refresh_token: Optional[str] = None,
) -> None:
    """
    Save the authentication configuration.

    The existing file is read first, so sections other than ``auth`` are
    preserved; the ``auth`` section is replaced entirely.

    Args:
        access_token (Optional[str]): The access token.
        id_token (Optional[str]): The ID token.
        refresh_token (Optional[str]): The refresh token.
    """
    config = configparser.ConfigParser()
    config.read(AUTH_CONFIG_USER)

    # ConfigParser only accepts string values (None raises TypeError), so a
    # missing token is stored as an empty string. get_token reads an empty
    # value back as None, and clean_session writes the same representation.
    config["auth"] = {
        "access_token": access_token or "",
        "id_token": id_token or "",
        "refresh_token": refresh_token or "",
    }

    with open(AUTH_CONFIG_USER, "w") as configfile:
        config.write(configfile)