| import time |
| from dataclasses import dataclass, field |
| from datetime import datetime |
| from enum import Enum |
| from typing import TYPE_CHECKING, Optional, Union |
|
|
| from huggingface_hub.errors import InferenceEndpointError, InferenceEndpointTimeoutError |
|
|
| from .utils import get_session, logging, parse_datetime |
|
|
|
|
if TYPE_CHECKING:
    # Imported only for static type checking: these modules import this one at
    # runtime, so importing them eagerly here would create a circular import.
    from .hf_api import HfApi
    from .inference._client import InferenceClient
    from .inference._generated._async_client import AsyncInferenceClient


# Module-level logger following the library-wide `logging` helper convention.
logger = logging.get_logger(__name__)
|
|
|
|
class InferenceEndpointStatus(str, Enum):
    """Possible deployment states of an Inference Endpoint, as returned by the API.

    Subclasses `str` so that members compare equal to the raw state strings
    received from the server (e.g. `"running" == InferenceEndpointStatus.RUNNING`).
    """

    PENDING = "pending"
    INITIALIZING = "initializing"
    UPDATING = "updating"
    UPDATE_FAILED = "updateFailed"
    RUNNING = "running"
    PAUSED = "paused"
    FAILED = "failed"
    SCALED_TO_ZERO = "scaledToZero"
|
|
|
|
class InferenceEndpointType(str, Enum):
    """Access level of an Inference Endpoint.

    Subclasses `str` so that members compare equal to the raw strings received
    from the server (e.g. `"public" == InferenceEndpointType.PUBLIC`).
    """

    PUBLIC = "public"
    # Backward-compatible alias: the member was historically misspelled with a
    # lowercase "l". Because it shares the value "public" and is declared after
    # PUBLIC, Enum treats it as an alias resolving to PUBLIC, so existing code
    # referencing `InferenceEndpointType.PUBlIC` keeps working.
    PUBlIC = "public"
    PROTECTED = "protected"
    PRIVATE = "private"
|
|
|
|
class InferenceEndpointScalingMetric(str, Enum):
    """Metric used to drive autoscaling of an Inference Endpoint.

    Subclasses `str` so that members compare equal to the raw strings exchanged
    with the server.
    """

    PENDING_REQUESTS = "pendingRequests"
    HARDWARE_USAGE = "hardwareUsage"
