updates
This commit is contained in:
@@ -1,18 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import typing
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any, ParamSpec
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from starlette._utils import is_async_callable
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import RedirectResponse, Response
|
||||
from starlette.responses import RedirectResponse
|
||||
from starlette.websockets import WebSocket
|
||||
|
||||
_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable)
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
|
||||
def has_required_scope(conn: HTTPConnection, scopes: Sequence[str]) -> bool:
|
||||
for scope in scopes:
|
||||
if scope not in conn.auth.scopes:
|
||||
return False
|
||||
@@ -20,32 +23,28 @@ def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bo
|
||||
|
||||
|
||||
def requires(
|
||||
scopes: typing.Union[str, typing.Sequence[str]],
|
||||
scopes: str | Sequence[str],
|
||||
status_code: int = 403,
|
||||
redirect: typing.Optional[str] = None,
|
||||
) -> typing.Callable[[_CallableType], _CallableType]:
|
||||
redirect: str | None = None,
|
||||
) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]:
|
||||
scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)
|
||||
|
||||
def decorator(func: typing.Callable) -> typing.Callable:
|
||||
def decorator(
|
||||
func: Callable[_P, Any],
|
||||
) -> Callable[_P, Any]:
|
||||
sig = inspect.signature(func)
|
||||
for idx, parameter in enumerate(sig.parameters.values()):
|
||||
if parameter.name == "request" or parameter.name == "websocket":
|
||||
type_ = parameter.name
|
||||
break
|
||||
else:
|
||||
raise Exception(
|
||||
f'No "request" or "websocket" argument on function "{func}"'
|
||||
)
|
||||
raise Exception(f'No "request" or "websocket" argument on function "{func}"')
|
||||
|
||||
if type_ == "websocket":
|
||||
# Handle websocket functions. (Always async)
|
||||
@functools.wraps(func)
|
||||
async def websocket_wrapper(
|
||||
*args: typing.Any, **kwargs: typing.Any
|
||||
) -> None:
|
||||
websocket = kwargs.get(
|
||||
"websocket", args[idx] if idx < len(args) else None
|
||||
)
|
||||
async def websocket_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
|
||||
websocket = kwargs.get("websocket", args[idx] if idx < len(args) else None)
|
||||
assert isinstance(websocket, WebSocket)
|
||||
|
||||
if not has_required_scope(websocket, scopes_list):
|
||||
@@ -58,19 +57,14 @@ def requires(
|
||||
elif is_async_callable(func):
|
||||
# Handle async request/response functions.
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(
|
||||
*args: typing.Any, **kwargs: typing.Any
|
||||
) -> Response:
|
||||
async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
|
||||
request = kwargs.get("request", args[idx] if idx < len(args) else None)
|
||||
assert isinstance(request, Request)
|
||||
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
orig_request_qparam = urlencode({"next": str(request.url)})
|
||||
next_url = "{redirect_path}?{orig_request}".format(
|
||||
redirect_path=request.url_for(redirect),
|
||||
orig_request=orig_request_qparam,
|
||||
)
|
||||
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
|
||||
return RedirectResponse(url=next_url, status_code=303)
|
||||
raise HTTPException(status_code=status_code)
|
||||
return await func(*args, **kwargs)
|
||||
@@ -80,24 +74,21 @@ def requires(
|
||||
else:
|
||||
# Handle sync request/response functions.
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
|
||||
def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
|
||||
request = kwargs.get("request", args[idx] if idx < len(args) else None)
|
||||
assert isinstance(request, Request)
|
||||
|
||||
if not has_required_scope(request, scopes_list):
|
||||
if redirect is not None:
|
||||
orig_request_qparam = urlencode({"next": str(request.url)})
|
||||
next_url = "{redirect_path}?{orig_request}".format(
|
||||
redirect_path=request.url_for(redirect),
|
||||
orig_request=orig_request_qparam,
|
||||
)
|
||||
next_url = f"{request.url_for(redirect)}?{orig_request_qparam}"
|
||||
return RedirectResponse(url=next_url, status_code=303)
|
||||
raise HTTPException(status_code=status_code)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return sync_wrapper
|
||||
|
||||
return decorator # type: ignore[return-value]
|
||||
return decorator
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
@@ -105,14 +96,12 @@ class AuthenticationError(Exception):
|
||||
|
||||
|
||||
class AuthenticationBackend:
|
||||
async def authenticate(
|
||||
self, conn: HTTPConnection
|
||||
) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
|
||||
async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class AuthCredentials:
|
||||
def __init__(self, scopes: typing.Optional[typing.Sequence[str]] = None):
|
||||
def __init__(self, scopes: Sequence[str] | None = None):
|
||||
self.scopes = [] if scopes is None else list(scopes)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user