updates
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import binascii
|
||||
from base64 import b64decode
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.openapi.models import HTTPBase as HTTPBaseModel
|
||||
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
|
||||
@@ -9,13 +10,13 @@ from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from pydantic import BaseModel
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class HTTPBasicCredentials(BaseModel):
|
||||
"""
|
||||
The HTTP Basic credendials given as the result of using `HTTPBasic` in a
|
||||
The HTTP Basic credentials given as the result of using `HTTPBasic` in a
|
||||
dependency.
|
||||
|
||||
Read more about it in the
|
||||
@@ -75,10 +76,22 @@ class HTTPBase(SecurityBase):
|
||||
description: Optional[str] = None,
|
||||
auto_error: bool = True,
|
||||
):
|
||||
self.model = HTTPBaseModel(scheme=scheme, description=description)
|
||||
self.model: HTTPBaseModel = HTTPBaseModel(
|
||||
scheme=scheme, description=description
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
def make_authenticate_headers(self) -> Dict[str, str]:
|
||||
return {"WWW-Authenticate": f"{self.model.scheme.title()}"}
|
||||
|
||||
def make_not_authenticated_error(self) -> HTTPException:
|
||||
return HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers=self.make_authenticate_headers(),
|
||||
)
|
||||
|
||||
async def __call__(
|
||||
self, request: Request
|
||||
) -> Optional[HTTPAuthorizationCredentials]:
|
||||
@@ -86,9 +99,7 @@ class HTTPBase(SecurityBase):
|
||||
scheme, credentials = get_authorization_scheme_param(authorization)
|
||||
if not (authorization and scheme and credentials):
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|
||||
@@ -98,6 +109,8 @@ class HTTPBasic(HTTPBase):
|
||||
"""
|
||||
HTTP Basic authentication.
|
||||
|
||||
Ref: https://datatracker.ietf.org/doc/html/rfc7617
|
||||
|
||||
## Usage
|
||||
|
||||
Create an instance object and use that object as the dependency in `Depends()`.
|
||||
@@ -184,36 +197,28 @@ class HTTPBasic(HTTPBase):
|
||||
self.realm = realm
|
||||
self.auto_error = auto_error
|
||||
|
||||
def make_authenticate_headers(self) -> Dict[str, str]:
|
||||
if self.realm:
|
||||
return {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
|
||||
return {"WWW-Authenticate": "Basic"}
|
||||
|
||||
async def __call__( # type: ignore
|
||||
self, request: Request
|
||||
) -> Optional[HTTPBasicCredentials]:
|
||||
authorization = request.headers.get("Authorization")
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if self.realm:
|
||||
unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
|
||||
else:
|
||||
unauthorized_headers = {"WWW-Authenticate": "Basic"}
|
||||
if not authorization or scheme.lower() != "basic":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers=unauthorized_headers,
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
invalid_user_credentials_exc = HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
headers=unauthorized_headers,
|
||||
)
|
||||
try:
|
||||
data = b64decode(param).decode("ascii")
|
||||
except (ValueError, UnicodeDecodeError, binascii.Error):
|
||||
raise invalid_user_credentials_exc # noqa: B904
|
||||
except (ValueError, UnicodeDecodeError, binascii.Error) as e:
|
||||
raise self.make_not_authenticated_error() from e
|
||||
username, separator, password = data.partition(":")
|
||||
if not separator:
|
||||
raise invalid_user_credentials_exc
|
||||
raise self.make_not_authenticated_error()
|
||||
return HTTPBasicCredentials(username=username, password=password)
|
||||
|
||||
|
||||
@@ -277,7 +282,7 @@ class HTTPBearer(HTTPBase):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if the HTTP Bearer token not provided (in an
|
||||
By default, if the HTTP Bearer token is not provided (in an
|
||||
`Authorization` header), `HTTPBearer` will automatically cancel the
|
||||
request and send the client an error.
|
||||
|
||||
@@ -305,17 +310,12 @@ class HTTPBearer(HTTPBase):
|
||||
scheme, credentials = get_authorization_scheme_param(authorization)
|
||||
if not (authorization and scheme and credentials):
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
if scheme.lower() != "bearer":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|
||||
@@ -325,6 +325,12 @@ class HTTPDigest(HTTPBase):
|
||||
"""
|
||||
HTTP Digest authentication.
|
||||
|
||||
**Warning**: this is only a stub to connect the components with OpenAPI in FastAPI,
|
||||
but it doesn't implement the full Digest scheme, you would need to to subclass it
|
||||
and implement it in your code.
|
||||
|
||||
Ref: https://datatracker.ietf.org/doc/html/rfc7616
|
||||
|
||||
## Usage
|
||||
|
||||
Create an instance object and use that object as the dependency in `Depends()`.
|
||||
@@ -380,7 +386,7 @@ class HTTPDigest(HTTPBase):
|
||||
bool,
|
||||
Doc(
|
||||
"""
|
||||
By default, if the HTTP Digest not provided, `HTTPDigest` will
|
||||
By default, if the HTTP Digest is not provided, `HTTPDigest` will
|
||||
automatically cancel the request and send the client an error.
|
||||
|
||||
If `auto_error` is set to `False`, when the HTTP Digest is not
|
||||
@@ -407,14 +413,12 @@ class HTTPDigest(HTTPBase):
|
||||
scheme, credentials = get_authorization_scheme_param(authorization)
|
||||
if not (authorization and scheme and credentials):
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
if scheme.lower() != "digest":
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
if self.auto_error:
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|
||||
|
||||
Reference in New Issue
Block a user