|
from __future__ import annotations |
|
|
|
import asyncio |
|
import functools |
|
import hashlib |
|
import hmac |
|
import json |
|
import os |
|
import re |
|
import shutil |
|
import sys |
|
from collections import deque |
|
from contextlib import AsyncExitStack, asynccontextmanager |
|
from dataclasses import dataclass as python_dataclass |
|
from datetime import datetime |
|
from pathlib import Path |
|
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper |
|
from typing import ( |
|
TYPE_CHECKING, |
|
AsyncContextManager, |
|
AsyncGenerator, |
|
BinaryIO, |
|
Callable, |
|
List, |
|
Optional, |
|
Tuple, |
|
Union, |
|
) |
|
from urllib.parse import urlparse |
|
|
|
import anyio |
|
import fastapi |
|
import gradio_client.utils as client_utils |
|
import httpx |
|
import multipart |
|
from gradio_client.documentation import document |
|
from multipart.multipart import parse_options_header |
|
from starlette.datastructures import FormData, Headers, MutableHeaders, UploadFile |
|
from starlette.formparsers import MultiPartException, MultipartPart |
|
from starlette.responses import PlainTextResponse, Response |
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send |
|
|
|
from gradio import processing_utils, utils |
|
from gradio.data_classes import PredictBody |
|
from gradio.exceptions import Error |
|
from gradio.helpers import EventData |
|
from gradio.state_holder import SessionState |
|
|
|
if TYPE_CHECKING: |
|
from gradio.blocks import Blocks |
|
from gradio.routes import App |
|
|
|
|
|
class Obj:
    """
    Using a class to convert dictionaries into objects. Used by the `Request` class.
    Credit: https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/

    Keys become attributes (and stay reachable as ``obj[key]``). Nested dicts are
    wrapped in ``Obj`` recursively, and list values are rebuilt with their dict
    elements wrapped the same way, e.g. ``Obj({"a": [{"b": 1}]}).a[0].b == 1``.
    """

    def __init__(self, dict_):
        self.__dict__.update(dict_)
        for key, value in dict_.items():
            # Called through the class (not self) so a key named "_convert" in
            # dict_ cannot shadow the helper after the update above.
            setattr(self, key, Obj._convert(value))

    @staticmethod
    def _convert(value):
        """Wrap dicts in ``Obj``, convert list elements recursively, return anything else as-is."""
        if isinstance(value, dict):
            return Obj(value)
        if isinstance(value, list):
            # A list cannot itself become an Obj (it has no ``.items()``), so
            # convert its elements and keep the container a list.
            return [Obj._convert(item) for item in value]
        return value

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, item, value):
        self.__dict__[item] = value

    def __iter__(self):
        """Yield ``(key, value)`` pairs; nested ``Obj`` values are yielded as shallow dicts."""
        for key, value in self.__dict__.items():
            if isinstance(value, Obj):
                yield (key, dict(value))
            else:
                yield (key, value)

    def __contains__(self, item) -> bool:
        """Return True if ``item`` is a key here or in any nested ``Obj`` (recursive)."""
        if item in self.__dict__:
            return True
        for value in self.__dict__.values():
            if isinstance(value, Obj) and item in value:
                return True
        return False

    def get(self, item, default=None):
        """Return the top-level value for ``item``, or ``default`` if it is not a top-level key.

        ``in`` is not used for the check because ``__contains__`` also matches keys
        of nested ``Obj`` values, which are absent from this object's ``__dict__``.
        """
        return self.__dict__.get(item, default)

    def keys(self):
        return self.__dict__.keys()

    def values(self):
        return self.__dict__.values()

    def items(self):
        return self.__dict__.items()

    def __str__(self) -> str:
        return str(self.__dict__)

    def __repr__(self) -> str:
        return str(self.__dict__)
