from __future__ import annotations
import hashlib
import json
from collections import deque
from dataclasses import dataclass as python_dataclass
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO, List, Optional, Tuple, Union
import fastapi
import httpx
import multipart
from gradio_client.documentation import document, set_documentation_group
from multipart.multipart import parse_options_header
from starlette.datastructures import FormData, Headers, UploadFile
from starlette.formparsers import MultiPartException, MultipartPart
from gradio import utils
from gradio.data_classes import PredictBody
from gradio.exceptions import Error
from gradio.helpers import EventData
from gradio.state_holder import SessionState
if TYPE_CHECKING:
from gradio.blocks import Blocks
from gradio.routes import App
set_documentation_group("routes")
class Obj:
"""
Using a class to convert dictionaries into objects. Used by the `Request` class.
Credit: https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/
"""
def __init__(self, dict_):
self.__dict__.update(dict_)
        for key, value in dict_.items():
            if isinstance(value, dict):
                value = Obj(value)
            elif isinstance(value, list):
                # Convert dicts inside lists; `Obj` itself cannot wrap a list.
                value = [Obj(v) if isinstance(v, dict) else v for v in value]
            setattr(self, key, value)
def __getitem__(self, item):
return self.__dict__[item]
def __setitem__(self, item, value):
self.__dict__[item] = value
def __iter__(self):
for key, value in self.__dict__.items():
if isinstance(value, Obj):
yield (key, dict(value))
else:
yield (key, value)
def __contains__(self, item) -> bool:
if item in self.__dict__:
return True
for value in self.__dict__.values():
if isinstance(value, Obj) and item in value:
return True
return False
def keys(self):
return self.__dict__.keys()
def values(self):
return self.__dict__.values()
def items(self):
return self.__dict__.items()
def __str__(self) -> str:
return str(self.__dict__)
def __repr__(self) -> str:
return str(self.__dict__)
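# Usage sketch (illustrative, not executed): `Obj` gives attribute-style access
# to nested dictionaries, which is how `Request` exposes deserialized payloads.
#
#   o = Obj({"headers": {"x-token": "abc"}, "tags": [{"name": "demo"}]})
#   o.headers["x-token"]   # "abc" (item access also works)
#   o.tags[0].name         # "demo" (dicts inside lists are converted too)
#   "x-token" in o         # True (`__contains__` searches nested Objs)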
@document()
class Request:
"""
A Gradio request object that can be used to access the request headers, cookies,
query parameters and other information about the request from within the prediction
function. The class is a thin wrapper around the fastapi.Request class. Attributes
of this class include: `headers`, `client`, `query_params`, and `path_params`. If
    auth is enabled, the `username` attribute can be used to get the logged-in user.
Example:
import gradio as gr
def echo(text, request: gr.Request):
if request:
print("Request headers dictionary:", request.headers)
print("IP address:", request.client.host)
print("Query parameters:", dict(request.query_params))
return text
io = gr.Interface(echo, "textbox", "textbox").launch()
Demos: request_ip_headers
"""
def __init__(
self,
request: fastapi.Request | None = None,
username: str | None = None,
**kwargs,
):
"""
Can be instantiated with either a fastapi.Request or by manually passing in
attributes (needed for queueing).
Parameters:
request: A fastapi.Request
"""
self.request = request
self.username = username
self.kwargs: dict = kwargs
def dict_to_obj(self, d):
if isinstance(d, dict):
return json.loads(json.dumps(d), object_hook=Obj)
else:
return d
def __getattr__(self, name):
if self.request:
return self.dict_to_obj(getattr(self.request, name))
else:
try:
obj = self.kwargs[name]
except KeyError as ke:
raise AttributeError(
f"'Request' object has no attribute '{name}'"
) from ke
return self.dict_to_obj(obj)
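# Usage sketch (illustrative): a `Request` either wraps a live fastapi.Request
# or, on the queue, is rebuilt from plain attributes passed as kwargs.
#
#   queued = Request(username="alice", headers={"x-token": "abc"})
#   queued.username             # "alice"
#   queued.headers["x-token"]   # "abc" (kwargs are converted via `Obj`)
#   queued.missing              # raises AttributeError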
class FnIndexInferError(Exception):
pass
def infer_fn_index(app: App, api_name: str, body: PredictBody) -> int:
if body.fn_index is None:
for i, fn in enumerate(app.get_blocks().dependencies):
if fn["api_name"] == api_name:
return i
raise FnIndexInferError(f"Could not infer fn_index for api_name {api_name}.")
else:
return body.fn_index
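# Worked sketch (values illustrative): suppose the app's dependencies are
# [{"api_name": "chat", ...}, {"api_name": "clear", ...}].
#
#   infer_fn_index(app, "clear", body_without_fn_index)    # -> 1
#   infer_fn_index(app, "missing", body_without_fn_index)  # raises FnIndexInferError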
def compile_gr_request(
app: App,
body: PredictBody,
fn_index_inferred: int,
username: Optional[str],
request: Optional[fastapi.Request],
):
# If this fn_index cancels jobs, then the only input we need is the
# current session hash
if app.get_blocks().dependencies[fn_index_inferred]["cancels"]:
body.data = [body.session_hash]
if body.request:
if body.batched:
gr_request = [Request(username=username, request=request)]
else:
gr_request = Request(username=username, request=body.request)
else:
if request is None:
raise ValueError("request must be provided if body.request is None")
gr_request = Request(username=username, request=request)
return gr_request
def restore_session_state(app: App, body: PredictBody):
event_id = body.event_id
session_hash = getattr(body, "session_hash", None)
if session_hash is not None:
session_state = app.state_holder[session_hash]
        # app.iterators_to_reset keeps track of the event_ids of jobs
        # that have been cancelled. When a job is cancelled,
        # the /reset route will mark the job as having been reset.
        # That way, if the cancel job finishes BEFORE the job being cancelled,
        # the job being cancelled will not overwrite the state of the iterator.
if event_id is None:
iterator = None
elif event_id in app.iterators_to_reset:
iterator = None
app.iterators_to_reset.remove(event_id)
else:
iterator = app.iterators.get(event_id)
else:
session_state = SessionState(app.get_blocks())
iterator = None
return session_state, iterator
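# Cancellation sketch (illustrative): when an event id has been marked for
# reset by the /reset route, its stored iterator is discarded, not resumed.
#
#   app.iterators = {"ev-1": running_generator}
#   app.iterators_to_reset = {"ev-1"}
#   restore_session_state(app, body)   # body.event_id == "ev-1"
#   # -> (session_state, None); "ev-1" is removed from iterators_to_reset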
def prepare_event_data(
blocks: Blocks,
body: PredictBody,
) -> EventData:
target = body.trigger_id
event_data = EventData(
blocks.blocks.get(target) if target else None,
body.event_data,
)
return event_data
async def call_process_api(
app: App,
body: PredictBody,
gr_request: Union[Request, list[Request]],
fn_index_inferred: int,
):
session_state, iterator = restore_session_state(app=app, body=body)
dependency = app.get_blocks().dependencies[fn_index_inferred]
event_data = prepare_event_data(app.get_blocks(), body)
event_id = body.event_id
session_hash = getattr(body, "session_hash", None)
inputs = body.data
batch_in_single_out = not body.batched and dependency["batch"]
if batch_in_single_out:
inputs = [inputs]
try:
with utils.MatplotlibBackendMananger():
output = await app.get_blocks().process_api(
fn_index=fn_index_inferred,
inputs=inputs,
request=gr_request,
state=session_state,
iterator=iterator,
session_hash=session_hash,
event_id=event_id,
event_data=event_data,
in_event_listener=True,
)
iterator = output.pop("iterator", None)
if event_id is not None:
app.iterators[event_id] = iterator # type: ignore
if isinstance(output, Error):
raise output
except BaseException:
iterator = app.iterators.get(event_id) if event_id is not None else None
if iterator is not None: # close off any streams that are still open
run_id = id(iterator)
pending_streams: dict[int, list] = (
app.get_blocks().pending_streams[session_hash].get(run_id, {})
)
for stream in pending_streams.values():
stream.append(None)
raise
if batch_in_single_out:
output["data"] = output["data"][0]
return output
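# Shape sketch (illustrative): for a dependency declared with batch=True but a
# non-batched client payload, inputs are wrapped on the way in and the single
# output row is unwrapped on the way out.
#
#   body.data = ["hello"]       # process_api receives [["hello"]]
#   output["data"] = [["HI"]]   # caller receives ["HI"]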
def strip_url(orig_url: str) -> str:
"""
Strips the query parameters and trailing slash from a URL.
"""
parsed_url = httpx.URL(orig_url)
stripped_url = parsed_url.copy_with(query=None)
stripped_url = str(stripped_url)
return stripped_url.rstrip("/")
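# For example (illustrative):
#
#   strip_url("https://example.com/app/?foo=bar")  # -> "https://example.com/app"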
def _user_safe_decode(src: bytes, codec: str) -> str:
try:
return src.decode(codec)
except (UnicodeDecodeError, LookupError):
return src.decode("latin-1")
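# Illustrative behavior: latin-1 can decode any byte sequence, so user-supplied
# form data never raises here even when the declared charset is wrong.
#
#   _user_safe_decode(b"caf\xe9", "utf-8")  # -> "café" via the latin-1 fallback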
class GradioUploadFile(UploadFile):
"""UploadFile with a sha attribute."""
def __init__(
self,
file: BinaryIO,
*,
size: int | None = None,
filename: str | None = None,
headers: Headers | None = None,
) -> None:
super().__init__(file, size=size, filename=filename, headers=headers)
self.sha = hashlib.sha1()
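# Usage sketch (illustrative): the parser below feeds every chunk it writes
# into `sha`, so a content hash is available after parsing without re-reading
# the temporary file.
#
#   upload.sha.update(chunk)   # done per chunk during parse()
#   upload.sha.hexdigest()     # stable digest, e.g. for deduplicating uploads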
@python_dataclass(frozen=True)
class FileUploadProgressUnit:
filename: str
chunk_size: int
is_done: bool
class FileUploadProgress:
def __init__(self) -> None:
self._statuses: dict[str, deque[FileUploadProgressUnit]] = {}
def track(self, upload_id: str):
if upload_id not in self._statuses:
self._statuses[upload_id] = deque()
def update(self, upload_id: str, filename: str, message_bytes: bytes):
if upload_id not in self._statuses:
self._statuses[upload_id] = deque()
self._statuses[upload_id].append(
FileUploadProgressUnit(filename, len(message_bytes), is_done=False)
)
def set_done(self, upload_id: str):
self._statuses[upload_id].append(FileUploadProgressUnit("", 0, is_done=True))
def stop_tracking(self, upload_id: str):
if upload_id in self._statuses:
del self._statuses[upload_id]
def status(self, upload_id: str) -> deque[FileUploadProgressUnit]:
if upload_id not in self._statuses:
return deque()
return self._statuses[upload_id]
def is_tracked(self, upload_id: str):
return upload_id in self._statuses
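# Lifecycle sketch (illustrative): a route registers an upload id, the parser
# appends one unit per chunk, and a progress endpoint drains the deque.
#
#   progress = FileUploadProgress()
#   progress.track("u1")
#   progress.update("u1", "a.png", chunk)   # appends (filename, len(chunk), False)
#   progress.set_done("u1")                 # appends a sentinel unit with is_done=True
#   progress.status("u1").popleft()         # -> FileUploadProgressUnit(...)
#   progress.stop_tracking("u1")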
class GradioMultiPartParser:
"""Vendored from starlette.MultipartParser.
Thanks starlette!
    Made the following modifications:
    - Use GradioUploadFile instead of UploadFile
    - Use NamedTemporaryFile instead of SpooledTemporaryFile
    - Compute the SHA-1 hash of the data as the request is streamed
"""
max_file_size = 1024 * 1024
def __init__(
self,
headers: Headers,
stream: AsyncGenerator[bytes, None],
*,
max_files: Union[int, float] = 1000,
max_fields: Union[int, float] = 1000,
upload_id: str | None = None,
upload_progress: FileUploadProgress | None = None,
) -> None:
assert (
multipart is not None
), "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.max_files = max_files
self.max_fields = max_fields
self.items: List[Tuple[str, Union[str, UploadFile]]] = []
self.upload_id = upload_id
self.upload_progress = upload_progress
self._current_files = 0
self._current_fields = 0
self._current_partial_header_name: bytes = b""
self._current_partial_header_value: bytes = b""
self._current_part = MultipartPart()
self._charset = ""
self._file_parts_to_write: List[Tuple[MultipartPart, bytes]] = []
self._file_parts_to_finish: List[MultipartPart] = []
self._files_to_close_on_error: List[_TemporaryFileWrapper] = []
def on_part_begin(self) -> None:
self._current_part = MultipartPart()
def on_part_data(self, data: bytes, start: int, end: int) -> None:
message_bytes = data[start:end]
        if self.upload_progress is not None and self._current_part.file is not None:
self.upload_progress.update(
self.upload_id, # type: ignore
self._current_part.file.filename, # type: ignore
message_bytes,
)
if self._current_part.file is None:
self._current_part.data += message_bytes
else:
self._file_parts_to_write.append((self._current_part, message_bytes))
def on_part_end(self) -> None:
if self._current_part.file is None:
self.items.append(
(
self._current_part.field_name,
_user_safe_decode(self._current_part.data, self._charset),
)
)
else:
self._file_parts_to_finish.append(self._current_part)
# The file can be added to the items right now even though it's not
# finished yet, because it will be finished in the `parse()` method, before
# self.items is used in the return value.
self.items.append((self._current_part.field_name, self._current_part.file))
def on_header_field(self, data: bytes, start: int, end: int) -> None:
self._current_partial_header_name += data[start:end]
def on_header_value(self, data: bytes, start: int, end: int) -> None:
self._current_partial_header_value += data[start:end]
def on_header_end(self) -> None:
field = self._current_partial_header_name.lower()
if field == b"content-disposition":
self._current_part.content_disposition = self._current_partial_header_value
self._current_part.item_headers.append(
(field, self._current_partial_header_value)
)
self._current_partial_header_name = b""
self._current_partial_header_value = b""
def on_headers_finished(self) -> None:
disposition, options = parse_options_header(
self._current_part.content_disposition
)
try:
self._current_part.field_name = _user_safe_decode(
options[b"name"], self._charset
)
except KeyError as e:
raise MultiPartException(
                'The Content-Disposition header field "name" must be provided.'
) from e
if b"filename" in options:
self._current_files += 1
if self._current_files > self.max_files:
raise MultiPartException(
f"Too many files. Maximum number of files is {self.max_files}."
)
filename = _user_safe_decode(options[b"filename"], self._charset)
tempfile = NamedTemporaryFile(delete=False)
self._files_to_close_on_error.append(tempfile)
self._current_part.file = GradioUploadFile(
file=tempfile, # type: ignore[arg-type]
size=0,
filename=filename,
headers=Headers(raw=self._current_part.item_headers),
)
else:
self._current_fields += 1
if self._current_fields > self.max_fields:
raise MultiPartException(
f"Too many fields. Maximum number of fields is {self.max_fields}."
)
self._current_part.file = None
def on_end(self) -> None:
pass
async def parse(self) -> FormData:
# Parse the Content-Type header to get the multipart boundary.
_, params = parse_options_header(self.headers["Content-Type"])
charset = params.get(b"charset", "utf-8")
if isinstance(charset, bytes):
charset = charset.decode("latin-1")
self._charset = charset
try:
boundary = params[b"boundary"]
except KeyError as e:
raise MultiPartException("Missing boundary in multipart.") from e
# Callbacks dictionary.
callbacks = {
"on_part_begin": self.on_part_begin,
"on_part_data": self.on_part_data,
"on_part_end": self.on_part_end,
"on_header_field": self.on_header_field,
"on_header_value": self.on_header_value,
"on_header_end": self.on_header_end,
"on_headers_finished": self.on_headers_finished,
"on_end": self.on_end,
}
# Create the parser.
parser = multipart.MultipartParser(boundary, callbacks)
try:
# Feed the parser with data from the request.
async for chunk in self.stream:
parser.write(chunk)
            # Write file data. This needs to use `await` with the UploadFile
            # methods, which run the corresponding file operations in a
            # threadpool; if they were called directly in the (regular,
            # non-async) callback methods above, they would block the event
            # loop in the main thread.
for part, data in self._file_parts_to_write:
assert part.file # for type checkers
await part.file.write(data)
part.file.sha.update(data) # type: ignore
for part in self._file_parts_to_finish:
assert part.file # for type checkers
await part.file.seek(0)
self._file_parts_to_write.clear()
self._file_parts_to_finish.clear()
except MultiPartException as exc:
# Close all the files if there was an error.
for file in self._files_to_close_on_error:
file.close()
raise exc
parser.finalize()
if self.upload_progress is not None:
self.upload_progress.set_done(self.upload_id) # type: ignore
return FormData(self.items)
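# Wiring sketch (hypothetical route, illustrative names): how the parser is
# typically attached to a request stream, e.g. in an upload endpoint.
#
#   async def upload(request: fastapi.Request):
#       parser = GradioMultiPartParser(
#           request.headers,
#           request.stream(),
#           max_files=1000,
#           upload_id="u1",
#           upload_progress=upload_progress,   # a shared FileUploadProgress
#       )
#       form = await parser.parse()   # FormData of (field_name, GradioUploadFile)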