| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Contains utilities used by both the sync and async inference clients.""" |
|
|
| import base64 |
| import io |
| import json |
| import logging |
| import mimetypes |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import TYPE_CHECKING, Any, AsyncIterable, BinaryIO, Iterable, Literal, NoReturn, Optional, Union, overload |
|
|
| import httpx |
|
|
| from huggingface_hub.errors import ( |
| GenerationError, |
| HfHubHTTPError, |
| IncompleteGenerationError, |
| OverloadedError, |
| TextGenerationError, |
| UnknownError, |
| ValidationError, |
| ) |
|
|
| from ..utils import get_session, is_numpy_available, is_pillow_available |
| from ._generated.types import ChatCompletionStreamOutput, TextGenerationStreamOutput |
|
|
|
|
| if TYPE_CHECKING: |
| from PIL.Image import Image |
|
|
| |
# Type aliases for the kinds of input accepted as binary "content" (see `_open_as_mime_bytes`).
UrlT = str  # a URL given as a string
PathT = Union[str, Path]  # a local file path, as a string or a `Path`
# "Image" is a forward reference to `PIL.Image.Image`, which is only imported under TYPE_CHECKING.
ContentT = Union[bytes, BinaryIO, PathT, UrlT, "Image", bytearray, memoryview]


# Names of tasks expecting images.
# NOTE(review): not referenced in the visible part of this file; presumably consulted by the clients when
# building requests or handling responses for these tasks. Confirm at the usage site.
TASKS_EXPECTING_IMAGES = {"text-to-image", "image-to-image"}


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class RequestParameters:
    """Plain data holder grouping the values that describe one request.

    NOTE(review): the code that builds and consumes these objects is not visible here. The per-field
    comments below are inferred from field names and types only and should be confirmed against callers.
    """

    url: str  # URL of the request
    task: str  # task name (presumably an inference task such as "text-to-image" — confirm)
    model: Optional[str]  # model identifier, may be None
    json: Optional[Union[str, dict, list]]  # presumably the JSON body of the request — TODO confirm
    data: Optional[bytes]  # presumably the raw binary body of the request — TODO confirm
    headers: dict[str, Any]  # mapping of header name to value
|
|
|
|
class MimeBytes(bytes):
    """
    A `bytes` subclass that additionally carries an optional mime type.

    To be returned by `_prepare_payload_open_as_mime_bytes` in subclasses.

    Wrapping an existing `MimeBytes` without passing `mime_type` keeps the mime type of the wrapped
    object; passing an explicit `mime_type` overrides it.

    Example:
    ```python
    >>> b = MimeBytes(b"hello", "text/plain")
    >>> isinstance(b, bytes)
    True
    >>> b.mime_type
    'text/plain'
    >>> MimeBytes(b).mime_type
    'text/plain'
    ```
    """

    mime_type: Optional[str]

    def __new__(cls, data: bytes, mime_type: Optional[str] = None):
        # No explicit mime type: inherit the one already attached to `data`, if any.
        if mime_type is None and isinstance(data, MimeBytes):
            mime_type = data.mime_type
        instance = super().__new__(cls, data)
        instance.mime_type = mime_type
        return instance
|
|
|
|
| |
|
|
|
|
def _import_numpy():
    """Return the `numpy` module, raising an `ImportError` if it is not available.

    The import is done lazily so that `numpy` is only required by the code paths that need it.
    """
    if is_numpy_available():
        import numpy

        return numpy

    raise ImportError("Please install numpy to use deal with embeddings (`pip install numpy`).")
|
|
|
|
def _import_pil_image():
    """Return the `PIL.Image` module, raising an `ImportError` if Pillow is not available.

    The import is done lazily so that Pillow is only required by the code paths that need it.
    """
    if is_pillow_available():
        from PIL import Image

        return Image

    raise ImportError(
        "Please install Pillow to use deal with images (`pip install Pillow`). If you don't want the image to be"
        " post-processed, use `client.post(...)` and get the raw response from the server."
    )
|
|
|
|
| |
|
|
|
|
@overload
def _open_as_mime_bytes(content: ContentT) -> MimeBytes: ...


@overload
def _open_as_mime_bytes(content: Literal[None]) -> Literal[None]: ...


def _open_as_mime_bytes(content: Optional[ContentT]) -> Optional[MimeBytes]:
    """Open `content` as a binary file, either from a URL, a local path, raw bytes, or a PIL Image.

    Do nothing if `content` is None.

    Args:
        content: Raw bytes (`bytes`, `bytearray`, `memoryview`), an object with a `read()` method, a
            string holding an `http(s)://` URL or a local file path, a `Path`, or a `PIL.Image.Image`.

    Returns:
        The content as `MimeBytes`, with `mime_type` set when it can be determined (from the file name,
        the `Content-Type` response header, or the image format), or `None` if `content` is None.

    Raises:
        TypeError: If a file-like object returns text instead of bytes, or if the type is unsupported.
        FileNotFoundError: If `content` is a string that is neither a URL nor an existing local path.
        Exception: Whatever `response.raise_for_status()` raises when a URL download does not succeed.
    """
    if content is None:
        return None

    # Raw bytes. A `MimeBytes` input keeps its mime type (see `MimeBytes.__new__`).
    if isinstance(content, bytes):
        return MimeBytes(content)

    # Other in-memory buffers are copied into an immutable `bytes` object.
    if isinstance(content, (bytearray, memoryview)):
        return MimeBytes(bytes(content))

    # File-like object: must be opened in binary mode.
    if hasattr(content, "read"):
        logger.debug("Reading content from BinaryIO")
        data = content.read()
        if isinstance(data, str):
            raise TypeError("Expected binary stream (bytes), but got text stream")
        mime_type = mimetypes.guess_type(str(content.name))[0] if hasattr(content, "name") else None
        return MimeBytes(data, mime_type=mime_type)

    # String: either a URL to download or a path to a local file.
    if isinstance(content, str):
        if content.startswith(("https://", "http://")):
            logger.debug(f"Downloading content from {content}")
            response = get_session().get(content)
            # Fail loudly on an unsuccessful download instead of returning the body of an error page
            # as if it were the requested content.
            response.raise_for_status()
            mime_type = response.headers.get("Content-Type")
            if mime_type is None:
                mime_type = mimetypes.guess_type(content)[0]
            return MimeBytes(response.content, mime_type=mime_type)

        content = Path(content)
        if not content.exists():
            raise FileNotFoundError(
                f"File not found at {content}. If `data` is a string, it must either be a URL or a path to a local"
                " file. To pass raw content, please encode it as bytes first."
            )

    # Local file (either passed as a `Path` or converted from a string just above).
    if isinstance(content, Path):
        logger.debug(f"Opening content from {content}")
        return MimeBytes(content.read_bytes(), mime_type=mimetypes.guess_type(content)[0])

    # PIL image: only checked if Pillow is installed, so Pillow stays an optional dependency.
    if is_pillow_available():
        from PIL import Image

        if isinstance(content, Image.Image):
            logger.debug("Converting PIL Image to bytes")
            buffer = io.BytesIO()
            # Images not loaded from a file have no format: fall back to PNG.
            image_format = content.format or "PNG"
            content.save(buffer, format=image_format)
            return MimeBytes(buffer.getvalue(), mime_type=f"image/{image_format.lower()}")

    raise TypeError(
        f"Unsupported content type: {type(content)}. "
        "Expected one of: bytes, bytearray, BinaryIO, memoryview, Path, str (URL or file path), or PIL.Image.Image."
    )
|
|
|
|
def _b64_encode(content: ContentT) -> str:
    """Return the base64 text encoding of a raw file (image, audio).

    `content` is first opened with `_open_as_mime_bytes`, so it can be bytes, an opened file, a path or a URL.
    """
    encoded = base64.b64encode(_open_as_mime_bytes(content))
    return encoded.decode()
|
|
|
|
def _as_url(content: ContentT, default_mime_type: str) -> str:
    """Return `content` as a URL string.

    Strings that already start with `http://`, `https://` or `data:` are returned unchanged. Anything else
    is opened with `_open_as_mime_bytes` and embedded in a base64 `data:` URL, using the detected mime type
    or `default_mime_type` when none was detected.
    """
    is_already_url = isinstance(content, str) and content.startswith(("http://", "https://", "data:"))
    if is_already_url:
        return content

    # Load the content, then pick the mime type to advertise in the data URL.
    mime_bytes = _open_as_mime_bytes(content)
    detected = mime_bytes.mime_type
    mime_type = detected if detected else default_mime_type

    b64_payload = base64.b64encode(mime_bytes).decode()
    return f"data:{mime_type};base64,{b64_payload}"
|
|
|
|
def _b64_to_image(encoded_image: str) -> "Image":
    """Decode a base64-encoded string and open the result as a PIL Image."""
    pil_image = _import_pil_image()
    decoded = base64.b64decode(encoded_image)
    return pil_image.open(io.BytesIO(decoded))
|
|
|
|
| def _bytes_to_list(content: bytes) -> list: |
| """Parse bytes from a Response object into a Python list. |
| |
| Expects the response body to be JSON-encoded data. |
| |
| NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a |
| dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect. |
| """ |
| return json.loads(content.decode()) |
|
|
|
|
| def _bytes_to_dict(content: bytes) -> dict: |
| """Parse bytes from a Response object into a Python dictionary. |
| |
| Expects the response body to be JSON-encoded data. |
| |
| NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a |
| list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect. |
| """ |
| return json.loads(content.decode()) |
|
|
|
|
def _bytes_to_image(content: bytes) -> "Image":
    """Open the raw bytes of a Response body as a PIL Image.

    The bytes must be the image data itself. For base64-encoded images, use `_b64_to_image` instead.
    """
    pil_image = _import_pil_image()
    stream = io.BytesIO(content)
    return pil_image.open(stream)
|
|
|
|
| def _as_dict(response: Union[bytes, dict]) -> dict: |
| return json.loads(response) if isinstance(response, bytes) else response |
|
|
|
|
| |
|
|
|
|
def _stream_text_generation_response(
    output_lines: Iterable[str], details: bool
) -> Union[Iterable[str], Iterable[TextGenerationStreamOutput]]:
    """Used in `InferenceClient.text_generation`.

    Parse each streamed line with `_format_text_generation_stream_output` and yield the parsed items.
    Lines that parse to `None` are skipped; iteration ends when the parser raises `StopIteration`.
    """
    for raw_line in output_lines:
        try:
            parsed = _format_text_generation_stream_output(raw_line, details)
        except StopIteration:
            # End-of-stream signal from the parser.
            return
        if parsed is None:
            continue
        yield parsed
|
|
|
|
async def _async_stream_text_generation_response(
    output_lines: AsyncIterable[str], details: bool
) -> Union[AsyncIterable[str], AsyncIterable[TextGenerationStreamOutput]]:
    """Used in `AsyncInferenceClient.text_generation`.

    Parse each streamed line with `_format_text_generation_stream_output` and yield the parsed items.
    Lines that parse to `None` are skipped; iteration ends when the parser raises `StopIteration`.
    """
    async for raw_line in output_lines:
        try:
            parsed = _format_text_generation_stream_output(raw_line, details)
        except StopIteration:
            # End-of-stream signal from the parser.
            return
        if parsed is None:
            continue
        yield parsed
|
|
|
|
def _format_text_generation_stream_output(
    line: str, details: bool
) -> Optional[Union[str, TextGenerationStreamOutput]]:
    """Parse one line of a text-generation stream.

    Args:
        line: A line of the streamed response. Only lines starting with `data:` carry a payload.
        details: If True, return the full `TextGenerationStreamOutput`; otherwise only the token text.

    Returns:
        `None` if the line does not start with `data:`, otherwise the parsed output (token text or full
        object depending on `details`).

    Raises:
        StopIteration: If the payload is the `[DONE]` end-of-stream signal.
        TextGenerationError: If the JSON payload has a non-null `"error"` entry.
    """
    if not line.startswith("data:"):
        return None  # nothing to parse on this line

    # Remove the exact "data:" prefix. `str.lstrip("data:")` would strip any leading run of the characters
    # 'd', 'a', 't', ':' and `rstrip("/n")` any trailing '/' or 'n', which is not what is intended here.
    payload = line.removeprefix("data:").strip()

    # Accept the signal with or without a space after the colon.
    if payload == "[DONE]":
        raise StopIteration("[DONE] signal received.")

    json_payload = json.loads(payload)

    # The server can report an error in the middle of a stream.
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    output = TextGenerationStreamOutput.parse_obj_as_instance(json_payload)
    return output.token.text if not details else output
|
|
|
|
def _stream_chat_completion_response(
    lines: Iterable[str],
) -> Iterable[ChatCompletionStreamOutput]:
    """Used in `InferenceClient.chat_completion` if model is served with TGI.

    Parse each streamed line with `_format_chat_completion_stream_output` and yield the parsed items.
    Lines that parse to `None` are skipped; iteration ends when the parser raises `StopIteration`.
    """
    for raw_line in lines:
        try:
            parsed = _format_chat_completion_stream_output(raw_line)
        except StopIteration:
            # End-of-stream signal from the parser.
            return
        if parsed is None:
            continue
        yield parsed
|
|
|
|
async def _async_stream_chat_completion_response(
    lines: AsyncIterable[str],
) -> AsyncIterable[ChatCompletionStreamOutput]:
    """Used in `AsyncInferenceClient.chat_completion`.

    Parse each streamed line with `_format_chat_completion_stream_output` and yield the parsed items.
    Lines that parse to `None` are skipped; iteration ends when the parser raises `StopIteration`.
    """
    async for raw_line in lines:
        try:
            parsed = _format_chat_completion_stream_output(raw_line)
        except StopIteration:
            # End-of-stream signal from the parser.
            return
        if parsed is None:
            continue
        yield parsed
|
|
|
|
def _format_chat_completion_stream_output(
    line: str,
) -> Optional[ChatCompletionStreamOutput]:
    """Parse one line of a chat-completion stream.

    Args:
        line: A line of the streamed response. Only lines starting with `data:` carry a payload.

    Returns:
        `None` if the line does not start with `data:`, otherwise the parsed `ChatCompletionStreamOutput`.

    Raises:
        StopIteration: If the payload is the `[DONE]` end-of-stream signal.
        TextGenerationError: If the JSON payload has a non-null `"error"` entry.
    """
    if not line.startswith("data:"):
        return None  # nothing to parse on this line

    # Remove the exact "data:" prefix. `str.lstrip("data:")` would strip any leading run of the characters
    # 'd', 'a', 't', ':' rather than the prefix itself.
    payload = line.removeprefix("data:").strip()

    # Accept the signal with or without a space after the colon.
    if payload == "[DONE]":
        raise StopIteration("[DONE] signal received.")

    json_payload = json.loads(payload)

    # The server can report an error in the middle of a stream.
    if json_payload.get("error") is not None:
        raise _parse_text_generation_error(json_payload["error"], json_payload.get("error_type"))

    return ChatCompletionStreamOutput.parse_obj_as_instance(json_payload)
|
|
|
|
async def _async_yield_from(client: httpx.AsyncClient, response: httpx.Response) -> AsyncIterable[str]:
    """Yield the lines of `response` one by one, stripped of leading and trailing whitespace.

    `client` is accepted but not used in the body of this function.
    """
    async for raw_line in response.aiter_lines():
        stripped = raw_line.strip()
        yield stripped
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Module-level registry mapping a model (or None) to the list of text-generation kwargs recorded as
# unsupported for it. Written by `_set_unsupported_text_generation_kwargs` and read by
# `_get_unsupported_text_generation_kwargs`.
_UNSUPPORTED_TEXT_GENERATION_KWARGS: dict[Optional[str], list[str]] = {}
|
|
|
|
def _set_unsupported_text_generation_kwargs(model: Optional[str], unsupported_kwargs: list[str]) -> None:
    """Record `unsupported_kwargs` for `model` in `_UNSUPPORTED_TEXT_GENERATION_KWARGS`.

    The names are appended to the list already stored for `model`; a new list is created on first use.
    """
    if model not in _UNSUPPORTED_TEXT_GENERATION_KWARGS:
        _UNSUPPORTED_TEXT_GENERATION_KWARGS[model] = []
    _UNSUPPORTED_TEXT_GENERATION_KWARGS[model].extend(unsupported_kwargs)
|
|
|
|
def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> list[str]:
    """Return the kwargs recorded as unsupported for `model`, or an empty list if none were recorded.

    When an entry exists, the stored list itself is returned (not a copy).
    """
    if model in _UNSUPPORTED_TEXT_GENERATION_KWARGS:
        return _UNSUPPORTED_TEXT_GENERATION_KWARGS[model]
    return []
|
|
|
|
| |
| |
| |
| |
| |
|
|
|
|
def raise_text_generation_error(http_error: HfHubHTTPError) -> NoReturn:
    """
    Try to parse a text-generation-inference error message and raise an error in any case.

    Args:
        http_error (`HfHubHTTPError`):
            The HTTP error that has been raised.

    Raises:
        `TextGenerationError`:
            Built by `_parse_text_generation_error` and chained from `http_error`, when the error payload
            contains a non-null `"error_type"`.
        `HfHubHTTPError`:
            `http_error` itself, in every other case (no response, unreadable payload, no error type).
    """
    response = http_error.response
    if response is None:
        raise http_error

    # Best-effort parsing: any failure to read the payload falls back to the original HTTP error.
    try:
        # Prefer a payload already attached to the error, otherwise decode the response body.
        payload = getattr(http_error, "response_error_payload", None) or response.json()
        error, error_type = payload.get("error"), payload.get("error_type")
    except Exception:
        raise http_error

    # Without an error type there is nothing more specific to raise.
    if error_type is None:
        raise http_error

    raise _parse_text_generation_error(error, error_type) from http_error
|
|
|
|
def _parse_text_generation_error(error: Optional[str], error_type: Optional[str]) -> TextGenerationError:
    """Build (without raising) the exception matching `error_type`, carrying `error` as its message.

    Any `error_type` other than "generation", "incomplete_generation", "overloaded" or "validation"
    yields an `UnknownError`.
    """
    known_error_types = (
        ("generation", GenerationError),
        ("incomplete_generation", IncompleteGenerationError),
        ("overloaded", OverloadedError),
        ("validation", ValidationError),
    )
    # Compared with `==` rather than looked up in a dict, so any value of `error_type` is tolerated.
    for type_name, error_cls in known_error_types:
        if error_type == type_name:
            return error_cls(error)
    return UnknownError(error)
|
|