|
|
|
|
|
@document()
class Request:
    """
    A Gradio request object that can be used to access the request headers, cookies,
    query parameters and other information about the request from within the prediction
    function. The class is a thin wrapper around the fastapi.Request class. Attributes
    of this class include: `headers`, `client`, `query_params`, `session_hash`, and `path_params`. If
    auth is enabled, the `username` attribute can be used to get the logged in user.
    Example:
        import gradio as gr
        def echo(text, request: gr.Request):
            if request:
                print("Request headers dictionary:", request.headers)
                print("IP address:", request.client.host)
                print("Query parameters:", dict(request.query_params))
                print("Session hash:", request.session_hash)
            return text
        io = gr.Interface(echo, "textbox", "textbox").launch()
    Demos: request_ip_headers
    """

    def __init__(
        self,
        request: fastapi.Request | None = None,
        username: str | None = None,
        session_hash: str | None = None,
        **kwargs,
    ):
        """
        Can be instantiated with either a fastapi.Request or by manually passing in
        attributes (needed for queueing).
        Parameters:
            request: A fastapi.Request
            username: The username of the logged in user (if auth is enabled)
            session_hash: The session hash of the current session. It is unique for each page load.
        """
        self.request = request
        self.username = username
        self.session_hash = session_hash
        self.kwargs: dict = kwargs

    def dict_to_obj(self, d):
        # Dicts are round-tripped through JSON so every nested dict becomes an Obj
        # (attribute access); any other value is returned unchanged.
        if isinstance(d, dict):
            return json.loads(json.dumps(d), object_hook=Obj)
        else:
            return d

    def __getattr__(self, name):
        # Only reached for names not found normally. Delegate to the wrapped
        # request if there is one, otherwise to the manually passed kwargs.
        #
        # Instance state is read through __dict__ rather than self.request /
        # self.kwargs: copy/pickle create instances without running __init__ and
        # then probe attributes such as __setstate__. Touching a missing
        # attribute here would re-enter __getattr__ and recurse without end.
        request = self.__dict__.get("request")
        if request:
            return self.dict_to_obj(getattr(request, name))
        try:
            obj = self.__dict__.get("kwargs", {})[name]
        except KeyError as ke:
            raise AttributeError(
                f"'Request' object has no attribute '{name}'"
            ) from ke
        return self.dict_to_obj(obj)
|
|
|
|
|
class FnIndexInferError(Exception):
    """Raised by ``infer_fn_index`` when no dependency has the requested ``api_name``."""

    pass
|
|
|
|
|
def infer_fn_index(app: App, api_name: str, body: PredictBody) -> int:
    """Resolve which dependency a prediction request targets.

    An explicit ``body.fn_index`` wins. Otherwise, the index of the first dependency
    in ``app.get_blocks().fns`` whose ``api_name`` equals ``api_name`` is returned.

    Raises:
        FnIndexInferError: if ``body.fn_index`` is None and no dependency matches.
    """
    if body.fn_index is not None:
        return body.fn_index

    dependencies = app.get_blocks().fns
    matched = next(
        (idx for idx, dep in enumerate(dependencies) if dep.api_name == api_name),
        None,
    )
    if matched is None:
        raise FnIndexInferError(f"Could not infer fn_index for api_name {api_name}.")
    return matched
|
|
|
|
|
def compile_gr_request(
    app: App,
    body: PredictBody,
    fn_index_inferred: int,
    username: Optional[str],
    request: Optional[fastapi.Request],
):
    """Build the ``Request`` object(s) handed to the user function.

    If the target dependency cancels other jobs, ``body.data`` is overwritten with
    ``[body.session_hash]``, since that is the only input a cancel needs.

    Returns a one-element list of ``Request`` for batched bodies that carry their
    own request; otherwise a single ``Request``.

    Raises:
        ValueError: if ``body.request`` is falsy and ``request`` is None.
    """
    dependency = app.get_blocks().fns[fn_index_inferred]
    if dependency.cancels:
        body.data = [body.session_hash]

    if not body.request:
        if request is None:
            raise ValueError("request must be provided if body.request is None")
        return Request(
            username=username, request=request, session_hash=body.session_hash
        )

    if body.batched:
        return [Request(username=username, request=request)]

    return Request(
        username=username, request=body.request, session_hash=body.session_hash
    )
|
|
|
|
|
def restore_session_state(app: App, body: PredictBody):
    """Return ``(session_state, iterator)`` for the request in ``body``.

    Without a session hash, a fresh ``SessionState`` and no iterator are returned.
    With one, the stored state for that session is returned, along with the
    iterator saved under ``body.event_id`` (if any). An event id found in
    ``app.iterators_to_reset`` is removed from that collection and gets no iterator.
    This keeps a reset event (presumably a cancelled job) from resuming its old
    iterator.
    """
    event_id = body.event_id
    session_hash = getattr(body, "session_hash", None)

    if session_hash is None:
        return SessionState(app.get_blocks()), None

    session_state = app.state_holder[session_hash]
    if event_id is None:
        return session_state, None
    if event_id in app.iterators_to_reset:
        app.iterators_to_reset.remove(event_id)
        return session_state, None
    return session_state, app.iterators.get(event_id)
|
|
|
|
|
def prepare_event_data(
    blocks: Blocks,
    body: PredictBody,
) -> EventData:
    """Build the ``EventData`` for a prediction request.

    The target is the block registered under ``body.trigger_id``. It is ``None``
    when the trigger id is falsy (e.g. None or 0) or not found in ``blocks.blocks``.
    The target is paired with the raw ``body.event_data``.
    """
    trigger_id = body.trigger_id
    trigger_block = None
    if trigger_id:
        trigger_block = blocks.blocks.get(trigger_id)
    return EventData(trigger_block, body.event_data)
