This commit is contained in:
Iliyan Angelov
2025-12-01 06:50:10 +02:00
parent 91f51bc6fe
commit 62c1fe5951
4682 changed files with 544807 additions and 31208 deletions

View File

@@ -1,15 +1,54 @@
from typing import Optional
from typing import Optional, Union
from annotated_doc import Doc
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated
class APIKeyBase(SecurityBase):
pass
def __init__(
self,
location: APIKeyIn,
name: str,
description: Union[str, None],
scheme_name: Union[str, None],
auto_error: bool,
):
self.auto_error = auto_error
self.model: APIKey = APIKey(
**{"in": location},
name=name,
description=description,
)
self.scheme_name = scheme_name or self.__class__.__name__
def make_not_authenticated_error(self) -> HTTPException:
"""
The WWW-Authenticate header is not standardized for API Key authentication but
the HTTP specification requires that an error of 401 "Unauthorized" must
include a WWW-Authenticate header.
Ref: https://datatracker.ietf.org/doc/html/rfc9110#name-401-unauthorized
For this, this method sends a custom challenge `APIKey`.
"""
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "APIKey"},
)
def check_api_key(self, api_key: Optional[str]) -> Optional[str]:
if not api_key:
if self.auto_error:
raise self.make_not_authenticated_error()
return None
return api_key
class APIKeyQuery(APIKeyBase):
@@ -76,7 +115,7 @@ class APIKeyQuery(APIKeyBase):
Doc(
"""
By default, if the query parameter is not provided, `APIKeyQuery` will
automatically cancel the request and sebd the client an error.
automatically cancel the request and send the client an error.
If `auto_error` is set to `False`, when the query parameter is not
available, instead of erroring out, the dependency result will be
@@ -91,24 +130,17 @@ class APIKeyQuery(APIKeyBase):
),
] = True,
):
self.model: APIKey = APIKey(
**{"in": APIKeyIn.query}, # type: ignore[arg-type]
super().__init__(
location=APIKeyIn.query,
name=name,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
api_key = request.query_params.get(self.model.name)
if not api_key:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return api_key
return self.check_api_key(api_key)
class APIKeyHeader(APIKeyBase):
@@ -186,24 +218,17 @@ class APIKeyHeader(APIKeyBase):
),
] = True,
):
self.model: APIKey = APIKey(
**{"in": APIKeyIn.header}, # type: ignore[arg-type]
super().__init__(
location=APIKeyIn.header,
name=name,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
api_key = request.headers.get(self.model.name)
if not api_key:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return api_key
return self.check_api_key(api_key)
class APIKeyCookie(APIKeyBase):
@@ -281,21 +306,14 @@ class APIKeyCookie(APIKeyBase):
),
] = True,
):
self.model: APIKey = APIKey(
**{"in": APIKeyIn.cookie}, # type: ignore[arg-type]
super().__init__(
location=APIKeyIn.cookie,
name=name,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
api_key = request.cookies.get(self.model.name)
if not api_key:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
else:
return None
return api_key
return self.check_api_key(api_key)

View File

@@ -1,7 +1,8 @@
import binascii
from base64 import b64decode
from typing import Optional
from typing import Dict, Optional
from annotated_doc import Doc
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
@@ -9,13 +10,13 @@ from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated
class HTTPBasicCredentials(BaseModel):
"""
The HTTP Basic credendials given as the result of using `HTTPBasic` in a
The HTTP Basic credentials given as the result of using `HTTPBasic` in a
dependency.
Read more about it in the
@@ -75,10 +76,22 @@ class HTTPBase(SecurityBase):
description: Optional[str] = None,
auto_error: bool = True,
):
self.model = HTTPBaseModel(scheme=scheme, description=description)
self.model: HTTPBaseModel = HTTPBaseModel(
scheme=scheme, description=description
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
def make_authenticate_headers(self) -> Dict[str, str]:
return {"WWW-Authenticate": f"{self.model.scheme.title()}"}
def make_not_authenticated_error(self) -> HTTPException:
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers=self.make_authenticate_headers(),
)
async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
@@ -86,9 +99,7 @@ class HTTPBase(SecurityBase):
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
@@ -98,6 +109,8 @@ class HTTPBasic(HTTPBase):
"""
HTTP Basic authentication.
Ref: https://datatracker.ietf.org/doc/html/rfc7617
## Usage
Create an instance object and use that object as the dependency in `Depends()`.
@@ -184,36 +197,28 @@ class HTTPBasic(HTTPBase):
self.realm = realm
self.auto_error = auto_error
def make_authenticate_headers(self) -> Dict[str, str]:
if self.realm:
return {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
return {"WWW-Authenticate": "Basic"}
async def __call__( # type: ignore
self, request: Request
) -> Optional[HTTPBasicCredentials]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if self.realm:
unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
else:
unauthorized_headers = {"WWW-Authenticate": "Basic"}
if not authorization or scheme.lower() != "basic":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers=unauthorized_headers,
)
raise self.make_not_authenticated_error()
else:
return None
invalid_user_credentials_exc = HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers=unauthorized_headers,
)
try:
data = b64decode(param).decode("ascii")
except (ValueError, UnicodeDecodeError, binascii.Error):
raise invalid_user_credentials_exc # noqa: B904
except (ValueError, UnicodeDecodeError, binascii.Error) as e:
raise self.make_not_authenticated_error() from e
username, separator, password = data.partition(":")
if not separator:
raise invalid_user_credentials_exc
raise self.make_not_authenticated_error()
return HTTPBasicCredentials(username=username, password=password)
@@ -277,7 +282,7 @@ class HTTPBearer(HTTPBase):
bool,
Doc(
"""
By default, if the HTTP Bearer token not provided (in an
By default, if the HTTP Bearer token is not provided (in an
`Authorization` header), `HTTPBearer` will automatically cancel the
request and send the client an error.
@@ -305,17 +310,12 @@ class HTTPBearer(HTTPBase):
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
if scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid authentication credentials",
)
raise self.make_not_authenticated_error()
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
@@ -325,6 +325,12 @@ class HTTPDigest(HTTPBase):
"""
HTTP Digest authentication.
**Warning**: this is only a stub to connect the components with OpenAPI in FastAPI,
but it doesn't implement the full Digest scheme, you would need to to subclass it
and implement it in your code.
Ref: https://datatracker.ietf.org/doc/html/rfc7616
## Usage
Create an instance object and use that object as the dependency in `Depends()`.
@@ -380,7 +386,7 @@ class HTTPDigest(HTTPBase):
bool,
Doc(
"""
By default, if the HTTP Digest not provided, `HTTPDigest` will
By default, if the HTTP Digest is not provided, `HTTPDigest` will
automatically cancel the request and send the client an error.
If `auto_error` is set to `False`, when the HTTP Digest is not
@@ -407,14 +413,12 @@ class HTTPDigest(HTTPBase):
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
if scheme.lower() != "digest":
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid authentication credentials",
)
if self.auto_error:
raise self.make_not_authenticated_error()
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List, Optional, Union, cast
from annotated_doc import Doc
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import OAuth2 as OAuth2Model
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
@@ -7,10 +8,10 @@ from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from starlette.status import HTTP_401_UNAUTHORIZED
# TODO: import from typing when deprecating Python 3.9
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
from typing_extensions import Annotated
class OAuth2PasswordRequestForm:
@@ -52,9 +53,9 @@ class OAuth2PasswordRequestForm:
```
Note that for OAuth2 the scope `items:read` is a single scope in an opaque string.
You could have custom internal logic to separate it by colon caracters (`:`) or
You could have custom internal logic to separate it by colon characters (`:`) or
similar, and get the two parts `items` and `read`. Many applications do that to
group and organize permisions, you could do it as well in your application, just
group and organize permissions, you could do it as well in your application, just
know that that it is application specific, it's not part of the specification.
"""
@@ -63,7 +64,7 @@ class OAuth2PasswordRequestForm:
*,
grant_type: Annotated[
Union[str, None],
Form(pattern="password"),
Form(pattern="^password$"),
Doc(
"""
The OAuth2 spec says it is required and MUST be the fixed string
@@ -85,11 +86,11 @@ class OAuth2PasswordRequestForm:
],
password: Annotated[
str,
Form(),
Form(json_schema_extra={"format": "password"}),
Doc(
"""
`password` string. The OAuth2 spec requires the exact field name
`password".
`password`.
"""
),
],
@@ -130,7 +131,7 @@ class OAuth2PasswordRequestForm:
] = None,
client_secret: Annotated[
Union[str, None],
Form(),
Form(json_schema_extra={"format": "password"}),
Doc(
"""
If there's a `client_password` (and a `client_id`), they can be sent
@@ -194,9 +195,9 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
```
Note that for OAuth2 the scope `items:read` is a single scope in an opaque string.
You could have custom internal logic to separate it by colon caracters (`:`) or
You could have custom internal logic to separate it by colon characters (`:`) or
similar, and get the two parts `items` and `read`. Many applications do that to
group and organize permisions, you could do it as well in your application, just
group and organize permissions, you could do it as well in your application, just
know that that it is application specific, it's not part of the specification.
@@ -217,7 +218,7 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
self,
grant_type: Annotated[
str,
Form(pattern="password"),
Form(pattern="^password$"),
Doc(
"""
The OAuth2 spec says it is required and MUST be the fixed string
@@ -243,7 +244,7 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
Doc(
"""
`password` string. The OAuth2 spec requires the exact field name
`password".
`password`.
"""
),
],
@@ -353,7 +354,7 @@ class OAuth2(SecurityBase):
bool,
Doc(
"""
By default, if no HTTP Auhtorization header is provided, required for
By default, if no HTTP Authorization header is provided, required for
OAuth2 authentication, it will automatically cancel the request and
send the client an error.
@@ -376,13 +377,33 @@ class OAuth2(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
def make_not_authenticated_error(self) -> HTTPException:
"""
The OAuth 2 specification doesn't define the challenge that should be used,
because a `Bearer` token is not really the only option to authenticate.
But declaring any other authentication challenge would be application-specific
as it's not defined in the specification.
For practical reasons, this method uses the `Bearer` challenge by default, as
it's probably the most common one.
If you are implementing an OAuth2 authentication scheme other than the provided
ones in FastAPI (based on bearer tokens), you might want to override this.
Ref: https://datatracker.ietf.org/doc/html/rfc6749
"""
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
async def __call__(self, request: Request) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
return authorization
@@ -441,7 +462,7 @@ class OAuth2PasswordBearer(OAuth2):
bool,
Doc(
"""
By default, if no HTTP Auhtorization header is provided, required for
By default, if no HTTP Authorization header is provided, required for
OAuth2 authentication, it will automatically cancel the request and
send the client an error.
@@ -457,11 +478,26 @@ class OAuth2PasswordBearer(OAuth2):
"""
),
] = True,
refreshUrl: Annotated[
Optional[str],
Doc(
"""
The URL to refresh the token and obtain a new one.
"""
),
] = None,
):
if not scopes:
scopes = {}
flows = OAuthFlowsModel(
password=cast(Any, {"tokenUrl": tokenUrl, "scopes": scopes})
password=cast(
Any,
{
"tokenUrl": tokenUrl,
"refreshUrl": refreshUrl,
"scopes": scopes,
},
)
)
super().__init__(
flows=flows,
@@ -475,11 +511,7 @@ class OAuth2PasswordBearer(OAuth2):
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
raise self.make_not_authenticated_error()
else:
return None
return param
@@ -543,7 +575,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
bool,
Doc(
"""
By default, if no HTTP Auhtorization header is provided, required for
By default, if no HTTP Authorization header is provided, required for
OAuth2 authentication, it will automatically cancel the request and
send the client an error.
@@ -585,11 +617,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
raise self.make_not_authenticated_error()
else:
return None # pragma: nocover
return param

View File

@@ -1,17 +1,23 @@
from typing import Optional
from annotated_doc import Doc
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc # type: ignore [attr-defined]
from starlette.status import HTTP_401_UNAUTHORIZED
from typing_extensions import Annotated
class OpenIdConnect(SecurityBase):
"""
OpenID Connect authentication class. An instance of it would be used as a
dependency.
**Warning**: this is only a stub to connect the components with OpenAPI in FastAPI,
but it doesn't implement the full OpenIdConnect scheme, for example, it doesn't use
the OpenIDConnect URL. You would need to to subclass it and implement it in your
code.
"""
def __init__(
@@ -49,7 +55,7 @@ class OpenIdConnect(SecurityBase):
bool,
Doc(
"""
By default, if no HTTP Auhtorization header is provided, required for
By default, if no HTTP Authorization header is provided, required for
OpenID Connect authentication, it will automatically cancel the request
and send the client an error.
@@ -72,13 +78,18 @@ class OpenIdConnect(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
def make_not_authenticated_error(self) -> HTTPException:
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
async def __call__(self, request: Request) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
return authorization