309 lines
9.8 KiB
Python
309 lines
9.8 KiB
Python
# type: ignore
|
|
import http.server
|
|
import json
|
|
import logging
|
|
import random
|
|
import socket
|
|
import sys
|
|
import time
|
|
from typing import Any, Optional, Dict, Tuple
|
|
import urllib.parse
|
|
import threading
|
|
import click
|
|
|
|
from safety.auth.utils import is_jupyter_notebook
|
|
from safety.console import main_console as console
|
|
|
|
from safety.auth.constants import (
|
|
AUTH_SERVER_URL,
|
|
CLI_AUTH_SUCCESS,
|
|
CLI_LOGOUT_SUCCESS,
|
|
HOST,
|
|
)
|
|
from safety.auth.main import save_auth_config
|
|
from rich.prompt import Prompt
|
|
|
|
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)
|
|
|
|
|
|
def find_available_port(
    *, host: str = "localhost", start: int = 49152, end: int = 65536
) -> Optional[int]:
    """
    Find a free TCP port on ``host`` within the half-open range ``[start, end)``.

    The default range is the IANA dynamic/private port range (49152-65535).
    Candidate ports are tried in random order. A port counts as free when a
    socket can be bound to it; binding (rather than connecting) also catches
    ports that another socket has bound without listening on them. There is
    an inherent race: another process may still take the port before the
    caller binds it.

    Args:
        host (str): Address to probe. Defaults to ``"localhost"``.
        start (int): First port of the range (inclusive).
        end (int): End of the range (exclusive).

    Returns:
        Optional[int]: An available port number, or None if no ports are available.
    """
    candidates = list(range(start, end))
    random.shuffle(candidates)

    for port in candidates:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            try:
                probe.bind((host, port))
            except OSError:
                # Already bound by something else (or not bindable): try the next one.
                continue
        # The probe socket is closed on leaving the ``with`` block, so the port
        # is free again for the caller's server to bind.
        return port

    return None
|
|
|
|
|
|
def auth_process(
    code: str, state: str, initial_state: str, code_verifier: str, client: Any
) -> Any:
    """
    Process the authentication callback and exchange the authorization code for tokens.

    Args:
        code (str): The authorization code.
        state (str): The state parameter from the callback.
        initial_state (str): The initial state parameter.
        code_verifier (str): The code verifier for PKCE.
        client (Any): The OAuth client.

    Returns:
        Any: The user information returned by ``client.fetch_user_info()``.

    Raises:
        SystemExit: If the state does not match or the token exchange fails.

    NOTE(review): when this runs inside the local callback server's thread,
    ``sys.exit`` only terminates that thread, not the CLI process.
    """
    # The state echoed back in the callback must equal the one we generated,
    # otherwise the callback may be forged (CSRF protection).
    if initial_state is None or initial_state != state:
        err = (
            "The state parameter value provided does not match the expected "
            "value. The state parameter is used to protect against Cross-Site "
            "Request Forgery (CSRF) attacks. For security reasons, the "
            "authorization process cannot proceed with an invalid state "
            "parameter value. Please try again, ensuring that the state "
            "parameter value provided in the authorization request matches "
            "the value returned in the callback."
        )
        click.secho(f"Error: {err}", fg="red")
        sys.exit(1)

    try:
        tokens = client.fetch_token(
            url=f"{AUTH_SERVER_URL}/oauth/token",
            code_verifier=code_verifier,
            client_id=client.client_id,
            grant_type="authorization_code",
            code=code,
        )

        save_auth_config(
            access_token=tokens["access_token"],
            id_token=tokens["id_token"],
            refresh_token=tokens["refresh_token"],
        )

        return client.fetch_user_info()

    except Exception:
        # Keep the full traceback in the log, but also tell the user why the
        # command is exiting instead of failing silently.
        LOG.exception("Failed to complete the authentication code exchange")
        click.secho(
            "Error: Unable to complete authentication, please try again.", fg="red"
        )
        sys.exit(1)
