"""Contains utilities used by both the sync and async inference clients.""" |
|
import base64 |
|
import io |
|
import json |
|
import logging |
|
from contextlib import contextmanager |
|
from dataclasses import dataclass |
|
from pathlib import Path |
|
from typing import ( |
|
TYPE_CHECKING, |
|
Any, |
|
AsyncIterable, |
|
BinaryIO, |
|
ContextManager, |
|
Dict, |
|
Generator, |
|
Iterable, |
|
List, |
|
Literal, |
|
Optional, |
|
Set, |
|
Union, |
|
overload, |
|
) |
|
|
|
from requests import HTTPError |
|
|
|
from ..constants import ENDPOINT |
|
from ..utils import ( |
|
build_hf_headers, |
|
get_session, |
|
hf_raise_for_status, |
|
is_aiohttp_available, |
|
is_numpy_available, |
|
is_pillow_available, |
|
) |
|
from ._text_generation import TextGenerationStreamResponse, _parse_text_generation_error |
|
|
|
|
|
if TYPE_CHECKING: |
|
from aiohttp import ClientResponse, ClientSession |
|
from PIL import Image |
|
|
|
|
|
UrlT = str |
|
PathT = Union[str, Path] |
|
BinaryT = Union[bytes, BinaryIO] |
|
ContentT = Union[BinaryT, PathT, UrlT] |
|
|
|
|
|
TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"} |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
@dataclass
class ModelStatus:
    """
    This dataclass represents the model status in the Hugging Face Inference API.

    Args:
        loaded (`bool`):
            Whether the model is currently loaded into Hugging Face's Inference API. Models
            are loaded on-demand, so a user's first request takes longer. If a model is
            loaded, you can be assured that it is in a healthy state.
        state (`str`):
            The current state of the model. This can be 'Loaded', 'Loadable' or 'TooBig'.
            If a model's state is 'Loadable', it's not too big and has a supported
            backend. Loadable models are automatically loaded when the user first
            requests inference on the endpoint. This means it is transparent for the
            user to load a model, except that the first call takes longer to complete.
        compute_type (`str`):
            The type of compute resource the model is using or will use, such as 'gpu' or 'cpu'.
        framework (`str`):
            The name of the framework that the model was built with, such as 'transformers'
            or 'text-generation-inference'.
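    Example:
    ```py
    # Illustrative usage: `InferenceClient.get_model_status` returns a `ModelStatus`.
    # The model id and the returned values below are hypothetical.
    >>> from huggingface_hub import InferenceClient
    >>> client = InferenceClient()
    >>> client.get_model_status("bigcode/starcoder")
    ModelStatus(loaded=True, state='Loaded', compute_type='gpu', framework='text-generation-inference')
    ```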
|
""" |
|
|
|
loaded: bool |
|
state: str |
|
compute_type: str |
|
framework: str |
|
|
|
|
|
class InferenceTimeoutError(HTTPError, TimeoutError):
    """Error raised when a model is unavailable or the request times out."""


## IMPORT UTILS


def _import_aiohttp():
    """Make sure `aiohttp` is installed on the machine."""
    if not is_aiohttp_available():
        raise ImportError("Please install aiohttp to use `AsyncInferenceClient` (`pip install aiohttp`).")
    import aiohttp

    return aiohttp
|
|
|
|
|
def _import_numpy():
    """Make sure `numpy` is installed on the machine."""
    if not is_numpy_available():
        raise ImportError("Please install numpy to deal with embeddings (`pip install numpy`).")
    import numpy

    return numpy


def _import_pil_image():
    """Make sure `PIL` is installed on the machine."""
    if not is_pillow_available():
        raise ImportError(
            "Please install Pillow to deal with images (`pip install Pillow`). If you don't want the image to be"
            " post-processed, use `client.post(...)` and get the raw response from the server."
        )
    from PIL import Image

    return Image
|
|
|
|
|
|
|
|
|
|
|
## RECOMMENDED MODELS

# Fetched once and then cached (see `_fetch_recommended_models`).
_RECOMMENDED_MODELS: Optional[Dict[str, Optional[str]]] = None


def _fetch_recommended_models() -> Dict[str, Optional[str]]:
    """Fetch (and cache) the Hub's recommended widget model for each task."""
    global _RECOMMENDED_MODELS
    if _RECOMMENDED_MODELS is None:
        response = get_session().get(f"{ENDPOINT}/api/tasks", headers=build_hf_headers())
        hf_raise_for_status(response)
        _RECOMMENDED_MODELS = {
            task: _first_or_none(details["widgetModels"]) for task, details in response.json().items()
        }
    return _RECOMMENDED_MODELS
|
|
|
|
|
def _first_or_none(items: List[Any]) -> Optional[Any]:
    try:
        return items[0] or None
    except IndexError:
        return None
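# Usage sketch (illustrative; actual recommendations depend on the Hub's /api/tasks response):
#
#     _fetch_recommended_models()["summarization"]
#     # e.g. 'facebook/bart-large-cnn', or None if no widget model is set for the task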
|
|
|
|
|
|
|
|
|
|
|
## ENCODING / DECODING UTILS


@overload
def _open_as_binary(content: ContentT) -> ContextManager[BinaryT]:
    ...


@overload
def _open_as_binary(content: Literal[None]) -> ContextManager[Literal[None]]:
    ...


@contextmanager
def _open_as_binary(content: Optional[ContentT]) -> Generator[Optional[BinaryT], None, None]:
    """Open `content` as a binary file, either from a URL, a local path, or raw bytes.

    Do nothing if `content` is None.

    TODO: handle a PIL.Image as input
    TODO: handle base64 as input
    """
|
|
|
    # If content is a string: it can either be a URL or a path to a local file.
    if isinstance(content, str):
        if content.startswith("https://") or content.startswith("http://"):
            logger.debug(f"Downloading content from {content}")
            yield get_session().get(content).content
            return
        content = Path(content)
        if not content.exists():
            raise FileNotFoundError(
                f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local"
                " file. To pass raw content, please encode it as bytes first."
            )

    # At this point, content is either a Path, raw bytes, an open binary file, or None.
    if isinstance(content, Path):
        logger.debug(f"Opening content from {content}")
        with content.open("rb") as f:
            yield f
    else:
        # Raw bytes or an already-open binary stream: yield as-is.
        yield content
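# Usage sketch (illustrative; "cat.png" is a hypothetical local file):
#
#     with _open_as_binary("cat.png") as data:                      # local path -> open file handle
#         ...
#     with _open_as_binary("https://example.com/cat.png") as data:  # URL -> downloaded bytes
#         ...
#     with _open_as_binary(b"raw bytes") as data:                   # bytes -> yielded unchanged
#         ...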
|
|
|
|
|
def _b64_encode(content: ContentT) -> str:
    """Encode a raw file (image, audio) into base64. Can be bytes, an opened file, a path or a URL."""
    with _open_as_binary(content) as data:
        data_as_bytes = data if isinstance(data, bytes) else data.read()
        return base64.b64encode(data_as_bytes).decode()


def _b64_to_image(encoded_image: str) -> "Image":
    """Parse a base64-encoded string into a PIL Image."""
    Image = _import_pil_image()
    return Image.open(io.BytesIO(base64.b64decode(encoded_image)))
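# Round-trip sketch (illustrative; assumes Pillow is installed and "cat.png" exists):
#
#     encoded = _b64_encode("cat.png")  # base64 string of the raw file content
#     image = _b64_to_image(encoded)    # back to a `PIL.Image.Image`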
|
|
|
|
|
def _bytes_to_list(content: bytes) -> List:
    """Parse bytes from a Response object into a Python list.

    Expects the response body to be JSON-encoded data.

    NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a
    dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
    """
    return json.loads(content.decode())


def _bytes_to_dict(content: bytes) -> Dict:
    """Parse bytes from a Response object into a Python dictionary.

    Expects the response body to be JSON-encoded data.

    NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a
    list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
    """
    return json.loads(content.decode())


def _bytes_to_image(content: bytes) -> "Image":
    """Parse bytes from a Response object into a PIL Image.

    Expects the response body to be raw bytes. To deal with b64 encoded images, use `_b64_to_image` instead.
    """
    Image = _import_pil_image()
    return Image.open(io.BytesIO(content))
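# Usage sketch (illustrative payloads):
#
#     _bytes_to_dict(b'{"generated_text": "Hello"}')  # -> {'generated_text': 'Hello'}
#     _bytes_to_list(b'[{"label": "POSITIVE"}]')      # -> [{'label': 'POSITIVE'}]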
|
|
|
|
|
|
|
|
|
|
|
## STREAMING UTILS


def _stream_text_generation_response(
    bytes_output_as_lines: Iterable[bytes], details: bool
) -> Union[Iterable[str], Iterable[TextGenerationStreamResponse]]:
    """Parse the server-sent events of a text-generation stream, yielding either raw text or detailed responses."""
    for byte_payload in bytes_output_as_lines:
        # Skip "keep-alive" lines
        if byte_payload == b"\n":
            continue

        payload = byte_payload.decode("utf-8")

        # Event data
        if payload.startswith("data:"):
            # Drop the "data:" prefix and surrounding whitespace before parsing the JSON body
            json_payload = json.loads(payload[len("data:") :].strip())
            if json_payload.get("error") is not None:
                raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
            output = TextGenerationStreamResponse(**json_payload)
            yield output.token.text if not details else output
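# Wire-format sketch (illustrative payload, not a real server response):
#
#     line = b'data: {"token": {"id": 1, "text": " Hello", "logprob": -0.1, "special": false}, "generated_text": null, "details": null}\n'
#     next(_stream_text_generation_response(iter([line]), details=False))  # -> " Hello"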
|
|
|
|
|
async def _async_stream_text_generation_response(
    bytes_output_as_lines: AsyncIterable[bytes], details: bool
) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamResponse]]:
    """Async variant of `_stream_text_generation_response`, consuming an async iterable of lines."""
    async for byte_payload in bytes_output_as_lines:
        # Skip "keep-alive" lines
        if byte_payload == b"\n":
            continue

        payload = byte_payload.decode("utf-8")

        # Event data
        if payload.startswith("data:"):
            # Drop the "data:" prefix and surrounding whitespace before parsing the JSON body
            json_payload = json.loads(payload[len("data:") :].strip())
            if json_payload.get("error") is not None:
                raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))
            output = TextGenerationStreamResponse(**json_payload)
            yield output.token.text if not details else output


async def _async_yield_from(client: "ClientSession", response: "ClientResponse") -> AsyncIterable[bytes]:
    """Yield the raw lines of an aiohttp response, then close the client session."""
    async for byte_payload in response.content:
        yield byte_payload
    await client.close()
|
# "TGI servers" are servers running with the `text-generation-inference` backend. This
# backend is the go-to solution for serving large language models at scale, but some
# models are still served by the default `transformers` backend, whose API is very
# similar yet not identical. Models are assumed to be TGI-served until a request
# proves otherwise, at which point they are registered here so that later calls can
# skip TGI-specific parameters.
_NON_TGI_SERVERS: Set[Optional[str]] = set()


def _set_as_non_tgi(model: Optional[str]) -> None:
    _NON_TGI_SERVERS.add(model)


def _is_tgi_server(model: Optional[str]) -> bool:
    return model not in _NON_TGI_SERVERS
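# Usage sketch (illustrative):
#
#     _is_tgi_server("gpt2")   # True: assumed TGI until proven otherwise
#     _set_as_non_tgi("gpt2")  # registered after a failed TGI-style request
#     _is_tgi_server("gpt2")   # False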
|
|