|
|
|
|
|
async def call_process_api(
    app: App,
    body: PredictBody,
    gr_request: Union[Request, list[Request]],
    fn_index_inferred: int,
    root_path: str,
):
    """Run the dependency at ``fn_index_inferred`` through ``Blocks.process_api``.

    Restores the session state and any saved iterator for ``body``, builds the event
    data, and awaits ``process_api``. The iterator returned under the output's
    ``"iterator"`` key is stored in ``app.iterators`` under ``body.event_id``, where
    ``restore_session_state`` picks it up on the next call with the same event id.

    Parameters:
        app: the running Gradio app.
        body: the prediction request payload.
        gr_request: the ``Request`` (or list of them for batched calls) passed to the function.
        fn_index_inferred: index into ``app.get_blocks().fns`` of the dependency to run.
        root_path: forwarded unchanged to ``process_api``.
    Returns:
        The output dict from ``process_api`` with ``"iterator"`` removed. For a batch
        function called with a single unbatched sample, ``output["data"]`` is
        unwrapped to that sample's result.
    Raises:
        Re-raises any exception from ``process_api``, after appending ``None`` to every
        pending stream of the current iterator's run.
    """
    session_state, iterator = restore_session_state(app=app, body=body)

    dependency = app.get_blocks().fns[fn_index_inferred]
    event_data = prepare_event_data(app.get_blocks(), body)
    event_id = body.event_id

    session_hash = getattr(body, "session_hash", None)
    inputs = body.data

    # A batch function invoked with one unbatched sample: wrap the inputs as a
    # batch of one here and unwrap the result at the end.
    batch_in_single_out = not body.batched and dependency.batch
    if batch_in_single_out:
        inputs = [inputs]

    try:
        with utils.MatplotlibBackendMananger():
            output = await app.get_blocks().process_api(
                fn_index=fn_index_inferred,
                inputs=inputs,
                request=gr_request,
                state=session_state,
                iterator=iterator,
                session_hash=session_hash,
                event_id=event_id,
                event_data=event_data,
                in_event_listener=True,
                simple_format=body.simple_format,
                root_path=root_path,
            )
        iterator = output.pop("iterator", None)
        # Save the (possibly new) iterator so the next call with this event id resumes it.
        if event_id is not None:
            app.iterators[event_id] = iterator
        # NOTE(review): `output` was just used as a dict (`.pop(key, default)`), so
        # this check looks unreachable; confirm whether process_api can return an Error.
        if isinstance(output, Error):
            raise output
    except BaseException:
        iterator = app.iterators.get(event_id) if event_id is not None else None
        if iterator is not None:
            # Close off streams still open for this iterator's run. The appended
            # None is presumably an end-of-stream sentinel for their consumers.
            run_id = id(iterator)
            pending_streams: dict[int, list] = (
                app.get_blocks().pending_streams[session_hash].get(run_id, {})
            )
            for stream in pending_streams.values():
                stream.append(None)
        raise

    if batch_in_single_out:
        output["data"] = output["data"][0]

    return output
|
|
|
|
|
def get_root_url(
    request: fastapi.Request, route_path: str, root_path: str | None
) -> str:
    """
    Gets the root url of the Gradio app (i.e. the public url of the app) without a trailing slash.

    This is how the root_url is resolved:
    1. If a user provides a `root_path` manually that is a full URL, it is returned directly (minus any trailing slash).
    2. If the request has an x-forwarded-host header (e.g. because it is behind a proxy), the root url is
    constructed from the x-forwarded-host header. In this case, `route_path` is not used to construct the root url.
    3. Otherwise, the root url is constructed from the request url. The query parameters and `route_path` are stripped off.

    In cases (2) and (3), we also check whether the x-forwarded-proto header is "https", and if so, convert the root url to https.
    If there are multiple hosts in the x-forwarded-host or multiple protocols in the x-forwarded-proto, the first one is used.
    Finally, if a relative `root_path` is provided and it differs from the path of the resulting URL, the URL's path is
    replaced with `root_path`. Any existing path is overwritten, not extended.

    Parameters:
        request: the incoming request.
        route_path: path of the route that served the request; in case (3) it is cut off the end of the request URL.
        root_path: optional user-provided root, either a full URL or a relative path.
    Returns:
        The root url, with no trailing slash.
    """

    def get_first_header_value(header_name: str):
        # Proxies may send comma-separated lists; keep only the first entry.
        header_value = request.headers.get(header_name)
        if header_value:
            return header_value.split(",")[0].strip()
        return None

    # Case 1: an absolute root_path overrides everything else.
    if root_path and client_utils.is_http_url_like(root_path):
        return root_path.rstrip("/")

    x_forwarded_host = get_first_header_value("x-forwarded-host")
    root_url = f"http://{x_forwarded_host}" if x_forwarded_host else str(request.url)
    root_url = httpx.URL(root_url)
    root_url = root_url.copy_with(query=None)
    root_url = str(root_url).rstrip("/")
    if get_first_header_value("x-forwarded-proto") == "https":
        root_url = root_url.replace("http://", "https://")

    route_path = route_path.rstrip("/")
    if len(route_path) > 0 and not x_forwarded_host:
        # NOTE(review): slices blindly, assuming the request URL ends with
        # route_path; confirm this holds for every caller.
        root_url = root_url[: -len(route_path)]
    root_url = root_url.rstrip("/")

    root_url = httpx.URL(root_url)
    # Replaces (does not append to) the current path.
    if root_path and root_url.path != root_path:
        root_url = root_url.copy_with(path=root_path)

    return str(root_url).rstrip("/")
