updates
This commit is contained in:
@@ -0,0 +1,290 @@
|
||||
import logging
|
||||
from typing import Dict, Optional, Tuple, Any, Callable
|
||||
|
||||
import click
|
||||
|
||||
from .main import (
|
||||
get_auth_info,
|
||||
get_host_config,
|
||||
get_organization,
|
||||
get_proxy_config,
|
||||
get_redirect_url,
|
||||
get_token_data,
|
||||
save_auth_config,
|
||||
get_token,
|
||||
clean_session,
|
||||
)
|
||||
from authlib.common.security import generate_token
|
||||
from safety.auth.constants import CLIENT_ID
|
||||
|
||||
from safety.auth.models import Organization, Auth
|
||||
from safety.auth.utils import (
|
||||
S3PresignedAdapter,
|
||||
SafetyAuthSession,
|
||||
get_keys,
|
||||
is_email_verified,
|
||||
)
|
||||
from safety.scan.constants import (
|
||||
CLI_KEY_HELP,
|
||||
CLI_PROXY_HOST_HELP,
|
||||
CLI_PROXY_PORT_HELP,
|
||||
CLI_PROXY_PROTOCOL_HELP,
|
||||
CLI_STAGE_HELP,
|
||||
)
|
||||
from safety.util import DependentOption, SafetyContext, get_proxy_dict
|
||||
from safety.models import SafetyCLI
|
||||
from safety_schemas.models import Stage
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_client_session(
|
||||
api_key: Optional[str] = None,
|
||||
proxies: Optional[Dict[str, str]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[SafetyAuthSession, Dict[str, Any]]:
|
||||
"""
|
||||
Builds and configures the client session for authentication.
|
||||
|
||||
Args:
|
||||
api_key (Optional[str]): The API key for authentication.
|
||||
proxies (Optional[Dict[str, str]]): Proxy configuration.
|
||||
headers (Optional[Dict[str, str]]): Additional headers.
|
||||
|
||||
Returns:
|
||||
Tuple[SafetyAuthSession, Dict[str, Any]]: The configured client session and OpenID configuration.
|
||||
"""
|
||||
|
||||
kwargs = {}
|
||||
target_proxies = proxies
|
||||
|
||||
# Global proxy defined in the config.ini
|
||||
proxy_config, proxy_timeout, proxy_required = get_proxy_config()
|
||||
|
||||
if not proxies:
|
||||
target_proxies = proxy_config
|
||||
|
||||
def update_token(tokens, **kwargs):
|
||||
save_auth_config(
|
||||
access_token=tokens["access_token"],
|
||||
id_token=tokens["id_token"],
|
||||
refresh_token=tokens["refresh_token"],
|
||||
)
|
||||
load_auth_session(click_ctx=click.get_current_context(silent=True)) # type: ignore
|
||||
|
||||
client_session = SafetyAuthSession(
|
||||
client_id=CLIENT_ID,
|
||||
code_challenge_method="S256",
|
||||
redirect_uri=get_redirect_url(),
|
||||
update_token=update_token,
|
||||
scope="openid email profile offline_access",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
client_session.mount("https://pyup.io/static-s3/", S3PresignedAdapter())
|
||||
|
||||
client_session.proxy_required = proxy_required
|
||||
client_session.proxy_timeout = proxy_timeout
|
||||
client_session.proxies = target_proxies # type: ignore
|
||||
client_session.headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
openid_config = client_session.fetch_openid_config()
|
||||
|
||||
client_session.metadata["token_endpoint"] = openid_config.get(
|
||||
"token_endpoint", None
|
||||
)
|
||||
|
||||
if api_key:
|
||||
client_session.api_key = api_key # type: ignore
|
||||
client_session.headers["X-Api-Key"] = api_key
|
||||
|
||||
if headers:
|
||||
client_session.headers.update(headers)
|
||||
|
||||
return client_session, openid_config
|
||||
|
||||
|
||||
def load_auth_session(click_ctx: click.Context) -> None:
|
||||
"""
|
||||
Loads the authentication session from the context.
|
||||
|
||||
Args:
|
||||
click_ctx (click.Context): The Click context object.
|
||||
"""
|
||||
if not click_ctx:
|
||||
LOG.warning("Click context is needed to be able to load the Auth data.")
|
||||
return
|
||||
|
||||
client = click_ctx.obj.auth.client
|
||||
keys = click_ctx.obj.auth.keys
|
||||
|
||||
access_token: str = get_token(name="access_token") # type: ignore
|
||||
refresh_token: str = get_token(name="refresh_token") # type: ignore
|
||||
id_token: str = get_token(name="id_token") # type: ignore
|
||||
|
||||
if access_token and keys:
|
||||
try:
|
||||
token = get_token_data(access_token, keys, silent_if_expired=True)
|
||||
client.token = {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"id_token": id_token,
|
||||
"token_type": "bearer",
|
||||
"expires_at": token.get("exp", None), # type: ignore
|
||||
}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
clean_session(client)
|
||||
|
||||
|
||||
def proxy_options(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator that defines proxy options for Click commands.
|
||||
|
||||
Options defined per command, this will override the proxy settings defined in the
|
||||
config.ini file.
|
||||
|
||||
Args:
|
||||
func (Callable): The Click command function.
|
||||
|
||||
Returns:
|
||||
Callable: The wrapped Click command function with proxy options.
|
||||
"""
|
||||
func = click.option(
|
||||
"--proxy-protocol",
|
||||
type=click.Choice(["http", "https"]),
|
||||
default="https",
|
||||
cls=DependentOption,
|
||||
required_options=["proxy_host"],
|
||||
help=CLI_PROXY_PROTOCOL_HELP,
|
||||
)(func)
|
||||
func = click.option(
|
||||
"--proxy-port",
|
||||
multiple=False,
|
||||
type=int,
|
||||
default=80,
|
||||
cls=DependentOption,
|
||||
required_options=["proxy_host"],
|
||||
help=CLI_PROXY_PORT_HELP,
|
||||
)(func)
|
||||
func = click.option(
|
||||
"--proxy-host", multiple=False, type=str, default=None, help=CLI_PROXY_HOST_HELP
|
||||
)(func)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def auth_options(stage: bool = True) -> Callable:
|
||||
"""
|
||||
Decorator that defines authentication options for Click commands.
|
||||
|
||||
Args:
|
||||
stage (bool): Whether to include the stage option.
|
||||
|
||||
Returns:
|
||||
Callable: The decorator function.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
func = click.option(
|
||||
"--key", default=None, envvar="SAFETY_API_KEY", help=CLI_KEY_HELP
|
||||
)(func)
|
||||
|
||||
if stage:
|
||||
func = click.option(
|
||||
"--stage", default=None, envvar="SAFETY_STAGE", help=CLI_STAGE_HELP
|
||||
)(func)
|
||||
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def inject_session(
|
||||
ctx: click.Context,
|
||||
proxy_protocol: Optional[str] = None,
|
||||
proxy_host: Optional[str] = None,
|
||||
proxy_port: Optional[str] = None,
|
||||
key: Optional[str] = None,
|
||||
stage: Optional[Stage] = None,
|
||||
invoked_command: str = "",
|
||||
) -> Any:
|
||||
org: Optional[Organization] = get_organization()
|
||||
|
||||
if not stage:
|
||||
host_stage = get_host_config(key_name="stage")
|
||||
stage = host_stage if host_stage else Stage.development
|
||||
|
||||
proxy_config: Optional[Dict[str, str]] = get_proxy_dict(
|
||||
proxy_protocol, # type: ignore
|
||||
proxy_host, # type: ignore
|
||||
proxy_port, # type: ignore
|
||||
)
|
||||
|
||||
client_session, openid_config = build_client_session(
|
||||
api_key=key, proxies=proxy_config
|
||||
)
|
||||
keys = get_keys(client_session, openid_config)
|
||||
|
||||
auth = Auth(
|
||||
stage=stage,
|
||||
keys=keys,
|
||||
org=org,
|
||||
client_id=CLIENT_ID, # type: ignore
|
||||
client=client_session,
|
||||
code_verifier=generate_token(48),
|
||||
)
|
||||
|
||||
if not ctx.obj:
|
||||
ctx.obj = SafetyCLI()
|
||||
|
||||
ctx.obj.auth = auth
|
||||
|
||||
load_auth_session(ctx)
|
||||
|
||||
info = get_auth_info(ctx)
|
||||
|
||||
if info:
|
||||
ctx.obj.auth.name = info.get("name")
|
||||
ctx.obj.auth.email = info.get("email")
|
||||
ctx.obj.auth.email_verified = is_email_verified(info) # type: ignore
|
||||
SafetyContext().account = info["email"]
|
||||
else:
|
||||
SafetyContext().account = ""
|
||||
|
||||
@ctx.call_on_close
|
||||
def clean_up_on_close():
|
||||
LOG.debug("Closing requests session.")
|
||||
ctx.obj.auth.client.close()
|
||||
|
||||
if ctx.obj.event_bus:
|
||||
from safety.events.utils import (
|
||||
create_internal_event,
|
||||
InternalEventType,
|
||||
InternalPayload,
|
||||
)
|
||||
|
||||
payload = InternalPayload(ctx=ctx)
|
||||
|
||||
flush_event = create_internal_event(
|
||||
event_type=InternalEventType.FLUSH_SECURITY_TRACES, payload=payload
|
||||
)
|
||||
close_event = create_internal_event(
|
||||
event_type=InternalEventType.CLOSE_RESOURCES, payload=payload
|
||||
)
|
||||
|
||||
flush_future = ctx.obj.event_bus.emit(flush_event)
|
||||
close_future = ctx.obj.event_bus.emit(close_event)
|
||||
|
||||
# Wait for both events to be processed
|
||||
if flush_future and close_future:
|
||||
try:
|
||||
flush_future.result()
|
||||
close_future.result()
|
||||
except Exception as e:
|
||||
LOG.warning(f"Error waiting for events to process: {e}")
|
||||
|
||||
ctx.obj.event_bus.stop()
|
||||
Reference in New Issue
Block a user