|
|
|
|
@dataclass
class InferenceEndpoint:
    """
    Contains information about a deployed Inference Endpoint.

    Args:
        name (`str`):
            The unique name of the Inference Endpoint.
        namespace (`str`):
            The namespace where the Inference Endpoint is located.
        repository (`str`):
            The name of the model repository deployed on this Inference Endpoint.
        status ([`InferenceEndpointStatus`]):
            The current status of the Inference Endpoint.
        health_route (`str`):
            The route (relative to `url`) used to probe whether the Inference Endpoint is healthy.
        url (`str`, *optional*):
            The URL of the Inference Endpoint, if available. Only a deployed Inference Endpoint will have a URL.
        framework (`str`):
            The machine learning framework used for the model.
        revision (`str`):
            The specific model revision deployed on the Inference Endpoint.
        task (`str`):
            The task associated with the deployed model.
        created_at (`datetime.datetime`):
            The timestamp when the Inference Endpoint was created.
        updated_at (`datetime.datetime`):
            The timestamp of the last update of the Inference Endpoint.
        type ([`InferenceEndpointType`]):
            The type of the Inference Endpoint (public, protected, private).
        raw (`dict`):
            The raw dictionary data returned from the API.
        token (`str` or `bool`, *optional*):
            Authentication token for the Inference Endpoint, if set when requesting the API. Will default to the
            locally saved token if not provided. Pass `token=False` if you don't want to send your token to the server.

    Example:
    ```python
    >>> from huggingface_hub import get_inference_endpoint
    >>> endpoint = get_inference_endpoint("my-text-to-image")
    >>> endpoint
    InferenceEndpoint(name='my-text-to-image', ...)

    # Get status
    >>> endpoint.status
    'running'
    >>> endpoint.url
    'https://my-text-to-image.region.vendor.endpoints.huggingface.cloud'

    # Run inference
    >>> endpoint.client.text_to_image(...)

    # Pause endpoint to save $$$
    >>> endpoint.pause()

    # ...
    # Resume and wait for deployment
    >>> endpoint.resume()
    >>> endpoint.wait()
    >>> endpoint.client.text_to_image(...)
    ```
    """

    # Field declaration order is significant for @dataclass. Only `namespace`,
    # `raw`, `_token` and `_api` are constructor arguments; every field marked
    # `init=False` is populated from `raw` in __post_init__ / _populate_from_raw.
    name: str = field(init=False)
    namespace: str
    repository: str = field(init=False)
    status: InferenceEndpointStatus = field(init=False)
    health_route: str = field(init=False)
    url: Optional[str] = field(init=False)

    # Secondary fields, hidden from __repr__ to keep the repr short.
    framework: str = field(repr=False, init=False)
    revision: str = field(repr=False, init=False)
    task: str = field(repr=False, init=False)
    created_at: datetime = field(repr=False, init=False)
    updated_at: datetime = field(repr=False, init=False)
    type: InferenceEndpointType = field(repr=False, init=False)

    # Raw dict as returned by the server; kept as the source of truth.
    raw: dict = field(repr=False)

    # Internal fields (auth + API client), excluded from repr and equality.
    _token: Union[str, bool, None] = field(repr=False, compare=False)
    _api: "HfApi" = field(repr=False, compare=False)

    @classmethod
    def from_raw(
        cls, raw: dict, namespace: str, token: Union[str, bool, None] = None, api: Optional["HfApi"] = None
    ) -> "InferenceEndpoint":
        """Initialize object from raw dictionary."""
        if api is None:
            # Local import to avoid a circular import at module load time.
            from .hf_api import HfApi

            api = HfApi()
        if token is None:
            # Fall back to the token configured on the HfApi instance.
            token = api.token

        return cls(raw=raw, namespace=namespace, _token=token, _api=api)

    def __post_init__(self) -> None:
        """Populate fields from raw dictionary."""
        self._populate_from_raw()

    @property
    def client(self) -> "InferenceClient":
        """Returns a client to make predictions on this Inference Endpoint.

        Returns:
            [`InferenceClient`]: an inference client pointing to the deployed endpoint.

        Raises:
            [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
        """
        if self.url is None:
            raise InferenceEndpointError(
                "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
                "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
            )
        # Local import to avoid a circular import at module load time.
        from .inference._client import InferenceClient

        return InferenceClient(
            model=self.url,
            token=self._token,
        )

    @property
    def async_client(self) -> "AsyncInferenceClient":
        """Returns a client to make predictions on this Inference Endpoint.

        Returns:
            [`AsyncInferenceClient`]: an asyncio-compatible inference client pointing to the deployed endpoint.

        Raises:
            [`InferenceEndpointError`]: If the Inference Endpoint is not yet deployed.
        """
        if self.url is None:
            raise InferenceEndpointError(
                "Cannot create a client for this Inference Endpoint as it is not yet deployed. "
                "Please wait for the Inference Endpoint to be deployed using `endpoint.wait()` and try again."
            )
        # Local import to avoid a circular import at module load time.
        from .inference._generated._async_client import AsyncInferenceClient

        return AsyncInferenceClient(
            model=self.url,
            token=self._token,
        )

    def wait(self, timeout: Optional[int] = None, refresh_every: int = 5) -> "InferenceEndpoint":
        """Wait for the Inference Endpoint to be deployed.

        Information from the server will be fetched every `refresh_every` seconds (5s by default). If the Inference
        Endpoint is not deployed after `timeout` seconds, a [`InferenceEndpointTimeoutError`] will be raised. The
        [`InferenceEndpoint`] will be mutated in place with the latest data.

        Args:
            timeout (`int`, *optional*):
                The maximum time to wait for the Inference Endpoint to be deployed, in seconds. If `None`, will wait
                indefinitely.
            refresh_every (`int`, *optional*):
                The time to wait between each fetch of the Inference Endpoint status, in seconds. Defaults to 5s.

        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.

        Raises:
            [`InferenceEndpointError`]
                If the Inference Endpoint ended up in a failed state.
            [`InferenceEndpointTimeoutError`]
                If the Inference Endpoint is not deployed after `timeout` seconds.
        """
        if timeout is not None and timeout < 0:
            raise ValueError("`timeout` cannot be negative.")
        if refresh_every <= 0:
            raise ValueError("`refresh_every` must be positive.")

        start = time.time()
        while True:
            # Terminal failure states: stop waiting and surface the error.
            if self.status == InferenceEndpointStatus.FAILED:
                raise InferenceEndpointError(
                    f"Inference Endpoint {self.name} failed to deploy. Please check the logs for more information."
                )
            if self.status == InferenceEndpointStatus.UPDATE_FAILED:
                raise InferenceEndpointError(
                    f"Inference Endpoint {self.name} failed to update. Please check the logs for more information."
                )
            if self.status == InferenceEndpointStatus.RUNNING and self.url is not None:
                # A "running" status is not enough: probe the health route and
                # only report ready once the endpoint actually answers 200.
                _health_url = f"{self.url.rstrip('/')}/{self.health_route.lstrip('/')}"
                response = get_session().get(_health_url, headers=self._api._build_hf_headers(token=self._token))
                if response.status_code == 200:
                    logger.info("Inference Endpoint is ready to be used.")
                    return self

            if timeout is not None:
                if time.time() - start > timeout:
                    raise InferenceEndpointTimeoutError("Timeout while waiting for Inference Endpoint to be deployed.")
            logger.info(f"Inference Endpoint is not deployed yet ({self.status}). Waiting {refresh_every}s...")
            time.sleep(refresh_every)
            self.fetch()

    def fetch(self) -> "InferenceEndpoint":
        """Fetch latest information about the Inference Endpoint.

        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
        """
        obj = self._api.get_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
        self.raw = obj.raw
        self._populate_from_raw()
        return self

    def update(
        self,
        *,
        # Compute configuration
        accelerator: Optional[str] = None,
        instance_size: Optional[str] = None,
        instance_type: Optional[str] = None,
        min_replica: Optional[int] = None,
        max_replica: Optional[int] = None,
        scale_to_zero_timeout: Optional[int] = None,
        # Deployed model configuration
        repository: Optional[str] = None,
        framework: Optional[str] = None,
        revision: Optional[str] = None,
        task: Optional[str] = None,
        custom_image: Optional[dict] = None,
        secrets: Optional[dict[str, str]] = None,
    ) -> "InferenceEndpoint":
        """Update the Inference Endpoint.

        This method allows the update of either the compute configuration, the deployed model, or both. All arguments are
        optional but at least one must be provided.

        This is an alias for [`HfApi.update_inference_endpoint`]. The current object is mutated in place with the
        latest data from the server.

        Args:
            accelerator (`str`, *optional*):
                The hardware accelerator to be used for inference (e.g. `"cpu"`).
            instance_size (`str`, *optional*):
                The size or type of the instance to be used for hosting the model (e.g. `"x4"`).
            instance_type (`str`, *optional*):
                The cloud instance type where the Inference Endpoint will be deployed (e.g. `"intel-icl"`).
            min_replica (`int`, *optional*):
                The minimum number of replicas (instances) to keep running for the Inference Endpoint.
            max_replica (`int`, *optional*):
                The maximum number of replicas (instances) to scale to for the Inference Endpoint.
            scale_to_zero_timeout (`int`, *optional*):
                The duration in minutes before an inactive endpoint is scaled to zero.

            repository (`str`, *optional*):
                The name of the model repository associated with the Inference Endpoint (e.g. `"gpt2"`).
            framework (`str`, *optional*):
                The machine learning framework used for the model (e.g. `"custom"`).
            revision (`str`, *optional*):
                The specific model revision to deploy on the Inference Endpoint (e.g. `"6c0e6080953db56375760c0471a8c5f2929baf11"`).
            task (`str`, *optional*):
                The task on which to deploy the model (e.g. `"text-classification"`).
            custom_image (`dict`, *optional*):
                A custom Docker image to use for the Inference Endpoint. This is useful if you want to deploy an
                Inference Endpoint running on the `text-generation-inference` (TGI) framework (see examples).
            secrets (`dict[str, str]`, *optional*):
                Secret values to inject in the container environment.
        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
        """
        # Delegate to HfApi; it returns a fresh InferenceEndpoint object whose
        # raw payload is used to refresh this instance in place.
        obj = self._api.update_inference_endpoint(
            name=self.name,
            namespace=self.namespace,
            accelerator=accelerator,
            instance_size=instance_size,
            instance_type=instance_type,
            min_replica=min_replica,
            max_replica=max_replica,
            scale_to_zero_timeout=scale_to_zero_timeout,
            repository=repository,
            framework=framework,
            revision=revision,
            task=task,
            custom_image=custom_image,
            secrets=secrets,
            token=self._token,
        )

        self.raw = obj.raw
        self._populate_from_raw()
        return self

    def pause(self) -> "InferenceEndpoint":
        """Pause the Inference Endpoint.

        A paused Inference Endpoint will not be charged. It can be resumed at any time using [`InferenceEndpoint.resume`].
        This is different from scaling the Inference Endpoint to zero with [`InferenceEndpoint.scale_to_zero`], which
        would be automatically restarted when a request is made to it.

        This is an alias for [`HfApi.pause_inference_endpoint`]. The current object is mutated in place with the
        latest data from the server.

        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
        """
        obj = self._api.pause_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
        self.raw = obj.raw
        self._populate_from_raw()
        return self

    def resume(self, running_ok: bool = True) -> "InferenceEndpoint":
        """Resume the Inference Endpoint.

        This is an alias for [`HfApi.resume_inference_endpoint`]. The current object is mutated in place with the
        latest data from the server.

        Args:
            running_ok (`bool`, *optional*):
                If `True`, the method will not raise an error if the Inference Endpoint is already running. Defaults to
                `True`.

        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
        """
        obj = self._api.resume_inference_endpoint(
            name=self.name, namespace=self.namespace, running_ok=running_ok, token=self._token
        )
        self.raw = obj.raw
        self._populate_from_raw()
        return self

    def scale_to_zero(self) -> "InferenceEndpoint":
        """Scale Inference Endpoint to zero.

        An Inference Endpoint scaled to zero will not be charged. It will be resumed on the next request to it, with a
        cold start delay. This is different from pausing the Inference Endpoint with [`InferenceEndpoint.pause`], which
        would require a manual resume with [`InferenceEndpoint.resume`].

        This is an alias for [`HfApi.scale_to_zero_inference_endpoint`]. The current object is mutated in place with the
        latest data from the server.

        Returns:
            [`InferenceEndpoint`]: the same Inference Endpoint, mutated in place with the latest data.
        """
        obj = self._api.scale_to_zero_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)
        self.raw = obj.raw
        self._populate_from_raw()
        return self

    def delete(self) -> None:
        """Delete the Inference Endpoint.

        This operation is not reversible. If you don't want to be charged for an Inference Endpoint, it is preferable
        to pause it with [`InferenceEndpoint.pause`] or scale it to zero with [`InferenceEndpoint.scale_to_zero`].

        This is an alias for [`HfApi.delete_inference_endpoint`].
        """
        self._api.delete_inference_endpoint(name=self.name, namespace=self.namespace, token=self._token)

    def _populate_from_raw(self) -> None:
        """Populate fields from raw dictionary.

        Called in __post_init__ + each time the Inference Endpoint is updated.
        """
        # Fields shown in __repr__. Note that `status` stores the raw state
        # string; comparisons against InferenceEndpointStatus members still work
        # because that enum subclasses `str`.
        self.name = self.raw["name"]
        self.repository = self.raw["model"]["repository"]
        self.status = self.raw["status"]["state"]
        self.url = self.raw["status"].get("url")  # absent until deployed
        self.health_route = self.raw["healthRoute"]

        # Other fields (hidden from __repr__).
        self.framework = self.raw["model"]["framework"]
        self.revision = self.raw["model"]["revision"]
        self.task = self.raw["model"]["task"]
        self.created_at = parse_datetime(self.raw["status"]["createdAt"])
        self.updated_at = parse_datetime(self.raw["status"]["updatedAt"])
        self.type = self.raw["type"]
|
|