from __future__ import annotations

import asyncio
import base64
import concurrent.futures
import copy
import json
import mimetypes
import os
import pkgutil
import secrets
import shutil
import tempfile
import time
import warnings
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from threading import Lock
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Literal, Optional, TypedDict

import fsspec.asyn
import httpx
import huggingface_hub
from huggingface_hub import SpaceStage
from websockets.legacy.protocol import WebSocketCommonProtocol

if TYPE_CHECKING:
    from gradio_client.data_classes import ParameterInfo

API_URL = "api/predict/"
SSE_URL_V0 = "queue/join"
SSE_DATA_URL_V0 = "queue/data"
SSE_URL = "queue/data"
SSE_DATA_URL = "queue/join"
WS_URL = "queue/join"
UPLOAD_URL = "upload"
LOGIN_URL = "login"
CONFIG_URL = "config"
API_INFO_URL = "info?all_endpoints=True"
RAW_API_INFO_URL = "info?serialize=False"
SPACE_FETCHER_URL = "https://gradio-space-api-fetcher-v2.hf.space/api"
RESET_URL = "reset"
SPACE_URL = "https://hf.space/{}"
HEARTBEAT_URL = "heartbeat/{session_hash}"
CANCEL_URL = "cancel"

STATE_COMPONENT = "state"
INVALID_RUNTIME = [
    SpaceStage.NO_APP_FILE,
    SpaceStage.CONFIG_ERROR,
    SpaceStage.BUILD_ERROR,
    SpaceStage.RUNTIME_ERROR,
    SpaceStage.PAUSED,
]


class Message(TypedDict, total=False):
    msg: str
    output: dict[str, Any]
    event_id: str
    rank: int
    rank_eta: float
    queue_size: int
    success: bool
    progress_data: list[dict]
    log: str
    level: str


def get_package_version() -> str:
    try:
        package_json_data = (
            pkgutil.get_data(__name__, "package.json").decode("utf-8").strip()
        )
        package_data = json.loads(package_json_data)
        version = package_data.get("version", "")
        return version
    except Exception:
        return ""


__version__ = get_package_version()


class TooManyRequestsError(Exception):
    """Raised when the API returns a 429 status code."""

    pass


class QueueError(Exception):
    """Raised when the queue is full or there is an issue adding a job to the queue."""

    pass


class InvalidAPIEndpointError(Exception):
    """Raised when the API endpoint is invalid."""

    pass


class SpaceDuplicationError(Exception):
    """Raised when something goes wrong with a Space Duplication."""

    pass


class ServerMessage(str, Enum):
    send_hash = "send_hash"
    queue_full = "queue_full"
    estimation = "estimation"
    send_data = "send_data"
    process_starts = "process_starts"
    process_generating = "process_generating"
    process_completed = "process_completed"
    log = "log"
    progress = "progress"
    heartbeat = "heartbeat"
    server_stopped = "server_stopped"
    unexpected_error = "unexpected_error"
    close_stream = "close_stream"


class Status(Enum):
    """Status codes presented to client users."""

    STARTING = "STARTING"
    JOINING_QUEUE = "JOINING_QUEUE"
    QUEUE_FULL = "QUEUE_FULL"
    IN_QUEUE = "IN_QUEUE"
    SENDING_DATA = "SENDING_DATA"
    PROCESSING = "PROCESSING"
    ITERATING = "ITERATING"
    PROGRESS = "PROGRESS"
    FINISHED = "FINISHED"
    CANCELLED = "CANCELLED"
    LOG = "LOG"

    @staticmethod
    def ordering(status: Status) -> int:
        """Order of messages. Helpful for testing."""
        order = [
            Status.STARTING,
            Status.JOINING_QUEUE,
            Status.QUEUE_FULL,
            Status.IN_QUEUE,
            Status.SENDING_DATA,
            Status.PROCESSING,
            Status.PROGRESS,
            Status.ITERATING,
            Status.FINISHED,
            Status.CANCELLED,
        ]
        return order.index(status)

    def __lt__(self, other: Status):
        return self.ordering(self) < self.ordering(other)

    @staticmethod
    def msg_to_status(msg: str) -> Status:
        """Map the raw message from the backend to the status code presented to users."""
        return {
            ServerMessage.send_hash: Status.JOINING_QUEUE,
            ServerMessage.queue_full: Status.QUEUE_FULL,
            ServerMessage.estimation: Status.IN_QUEUE,
            ServerMessage.send_data: Status.SENDING_DATA,
            ServerMessage.process_starts: Status.PROCESSING,
            ServerMessage.process_generating: Status.ITERATING,
            ServerMessage.process_completed: Status.FINISHED,
            ServerMessage.progress: Status.PROGRESS,
            ServerMessage.log: Status.LOG,
            ServerMessage.server_stopped: Status.FINISHED,
        }[msg]
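

# Illustrative note (added commentary, not part of the original module): since
# ServerMessage is a str-valued Enum, `msg_to_status` also accepts the raw
# string received over the wire, e.g.
#
#     >>> Status.msg_to_status("estimation")
#     <Status.IN_QUEUE: 'IN_QUEUE'>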


@dataclass
class ProgressUnit:
    index: Optional[int]
    length: Optional[int]
    unit: Optional[str]
    progress: Optional[float]
    desc: Optional[str]

    @classmethod
    def from_msg(cls, data: list[dict]) -> list[ProgressUnit]:
        return [
            cls(
                index=d.get("index"),
                length=d.get("length"),
                unit=d.get("unit"),
                progress=d.get("progress"),
                desc=d.get("desc"),
            )
            for d in data
        ]


@dataclass
class StatusUpdate:
    """Update message sent from the worker thread to the Job on the main thread."""

    code: Status
    rank: int | None
    queue_size: int | None
    eta: float | None
    success: bool | None
    time: datetime | None
    progress_data: list[ProgressUnit] | None
    log: tuple[str, str] | None = None


def create_initial_status_update():
    return StatusUpdate(
        code=Status.STARTING,
        rank=None,
        queue_size=None,
        eta=None,
        success=None,
        time=datetime.now(),
        progress_data=None,
    )


@dataclass
class JobStatus:
    """The job status.

    Keeps track of the latest status update and intermediate outputs (not yet implemented).
    """

    latest_status: StatusUpdate = field(default_factory=create_initial_status_update)
    outputs: list[Any] = field(default_factory=list)


@dataclass
class Communicator:
    """Helper class to help communicate between the worker thread and main thread."""

    lock: Lock
    job: JobStatus
    prediction_processor: Callable[..., tuple]
    reset_url: str
    should_cancel: bool = False
    event_id: str | None = None
    thread_complete: bool = False


########################
# Network utils
########################


def is_http_url_like(possible_url) -> bool:
    """
    Check if the given value is a string that looks like an HTTP(S) URL.
    """
    if not isinstance(possible_url, str):
        return False
    return possible_url.startswith(("http://", "https://"))


def probe_url(possible_url: str) -> bool:
    """
    Probe the given URL with a HEAD request (falling back to GET if HEAD is not
    allowed) and report whether it responds successfully.
    """
    headers = {"User-Agent": "gradio (https://gradio.app/; gradio-team@huggingface.co)"}
    try:
        with httpx.Client() as client:
            head_request = client.head(possible_url, headers=headers)
            if head_request.status_code == 405:  # HEAD not allowed; retry with GET
                return client.get(possible_url, headers=headers).is_success
            return head_request.is_success
    except Exception:
        return False


def is_valid_url(possible_url: str) -> bool:
    """
    Check if the given string is a valid URL.
    """
    warnings.warn(
        "is_valid_url should not be used. "
        "Use is_http_url_like() and probe_url(), as suitable, instead.",
    )
    return is_http_url_like(possible_url) and probe_url(possible_url)


async def get_pred_from_ws(
    websocket: WebSocketCommonProtocol,
    data: str,
    hash_data: str,
    helper: Communicator | None = None,
) -> dict[str, Any]:
    completed = False
    resp = {}
    while not completed:
        # Receive the next message in a background task so that we can
        # poll for cancellation while waiting.
        task = asyncio.create_task(websocket.recv())
        while not task.done():
            if helper:
                with helper.lock:
                    if helper.should_cancel:
                        # Reset the queue session on the server, since the
                        # client will not reset it.
                        async with httpx.AsyncClient() as http:
                            reset = http.post(
                                helper.reset_url, json=json.loads(hash_data)
                            )
                            # Cancel the recv task and await both coroutines so
                            # that their exceptions are consumed.
                            task.cancel()
                            await asyncio.gather(task, reset, return_exceptions=True)
                        raise concurrent.futures.CancelledError()
            # Yield control so that the recv task actually runs.
            await asyncio.sleep(0.01)
        msg = task.result()
        resp = json.loads(msg)
        if helper:
            with helper.lock:
                has_progress = "progress_data" in resp
                status_update = StatusUpdate(
                    code=Status.msg_to_status(resp["msg"]),
                    queue_size=resp.get("queue_size"),
                    rank=resp.get("rank", None),
                    success=resp.get("success"),
                    time=datetime.now(),
                    eta=resp.get("rank_eta"),
                    progress_data=ProgressUnit.from_msg(resp["progress_data"])
                    if has_progress
                    else None,
                )
                output = resp.get("output", {}).get("data", [])
                if output and status_update.code != Status.FINISHED:
                    try:
                        result = helper.prediction_processor(*output)
                    except Exception as e:
                        result = [e]
                    helper.job.outputs.append(result)
                helper.job.latest_status = status_update
        if resp["msg"] == "queue_full":
            raise QueueError("Queue is full! Please try again.")
        if resp["msg"] == "send_hash":
            await websocket.send(hash_data)
        elif resp["msg"] == "send_data":
            await websocket.send(data)
        completed = resp["msg"] == "process_completed"
    return resp["output"]


def get_pred_from_sse_v0(
    client: httpx.Client,
    data: dict,
    hash_data: dict,
    helper: Communicator,
    sse_url: str,
    sse_data_url: str,
    headers: dict[str, str],
    cookies: dict[str, str] | None,
    ssl_verify: bool,
    executor: concurrent.futures.ThreadPoolExecutor,
) -> dict[str, Any] | None:
    helper.thread_complete = False
    future_cancel = executor.submit(
        check_for_cancel, helper, headers, cookies, ssl_verify
    )
    future_sse = executor.submit(
        stream_sse_v0,
        client,
        data,
        hash_data,
        helper,
        sse_url,
        sse_data_url,
        headers,
        cookies,
    )
    done, _ = concurrent.futures.wait(
        [future_cancel, future_sse],
        return_when=concurrent.futures.FIRST_COMPLETED,
    )
    helper.thread_complete = True

    if len(done) != 1:
        raise ValueError(f"Did not expect {len(done)} tasks to be done.")
    for future in done:
        return future.result()


def get_pred_from_sse_v1plus(
    helper: Communicator,
    headers: dict[str, str],
    cookies: dict[str, str] | None,
    pending_messages_per_event: dict[str, list[Message | None]],
    event_id: str,
    protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"],
    ssl_verify: bool,
    executor: concurrent.futures.ThreadPoolExecutor,
) -> dict[str, Any] | None:
    helper.thread_complete = False
    future_cancel = executor.submit(
        check_for_cancel, helper, headers, cookies, ssl_verify
    )
    future_sse = executor.submit(
        stream_sse_v1plus, helper, pending_messages_per_event, event_id, protocol
    )
    done, _ = concurrent.futures.wait(
        [future_cancel, future_sse],
        return_when=concurrent.futures.FIRST_COMPLETED,
    )
    helper.thread_complete = True

    if len(done) != 1:
        raise ValueError(f"Did not expect {len(done)} tasks to be done.")
    for future in done:
        exception = future.exception()
        if exception:
            raise exception
        return future.result()


def check_for_cancel(
    helper: Communicator,
    headers: dict[str, str],
    cookies: dict[str, str] | None,
    ssl_verify: bool,
):
    while True:
        time.sleep(0.05)
        with helper.lock:
            if helper.should_cancel:
                break
            if helper.thread_complete:
                raise concurrent.futures.CancelledError()
    if helper.event_id:
        httpx.post(
            helper.reset_url,
            json={"event_id": helper.event_id},
            headers=headers,
            cookies=cookies,
            verify=ssl_verify,
        )
    raise concurrent.futures.CancelledError()


def stream_sse_v0(
    client: httpx.Client,
    data: dict,
    hash_data: dict,
    helper: Communicator,
    sse_url: str,
    sse_data_url: str,
    headers: dict[str, str],
    cookies: dict[str, str] | None,
) -> dict[str, Any]:
    try:
        with client.stream(
            "GET",
            sse_url,
            params=hash_data,
            headers=headers,
            cookies=cookies,
        ) as response:
            for line in response.iter_lines():
                line = line.rstrip("\n")
                if len(line) == 0:
                    continue
                if line.startswith("data:"):
                    resp = json.loads(line[5:])
                    if resp["msg"] in [ServerMessage.log, ServerMessage.heartbeat]:
                        continue
                    with helper.lock:
                        has_progress = "progress_data" in resp
                        status_update = StatusUpdate(
                            code=Status.msg_to_status(resp["msg"]),
                            queue_size=resp.get("queue_size"),
                            rank=resp.get("rank", None),
                            success=resp.get("success"),
                            time=datetime.now(),
                            eta=resp.get("rank_eta"),
                            progress_data=ProgressUnit.from_msg(resp["progress_data"])
                            if has_progress
                            else None,
                        )
                        output = resp.get("output", {}).get("data", [])
                        if output and status_update.code != Status.FINISHED:
                            try:
                                result = helper.prediction_processor(*output)
                            except Exception as e:
                                result = [e]
                            helper.job.outputs.append(result)
                        helper.job.latest_status = status_update
                    if helper.thread_complete:
                        raise concurrent.futures.CancelledError()
                    if resp["msg"] == "queue_full":
                        raise QueueError("Queue is full! Please try again.")
                    elif resp["msg"] == "send_data":
                        event_id = resp["event_id"]
                        helper.event_id = event_id
                        req = client.post(
                            sse_data_url,
                            json={"event_id": event_id, **data, **hash_data},
                            headers=headers,
                            cookies=cookies,
                        )
                        req.raise_for_status()
                    elif resp["msg"] == "process_completed":
                        return resp["output"]
                else:
                    raise ValueError(f"Unexpected message: {line}")
            raise ValueError("Did not receive process_completed message.")
    except concurrent.futures.CancelledError:
        raise


def stream_sse_v1plus(
    helper: Communicator,
    pending_messages_per_event: dict[str, list[Message | None]],
    event_id: str,
    protocol: Literal["sse_v1", "sse_v2", "sse_v2.1", "sse_v3"],
) -> dict[str, Any]:
    try:
        pending_messages = pending_messages_per_event[event_id]
        pending_responses_for_diffs = None

        while True:
            if len(pending_messages) > 0:
                msg = pending_messages.pop(0)
            else:
                time.sleep(0.05)
                continue

            if msg is None or helper.thread_complete:
                raise concurrent.futures.CancelledError()

            with helper.lock:
                log_message = None
                if msg["msg"] == ServerMessage.log:
                    log = msg.get("log")
                    level = msg.get("level")
                    if log and level:
                        log_message = (log, level)
                status_update = StatusUpdate(
                    code=Status.msg_to_status(msg["msg"]),
                    queue_size=msg.get("queue_size"),
                    rank=msg.get("rank", None),
                    success=msg.get("success"),
                    time=datetime.now(),
                    eta=msg.get("rank_eta"),
                    progress_data=ProgressUnit.from_msg(msg["progress_data"])
                    if "progress_data" in msg
                    else None,
                    log=log_message,
                )
                output = msg.get("output", {}).get("data", [])
                if msg["msg"] == ServerMessage.process_generating and protocol in [
                    "sse_v2",
                    "sse_v2.1",
                    "sse_v3",
                ]:
                    # From sse_v2 onwards, intermediate outputs arrive as diffs
                    # against the previous output, so accumulate and apply them.
                    if pending_responses_for_diffs is None:
                        pending_responses_for_diffs = list(output)
                    else:
                        for i, value in enumerate(output):
                            prev_output = pending_responses_for_diffs[i]
                            new_output = apply_diff(prev_output, value)
                            pending_responses_for_diffs[i] = new_output
                            output[i] = new_output

                if output and status_update.code != Status.FINISHED:
                    try:
                        result = helper.prediction_processor(*output)
                    except Exception as e:
                        result = [e]
                    helper.job.outputs.append(result)
                helper.job.latest_status = status_update
            if msg["msg"] == ServerMessage.process_completed:
                del pending_messages_per_event[event_id]
                return msg["output"]
            elif msg["msg"] == ServerMessage.server_stopped:
                raise ValueError("Server stopped.")

    except concurrent.futures.CancelledError:
        raise
def apply_diff(obj, diff):
    obj = copy.deepcopy(obj)

    def apply_edit(target, path, action, value):
        if len(path) == 0:
            if action == "replace":
                return value
            elif action == "append":
                return target + value
            else:
                raise ValueError(f"Unsupported action: {action}")

        current = target
        for i in range(len(path) - 1):
            current = current[path[i]]

        last_path = path[-1]
        if action == "replace":
            current[last_path] = value
        elif action == "append":
            current[last_path] += value
        elif action == "add":
            if isinstance(current, list):
                current.insert(int(last_path), value)
            else:
                current[last_path] = value
        elif action == "delete":
            if isinstance(current, list):
                del current[int(last_path)]
            else:
                del current[last_path]
        else:
            raise ValueError(f"Unknown action: {action}")

        return target

    for action, path, value in diff:
        obj = apply_edit(obj, path, action, value)

    return obj
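
# Illustrative doctest-style examples (added commentary, not part of the
# original module). Each diff entry is an [action, path, value] triple:
#
#     >>> apply_diff({"a": [1, 2]}, [["replace", ["a", 0], 42]])
#     {'a': [42, 2]}
#     >>> apply_diff("hello", [["append", [], " world"]])
#     'hello world'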


########################
# Data processing utils
########################


def create_tmp_copy_of_file(file_path: str, dir: str | None = None) -> str:
    directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
    directory.mkdir(exist_ok=True, parents=True)
    dest = directory / Path(file_path).name
    shutil.copy2(file_path, dest)
    return str(dest.resolve())


def download_tmp_copy_of_file(
    url_path: str, hf_token: str | None = None, dir: str | None = None
) -> str:
    """Kept for backwards compatibility for 3.x spaces."""
    if dir is not None:
        os.makedirs(dir, exist_ok=True)
    headers = {"Authorization": "Bearer " + hf_token} if hf_token else {}
    directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
    directory.mkdir(exist_ok=True, parents=True)
    file_path = directory / Path(url_path).name

    with httpx.stream(
        "GET", url_path, headers=headers, follow_redirects=True
    ) as response:
        response.raise_for_status()
        with open(file_path, "wb") as f:
            for chunk in response.iter_raw():
                f.write(chunk)
    return str(file_path.resolve())


def get_mimetype(filename: str) -> str | None:
    if filename.endswith(".vtt"):
        return "text/vtt"
    mimetype = mimetypes.guess_type(filename)[0]
    if mimetype is not None:
        mimetype = mimetype.replace("x-wav", "wav").replace("x-flac", "flac")
    return mimetype


def get_extension(encoding: str) -> str | None:
    encoding = encoding.replace("audio/wav", "audio/x-wav")
    type = mimetypes.guess_type(encoding)[0]
    if type == "audio/flac":
        return "flac"
    elif type is None:
        return None
    extension = mimetypes.guess_extension(type)
    if extension is not None and extension.startswith("."):
        extension = extension[1:]
    return extension
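
# Illustrative example (added commentary, not part of the original module):
# `get_extension` accepts data URLs, since `mimetypes.guess_type` parses them:
#
#     >>> get_extension("data:audio/flac;base64,AAAA")
#     'flac'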


def encode_file_to_base64(f: str | Path):
    with open(f, "rb") as file:
        encoded_string = base64.b64encode(file.read())
        base64_str = str(encoded_string, "utf-8")
        mimetype = get_mimetype(str(f))
        return (
            "data:"
            + (mimetype if mimetype is not None else "")
            + ";base64,"
            + base64_str
        )


def encode_url_to_base64(url: str):
    resp = httpx.get(url)
    resp.raise_for_status()
    encoded_string = base64.b64encode(resp.content)
    base64_str = str(encoded_string, "utf-8")
    mimetype = get_mimetype(url)
    return (
        "data:" + (mimetype if mimetype is not None else "") + ";base64," + base64_str
    )


def encode_url_or_file_to_base64(path: str | Path):
    path = str(path)
    if is_http_url_like(path):
        return encode_url_to_base64(path)
    return encode_file_to_base64(path)


def download_byte_stream(url: str, hf_token=None):
    arr = bytearray()
    headers = {"Authorization": "Bearer " + hf_token} if hf_token else {}
    with httpx.stream("GET", url, headers=headers) as r:
        for data in r.iter_bytes():
            arr += data
            yield data
        yield arr


def decode_base64_to_binary(encoding: str) -> tuple[bytes, str | None]:
    extension = get_extension(encoding)
    data = encoding.rsplit(",", 1)[-1]
    return base64.b64decode(data), extension
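
# Illustrative example (added commentary, not part of the original module):
#
#     >>> decode_base64_to_binary("data:text/plain;base64,aGVsbG8=")
#     (b'hello', 'txt')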


def strip_invalid_filename_characters(filename: str, max_bytes: int = 200) -> str:
    """Strips invalid characters from a filename and ensures that the filename is no longer than `max_bytes` bytes."""
    filename = "".join([char for char in filename if char.isalnum() or char in "._- "])
    filename_len = len(filename.encode())
    if filename_len > max_bytes:
        while filename_len > max_bytes:
            if len(filename) == 0:
                break
            filename = filename[:-1]
            filename_len = len(filename.encode())
    return filename
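
# Illustrative example (added commentary, not part of the original module):
# only alphanumerics, ".", "_", "-", and spaces survive:
#
#     >>> strip_invalid_filename_characters("my file:v2?.txt")
#     'my filev2.txt'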


def sanitize_parameter_names(original_name: str) -> str:
    """Cleans up a Python parameter name to make the API info more readable."""
    return (
        "".join([char for char in original_name if char.isalnum() or char in " _"])
        .replace(" ", "_")
        .lower()
    )
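
# Illustrative example (added commentary, not part of the original module):
#
#     >>> sanitize_parameter_names("Select Language (en/fr)")
#     'select_language_enfr'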


def decode_base64_to_file(
    encoding: str,
    file_path: str | None = None,
    dir: str | Path | None = None,
    prefix: str | None = None,
):
    directory = Path(dir or tempfile.gettempdir()) / secrets.token_hex(20)
    directory.mkdir(exist_ok=True, parents=True)
    data, extension = decode_base64_to_binary(encoding)
    if file_path is not None and prefix is None:
        filename = Path(file_path).name
        prefix = filename
        if "." in filename:
            prefix = filename[0 : filename.index(".")]
            extension = filename[filename.index(".") + 1 :]

    if prefix is not None:
        prefix = strip_invalid_filename_characters(prefix)

    if extension is None:
        file_obj = tempfile.NamedTemporaryFile(
            delete=False, prefix=prefix, dir=directory
        )
    else:
        file_obj = tempfile.NamedTemporaryFile(
            delete=False,
            prefix=prefix,
            suffix="." + extension,
            dir=directory,
        )
    file_obj.write(data)
    file_obj.flush()
    return file_obj


def dict_or_str_to_json_file(jsn: str | dict | list, dir: str | Path | None = None):
    if dir is not None:
        os.makedirs(dir, exist_ok=True)

    file_obj = tempfile.NamedTemporaryFile(
        delete=False, suffix=".json", dir=dir, mode="w+"
    )
    if isinstance(jsn, str):
        jsn = json.loads(jsn)
    json.dump(jsn, file_obj)
    file_obj.flush()
    return file_obj


def file_to_json(file_path: str | Path) -> dict | list:
    with open(file_path) as f:
        return json.load(f)


########################
# HuggingFace Hub API Utils
########################
def set_space_timeout(
    space_id: str,
    hf_token: str | None = None,
    timeout_in_seconds: int = 300,
):
    headers = huggingface_hub.utils.build_hf_headers(
        token=hf_token,
        library_name="gradio_client",
        library_version=__version__,
    )
    try:
        # Raise on 4xx/5xx responses so failures surface as SpaceDuplicationError.
        httpx.post(
            f"https://huggingface.co/api/spaces/{space_id}/sleeptime",
            json={"seconds": timeout_in_seconds},
            headers=headers,
        ).raise_for_status()
    except httpx.HTTPStatusError as e:
        raise SpaceDuplicationError(
            f"Could not set sleep timeout on duplicated Space. Please visit {SPACE_URL.format(space_id)} "
            "to set a timeout manually to reduce billing charges."
        ) from e


########################
# Helper functions
########################


def synchronize_async(func: Callable, *args, **kwargs) -> Any:
    """
    Runs async functions in sync scopes. Can be used in any scope.

    Example:
        if inspect.iscoroutinefunction(block_fn.fn):
            predictions = utils.synchronize_async(block_fn.fn, *processed_input)

    Args:
        func:
        *args:
        **kwargs:
    """
    return fsspec.asyn.sync(fsspec.asyn.get_loop(), func, *args, **kwargs)


class APIInfoParseError(ValueError):
    pass


def get_type(schema: dict):
    if "const" in schema:
        return "const"
    if "enum" in schema:
        return "enum"
    elif "type" in schema:
        return schema["type"]
    elif schema.get("$ref"):
        return "$ref"
    elif schema.get("oneOf"):
        return "oneOf"
    elif schema.get("anyOf"):
        return "anyOf"
    elif schema.get("allOf"):
        return "allOf"
    elif "type" not in schema:
        return {}
    else:
        raise APIInfoParseError(f"Cannot parse type for {schema}")


FILE_DATA_FORMATS = [
    "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None)",
    "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None, is_stream: bool)",
    "Dict(path: str, url: str | None, size: int | None, orig_name: str | None, mime_type: str | None, is_stream: bool, meta: Dict())",
]

CURRENT_FILE_DATA_FORMAT = FILE_DATA_FORMATS[-1]


def json_schema_to_python_type(schema: Any) -> str:
    type_ = _json_schema_to_python_type(schema, schema.get("$defs"))
    return type_.replace(CURRENT_FILE_DATA_FORMAT, "filepath")
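
# Illustrative examples (added commentary, not part of the original module):
#
#     >>> json_schema_to_python_type({"type": "array", "items": {"type": "string"}})
#     'List[str]'
#     >>> json_schema_to_python_type({"anyOf": [{"type": "integer"}, {"type": "null"}]})
#     'int | None'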


def _json_schema_to_python_type(schema: Any, defs) -> str:
    """Convert the json schema into a python type hint"""
    if schema == {}:
        return "Any"
    type_ = get_type(schema)
    if type_ == {}:
        if "json" in schema.get("description", {}):
            return "Dict[Any, Any]"
        else:
            return "Any"
    elif type_ == "$ref":
        return _json_schema_to_python_type(defs[schema["$ref"].split("/")[-1]], defs)
    elif type_ == "null":
        return "None"
    elif type_ == "const":
        return f"Literal[{schema['const']}]"
    elif type_ == "enum":
        return (
            "Literal[" + ", ".join(["'" + str(v) + "'" for v in schema["enum"]]) + "]"
        )
    elif type_ == "integer":
        return "int"
    elif type_ == "string":
        return "str"
    elif type_ == "boolean":
        return "bool"
    elif type_ == "number":
        return "float"
    elif type_ == "array":
        items = schema.get("items", [])
        if "prefixItems" in items:
            elements = ", ".join(
                [_json_schema_to_python_type(i, defs) for i in items["prefixItems"]]
            )
            return f"Tuple[{elements}]"
        elif "prefixItems" in schema:
            elements = ", ".join(
                [_json_schema_to_python_type(i, defs) for i in schema["prefixItems"]]
            )
            return f"Tuple[{elements}]"
        else:
            elements = _json_schema_to_python_type(items, defs)
            return f"List[{elements}]"
    elif type_ == "object":

        def get_desc(v):
            return f" ({v.get('description')})" if v.get("description") else ""

        props = schema.get("properties", {})

        des = [
            f"{n}: {_json_schema_to_python_type(v, defs)}{get_desc(v)}"
            for n, v in props.items()
            if n != "$defs"
        ]

        if "additionalProperties" in schema:
            des += [
                f"str, {_json_schema_to_python_type(schema['additionalProperties'], defs)}"
            ]
        des = ", ".join(des)
        return f"Dict({des})"
    elif type_ in ["oneOf", "anyOf"]:
        desc = " | ".join([_json_schema_to_python_type(i, defs) for i in schema[type_]])
        return desc
    elif type_ == "allOf":
        data = ", ".join(_json_schema_to_python_type(i, defs) for i in schema[type_])
        desc = f"All[{data}]"
        return desc
    else:
        raise APIInfoParseError(f"Cannot parse schema {schema}")


def traverse(json_obj: Any, func: Callable, is_root: Callable[..., bool]) -> Any:
    """
    Traverse a JSON object and apply a function to each element that satisfies the is_root condition.
    """
    if is_root(json_obj):
        return func(json_obj)
    elif isinstance(json_obj, dict):
        new_obj = {}
        for key, value in json_obj.items():
            new_obj[key] = traverse(value, func, is_root)
        return new_obj
    elif isinstance(json_obj, (list, tuple)):
        new_obj = []
        for item in json_obj:
            new_obj.append(traverse(item, func, is_root))
        return new_obj
    else:
        return json_obj
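
# Illustrative example (added commentary, not part of the original module):
# replace every file-object dict with its "path" value:
#
#     >>> traverse({"a": [{"path": "x.png"}, 1]}, lambda d: d["path"], is_file_obj)
#     {'a': ['x.png', 1]}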


async def async_traverse(
    json_obj: Any,
    func: Callable[..., Coroutine[Any, Any, Any]],
    is_root: Callable[..., bool],
) -> Any:
    """
    Traverse a JSON object and apply an async function to each element that satisfies the is_root condition.
    """
    if is_root(json_obj):
        return await func(json_obj)
    elif isinstance(json_obj, dict):
        new_obj = {}
        for key, value in json_obj.items():
            new_obj[key] = await async_traverse(value, func, is_root)
        return new_obj
    elif isinstance(json_obj, (list, tuple)):
        new_obj = []
        for item in json_obj:
            new_obj.append(await async_traverse(item, func, is_root))
        return new_obj
    else:
        return json_obj


def value_is_file(api_info: dict) -> bool:
    info = _json_schema_to_python_type(api_info, api_info.get("$defs"))
    return any(file_data_format in info for file_data_format in FILE_DATA_FORMATS)


def is_filepath(s) -> bool:
    """
    Check if the given value is a valid str or Path filepath on the local filesystem, e.g. "path/to/file".
    """
    return isinstance(s, (str, Path)) and Path(s).exists() and Path(s).is_file()


def is_file_obj(d) -> bool:
    """
    Check if the given value is a valid FileData object dictionary in versions of Gradio<=4.20, e.g.
    {
        "path": "path/to/file",
    }
    """
    return isinstance(d, dict) and "path" in d and isinstance(d["path"], str)


def is_file_obj_with_meta(d) -> bool:
    """
    Check if the given value is a valid FileData object dictionary in newer versions of Gradio
    where the file objects include a specific "meta" key, e.g.
    {
        "path": "path/to/file",
        "meta": {"_type": "gradio.FileData"}
    }
    """
    return (
        isinstance(d, dict)
        and "path" in d
        and isinstance(d["path"], str)
        and "meta" in d
        and d["meta"].get("_type", "") == "gradio.FileData"
    )


def is_file_obj_with_url(d) -> bool:
    """
    Check if the given value is a valid FileData object dictionary in newer versions of Gradio
    where the file objects include a specific "meta" key, and ALSO include a "url" key, e.g.
    {
        "path": "path/to/file",
        "url": "/file=path/to/file",
        "meta": {"_type": "gradio.FileData"}
    }
    """
    return is_file_obj_with_meta(d) and "url" in d and isinstance(d["url"], str)
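
# Illustrative examples (added commentary, not part of the original module):
#
#     >>> is_file_obj({"path": "a.png"})
#     True
#     >>> is_file_obj_with_meta({"path": "a.png"})
#     False
#     >>> is_file_obj_with_meta({"path": "a.png", "meta": {"_type": "gradio.FileData"}})
#     True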


SKIP_COMPONENTS = {
    "state",
    "row",
    "column",
    "tabs",
    "tab",
    "tabitem",
    "box",
    "form",
    "accordion",
    "group",
    "interpretation",
    "dataset",
}


def handle_file(filepath_or_url: str | Path):
    s = str(filepath_or_url)
    data = {"path": s, "meta": {"_type": "gradio.FileData"}}
    if is_http_url_like(s):
        return {**data, "orig_name": s.split("/")[-1], "url": s}
    elif Path(s).exists():
        return {**data, "orig_name": Path(s).name}
    else:
        raise ValueError(
            f"File {s} does not exist on local filesystem and is not a valid URL."
        )
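
# Illustrative example (added commentary, not part of the original module):
#
#     >>> handle_file("https://example.com/cat.png")  # doctest: +NORMALIZE_WHITESPACE
#     {'path': 'https://example.com/cat.png', 'meta': {'_type': 'gradio.FileData'},
#      'orig_name': 'cat.png', 'url': 'https://example.com/cat.png'}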


def file(filepath_or_url: str | Path):
    warnings.warn(
        "file() is deprecated and will be removed in a future version. Use handle_file() instead."
    )
    return handle_file(filepath_or_url)


def construct_args(
    parameters_info: list[ParameterInfo] | None, args: tuple, kwargs: dict
) -> list:
    class _Keywords(Enum):
        NO_VALUE = "NO_VALUE"

    _args = list(args)
    if parameters_info is None:
        if kwargs:
            raise ValueError(
                "This endpoint does not support keyword arguments. Please click on 'view API' in the footer of the Gradio app to see usage."
            )
        return _args
    num_args = len(args)
    _args = _args + [_Keywords.NO_VALUE] * (len(parameters_info) - num_args)

    kwarg_arg_mapping = {}
    kwarg_names = []
    for index, param_info in enumerate(parameters_info):
        if "parameter_name" in param_info:
            kwarg_arg_mapping[param_info["parameter_name"]] = index
            kwarg_names.append(param_info["parameter_name"])
        else:
            kwarg_names.append(f"argument {index}")
        if (
            param_info.get("parameter_has_default", False)
            and _args[index] == _Keywords.NO_VALUE
        ):
            _args[index] = param_info.get("parameter_default")

    for key, value in kwargs.items():
        if key in kwarg_arg_mapping:
            if kwarg_arg_mapping[key] < num_args:
                raise TypeError(
                    f"Parameter `{key}` is already set as a positional argument. Please click on 'view API' in the footer of the Gradio app to see usage."
                )
            else:
                _args[kwarg_arg_mapping[key]] = value
        else:
            raise TypeError(
                f"Parameter `{key}` is not a valid keyword argument. Please click on 'view API' in the footer of the Gradio app to see usage."
            )

    if _Keywords.NO_VALUE in _args:
        raise TypeError(
            f"No value provided for required argument: {kwarg_names[_args.index(_Keywords.NO_VALUE)]}"
        )

    return _args
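
# Illustrative example (added commentary; the parameter dicts below are
# hypothetical ParameterInfo-shaped inputs, not taken from a real endpoint):
#
#     >>> params = [
#     ...     {"parameter_name": "message"},
#     ...     {"parameter_name": "temperature",
#     ...      "parameter_has_default": True, "parameter_default": 0.7},
#     ... ]
#     >>> construct_args(params, ("hi",), {})
#     ['hi', 0.7]
#     >>> construct_args(params, (), {"message": "hi"})
#     ['hi', 0.7]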