updates
This commit is contained in:
@@ -1,17 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import importlib.util
|
||||
import os
|
||||
import stat
|
||||
import typing
|
||||
from email.utils import parsedate
|
||||
from typing import Union
|
||||
|
||||
import anyio
|
||||
import anyio.to_thread
|
||||
|
||||
from starlette._utils import get_route_path
|
||||
from starlette.datastructures import URL, Headers
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.responses import FileResponse, RedirectResponse, Response
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
PathLike = typing.Union[str, "os.PathLike[str]"]
|
||||
PathLike = Union[str, "os.PathLike[str]"]
|
||||
|
||||
|
||||
class NotModifiedResponse(Response):
|
||||
@@ -27,11 +32,7 @@ class NotModifiedResponse(Response):
|
||||
def __init__(self, headers: Headers):
|
||||
super().__init__(
|
||||
status_code=304,
|
||||
headers={
|
||||
name: value
|
||||
for name, value in headers.items()
|
||||
if name in self.NOT_MODIFIED_HEADERS
|
||||
},
|
||||
headers={name: value for name, value in headers.items() if name in self.NOT_MODIFIED_HEADERS},
|
||||
)
|
||||
|
||||
|
||||
@@ -39,10 +40,8 @@ class StaticFiles:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
directory: typing.Optional[PathLike] = None,
|
||||
packages: typing.Optional[
|
||||
typing.List[typing.Union[str, typing.Tuple[str, str]]]
|
||||
] = None,
|
||||
directory: PathLike | None = None,
|
||||
packages: list[str | tuple[str, str]] | None = None,
|
||||
html: bool = False,
|
||||
check_dir: bool = True,
|
||||
follow_symlink: bool = False,
|
||||
@@ -58,11 +57,9 @@ class StaticFiles:
|
||||
|
||||
def get_directories(
|
||||
self,
|
||||
directory: typing.Optional[PathLike] = None,
|
||||
packages: typing.Optional[
|
||||
typing.List[typing.Union[str, typing.Tuple[str, str]]]
|
||||
] = None,
|
||||
) -> typing.List[PathLike]:
|
||||
directory: PathLike | None = None,
|
||||
packages: list[str | tuple[str, str]] | None = None,
|
||||
) -> list[PathLike]:
|
||||
"""
|
||||
Given `directory` and `packages` arguments, return a list of all the
|
||||
directories that should be used for serving static files from.
|
||||
@@ -79,12 +76,10 @@ class StaticFiles:
|
||||
spec = importlib.util.find_spec(package)
|
||||
assert spec is not None, f"Package {package!r} could not be found."
|
||||
assert spec.origin is not None, f"Package {package!r} could not be found."
|
||||
package_directory = os.path.normpath(
|
||||
os.path.join(spec.origin, "..", statics_dir)
|
||||
package_directory = os.path.normpath(os.path.join(spec.origin, "..", statics_dir))
|
||||
assert os.path.isdir(package_directory), (
|
||||
f"Directory '{statics_dir!r}' in package {package!r} could not be found."
|
||||
)
|
||||
assert os.path.isdir(
|
||||
package_directory
|
||||
), f"Directory '{statics_dir!r}' in package {package!r} could not be found."
|
||||
directories.append(package_directory)
|
||||
|
||||
return directories
|
||||
@@ -108,7 +103,8 @@ class StaticFiles:
|
||||
Given the ASGI scope, return the `path` string to serve up,
|
||||
with OS specific path separators, and any '..', '.' components removed.
|
||||
"""
|
||||
return os.path.normpath(os.path.join(*scope["path"].split("/")))
|
||||
route_path = get_route_path(scope)
|
||||
return os.path.normpath(os.path.join(*route_path.split("/")))
|
||||
|
||||
async def get_response(self, path: str, scope: Scope) -> Response:
|
||||
"""
|
||||
@@ -118,13 +114,15 @@ class StaticFiles:
|
||||
raise HTTPException(status_code=405)
|
||||
|
||||
try:
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(
|
||||
self.lookup_path, path
|
||||
)
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, path)
|
||||
except PermissionError:
|
||||
raise HTTPException(status_code=401)
|
||||
except OSError:
|
||||
raise
|
||||
except OSError as exc:
|
||||
# Filename is too long, so it can't be a valid static file.
|
||||
if exc.errno == errno.ENAMETOOLONG:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
raise exc
|
||||
|
||||
if stat_result and stat.S_ISREG(stat_result.st_mode):
|
||||
# We have a static file to serve.
|
||||
@@ -134,9 +132,7 @@ class StaticFiles:
|
||||
# We're in HTML mode, and have got a directory URL.
|
||||
# Check if we have 'index.html' file to serve.
|
||||
index_path = os.path.join(path, "index.html")
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(
|
||||
self.lookup_path, index_path
|
||||
)
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, index_path)
|
||||
if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
|
||||
if not scope["path"].endswith("/"):
|
||||
# Directory URLs should redirect to always end in "/".
|
||||
@@ -147,31 +143,22 @@ class StaticFiles:
|
||||
|
||||
if self.html:
|
||||
# Check for '404.html' if we're in HTML mode.
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(
|
||||
self.lookup_path, "404.html"
|
||||
)
|
||||
full_path, stat_result = await anyio.to_thread.run_sync(self.lookup_path, "404.html")
|
||||
if stat_result and stat.S_ISREG(stat_result.st_mode):
|
||||
return FileResponse(
|
||||
full_path,
|
||||
stat_result=stat_result,
|
||||
method=scope["method"],
|
||||
status_code=404,
|
||||
)
|
||||
return FileResponse(full_path, stat_result=stat_result, status_code=404)
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
def lookup_path(
|
||||
self, path: str
|
||||
) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
|
||||
def lookup_path(self, path: str) -> tuple[str, os.stat_result | None]:
|
||||
for directory in self.all_directories:
|
||||
joined_path = os.path.join(directory, path)
|
||||
if self.follow_symlink:
|
||||
full_path = os.path.abspath(joined_path)
|
||||
directory = os.path.abspath(directory)
|
||||
else:
|
||||
full_path = os.path.realpath(joined_path)
|
||||
directory = os.path.realpath(directory)
|
||||
if os.path.commonpath([full_path, directory]) != directory:
|
||||
# Don't allow misbehaving clients to break out of the static files
|
||||
# directory.
|
||||
directory = os.path.realpath(directory)
|
||||
if os.path.commonpath([full_path, directory]) != str(directory):
|
||||
# Don't allow misbehaving clients to break out of the static files directory.
|
||||
continue
|
||||
try:
|
||||
return full_path, os.stat(full_path)
|
||||
@@ -186,12 +173,9 @@ class StaticFiles:
|
||||
scope: Scope,
|
||||
status_code: int = 200,
|
||||
) -> Response:
|
||||
method = scope["method"]
|
||||
request_headers = Headers(scope=scope)
|
||||
|
||||
response = FileResponse(
|
||||
full_path, status_code=status_code, stat_result=stat_result, method=method
|
||||
)
|
||||
response = FileResponse(full_path, status_code=status_code, stat_result=stat_result)
|
||||
if self.is_not_modified(response.headers, request_headers):
|
||||
return NotModifiedResponse(response.headers)
|
||||
return response
|
||||
@@ -208,37 +192,24 @@ class StaticFiles:
|
||||
try:
|
||||
stat_result = await anyio.to_thread.run_sync(os.stat, self.directory)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(
|
||||
f"StaticFiles directory '{self.directory}' does not exist."
|
||||
)
|
||||
raise RuntimeError(f"StaticFiles directory '{self.directory}' does not exist.")
|
||||
if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
|
||||
raise RuntimeError(
|
||||
f"StaticFiles path '{self.directory}' is not a directory."
|
||||
)
|
||||
raise RuntimeError(f"StaticFiles path '{self.directory}' is not a directory.")
|
||||
|
||||
def is_not_modified(
|
||||
self, response_headers: Headers, request_headers: Headers
|
||||
) -> bool:
|
||||
def is_not_modified(self, response_headers: Headers, request_headers: Headers) -> bool:
|
||||
"""
|
||||
Given the request and response headers, return `True` if an HTTP
|
||||
"Not Modified" response could be returned instead.
|
||||
"""
|
||||
try:
|
||||
if_none_match = request_headers["if-none-match"]
|
||||
if if_none_match := request_headers.get("if-none-match"):
|
||||
# The "etag" header is added by FileResponse, so it's always present.
|
||||
etag = response_headers["etag"]
|
||||
if if_none_match == etag:
|
||||
return True
|
||||
except KeyError:
|
||||
pass
|
||||
return etag in [tag.strip(" W/") for tag in if_none_match.split(",")]
|
||||
|
||||
try:
|
||||
if_modified_since = parsedate(request_headers["if-modified-since"])
|
||||
last_modified = parsedate(response_headers["last-modified"])
|
||||
if (
|
||||
if_modified_since is not None
|
||||
and last_modified is not None
|
||||
and if_modified_since >= last_modified
|
||||
):
|
||||
if if_modified_since is not None and last_modified is not None and if_modified_since >= last_modified:
|
||||
return True
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user