|
|
|
|
|
|
class CallbackHandler(http.server.BaseHTTPRequestHandler):
    """
    Handle the browser callback received by the local authentication server.

    The server instance is expected to expose ``initial_state`` and ``ctx``
    attributes and a writable ``callback`` attribute; ``process_browser_callback``
    sets these up before serving a request.
    """

    @staticmethod
    def _build_location(location: str, params: Dict) -> str:
        """
        Return ``location`` with ``params`` URL-encoded and appended to its query.

        Args:
            location (str): The base URL, with or without an existing query string.
            params (dict): Query parameters to add; an empty dict leaves
                ``location`` untouched.

        Returns:
            str: The URL to send in the ``Location`` header.
        """
        if not params:
            return location
        parts = urllib.parse.urlsplit(location)
        extra = urllib.parse.urlencode(params)
        query = f"{parts.query}&{extra}" if parts.query else extra
        return urllib.parse.urlunsplit(parts._replace(query=query))

    def auth(self, code: str, state: str, err: str, error_description: str) -> None:
        """
        Handle the authentication callback.

        Args:
            code (str): The authorization code.
            state (str): The state parameter.
            err (str): The error message, if any.
            error_description (str): The error description, if any.
        """
        if err:
            # The callback carried an ``error`` parameter instead of a usable
            # authorization code, so there is nothing to exchange. Report it to
            # the user and to the browser; ``self.server.callback`` is not set.
            detail = error_description or err
            click.secho(f"Error: {detail}", fg="red")
            # ``explain`` is HTML-escaped into the response body (unlike the
            # status-line ``message``), so the untrusted text goes there.
            self.send_error(400, "Authentication failed", detail)
            return

        initial_state = self.server.initial_state
        ctx = self.server.ctx

        result = auth_process(
            code=code,
            state=state,
            initial_state=initial_state,
            code_verifier=ctx.obj.auth.code_verifier,
            client=ctx.obj.auth.client,
        )

        self.server.callback = result
        self.do_redirect(location=CLI_AUTH_SUCCESS, params={})

    def logout(self) -> None:
        """
        Handle the logout callback.

        Redirects to the logout-success page, passing the current organization
        id as ``org_id`` when one is set.
        """
        ctx = self.server.ctx
        params: Dict = {}

        if ctx.obj.auth.org:
            params["org_id"] = ctx.obj.auth.org.id

        self.do_redirect(location=CLI_LOGOUT_SUCCESS, params=params)

    def do_GET(self) -> None:
        """
        Handle GET requests.

        A ``type=logout`` query parameter routes to :meth:`logout`; every other
        request is treated as an authentication callback.
        """
        query = urllib.parse.urlparse(self.path).query
        params = urllib.parse.parse_qs(query)
        callback_type: Optional[str] = None

        try:
            c_type = params.get("type", [])
            # Only accept exactly one ``type`` value.
            if (
                isinstance(c_type, list)
                and len(c_type) == 1
                and isinstance(c_type[0], str)
            ):
                callback_type = c_type[0]
        except Exception:
            msg = "Unable to process the callback, try again."
            self.send_error(400, msg)
            click.secho(msg)
            return

        if callback_type == "logout":
            self.logout()
            return

        code = params.get("code", [""])[0]
        state = params.get("state", [""])[0]
        err = params.get("error", [""])[0]
        error_description = params.get("error_description", [""])[0]

        self.auth(code=code, state=state, err=err, error_description=error_description)

    def do_redirect(self, location: str, params: Dict) -> None:
        """
        Redirect the client to the specified location.

        Args:
            location (str): The URL to redirect to.
            params (dict): Additional query parameters appended to ``location``.
        """
        self.send_response(302)
        self.send_header("Location", self._build_location(location, params))
        self.send_header("Connection", "close")
        self.send_header("Cache-Control", "no-store, no-cache, must-revalidate")
        self.end_headers()

    def log_message(self, format: str, *args: Any) -> None:
        """
        Route the handler's access log through the module logger at INFO level.

        Args:
            format (str): The format string.
            args (Any): Arguments for the format string.
        """
        LOG.info(format % args)
