updates
This commit is contained in:
@@ -1,15 +1,54 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi.openapi.models import APIKey, APIKeyIn
|
||||
from fastapi.security.base import SecurityBase
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class APIKeyBase(SecurityBase):
|
||||
pass
|
||||
def __init__(
|
||||
self,
|
||||
location: APIKeyIn,
|
||||
name: str,
|
||||
description: Union[str, None],
|
||||
scheme_name: Union[str, None],
|
||||
auto_error: bool,
|
||||
):
|
||||
self.auto_error = auto_error
|
||||
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": location},
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
|
||||
def make_not_authenticated_error(self) -> HTTPException:
|
||||
"""
|
||||
The WWW-Authenticate header is not standardized for API Key authentication but
|
||||
the HTTP specification requires that an error of 401 "Unauthorized" must
|
||||
include a WWW-Authenticate header.
|
||||
|
||||
Ref: https://datatracker.ietf.org/doc/html/rfc9110#name-401-unauthorized
|
||||
|
||||
For this, this method sends a custom challenge `APIKey`.
|
||||
"""
|
||||
return HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "APIKey"},
|
||||
)
|
||||
|
||||
def check_api_key(self, api_key: Optional[str]) -> Optional[str]:
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise self.make_not_authenticated_error()
|
||||
return None
|
||||
return api_key
|
||||
|
||||
|
||||
class APIKeyQuery(APIKeyBase):
|
||||
@@ -76,7 +115,7 @@ class APIKeyQuery(APIKeyBase):
|
||||
Doc(
|
||||
"""
|
||||
By default, if the query parameter is not provided, `APIKeyQuery` will
|
||||
automatically cancel the request and sebd the client an error.
|
||||
automatically cancel the request and send the client an error.
|
||||
|
||||
If `auto_error` is set to `False`, when the query parameter is not
|
||||
available, instead of erroring out, the dependency result will be
|
||||
@@ -91,24 +130,17 @@ class APIKeyQuery(APIKeyBase):
|
||||
),
|
||||
] = True,
|
||||
):
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": APIKeyIn.query}, # type: ignore[arg-type]
|
||||
super().__init__(
|
||||
location=APIKeyIn.query,
|
||||
name=name,
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.query_params.get(self.model.name)
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return api_key
|
||||
return self.check_api_key(api_key)
|
||||
|
||||
|
||||
class APIKeyHeader(APIKeyBase):
|
||||
@@ -186,24 +218,17 @@ class APIKeyHeader(APIKeyBase):
|
||||
),
|
||||
] = True,
|
||||
):
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": APIKeyIn.header}, # type: ignore[arg-type]
|
||||
super().__init__(
|
||||
location=APIKeyIn.header,
|
||||
name=name,
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.headers.get(self.model.name)
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return api_key
|
||||
return self.check_api_key(api_key)
|
||||
|
||||
|
||||
class APIKeyCookie(APIKeyBase):
|
||||
@@ -281,21 +306,14 @@ class APIKeyCookie(APIKeyBase):
|
||||
),
|
||||
] = True,
|
||||
):
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": APIKeyIn.cookie}, # type: ignore[arg-type]
|
||||
super().__init__(
|
||||
location=APIKeyIn.cookie,
|
||||
name=name,
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.cookies.get(self.model.name)
|
||||
if not api_key:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return api_key
|
||||
return self.check_api_key(api_key)
|
||||
|
||||
Reference in New Issue
Block a user