|
|
|
|
|
def _user_safe_decode(src: bytes, codec: str) -> str: |
|
try: |
|
return src.decode(codec) |
|
except (UnicodeDecodeError, LookupError): |
|
return src.decode("latin-1") |
|
|
|
|
|
class GradioUploadFile(UploadFile):
    """Starlette ``UploadFile`` extended with ``sha``, a running SHA-1 hash object.

    ``GradioMultiPartParser.parse`` feeds each written chunk into ``sha``, so the
    upload's digest is available without re-reading the file.
    """

    def __init__(
        self,
        file: BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        upload_kwargs = {"size": size, "filename": filename, "headers": headers}
        super().__init__(file, **upload_kwargs)
        # Starts empty; updated incrementally as chunks are written.
        self.sha = hashlib.sha1()
|
|
|
|
|
@python_dataclass(frozen=True)
class FileUploadProgressUnit:
    """Immutable record of bytes received for one file of an upload.

    The dataclass is frozen, so ``FileUploadProgress.append`` replaces the record
    rather than mutating it. The replacement carries the accumulated
    ``chunk_size`` while the filename stays the same.
    """

    # Name of the file these bytes belong to.
    filename: str
    # Bytes received for ``filename`` since the record was last popped (or since
    # the upload switched to this file).
    chunk_size: int
|
|
|
|
|
@python_dataclass
class FileUploadProgressTracker:
    """Per-upload state held by ``FileUploadProgress``."""

    # Pending progress records; ``FileUploadProgress.append`` keeps at most one entry.
    deque: deque[FileUploadProgressUnit]
    # Set by ``FileUploadProgress.set_done`` when parsing finishes or a file exceeds
    # the size limit.
    is_done: bool
|
|
|
|
|
class FileUploadProgressNotTrackedError(Exception):
    """Raised when querying progress for an ``upload_id`` that is not being tracked."""

    pass
|
|
|
|
|
class FileUploadProgressNotQueuedError(Exception):
    """Raised by ``FileUploadProgress.pop`` when a tracked upload has no pending progress record."""

    pass
|
|
|
|
|
class FileUploadProgress:
    """In-memory progress registry for file uploads, keyed by ``upload_id``.

    Each tracked upload has a ``FileUploadProgressTracker`` whose deque holds at
    most one ``FileUploadProgressUnit``: the byte count received for the current
    file since the last ``pop``.
    """

    def __init__(self) -> None:
        self._statuses: dict[str, FileUploadProgressTracker] = {}

    def track(self, upload_id: str):
        """Begin tracking ``upload_id``; an already-tracked id is left untouched."""
        if upload_id in self._statuses:
            return
        self._statuses[upload_id] = FileUploadProgressTracker(deque(), False)

    def append(self, upload_id: str, filename: str, message_bytes: bytes):
        """Record ``len(message_bytes)`` more bytes received for ``filename``.

        Bytes for the same file accumulate into a single record. A chunk for a
        different file replaces the pending record. Untracked ids are tracked
        implicitly.
        """
        self.track(upload_id)
        pending = self._statuses[upload_id].deque
        received = len(message_bytes)
        if pending:
            previous = pending.popleft()
            if previous.filename == filename:
                received += previous.chunk_size
        pending.append(FileUploadProgressUnit(filename, received))

    def set_done(self, upload_id: str):
        """Mark ``upload_id`` as finished, tracking it first if necessary."""
        self.track(upload_id)
        self._statuses[upload_id].is_done = True

    def is_done(self, upload_id: str):
        """Return whether ``upload_id`` is finished.

        Raises:
            FileUploadProgressNotTrackedError: if ``upload_id`` is not tracked.
        """
        tracker = self._statuses.get(upload_id)
        if tracker is None:
            raise FileUploadProgressNotTrackedError()
        return tracker.is_done

    def stop_tracking(self, upload_id: str):
        """Forget ``upload_id``; silently ignores ids that were never tracked."""
        self._statuses.pop(upload_id, None)

    def pop(self, upload_id: str) -> FileUploadProgressUnit:
        """Remove and return the pending progress record for ``upload_id``.

        Raises:
            FileUploadProgressNotTrackedError: if ``upload_id`` is not tracked.
            FileUploadProgressNotQueuedError: if no record is pending.
        """
        tracker = self._statuses.get(upload_id)
        if tracker is None:
            raise FileUploadProgressNotTrackedError()
        try:
            return tracker.deque.pop()
        except IndexError as exc:
            raise FileUploadProgressNotQueuedError() from exc
|
|
|
|
|
class GradioMultiPartParser:
    """Vendored from starlette.MultipartParser.

    Thanks starlette!

    Made the following modifications
        - Use GradioUploadFile instead of UploadFile
        - Use NamedTemporaryFile instead of SpooledTemporaryFile
        - Compute hash of data as the request is streamed
        - Report upload progress (per ``upload_id``) for file parts only

    The ``on_*`` methods are synchronous callbacks invoked by
    ``multipart.MultipartParser``. File bytes they receive are queued and written
    to disk asynchronously in ``parse``.
    """

    # Class-level default; every instance overrides it with the required
    # ``max_file_size`` constructor argument.
    max_file_size = 1024 * 1024

    def __init__(
        self,
        headers: Headers,
        stream: AsyncGenerator[bytes, None],
        *,
        max_files: Union[int, float] = 1000,
        max_fields: Union[int, float] = 1000,
        upload_id: str | None = None,
        upload_progress: FileUploadProgress | None = None,
        max_file_size: int | float,
    ) -> None:
        """
        Parameters:
            headers: request headers; ``Content-Type`` must carry the multipart boundary.
            stream: async iterator over the raw request body chunks.
            max_files: maximum number of file parts before a ``MultiPartException``.
            max_fields: maximum number of non-file fields before a ``MultiPartException``.
            upload_id: key under which progress is reported to ``upload_progress``.
            upload_progress: optional tracker that receives per-file byte counts.
            max_file_size: maximum size in bytes of any single uploaded file.
        """
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        self.items: List[Tuple[str, Union[str, UploadFile]]] = []
        self.upload_id = upload_id
        self.upload_progress = upload_progress
        self._current_files = 0
        self._current_fields = 0
        self.max_file_size = max_file_size
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        self._charset = ""
        self._file_parts_to_write: List[Tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: List[MultipartPart] = []
        self._files_to_close_on_error: List[_TemporaryFileWrapper] = []

    def on_part_begin(self) -> None:
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        message_bytes = data[start:end]
        current_file = self._current_part.file
        if current_file is None:
            # Plain form field: buffer its value in memory (decoded in on_part_end).
            self._current_part.data += message_bytes
            return
        if self.upload_progress is not None:
            # Only file parts have a filename to report progress against; plain
            # form fields (handled above) are not tracked.
            self.upload_progress.append(
                self.upload_id,
                current_file.filename,
                message_bytes,
            )
        # Callbacks are synchronous, so the bytes are queued and written in parse().
        self._file_parts_to_write.append((self._current_part, message_bytes))

    def on_part_end(self) -> None:
        if self._current_part.file is None:
            self.items.append(
                (
                    self._current_part.field_name,
                    _user_safe_decode(self._current_part.data, self._charset),
                )
            )
        else:
            self._file_parts_to_finish.append(self._current_part)
            # The file can be added to the items now even though it is not
            # finished yet: parse() finishes it before self.items is returned.
            self.items.append((self._current_part.field_name, self._current_part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        field = self._current_partial_header_name.lower()
        if field == b"content-disposition":
            self._current_part.content_disposition = self._current_partial_header_value
        self._current_part.item_headers.append(
            (field, self._current_partial_header_value)
        )
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        _, options = parse_options_header(self._current_part.content_disposition or b"")
        try:
            self._current_part.field_name = _user_safe_decode(
                options[b"name"], str(self._charset)
            )
        except KeyError as e:
            raise MultiPartException(
                'The Content-Disposition header field "name" must be ' "provided."
            ) from e
        if b"filename" in options:
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(
                    f"Too many files. Maximum number of files is {self.max_files}."
                )
            filename = _user_safe_decode(options[b"filename"], str(self._charset))
            # delete=False: the file must outlive this parser (callers move it
            # into the cache); it is removed explicitly in parse() on error.
            tempfile = NamedTemporaryFile(delete=False)
            self._files_to_close_on_error.append(tempfile)
            self._current_part.file = GradioUploadFile(
                file=tempfile,
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(
                    f"Too many fields. Maximum number of fields is {self.max_fields}."
                )
            self._current_part.file = None

    def on_end(self) -> None:
        pass

    async def parse(self) -> FormData:
        """Consume ``self.stream`` and return the parsed form fields and files.

        Raises:
            MultiPartException: on a missing boundary, a part without a ``name``,
                too many files or fields, or a file larger than ``max_file_size``.
                Temporary files created during streaming are closed and deleted
                before the exception propagates.
        """
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError as e:
            raise MultiPartException("Missing boundary in multipart.") from e

        # Callbacks dictionary.
        callbacks: multipart.multipart.MultipartCallbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                # Write queued file data here: the UploadFile methods must be awaited
                # (they run file I/O in a threadpool), which the synchronous
                # callbacks above cannot do without blocking the event loop.
                for part, data in self._file_parts_to_write:
                    assert part.file
                    await part.file.write(data)
                    part.file.sha.update(data)
                    if os.stat(part.file.file.name).st_size > self.max_file_size:
                        if self.upload_progress is not None:
                            self.upload_progress.set_done(self.upload_id)
                        raise MultiPartException(
                            f"File size exceeded maximum allowed size of {self.max_file_size} bytes."
                        )
                for part in self._file_parts_to_finish:
                    assert part.file
                    await part.file.seek(0)
                self._file_parts_to_write.clear()
                self._file_parts_to_finish.clear()
        except MultiPartException:
            # Close and delete every temp file created so far; they were opened
            # with delete=False and would otherwise be left on disk.
            for file in self._files_to_close_on_error:
                file.close()
                Path(file.name).unlink()
            raise

        parser.finalize()
        if self.upload_progress is not None:
            self.upload_progress.set_done(self.upload_id)
        return FormData(self.items)
|
|
|
|
|
def move_uploaded_files_to_cache(files: list[str], destinations: list[str]) -> None:
    """Move each path in ``files`` to the path at the same index in ``destinations``.

    Uses ``shutil.move``. Surplus entries in the longer list are ignored, as with ``zip``.
    """
    for source_and_target in zip(files, destinations):
        shutil.move(*source_and_target)
|
|
|
|
|
def update_root_in_config(config: dict, root: str) -> dict:
    """
    Set ``config["root"]`` to ``root`` and rewrite component file URLs to match.

    If the config already has the same non-None root, it is returned untouched.
    Otherwise the root is stored and the result of
    ``processing_utils.add_root_url(config, root, previous_root)`` is returned.
    """
    previous_root = config.get("root")
    root_unchanged = previous_root is not None and previous_root == root
    if root_unchanged:
        return config
    config["root"] = root
    return processing_utils.add_root_url(config, root, previous_root)
|
|
|
|
|
def compare_passwords_securely(input_password: str, correct_password: str) -> bool:
    """Compare two passwords with ``hmac.compare_digest`` to resist timing attacks.

    Both are UTF-8 encoded first, because ``compare_digest`` accepts ``str`` only
    when it is ASCII.
    """
    candidate = input_password.encode()
    expected = correct_password.encode()
    return hmac.compare_digest(candidate, expected)
|
|
|
|
|
def starts_with_protocol(string: str) -> bool:
    """Return True if ``string`` begins with a URL scheme plus ``://``, or with a network-path prefix.

    A scheme is a letter followed by letters, digits, ``+``, ``-`` or ``.``. The
    network-path prefixes are ``//``, ``\\\\``, ``\\/`` and ``/\\``; they are
    included because Windows may interpret them as SMB/UNC paths.
    """
    pattern = r"^(?:[a-zA-Z][a-zA-Z0-9+\-.]*://|//|\\\\|\\/|/\\)"
    return bool(re.match(pattern, string))
|
|
|
|
|
def get_hostname(url: str) -> str:
    """
    Returns the hostname of a given url, or an empty string if the url cannot be parsed.
    A url without "://" is treated as http. Parsing errors of any kind yield "".
    Examples:
        get_hostname("https://www.gradio.app") -> "www.gradio.app"
        get_hostname("localhost:7860") -> "localhost"
        get_hostname("127.0.0.1") -> "127.0.0.1"
    """
    if not url:
        return ""
    candidate = url if "://" in url else "http://" + url
    try:
        parsed = urlparse(candidate)
        return parsed.hostname or ""
    except Exception:
        return ""
|
|
|
|
|
class CustomCORSMiddleware:
    """ASGI CORS middleware adapted from Starlette's ``CORSMiddleware``.

    Non-HTTP scopes and requests without an ``Origin`` header pass straight through.
    Preflight requests (``OPTIONS`` with ``Access-Control-Request-Method``) are
    answered directly with a 200 "OK". Every other response gets the
    ``simple_headers`` added. The request's ``Origin`` is echoed in
    ``Access-Control-Allow-Origin`` only when ``is_valid_origin`` accepts it, or,
    for non-preflight responses, when the request carries a cookie.
    """

    def __init__(
        self,
        app: ASGIApp,
    ) -> None:
        self.app = app
        self.all_methods = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
        # Static headers for every preflight response; the allowed origin and
        # headers are added per request in preflight_response.
        # NOTE(review): unlike simple_headers, this omits Access-Control-Allow-Credentials,
        # so browsers may reject credentialed requests that need a preflight; confirm intended.
        self.preflight_headers = {
            "Access-Control-Allow-Methods": ", ".join(self.all_methods),
            "Access-Control-Max-Age": str(600),
        }
        self.simple_headers = {"Access-Control-Allow-Credentials": "true"}
        # Hostnames treated as "local". When the server is reached through one of
        # these, only origins whose hostname is also one of these are accepted.
        # "null" presumably covers the literal `Origin: null` sent for opaque origins.
        # NOTE(review): opaque origins include sandboxed iframes, so treating "null"
        # as local may admit untrusted pages; confirm.
        self.localhost_aliases = ["localhost", "127.0.0.1", "0.0.0.0", "null"]

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Non-HTTP traffic (e.g. websockets, lifespan) is not touched.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        headers = Headers(scope=scope)
        origin = headers.get("origin")
        # Same-origin/non-browser requests carry no Origin; nothing to do.
        if origin is None:
            await self.app(scope, receive, send)
            return
        # CORS preflight: answer it here without calling the wrapped app.
        if scope["method"] == "OPTIONS" and "access-control-request-method" in headers:
            response = self.preflight_response(request_headers=headers)
            await response(scope, receive, send)
            return
        await self.simple_response(scope, receive, send, request_headers=headers)

    def preflight_response(self, request_headers: Headers) -> Response:
        """Build the 200 "OK" preflight reply, echoing Origin only if it is valid."""
        headers = dict(self.preflight_headers)
        origin = request_headers["Origin"]
        if self.is_valid_origin(request_headers):
            headers["Access-Control-Allow-Origin"] = origin
        # Allow whatever request headers the browser asked for.
        requested_headers = request_headers.get("access-control-request-headers")
        if requested_headers is not None:
            headers["Access-Control-Allow-Headers"] = requested_headers
        return PlainTextResponse("OK", status_code=200, headers=headers)

    async def simple_response(
        self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
    ) -> None:
        """Run the wrapped app with a ``send`` that injects CORS headers into the response start."""
        send = functools.partial(self._send, send=send, request_headers=request_headers)
        await self.app(scope, receive, send)

    async def _send(
        self, message: Message, send: Send, request_headers: Headers
    ) -> None:
        # Only the response-start message carries headers; body chunks pass through.
        if message["type"] != "http.response.start":
            await send(message)
            return
        message.setdefault("headers", [])
        headers = MutableHeaders(scope=message)
        headers.update(self.simple_headers)
        has_cookie = "cookie" in request_headers
        origin = request_headers["Origin"]
        # NOTE(review): with a cookie present, any origin is echoed together with
        # Access-Control-Allow-Credentials: true, bypassing the localhost check in
        # is_valid_origin. Confirm this is intended for locally hosted apps.
        if has_cookie or self.is_valid_origin(request_headers):
            self.allow_explicit_origin(headers, origin)
        await send(message)

    def is_valid_origin(self, request_headers: Headers) -> bool:
        """Return True unless the Host is a localhost alias while the Origin is not.

        So a server reached through a non-local host accepts every origin, and a
        locally addressed server accepts only local-looking origins.
        """
        origin = request_headers["Origin"]
        host = request_headers["Host"]
        host_name = get_hostname(host)
        origin_name = get_hostname(origin)
        return (
            host_name not in self.localhost_aliases
            or origin_name in self.localhost_aliases
        )

    @staticmethod
    def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
        """Echo ``origin`` as the allowed origin and add ``Vary: Origin`` for caches."""
        headers["Access-Control-Allow-Origin"] = origin
        headers.add_vary_header("Origin")
|
|
|
|
|
def delete_files_created_by_app(blocks: Blocks, age: int | None) -> None:
    """Delete files that are older than age (in seconds). If age is None, delete all files.

    Candidates are the paths in each set of ``blocks.temp_file_sets``. Paths listed
    in any component's ``keep_in_cache`` are never deleted. A file's age is measured
    from its ``lstat().st_ctime``. Deleted paths are removed from their set. Paths
    that no longer exist are skipped and left in their set.
    """
    dont_delete = set()
    for component in blocks.blocks.values():
        dont_delete.update(getattr(component, "keep_in_cache", set()))
    for temp_set in blocks.temp_file_sets:
        # Collect deletions and subtract afterwards: a set must not change size
        # while it is being iterated.
        to_remove = set()
        for file in temp_set:
            if file in dont_delete:
                continue
            try:
                file_path = Path(file)
                modified_time = datetime.fromtimestamp(file_path.lstat().st_ctime)
                # total_seconds(), not .seconds: timedelta.seconds is only the
                # 0-86399 seconds component, so files older than a day could look
                # young and never be deleted.
                if (
                    age is None
                    or (datetime.now() - modified_time).total_seconds() > age
                ):
                    os.remove(file)
                    to_remove.add(file)
            except FileNotFoundError:
                continue
        # In-place difference so blocks.temp_file_sets sees the update.
        temp_set -= to_remove
|
|
|
|
|
async def delete_files_on_schedule(app: App, frequency: int, age: int) -> None:
    """Startup task: every ``frequency`` seconds, delete app-created files older than ``age``.

    Each pass runs ``delete_files_created_by_app`` in a worker thread via
    ``anyio.to_thread.run_sync``, keeping file-system work off the event loop.
    The loop runs until the task is cancelled.
    """
    while True:
        await asyncio.sleep(frequency)
        current_blocks = app.get_blocks()
        await anyio.to_thread.run_sync(
            delete_files_created_by_app, current_blocks, age
        )
|
|
|
|
|
@asynccontextmanager
async def _lifespan_handler(
    app: App, frequency: int = 1, age: int = 1
) -> AsyncGenerator:
    """A context manager that triggers the startup and shutdown events of the app.

    On startup, runs ``delete_files_on_schedule`` as a background task. On exit,
    that task is cancelled. After a normal exit, the app's temporary files are
    deleted regardless of age (``age=None``).
    """
    # Keep a strong reference: the event loop holds only weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-execution.
    cleanup_task = asyncio.create_task(delete_files_on_schedule(app, frequency, age))
    try:
        yield
    finally:
        cleanup_task.cancel()
    delete_files_created_by_app(app.get_blocks(), age=None)
|
|
|
|
|
async def _delete_state(app: App):
    """Forever: purge expired session state through ``app.state_holder``, then wait one second."""
    pause_seconds = 1
    while True:
        app.state_holder.delete_all_expired_state()
        await asyncio.sleep(pause_seconds)
|
|
|
|
|
@asynccontextmanager
async def _delete_state_handler(app: App):
    """When the server launches, regularly delete expired state.

    Runs ``_delete_state`` as a background task for the lifetime of the context
    and cancels it on exit.
    """
    # On Python < 3.10, the stop event is bound to the running loop explicitly
    # (asyncio.Event's ``loop`` parameter was removed in 3.10).
    if sys.version_info < (3, 10):
        loop = asyncio.get_running_loop()
        app.stop_event = asyncio.Event(loop=loop)
    # Keep a strong reference: the event loop holds only weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-execution.
    cleanup_task = asyncio.create_task(_delete_state(app))
    try:
        yield
    finally:
        cleanup_task.cancel()
|
|
|
|
|
def create_lifespan_handler(
    user_lifespan: Callable[[App], AsyncContextManager] | None,
    frequency: int | None = 1,
    age: int | None = 1,
) -> Callable[[App], AsyncContextManager]:
    """Return a lifespan context manager combining Gradio's handlers with ``user_lifespan``.

    Entered in order, and exited in reverse:
      1. expired-state cleanup (always);
      2. scheduled file cleanup, unless ``frequency`` or ``age`` is falsy;
      3. ``user_lifespan(app)``, if provided.
    """

    @asynccontextmanager
    async def _handler(app: App):
        async with AsyncExitStack() as exit_stack:
            await exit_stack.enter_async_context(_delete_state_handler(app))
            if frequency and age:
                await exit_stack.enter_async_context(
                    _lifespan_handler(app, frequency, age)
                )
            if user_lifespan is not None:
                user_context = user_lifespan(app)
                await exit_stack.enter_async_context(user_context)
            yield

    return _handler
|
|