Files
Hotel-Booking/Backend/venv/lib/python3.12/site-packages/safety/auth/main.py
Iliyan Angelov 62c1fe5951 updates
2025-12-01 06:50:10 +02:00

330 lines
8.7 KiB
Python

import configparser
from typing import Any, Dict, Optional, Tuple, Union
from authlib.oidc.core import CodeIDToken
from authlib.jose import jwt
from authlib.jose.errors import ExpiredTokenError
from safety.auth.models import Organization
from safety.auth.constants import (
CLI_AUTH_LOGOUT,
CLI_CALLBACK,
AUTH_CONFIG_USER,
CLI_AUTH,
)
from safety.constants import CONFIG
from safety_schemas.models import Stage
from safety.util import get_proxy_dict
def get_authorization_data(
client,
code_verifier: str,
organization: Optional[Organization] = None,
sign_up: bool = False,
ensure_auth: bool = False,
headless: bool = False,
) -> Tuple[str, str]:
"""
Generate the authorization URL for the authentication process.
Args:
client: The authentication client.
code_verifier (str): The code verifier for the PKCE flow.
organization (Optional[Organization]): The organization to authenticate with.
sign_up (bool): Whether the URL is for sign-up.
ensure_auth (bool): Whether to ensure authentication.
headless (bool): Whether to run in headless mode.
Returns:
Tuple[str, str]: The authorization URL and initial state.
"""
kwargs = {
"sign_up": sign_up,
"locale": "en",
"ensure_auth": ensure_auth,
"headless": headless,
}
if organization:
kwargs["organization"] = organization.id
return client.create_authorization_url(
CLI_AUTH, code_verifier=code_verifier, **kwargs
)
def get_logout_url(id_token: str) -> str:
"""
Generate the logout URL.
Args:
id_token (str): The ID token.
Returns:
str: The logout URL.
"""
return f"{CLI_AUTH_LOGOUT}?id_token={id_token}"
def get_redirect_url() -> str:
"""
Get the redirect URL for the authentication callback.
Returns:
str: The redirect URL.
"""
return CLI_CALLBACK
def get_organization() -> Optional[Organization]:
"""
Retrieve the organization configuration.
Returns:
Optional[Organization]: The organization object, or None if not configured.
"""
config = configparser.ConfigParser()
config.read(CONFIG)
org_conf: Union[Dict[str, str], configparser.SectionProxy] = (
config["organization"] if "organization" in config.sections() else {}
)
org_id: Optional[str] = (
org_conf["id"].replace('"', "") if org_conf.get("id", None) else None
)
org_name: Optional[str] = (
org_conf["name"].replace('"', "") if org_conf.get("name", None) else None
)
if not org_id:
return None
org = Organization(id=org_id, name=org_name) # type: ignore
return org
def get_auth_info(ctx) -> Optional[Dict]:
"""
Retrieve the authentication information.
Args:
ctx: The context object containing authentication data.
Returns:
Optional[Dict]: The authentication information, or None if not authenticated.
"""
from safety.auth.utils import is_email_verified
info = None
if ctx.obj.auth.client.token:
try:
info = get_token_data(get_token(name="id_token"), keys=ctx.obj.auth.keys) # type: ignore
verified = is_email_verified(info) # type: ignore
if not verified:
user_info = ctx.obj.auth.client.fetch_user_info()
verified = is_email_verified(user_info)
if verified:
# refresh only if needed
raise ExpiredTokenError
except ExpiredTokenError:
# id_token expired. So fire a manually a refresh
try:
ctx.obj.auth.client.refresh_token(
ctx.obj.auth.client.metadata.get("token_endpoint"),
refresh_token=ctx.obj.auth.client.token.get("refresh_token"),
)
info = get_token_data(
get_token(name="id_token"), # type: ignore
keys=ctx.obj.auth.keys, # type: ignore
)
except Exception as _e:
clean_session(ctx.obj.auth.client)
except Exception as _g:
clean_session(ctx.obj.auth.client)
return info
def get_token_data(
token: str, keys: Any, silent_if_expired: bool = False
) -> Optional[Dict]:
"""
Decode and validate the token data.
Args:
token (str): The token to decode.
keys (Any): The keys to use for decoding.
silent_if_expired (bool): Whether to silently ignore expired tokens.
Returns:
Optional[Dict]: The decoded token data, or None if invalid.
"""
claims = jwt.decode(token, keys, claims_cls=CodeIDToken)
try:
claims.validate()
except ExpiredTokenError as e:
if not silent_if_expired:
raise e
return claims
def get_token(name: str = "access_token") -> Optional[str]:
""" "
Retrieve a token from the local authentication configuration.
This returns tokens saved in the local auth configuration.
There are two types of tokens: access_token and id_token
Args:
name (str): The name of the token to retrieve.
Returns:
Optional[str]: The token value, or None if not found.
"""
config = configparser.ConfigParser()
config.read(AUTH_CONFIG_USER)
if "auth" in config.sections() and name in config["auth"]:
value = config["auth"][name]
if value:
return value
return None
def get_host_config(key_name: str) -> Optional[Any]:
"""
Retrieve a configuration value from the host configuration.
Args:
key_name (str): The name of the configuration key.
Returns:
Optional[Any]: The configuration value, or None if not found.
"""
config = configparser.ConfigParser()
config.read(CONFIG)
if not config.has_section("host"):
return None
host_section = dict(config.items("host"))
if key_name in host_section:
if key_name == "stage":
# Support old alias in the config.ini
if host_section[key_name] == "dev":
host_section[key_name] = "development"
if host_section[key_name] not in {env.value for env in Stage}:
return None
return Stage(host_section[key_name])
return None
def str_to_bool(s: str) -> bool:
"""
Convert a string to a boolean value.
Args:
s (str): The string to convert.
Returns:
bool: The converted boolean value.
Raises:
ValueError: If the string cannot be converted.
"""
if s.lower() == "true" or s == "1":
return True
elif s.lower() == "false" or s == "0":
return False
else:
raise ValueError(f"Cannot convert '{s}' to a boolean value.")
def get_proxy_config() -> Tuple[Optional[Dict[str, str]], Optional[int], bool]:
"""
Retrieve the proxy configuration.
Returns:
Tuple[Optional[Dict[str, str]], Optional[int], bool]: The proxy configuration, timeout, and whether it is required.
"""
config = configparser.ConfigParser()
config.read(CONFIG)
proxy_dictionary = None
required = False
timeout = None
proxy = None
if config.has_section("proxy"):
proxy = dict(config.items("proxy"))
if proxy:
try:
proxy_dictionary = get_proxy_dict(
proxy["protocol"],
proxy["host"],
proxy["port"], # type: ignore
)
required = str_to_bool(proxy["required"])
timeout = proxy["timeout"]
except Exception:
pass
return proxy_dictionary, timeout, required # type: ignore
def clean_session(client) -> bool:
"""
Clean the authentication session.
Args:
client: The authentication client.
Returns:
bool: Always returns True.
"""
config = configparser.ConfigParser()
config["auth"] = {"access_token": "", "id_token": "", "refresh_token": ""}
with open(AUTH_CONFIG_USER, "w") as configfile:
config.write(configfile)
client.token = None
return True
def save_auth_config(
access_token: Optional[str] = None,
id_token: Optional[str] = None,
refresh_token: Optional[str] = None,
) -> None:
"""
Save the authentication configuration.
Args:
access_token (Optional[str]): The access token.
id_token (Optional[str]): The ID token.
refresh_token (Optional[str]): The refresh token.
"""
config = configparser.ConfigParser()
config.read(AUTH_CONFIG_USER)
config["auth"] = { # type: ignore
"access_token": access_token,
"id_token": id_token,
"refresh_token": refresh_token,
}
with open(AUTH_CONFIG_USER, "w") as configfile:
config.write(configfile) # type: ignore