|
|
|
|
|
|
def _read_headless_exchange() -> Tuple[Any, Any]:
    """
    Prompt until the user pastes a JSON object that contains ``code`` and ``state``.

    Returns:
        Tuple[Any, Any]: The ``(code, state)`` values taken from the pasted JSON.
    """
    while True:
        raw_response = Prompt.ask(
            "Paste the response here", default=None, console=console
        )
        try:
            exchange_data = json.loads(raw_response)
            return exchange_data["code"], exchange_data["state"]
        except (TypeError, ValueError, KeyError):
            # Empty input (None), invalid JSON, a non-object value or a missing
            # key: ask again rather than continuing without a code/state.
            console.print("The response could not be read, please paste it again.")


def process_browser_callback(uri: str, **kwargs: Any) -> Any:
    """
    Process the browser callback for authentication.

    In the default mode a local HTTP server is started in a background thread
    to receive a single callback request, and the browser is opened on ``uri``
    (with the chosen ``port`` appended). In headless mode the user is asked to
    paste the JSON response instead.

    Args:
        uri (str): The authorization URL.
        **kwargs (Any): Recognized keys are ``headless`` (default False),
            ``initial_state``, ``ctx`` and ``timeout`` (seconds the local
            server waits for the callback, default 600).

    Returns:
        Any: The user information, or None if no successful callback was
        received (for example when the local server timed out).

    Raises:
        SystemExit: If no port is available, the server cannot be started, or
            authentication fails.
    """
    import errno

    class ThreadedHTTPServer(http.server.HTTPServer):
        """HTTPServer carrying the state shared with ``CallbackHandler``."""

        def __init__(self, server_address: Tuple, RequestHandlerClass: Any) -> None:
            """
            Initialize the ThreadedHTTPServer.

            Args:
                server_address (Tuple): The server address as a tuple (host, port).
                RequestHandlerClass (Any): The request handler class.
            """
            super().__init__(server_address, RequestHandlerClass)
            self.initial_state = None
            self.ctx = None
            self.callback = None
            self.timeout_reached = False

        def handle_timeout(self) -> None:
            """
            Handle server timeout.
            """
            self.timeout_reached = True
            return super().handle_timeout()

    port = find_available_port()

    if port is None:
        click.secho("No available ports.")
        sys.exit(1)

    server = None
    try:
        headless = kwargs.get("headless", False)
        initial_state = kwargs.get("initial_state", None)
        ctx = kwargs.get("ctx", None)
        message = "Copy and paste this URL into your browser:\n:icon_warning: Ensure there are no extra spaces, especially at line breaks, as they may break the link."

        if not headless:
            # Serve exactly one request in a background thread. The thread is a
            # daemon so an interrupted login does not keep the process alive
            # until the server timeout expires.
            server = ThreadedHTTPServer((HOST, port), CallbackHandler)
            server.initial_state = initial_state
            server.timeout = kwargs.get("timeout", 600)
            server.ctx = ctx
            server_thread = threading.Thread(target=server.handle_request, daemon=True)
            server_thread.start()
            message = "If the browser does not automatically open in 5 seconds, copy and paste this url into your browser:"

        target = uri if headless else f"{uri}&port={port}"

        if is_jupyter_notebook():
            console.print(f"{message} {target}")
        else:
            console.print(f"{message} [link={target}]{target}[/link]")

        if headless:
            # The user manually provides the JSON response from the browser.
            code, state = _read_headless_exchange()
            return auth_process(
                code=code,
                state=state,
                initial_state=initial_state,
                code_verifier=ctx.obj.auth.code_verifier,
                client=ctx.obj.auth.client,
            )

        # Wait for the browser authentication in non-headless mode.
        console.print()
        wait_msg = "waiting for browser authentication"
        with console.status(wait_msg, spinner="bouncingBar"):
            time.sleep(2)
            click.launch(target)
            server_thread.join()

    except OSError as e:
        if e.errno == errno.EADDRINUSE:
            reason = f"The port {HOST}:{port} is currently being used by another application or process. Please choose a different port or terminate the conflicting application/process to free up the port."
        else:
            reason = "An error occurred while performing this operation."

        click.secho(reason)
        sys.exit(1)
    finally:
        # Release the listening socket; the request (if any) has been handled.
        if server is not None:
            server.server_close()

    return server.callback
|