	| # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Clients for easy interaction with models deployed on the Triton Inference Server. | |
| Typical usage example: | |
| ```python | |
| client = ModelClient("localhost", "MyModel") | |
| result_dict = client.infer_sample(input_a=a, input_b=b) | |
| client.close() | |
| ``` | |
| Inference inputs can be provided either as positional or keyword arguments: | |
| ```python | |
| result_dict = client.infer_sample(input1, input2) | |
| result_dict = client.infer_sample(a=input1, b=input2) | |
| ``` | |
| Mixing of argument passing conventions is not supported and will raise PyTritonClientValueError. | |
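| The clients also support the context manager protocol, which closes resources automatically: | |
| ```python | |
| with ModelClient("localhost", "MyModel") as client: | |
|     result_dict = client.infer_sample(input_a=a, input_b=b) | |
| ``` | |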
| """ | |
| import asyncio | |
| import contextlib | |
| import itertools | |
| import logging | |
| import socket | |
| import time | |
| import warnings | |
| from concurrent.futures import Future | |
| from queue import Empty, Full, Queue | |
| from threading import Lock, Thread | |
| from typing import Dict, Optional, Tuple, Union | |
| import gevent | |
| import numpy as np | |
| import tritonclient.grpc | |
| import tritonclient.grpc.aio | |
| import tritonclient.http | |
| import tritonclient.http.aio | |
| import tritonclient.utils | |
| from pytriton.client.asyncio_utils import asyncio_get_model_config, asyncio_wait_for_model_ready | |
| from pytriton.client.exceptions import ( | |
| PyTritonClientClosedError, | |
| PyTritonClientInferenceServerError, | |
| PyTritonClientModelDoesntSupportBatchingError, | |
| PyTritonClientQueueFullError, | |
| PyTritonClientTimeoutError, | |
| PyTritonClientValueError, | |
| ) | |
| from pytriton.client.utils import ( | |
| _DEFAULT_NETWORK_TIMEOUT_S, | |
| _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S, | |
| TritonUrl, | |
| get_model_config, | |
| wait_for_model_ready, | |
| wait_for_server_ready, | |
| ) | |
| from pytriton.client.warnings import NotSupportedTimeoutWarning | |
| from pytriton.model_config.triton_model_config import TritonModelConfig | |
| _LOGGER = logging.getLogger(__name__) | |
| _DEFAULT_SYNC_INIT_TIMEOUT_S = _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S | |
| _DEFAULT_FUTURES_INIT_TIMEOUT_S = _DEFAULT_WAIT_FOR_MODEL_TIMEOUT_S | |
| DEFAULT_INFERENCE_TIMEOUT_S = 60.0 | |
| _IOType = Union[Tuple[np.ndarray, ...], Dict[str, np.ndarray]] | |
| def _verify_inputs_args(inputs, named_inputs): | |
| if not inputs and not named_inputs: | |
| raise PyTritonClientValueError("Provide input data") | |
| if not bool(inputs) ^ bool(named_inputs): | |
| raise PyTritonClientValueError("Use either positional either keyword method arguments convention") | |
| def _verify_parameters(parameters_or_headers: Optional[Dict[str, Union[str, int, bool]]] = None): | |
| if parameters_or_headers is None: | |
| return | |
| if not isinstance(parameters_or_headers, dict): | |
| raise PyTritonClientValueError("Parameters and headers must be a dictionary") | |
| for key, value in parameters_or_headers.items(): | |
| if not isinstance(key, str): | |
| raise PyTritonClientValueError("Parameter/header key must be a string") | |
| if not isinstance(value, (str, int, bool)): | |
| raise PyTritonClientValueError("Parameter/header value must be a string, integer or boolean") | |
| class BaseModelClient: | |
| """Base client for model deployed on the Triton Inference Server.""" | |
| def __init__( | |
| self, | |
| url: str, | |
| model_name: str, | |
| model_version: Optional[str] = None, | |
| *, | |
| lazy_init: bool = True, | |
| init_timeout_s: Optional[float] = None, | |
| inference_timeout_s: Optional[float] = None, | |
| model_config: Optional[TritonModelConfig] = None, | |
| ensure_model_is_ready: bool = True, | |
| ): | |
| """Inits BaseModelClient for given model deployed on the Triton Inference Server. | |
| Common usage: | |
| ```python | |
| client = ModelClient("localhost", "BERT") | |
| result_dict = client.infer_sample(input1_sample, input2_sample) | |
| client.close() | |
| ``` | |
| Args: | |
| url: The Triton Inference Server url, e.g. `grpc://localhost:8001`. | |
| If no scheme is provided, the http scheme is used by default. | |
| If no port is provided, the default port for the given scheme is used: | |
| 8001 for the grpc scheme, 8000 for the http scheme. | |
| model_name: name of the model to interact with. | |
| model_version: version of the model to interact with. | |
| If model_version is None, inference is performed on the latest model. | |
| The latest version is the one with the numerically greatest version number. | |
| lazy_init: if True, initialization is postponed until just before the first request is sent to the inference server. | |
| init_timeout_s: timeout in seconds for the server and model to be ready. If not passed, the default timeout of 300 seconds will be used. | |
| inference_timeout_s: timeout in seconds for a single model inference request. If not passed, the default timeout of 60 seconds will be used. | |
| model_config: model configuration. If not passed, it will be read from inference server during initialization. | |
| ensure_model_is_ready: whether to check that the model is ready before the first inference request. | |
| Raises: | |
| PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable. | |
| PyTritonClientTimeoutError: | |
| if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`. | |
| PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid. | |
| """ | |
| self._init_timeout_s = _DEFAULT_SYNC_INIT_TIMEOUT_S if init_timeout_s is None else init_timeout_s | |
| self._inference_timeout_s = DEFAULT_INFERENCE_TIMEOUT_S if inference_timeout_s is None else inference_timeout_s | |
| self._network_timeout_s = min(_DEFAULT_NETWORK_TIMEOUT_S, self._init_timeout_s) | |
| self._general_client = self.create_client_from_url(url, network_timeout_s=self._network_timeout_s) | |
| self._infer_client = self.create_client_from_url(url, network_timeout_s=self._inference_timeout_s) | |
| self._model_name = model_name | |
| self._model_version = model_version | |
| self._request_id_generator = itertools.count(0) | |
| # Monkey patch __del__ method from client to catch error in client when instance is garbage collected. | |
| # This is needed because we are closing client in __exit__ method or in close method. | |
| # (InferenceClient uses gevent library which does not support closing twice from different threads) | |
| self._monkey_patch_client() | |
| if model_config is not None: | |
| self._model_config = model_config | |
| self._model_ready = None if ensure_model_is_ready else True | |
| else: | |
| self._model_config = None | |
| self._model_ready = None | |
| self._lazy_init: bool = lazy_init | |
| self._handle_lazy_init() | |
| @classmethod | |
| def from_existing_client(cls, existing_client: "BaseModelClient"): | |
| """Create a new instance from an existing client using the same class. | |
| Common usage: | |
| ```python | |
| client = BaseModelClient.from_existing_client(existing_client) | |
| ``` | |
| Args: | |
| existing_client: An instance of an already initialized subclass. | |
| Returns: | |
| A new instance of the same subclass with shared configuration and readiness state. | |
| """ | |
| kwargs = {} | |
| # Copy model configuration and readiness state if present | |
| if hasattr(existing_client, "_model_config"): | |
| kwargs["model_config"] = existing_client._model_config | |
| kwargs["ensure_model_is_ready"] = False | |
| new_client = cls( | |
| url=existing_client._url, | |
| model_name=existing_client._model_name, | |
| model_version=existing_client._model_version, | |
| init_timeout_s=existing_client._init_timeout_s, | |
| inference_timeout_s=existing_client._inference_timeout_s, | |
| **kwargs, | |
| ) | |
| return new_client | |
| def create_client_from_url(self, url: str, network_timeout_s: Optional[float] = None): | |
| """Create Triton Inference Server client. | |
| Args: | |
| url: url of the server to connect to. | |
| If url doesn't contain scheme (e.g. "localhost:8001") http scheme is added. | |
| If url doesn't contain port (e.g. "localhost") default port for given scheme is added. | |
| network_timeout_s: timeout for client commands. Default value is 60.0 s. | |
| Returns: | |
| Triton Inference Server client. | |
| Raises: | |
| PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid. | |
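| For example, a sketch of how different url forms are resolved, based on the defaults described above: | |
| ```python | |
| client.create_client_from_url("localhost")         # resolved as http://localhost:8000 | |
| client.create_client_from_url("localhost:8001")    # resolved as http://localhost:8001 | |
| client.create_client_from_url("grpc://localhost")  # resolved as grpc://localhost:8001 | |
| ``` | |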
| """ | |
| self._triton_url = TritonUrl.from_url(url) | |
| self._url = self._triton_url.without_scheme | |
| self._triton_client_lib = self.get_lib() | |
| self._monkey_patch_client() | |
| if self._triton_url.scheme == "grpc": | |
| # by default the grpc client has a very large timeout, so we make it equal to the http client timeout | |
| network_timeout_s = _DEFAULT_NETWORK_TIMEOUT_S if network_timeout_s is None else network_timeout_s | |
| warnings.warn( | |
| f"tritonclient.grpc doesn't support timeout for other commands than infer. Ignoring network_timeout: {network_timeout_s}.", | |
| NotSupportedTimeoutWarning, | |
| stacklevel=1, | |
| ) | |
| triton_client_init_kwargs = self._get_init_extra_args() | |
| _LOGGER.debug( | |
| f"Creating InferenceServerClient for {self._triton_url.with_scheme} with {triton_client_init_kwargs}" | |
| ) | |
| return self._triton_client_lib.InferenceServerClient(self._url, **triton_client_init_kwargs) | |
| def get_lib(self): | |
| """Returns tritonclient library for given scheme.""" | |
| raise NotImplementedError | |
| @property | |
| def _next_request_id(self) -> str: | |
| # pytype complained about creating generator in __init__ method | |
| # so we create it lazily | |
| if getattr(self, "_request_id_generator", None) is None: | |
| self._request_id_generator = itertools.count(0) | |
| return str(next(self._request_id_generator)) | |
| def _get_init_extra_args(self): | |
| timeout = self._inference_timeout_s # pytype: disable=attribute-error | |
| # The inference timeout is used for both the HTTP and the GRPC protocols. However, | |
| # the way the timeout is passed to the client differs depending on the protocol. | |
| # For the HTTP protocol, the timeout is set in the ``__init__`` method as ``network_timeout`` | |
| # and ``connection_timeout``. For the GRPC protocol, the timeout | |
| # is passed to the infer method as ``client_timeout``. | |
| # Both protocols support timeouts correctly and will raise an exception | |
| # if the network request or the inference process takes longer than the timeout. | |
| # This is a design choice of the underlying tritonclient library. | |
| if self._triton_url.scheme != "http": | |
| return {} | |
| kwargs = { | |
| # This value sets the maximum time allowed for each network request in both model loading and inference process | |
| "network_timeout": timeout, | |
| # This value sets the maximum time allowed for establishing a connection to the server. | |
| # We use the inference timeout here instead of the init timeout because the init timeout | |
| # is meant for waiting for the model to be ready. The connection timeout should be shorter | |
| # than the init timeout because it only checks if connection is established (e.g. correct port) | |
| "connection_timeout": timeout, | |
| } | |
| return kwargs | |
| def _monkey_patch_client(self): | |
| pass | |
| def _get_model_config_extra_args(self): | |
| # For the GRPC protocol, the timeout must be passed to each request as client_timeout | |
| # model_config doesn't yet support timeout but it is planned for the future | |
| # grpc_network_timeout_s will be used for model_config | |
| return {} | |
| def _handle_lazy_init(self): | |
| raise NotImplementedError | |
| class ModelClient(BaseModelClient): | |
| """Synchronous client for model deployed on the Triton Inference Server.""" | |
| def __init__( | |
| self, | |
| url: str, | |
| model_name: str, | |
| model_version: Optional[str] = None, | |
| *, | |
| lazy_init: bool = True, | |
| init_timeout_s: Optional[float] = None, | |
| inference_timeout_s: Optional[float] = None, | |
| model_config: Optional[TritonModelConfig] = None, | |
| ensure_model_is_ready: bool = True, | |
| ): | |
| """Inits ModelClient for given model deployed on the Triton Inference Server. | |
| If `lazy_init` argument is False, model configuration will be read | |
| from inference server during initialization. | |
| Common usage: | |
| ```python | |
| client = ModelClient("localhost", "BERT") | |
| result_dict = client.infer_sample(input1_sample, input2_sample) | |
| client.close() | |
| ``` | |
| Client supports also context manager protocol: | |
| ```python | |
| with ModelClient("localhost", "BERT") as client: | |
| result_dict = client.infer_sample(input1_sample, input2_sample) | |
| ``` | |
| Creating a client requires a connection to the server and downloading the model configuration. You can also create a new client from an existing one of the same class: | |
| ```python | |
| client = ModelClient.from_existing_client(existing_client) | |
| ``` | |
| Args: | |
| url: The Triton Inference Server url, e.g. 'grpc://localhost:8001'. | |
| If no scheme is provided, the http scheme is used by default. | |
| If no port is provided, the default port for the given scheme is used: | |
| 8001 for the grpc scheme, 8000 for the http scheme. | |
| model_name: name of the model to interact with. | |
| model_version: version of the model to interact with. | |
| If model_version is None, inference is performed on the latest model. | |
| The latest version is the one with the numerically greatest version number. | |
| lazy_init: if True, initialization is postponed until just before the first request is sent to the inference server. | |
| init_timeout_s: maximum time in seconds to wait, in a retry loop, for the server and model to become ready. | |
| It is applied at initialization time only when `lazy_init` is False; otherwise the retry loop runs before the first inference. | |
| If not passed, the default timeout of 300 seconds will be used. | |
| inference_timeout_s: timeout in seconds for the model inference process. | |
| If not passed, the default timeout of 60 seconds will be used. | |
| For the HTTP client it applies to every client request (e.g. getting the model config, checking if the model is loaded), not only inference. | |
| For the GRPC client it applies to inference only. | |
| model_config: model configuration. If not passed, it will be read from inference server during initialization. | |
| ensure_model_is_ready: whether to check that the model is ready before the first inference request. | |
| Raises: | |
| PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable. | |
| PyTritonClientTimeoutError: | |
| if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`. | |
| PyTritonClientUrlParseError: In case of problems with parsing url. | |
| """ | |
| super().__init__( | |
| url=url, | |
| model_name=model_name, | |
| model_version=model_version, | |
| lazy_init=lazy_init, | |
| init_timeout_s=init_timeout_s, | |
| inference_timeout_s=inference_timeout_s, | |
| model_config=model_config, | |
| ensure_model_is_ready=ensure_model_is_ready, | |
| ) | |
| def get_lib(self): | |
| """Returns tritonclient library for given scheme.""" | |
| return {"grpc": tritonclient.grpc, "http": tritonclient.http}[self._triton_url.scheme.lower()] | |
| def __enter__(self): | |
| """Create context for using ModelClient as a context manager.""" | |
| return self | |
| def __exit__(self, *_): | |
| """Close resources used by ModelClient instance when exiting from the context.""" | |
| self.close() | |
| def load_model(self, config: Optional[str] = None, files: Optional[dict] = None): | |
| """Load model on the Triton Inference Server. | |
| Args: | |
| config: str - Optional JSON representation of a model config provided for | |
| the load request, if provided, this config will be used for | |
| loading the model. | |
| files: dict - Optional dictionary specifying file path (with "file:" prefix) in | |
| the override model directory to the file content as bytes. | |
| The files will form the model directory that the model will be | |
| loaded from. If specified, 'config' must be provided to be | |
| the model configuration of the override model directory. | |
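| Example (a minimal sketch; the config JSON and file contents below are illustrative only): | |
| ```python | |
| client = ModelClient("localhost", "MyModel") | |
| config_json = '{"backend": "python", "max_batch_size": 8}' | |
| client.load_model(config=config_json, files={"file:1/model.py": b"<model file content>"}) | |
| client.close() | |
| ``` | |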
| """ | |
| self._general_client.load_model(self._model_name, config=config, files=files) | |
| def unload_model(self): | |
| """Unload model from the Triton Inference Server.""" | |
| self._general_client.unload_model(self._model_name) | |
| def close(self): | |
| """Close resources used by ModelClient. | |
| This method closes the resources used by the ModelClient instance, | |
| including the Triton Inference Server connections. | |
| Once this method is called, the ModelClient instance should not be used again. | |
| """ | |
| _LOGGER.debug("Closing ModelClient") | |
| try: | |
| if self._general_client is not None: | |
| self._general_client.close() | |
| if self._infer_client is not None: | |
| self._infer_client.close() | |
| self._general_client = None | |
| self._infer_client = None | |
| except Exception as e: | |
| _LOGGER.error(f"Error while closing ModelClient resources: {e}") | |
| raise e | |
| def wait_for_model(self, timeout_s: float): | |
| """Wait for the Triton Inference Server and the deployed model to be ready. | |
| Args: | |
| timeout_s: timeout in seconds to wait for the server and model to be ready. | |
| Raises: | |
| PyTritonClientTimeoutError: If the server and model are not ready before the given timeout. | |
| PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable. | |
| KeyboardInterrupt: If the hosting process receives SIGINT. | |
| PyTritonClientClosedError: If the ModelClient is closed. | |
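| Example (a sketch: with the default `lazy_init=True` you can block explicitly until the model is ready): | |
| ```python | |
| client = ModelClient("localhost", "MyModel") | |
| client.wait_for_model(timeout_s=300.0) | |
| result_dict = client.infer_sample(input_a=a, input_b=b) | |
| client.close() | |
| ``` | |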
| """ | |
| if self._general_client is None: | |
| raise PyTritonClientClosedError("ModelClient is closed") | |
| wait_for_model_ready(self._general_client, self._model_name, self._model_version, timeout_s=timeout_s) | |
| @property | |
| def is_batching_supported(self): | |
| """Checks if model supports batching. | |
| Also waits for server to get into readiness state. | |
| """ | |
| return self.model_config.max_batch_size > 0 | |
| def wait_for_server(self, timeout_s: float): | |
| """Wait for Triton Inference Server readiness. | |
| Args: | |
| timeout_s: timeout in seconds for the server to reach readiness state. | |
| Raises: | |
| PyTritonClientTimeoutError: If server is not in readiness state before given timeout. | |
| KeyboardInterrupt: If hosting process receives SIGINT | |
| """ | |
| wait_for_server_ready(self._general_client, timeout_s=timeout_s) | |
| @property | |
| def model_config(self) -> TritonModelConfig: | |
| """Obtain the configuration of the model deployed on the Triton Inference Server. | |
| This method waits for the server to get into readiness state before obtaining the model configuration. | |
| Returns: | |
| TritonModelConfig: configuration of the model deployed on the Triton Inference Server. | |
| Raises: | |
| PyTritonClientTimeoutError: If the server and model are not in readiness state before the given timeout. | |
| PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable. | |
| KeyboardInterrupt: If the hosting process receives SIGINT. | |
| PyTritonClientClosedError: If the ModelClient is closed. | |
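| Example (a sketch of inspecting the returned configuration; the attributes shown are the ones used elsewhere in this module): | |
| ```python | |
| client = ModelClient("localhost", "MyModel") | |
| config = client.model_config | |
| print(config.model_name, config.max_batch_size) | |
| print([input_spec.name for input_spec in config.inputs]) | |
| client.close() | |
| ``` | |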
| """ | |
| if not self._model_config: | |
| if self._general_client is None: | |
| raise PyTritonClientClosedError("ModelClient is closed") | |
| self._model_config = get_model_config( | |
| self._general_client, self._model_name, self._model_version, timeout_s=self._init_timeout_s | |
| ) | |
| return self._model_config | |
| def infer_sample( | |
| self, | |
| *inputs, | |
| parameters: Optional[Dict[str, Union[str, int, bool]]] = None, | |
| headers: Optional[Dict[str, Union[str, int, bool]]] = None, | |
| **named_inputs, | |
| ) -> Dict[str, np.ndarray]: | |
| """Run synchronous inference on a single data sample. | |
| Typical usage: | |
| ```python | |
| client = ModelClient("localhost", "MyModel") | |
| result_dict = client.infer_sample(input1, input2) | |
| client.close() | |
| ``` | |
| Inference inputs can be provided either as positional or keyword arguments: | |
| ```python | |
| result_dict = client.infer_sample(input1, input2) | |
| result_dict = client.infer_sample(a=input1, b=input2) | |
| ``` | |
| Args: | |
| *inputs: Inference inputs provided as positional arguments. | |
| parameters: Custom inference parameters. | |
| headers: Custom inference headers. | |
| **named_inputs: Inference inputs provided as named arguments. | |
| Returns: | |
| Dictionary with inference results, where dictionary keys are output names. | |
| Raises: | |
| PyTritonClientValueError: If mixing of positional and named arguments passing detected. | |
| PyTritonClientTimeoutError: If the wait time for the server and model being ready exceeds `init_timeout_s` or | |
| inference request time exceeds `inference_timeout_s`. | |
| PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable. | |
| PyTritonClientInferenceServerError: If an error occurred on the inference callable or Triton Inference Server side. | |
| """ | |
| _verify_inputs_args(inputs, named_inputs) | |
| _verify_parameters(parameters) | |
| _verify_parameters(headers) | |
| if self.is_batching_supported: | |
| if inputs: | |
| inputs = tuple(data[np.newaxis, ...] for data in inputs) | |
| elif named_inputs: | |
| named_inputs = {name: data[np.newaxis, ...] for name, data in named_inputs.items()} | |
| result = self._infer(inputs or named_inputs, parameters, headers) | |
| return self._debatch_result(result) | |
| def infer_batch( | |
| self, | |
| *inputs, | |
| parameters: Optional[Dict[str, Union[str, int, bool]]] = None, | |
| headers: Optional[Dict[str, Union[str, int, bool]]] = None, | |
| **named_inputs, | |
| ) -> Dict[str, np.ndarray]: | |
| """Run synchronous inference on batched data. | |
| Typical usage: | |
| ```python | |
| client = ModelClient("localhost", "MyModel") | |
| result_dict = client.infer_batch(input1, input2) | |
| client.close() | |
| ``` | |
| Inference inputs can be provided either as positional or keyword arguments: | |
| ```python | |
| result_dict = client.infer_batch(input1, input2) | |
| result_dict = client.infer_batch(a=input1, b=input2) | |
| ``` | |
| Args: | |
| *inputs: Inference inputs provided as positional arguments. | |
| parameters: Custom inference parameters. | |
| headers: Custom inference headers. | |
| **named_inputs: Inference inputs provided as named arguments. | |
| Returns: | |
| Dictionary with inference results, where dictionary keys are output names. | |
| Raises: | |
| PyTritonClientValueError: If mixing of positional and named arguments passing detected. | |
| PyTritonClientTimeoutError: If the wait time for the server and model being ready exceeds `init_timeout_s` or | |
| inference request time exceeds `inference_timeout_s`. | |
| PyTritonClientModelUnavailableError: If the model with the given name (and version) is unavailable. | |
| PyTritonClientInferenceServerError: If an error occurred on the inference callable or Triton Inference Server side. | |
| PyTritonClientModelDoesntSupportBatchingError: If the model doesn't support batching. | |
| """ | |
| _verify_inputs_args(inputs, named_inputs) | |
| _verify_parameters(parameters) | |
| _verify_parameters(headers) | |
| if not self.is_batching_supported: | |
| raise PyTritonClientModelDoesntSupportBatchingError( | |
| f"Model {self.model_config.model_name} doesn't support batching - use infer_sample method instead" | |
| ) | |
| return self._infer(inputs or named_inputs, parameters, headers) | |
| def _wait_and_init_model_config(self, init_timeout_s: float): | |
| if self._general_client is None: | |
| raise PyTritonClientClosedError("ModelClient is closed") | |
| should_finish_before_s = time.time() + init_timeout_s | |
| self.wait_for_model(init_timeout_s) | |
| self._model_ready = True | |
| timeout_s = max(0.0, should_finish_before_s - time.time()) | |
| self._model_config = get_model_config( | |
| self._general_client, self._model_name, self._model_version, timeout_s=timeout_s | |
| ) | |
| def _create_request(self, inputs: _IOType): | |
| if self._infer_client is None: | |
| raise PyTritonClientClosedError("ModelClient is closed") | |
| if not self._model_ready: | |
| self._wait_and_init_model_config(self._init_timeout_s) | |
| if isinstance(inputs, Tuple): | |
| inputs = {input_spec.name: input_data for input_spec, input_data in zip(self.model_config.inputs, inputs)} | |
| inputs_wrapped = [] | |
| for input_name, input_data in inputs.items(): | |
| if input_data.dtype == object and not isinstance(input_data.reshape(-1)[0], bytes): | |
| raise RuntimeError( | |
| f"Numpy array for {input_name!r} input with dtype=object should contain encoded strings \ | |
| \\(e.g. into utf-8\\). Element type: {type(input_data.reshape(-1)[0])}" | |
| ) | |
| if input_data.dtype.type == np.str_: | |
| raise RuntimeError( | |
| "Unicode inputs are not supported. " | |
| f"Encode numpy array for {input_name!r} input (ex. with np.char.encode(array, 'utf-8'))." | |
| ) | |
| triton_dtype = tritonclient.utils.np_to_triton_dtype(input_data.dtype) | |
| infer_input = self._triton_client_lib.InferInput(input_name, input_data.shape, triton_dtype) | |
| infer_input.set_data_from_numpy(input_data) | |
| inputs_wrapped.append(infer_input) | |
| outputs_wrapped = [ | |
| self._triton_client_lib.InferRequestedOutput(output_spec.name) for output_spec in self.model_config.outputs | |
| ] | |
| return inputs_wrapped, outputs_wrapped | |
| def _infer(self, inputs: _IOType, parameters, headers) -> Dict[str, np.ndarray]: | |
| if self.model_config.decoupled: | |
| raise PyTritonClientInferenceServerError("Model config is decoupled. Use DecoupledModelClient instead.") | |
| inputs_wrapped, outputs_wrapped = self._create_request(inputs) | |
| try: | |
| _LOGGER.debug("Sending inference request to Triton Inference Server") | |
| response = self._infer_client.infer( | |
| model_name=self._model_name, | |
| model_version=self._model_version or "", | |
| inputs=inputs_wrapped, | |
| headers=headers, | |
| outputs=outputs_wrapped, | |
| request_id=self._next_request_id, | |
| parameters=parameters, | |
| **self._get_infer_extra_args(), | |
| ) | |
| except tritonclient.utils.InferenceServerException as e: | |
| # tritonclient.grpc raises an exception with a message containing "Deadline Exceeded" for timeouts | |
| if "Deadline Exceeded" in e.message(): | |
| raise PyTritonClientTimeoutError( | |
| f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s. Message: {e.message()}" | |
| ) from e | |
| raise PyTritonClientInferenceServerError( | |
| f"Error occurred during inference request. Message: {e.message()}" | |
| ) from e | |
| except socket.timeout as e: # tritonclient.http raises socket.timeout for timeout | |
| message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}" | |
| _LOGGER.error(message) | |
| raise PyTritonClientTimeoutError(message) from e | |
| except OSError as e: # tritonclient.http raises socket.error for connection error | |
| message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}" | |
| _LOGGER.error(message) | |
| raise PyTritonClientTimeoutError(message) from e | |
| if isinstance(response, tritonclient.http.InferResult): | |
| outputs = { | |
| output["name"]: response.as_numpy(output["name"]) for output in response.get_response()["outputs"] | |
| } | |
| else: | |
| outputs = {output.name: response.as_numpy(output.name) for output in response.get_response().outputs} | |
| return outputs | |
| def _get_numpy_result(self, result): | |
| if isinstance(result, tritonclient.grpc.InferResult): | |
| result = {output.name: result.as_numpy(output.name) for output in result.get_response().outputs} | |
| else: | |
| result = {output["name"]: result.as_numpy(output["name"]) for output in result.get_response()["outputs"]} | |
| return result | |
| def _debatch_result(self, result): | |
| if self.is_batching_supported: | |
| result = {name: data[0] for name, data in result.items()} | |
| return result | |
| def _handle_lazy_init(self): | |
| if not self._lazy_init: | |
| self._wait_and_init_model_config(self._init_timeout_s) | |
| def _get_infer_extra_args(self): | |
| if self._triton_url.scheme == "http": | |
| return {} | |
| # For the GRPC protocol, the timeout is passed to the infer method as client_timeout | |
| # This timeout applies to the whole inference process and each network request | |
| # The ``infer`` supports also timeout argument for both GRPC and HTTP. | |
| # It is applied at server side and supported only for dynamic batching. | |
| # However, it is not used here yet and planned for future release | |
| kwargs = {"client_timeout": self._inference_timeout_s} | |
| return kwargs | |
| class DecoupledModelClient(ModelClient): | |
| """Synchronous client for decoupled model deployed on the Triton Inference Server.""" | |
| def __init__( | |
| self, | |
| url: str, | |
| model_name: str, | |
| model_version: Optional[str] = None, | |
| *, | |
| lazy_init: bool = True, | |
| init_timeout_s: Optional[float] = None, | |
| inference_timeout_s: Optional[float] = None, | |
| model_config: Optional[TritonModelConfig] = None, | |
| ensure_model_is_ready: bool = True, | |
| ): | |
| """Inits DecoupledModelClient for given decoupled model deployed on the Triton Inference Server. | |
| Common usage: | |
| ```python | |
| client = DecoupledModelClient("localhost", "BERT") | |
| for response in client.infer_sample(input1_sample, input2_sample): | |
| print(response) | |
| client.close() | |
| ``` | |
| Args: | |
| url: The Triton Inference Server url, e.g. `grpc://localhost:8001`. | |
| If no scheme is provided, the http scheme is used by default. | |
| If no port is provided, the default port for the given scheme is used: | |
| 8001 for the grpc scheme, 8000 for the http scheme. | |
| model_name: name of the model to interact with. | |
| model_version: version of the model to interact with. | |
| If model_version is None, inference is performed on the latest model. | |
| The latest version is the one with the numerically greatest version number. | |
| lazy_init: if True, initialization is postponed until just before the first request is sent to the inference server. | |
| init_timeout_s: timeout in seconds for the server and model to be ready. If not passed, the default timeout of 300 seconds will be used. | |
| inference_timeout_s: timeout in seconds for a single model inference request. If not passed, the default timeout of 60 seconds will be used. | |
| model_config: model configuration. If not passed, it will be read from inference server during initialization. | |
| ensure_model_is_ready: whether to check that the model is ready before the first inference request. | |
| Raises: | |
| PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable. | |
| PyTritonClientTimeoutError: | |
| if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`. | |
| PyTritonClientInvalidUrlError: If provided Triton Inference Server url is invalid. | |
| """ | |
| super().__init__( | |
| url, | |
| model_name, | |
| model_version, | |
| lazy_init=lazy_init, | |
| init_timeout_s=init_timeout_s, | |
| inference_timeout_s=inference_timeout_s, | |
| model_config=model_config, | |
| ensure_model_is_ready=ensure_model_is_ready, | |
| ) | |
| if self._triton_url.scheme == "http": | |
| raise PyTritonClientValueError("DecoupledModelClient is only supported for grpc protocol") | |
| self._queue = Queue() | |
| self._lock = Lock() | |
| def close(self): | |
| """Close resources used by DecoupledModelClient.""" | |
| _LOGGER.debug("Closing DecoupledModelClient") | |
| if self._lock.acquire(blocking=False): | |
| try: | |
| super().close() | |
| finally: | |
| self._lock.release() | |
| else: | |
| _LOGGER.warning("DecoupledModelClient is stil streaming answers") | |
| self._infer_client.stop_stream(False) | |
| super().close() | |
| def _infer(self, inputs: _IOType, parameters, headers): | |
| if not self._lock.acquire(blocking=False): | |
| raise PyTritonClientInferenceServerError("Inference is already in progress") | |
| if not self.model_config.decoupled: | |
| raise PyTritonClientInferenceServerError("Model config is coupled. Use ModelClient instead.") | |
| inputs_wrapped, outputs_wrapped = self._create_request(inputs) | |
| if parameters is not None: | |
| raise PyTritonClientValueError("DecoupledModelClient does not support parameters") | |
| if headers is not None: | |
| raise PyTritonClientValueError("DecoupledModelClient does not support headers") | |
| try: | |
| _LOGGER.debug("Sending inference request to Triton Inference Server") | |
| if self._infer_client._stream is None: | |
| self._infer_client.start_stream(callback=lambda result, error: self._response_callback(result, error)) | |
| self._infer_client.async_stream_infer( | |
| model_name=self._model_name, | |
| model_version=self._model_version or "", | |
| inputs=inputs_wrapped, | |
| outputs=outputs_wrapped, | |
| request_id=self._next_request_id, | |
| enable_empty_final_response=True, | |
| **self._get_infer_extra_args(), | |
| ) | |
| except tritonclient.utils.InferenceServerException as e: | |
| # tritonclient.grpc raises an exception with a message containing "Deadline Exceeded" for timeouts | |
| if "Deadline Exceeded" in e.message(): | |
| raise PyTritonClientTimeoutError( | |
| f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s. Message: {e.message()}" | |
| ) from e | |
| raise PyTritonClientInferenceServerError( | |
| f"Error occurred during inference request. Message: {e.message()}" | |
| ) from e | |
| except socket.timeout as e: # tritonclient.http raises socket.timeout for timeout | |
| message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}" | |
| _LOGGER.error(message) | |
| raise PyTritonClientTimeoutError(message) from e | |
| except OSError as e: # tritonclient.http raises socket.error for connection error | |
| message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s Message: {e}" | |
| _LOGGER.error(message) | |
| raise PyTritonClientTimeoutError(message) from e | |
| _LOGGER.debug("Returning response iterator") | |
| return self._create_response_iterator() | |
| def _response_callback(self, response, error): | |
| _LOGGER.debug(f"Received response from Triton Inference Server: {response}") | |
| if error: | |
| _LOGGER.error(f"Error occurred during inference request. Message: {error}") | |
| self._queue.put(error) | |
| else: | |
| actual_response = response.get_response() | |
| # Check if the object is not None | |
| triton_final_response = actual_response.parameters.get("triton_final_response") | |
| if triton_final_response and triton_final_response.bool_param: | |
| self._queue.put(None) | |
| else: | |
| result = self._get_numpy_result(response) | |
| self._queue.put(result) | |
| def _create_response_iterator(self): | |
| try: | |
| while True: | |
| try: | |
| item = self._queue.get(timeout=self._inference_timeout_s) | |
| except Empty as e: | |
| message = f"Timeout occurred during inference request. Timeout: {self._inference_timeout_s} s" | |
| _LOGGER.error(message) | |
| raise PyTritonClientTimeoutError(message) from e | |
| if isinstance(item, Exception): | |
| message = f"Error occurred during inference request. Message: {item.message()}" | |
| _LOGGER.error(message) | |
| raise PyTritonClientInferenceServerError(message) from item | |
| if item is None: | |
| break | |
| yield item | |
| finally: | |
| self._lock.release() | |
| def _debatch_result(self, result): | |
| if self.is_batching_supported: | |
| result = ({name: data[0] for name, data in result_.items()} for result_ in result) | |
| return result | |
| def _get_infer_extra_args(self): | |
| # kwargs = super()._get_infer_extra_args() | |
| kwargs = {} | |
| # kwargs["enable_empty_final_response"] = True | |
| return kwargs | |
| class AsyncioModelClient(BaseModelClient): | |
| """Asyncio client for model deployed on the Triton Inference Server. | |
| This client is based on Triton Inference Server Python clients and GRPC library: | |
| - ``tritonclient.http.aio.InferenceServerClient`` | |
| - ``tritonclient.grpc.aio.InferenceServerClient`` | |
| It can wait for server to be ready with model loaded and then perform inference on it. | |
| ``AsyncioModelClient`` supports asyncio context manager protocol. | |
| Typical usage: | |
| ```python | |
| from pytriton.client import AsyncioModelClient | |
| import numpy as np | |
| input1_sample = np.random.rand(1, 3, 224, 224).astype(np.float32) | |
| input2_sample = np.random.rand(1, 3, 224, 224).astype(np.float32) | |
| client = AsyncioModelClient("localhost", "MyModel") | |
| result_dict = await client.infer_sample(input1_sample, input2_sample) | |
| print(result_dict["output_name"]) | |
| await client.close() | |
| ``` | |
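| The client also supports the asyncio context manager protocol: | |
| ```python | |
| async with AsyncioModelClient("localhost", "MyModel") as client: | |
|     result_dict = await client.infer_sample(input1_sample, input2_sample) | |
| ``` | |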
| """ | |
| def __init__( | |
| self, | |
| url: str, | |
| model_name: str, | |
| model_version: Optional[str] = None, | |
| *, | |
| lazy_init: bool = True, | |
| init_timeout_s: Optional[float] = None, | |
| inference_timeout_s: Optional[float] = None, | |
| model_config: Optional[TritonModelConfig] = None, | |
| ensure_model_is_ready: bool = True, | |
| ): | |
| """Inits ModelClient for given model deployed on the Triton Inference Server. | |
| If `lazy_init` argument is False, model configuration will be read | |
| from inference server during initialization. | |
| Args: | |
| url: The Triton Inference Server url, e.g. 'grpc://localhost:8001'. | |
| If no scheme is provided, the http scheme is used by default. | |
| If no port is provided, the default port for the given scheme is used: | |
| 8001 for the grpc scheme, 8000 for the http scheme. | |
| model_name: name of the model to interact with. | |
| model_version: version of the model to interact with. | |
| If model_version is None, inference is performed on the latest model. | |
| The latest version is the one with the numerically greatest version number. | |
| lazy_init: if True, initialization is postponed until just before the first request is sent to the inference server. | |
| init_timeout_s: timeout in seconds for the server and model to be ready. If not passed, the default timeout of 300 seconds will be used. | |
| inference_timeout_s: timeout in seconds for a single model inference request. If not passed, the default timeout of 60 seconds will be used. | |
| model_config: model configuration. If not passed, it will be read from inference server during initialization. | |
| ensure_model_is_ready: whether to check that the model is ready before the first inference request. | |
| Raises: | |
| PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable. | |
| PyTritonClientTimeoutError: if `lazy_init` argument is False and wait time for server and model being ready exceeds `init_timeout_s`. | |
| PyTritonClientUrlParseError: In case of problems with parsing url. | |
| """ | |
| super().__init__( | |
| url=url, | |
| model_name=model_name, | |
| model_version=model_version, | |
| lazy_init=lazy_init, | |
| init_timeout_s=init_timeout_s, | |
| inference_timeout_s=inference_timeout_s, | |
| model_config=model_config, | |
| ensure_model_is_ready=ensure_model_is_ready, | |
| ) | |
| def get_lib(self): | |
| """Get Triton Inference Server Python client library.""" | |
| return {"grpc": tritonclient.grpc.aio, "http": tritonclient.http.aio}[self._triton_url.scheme.lower()] | |
| async def __aenter__(self): | |
| """Create context for use AsyncioModelClient as a context manager.""" | |
| _LOGGER.debug("Entering AsyncioModelClient context") | |
| try: | |
| if not self._lazy_init: | |
| _LOGGER.debug("Waiting in AsyncioModelClient context for model to be ready") | |
| await self._wait_and_init_model_config(self._init_timeout_s) | |
| _LOGGER.debug("Model is ready in AsyncioModelClient context") | |
| return self | |
| except Exception as e: | |
| _LOGGER.error("Error occurred during AsyncioModelClient context initialization") | |
| await self.close() | |
| raise e | |
| async def __aexit__(self, *_): | |
| """Close resources used by AsyncioModelClient when exiting from context.""" | |
| await self.close() | |
| _LOGGER.debug("Exiting AsyncioModelClient context") | |
| async def close(self): | |
| """Close resources used by _ModelClientBase.""" | |
| _LOGGER.debug("Closing InferenceServerClient") | |
| await self._general_client.close() | |
| await self._infer_client.close() | |
| _LOGGER.debug("InferenceServerClient closed") | |
| async def wait_for_model(self, timeout_s: float): | |
| """Asynchronous wait for Triton Inference Server and deployed on it model readiness. | |
| Args: | |
| timeout_s: timeout to server and model get into readiness state. | |
| Raises: | |
| PyTritonClientTimeoutError: If server and model are not in readiness state before given timeout. | |
| PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable. | |
| KeyboardInterrupt: If hosting process receives SIGINT | |
| """ | |
| _LOGGER.debug(f"Waiting for model {self._model_name} to be ready") | |
| try: | |
| await asyncio.wait_for( | |
| asyncio_wait_for_model_ready( | |
| self._general_client, self._model_name, self._model_version, timeout_s=timeout_s | |
| ), | |
| self._init_timeout_s, | |
| ) | |
| except asyncio.TimeoutError as e: | |
| message = f"Timeout while waiting for model {self._model_name} to be ready for {self._init_timeout_s}s" | |
| _LOGGER.error(message) | |
| raise PyTritonClientTimeoutError(message) from e | |
| @property | |
| async def model_config(self): | |
| """Obtain configuration of model deployed on the Triton Inference Server. | |
| Also waits for server to get into readiness state. | |
| """ | |
| try: | |
| if not self._model_config: | |
| kwargs = self._get_model_config_extra_args() | |
| _LOGGER.debug(f"Obtaining model config for {self._model_name}") | |
| self._model_config = await asyncio.wait_for( | |
| asyncio_get_model_config( | |
| self._general_client, | |
| self._model_name, | |
| self._model_version, | |
| timeout_s=self._init_timeout_s, | |
| **kwargs, | |
| ), | |
| self._init_timeout_s, | |
| ) | |
| _LOGGER.debug(f"Obtained model config for {self._model_name}") | |
| return self._model_config | |
| except asyncio.TimeoutError as e: | |
| message = f"Timeout while waiting for model {self._model_name} to be ready for {self._init_timeout_s}s" | |
| _LOGGER.error(message) | |
| raise PyTritonClientTimeoutError(message) from e | |
| async def infer_sample( | |
| self, | |
| *inputs, | |
| parameters: Optional[Dict[str, Union[str, int, bool]]] = None, | |
| headers: Optional[Dict[str, Union[str, int, bool]]] = None, | |
| **named_inputs, | |
| ) -> Dict[str, np.ndarray]: | |
| """Run asynchronous inference on single data sample. | |
| Typical usage: | |
| ```python | |
| client = AsyncioModelClient("localhost", "MyModel") | |
| result_dict = await client.infer_sample(input1, input2) | |
| await client.close() | |
| ``` | |
| Inference inputs can be provided either as positional or keyword arguments: | |
| ```python | |
| result_dict = await client.infer_sample(input1, input2) | |
| result_dict = await client.infer_sample(a=input1, b=input2) | |
| ``` | |
| Mixing of argument passing conventions is not supported and will raise PyTritonClientValueError. | |
| Args: | |
| *inputs: inference inputs provided as positional arguments. | |
| parameters: custom inference parameters. | |
| headers: custom inference headers. | |
| **named_inputs: inference inputs provided as named arguments. | |
| Returns: | |
| dictionary with inference results, where dictionary keys are output names. | |
| Raises: | |
| PyTritonClientValueError: if mixing of positional and named arguments passing detected. | |
| PyTritonClientTimeoutError: | |
| in case of first method call, `lazy_init` argument is False | |
| and wait time for server and model being ready exceeds `init_timeout_s` | |
| or inference time exceeds `inference_timeout_s`. | |
| PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable. | |
| PyTritonClientInferenceServerError: If error occurred on inference callable or Triton Inference Server side. | |
| """ | |
| _verify_inputs_args(inputs, named_inputs) | |
| _verify_parameters(parameters) | |
| _verify_parameters(headers) | |
| _LOGGER.debug(f"Running inference for {self._model_name}") | |
| model_config = await self.model_config | |
| _LOGGER.debug(f"Model config for {self._model_name} obtained") | |
| model_supports_batching = model_config.max_batch_size > 0 | |
| if model_supports_batching: | |
| if inputs: | |
| inputs = tuple(data[np.newaxis, ...] for data in inputs) | |
| elif named_inputs: | |
| named_inputs = {name: data[np.newaxis, ...] for name, data in named_inputs.items()} | |
| _LOGGER.debug(f"Running _infer for {self._model_name}") | |
| result = await self._infer(inputs or named_inputs, parameters, headers) | |
| _LOGGER.debug(f"_infer for {self._model_name} finished") | |
| if model_supports_batching: | |
| result = {name: data[0] for name, data in result.items()} | |
| return result | |
| async def infer_batch( | |
| self, | |
| *inputs, | |
| parameters: Optional[Dict[str, Union[str, int, bool]]] = None, | |
| headers: Optional[Dict[str, Union[str, int, bool]]] = None, | |
| **named_inputs, | |
| ) -> Dict[str, np.ndarray]: | |
| """Run asynchronous inference on batched data. | |
| Typical usage: | |
| ```python | |
| client = AsyncioModelClient("localhost", "MyModel") | |
| result_dict = await client.infer_batch(input1, input2) | |
| await client.close() | |
| ``` | |
| Inference inputs can be provided either as positional or keyword arguments: | |
| ```python | |
| result_dict = await client.infer_batch(input1, input2) | |
| result_dict = await client.infer_batch(a=input1, b=input2) | |
| ``` | |
| Mixing of argument passing conventions is not supported and will raise PyTritonClientValueError. | |
| Args: | |
| *inputs: inference inputs provided as positional arguments. | |
| parameters: custom inference parameters. | |
| headers: custom inference headers. | |
| **named_inputs: inference inputs provided as named arguments. | |
| Returns: | |
| dictionary with inference results, where dictionary keys are output names. | |
| Raises: | |
| PyTritonClientValueError: if mixing of positional and named arguments passing detected. | |
| PyTritonClientTimeoutError: | |
| in case of first method call, `lazy_init` argument is False | |
| and wait time for server and model being ready exceeds `init_timeout_s` | |
| or inference time exceeds `inference_timeout_s`. | |
| PyTritonClientModelDoesntSupportBatchingError: if model doesn't support batching. | |
| PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable. | |
| PyTritonClientInferenceServerError: If error occurred on inference callable or Triton Inference Server side. | |
| """ | |
| _verify_inputs_args(inputs, named_inputs) | |
| _verify_parameters(parameters) | |
| _verify_parameters(headers) | |
| _LOGGER.debug(f"Running inference for {self._model_name}") | |
| model_config = await self.model_config | |
| _LOGGER.debug(f"Model config for {self._model_name} obtained") | |
| model_supports_batching = model_config.max_batch_size > 0 | |
| if not model_supports_batching: | |
| _LOGGER.error(f"Model {model_config.model_name} doesn't support batching") | |
| raise PyTritonClientModelDoesntSupportBatchingError( | |
| f"Model {model_config.model_name} doesn't support batching - use infer_sample method instead" | |
| ) | |
| _LOGGER.debug(f"Running _infer for {self._model_name}") | |
| result = await self._infer(inputs or named_inputs, parameters, headers) | |
| _LOGGER.debug(f"_infer for {self._model_name} finished") | |
| return result | |
| async def _wait_and_init_model_config(self, init_timeout_s: float): | |
| """Asynchronous wait for model and obtain model configuration. | |
| Args: | |
| init_timeout_s: timeout for server and model being ready. | |
| Raises: | |
| PyTritonClientTimeoutError: if wait time for server and model being ready exceeds `init_timeout_s` | |
| PyTritonClientModelUnavailableError: If model with given name (and version) is unavailable. | |
| """ | |
| try: | |
| should_finish_before_s = time.time() + init_timeout_s | |
| _LOGGER.debug(f"Waiting for model {self._model_name} to be ready") | |
| await asyncio.wait_for(self.wait_for_model(init_timeout_s), init_timeout_s) | |
| _LOGGER.debug(f"Model {self._model_name} is ready") | |
| self._model_ready = True | |
| timeout_s = max(0.0, should_finish_before_s - time.time()) | |
| _LOGGER.debug(f"Obtaining model config for {self._model_name}") | |
| self._model_config = await asyncio.wait_for( | |
| asyncio_get_model_config( | |
| self._general_client, self._model_name, self._model_version, timeout_s=timeout_s | |
| ), | |
| timeout_s, | |
| ) | |
| _LOGGER.debug(f"Model config for {self._model_name} obtained") | |
| except asyncio.TimeoutError as e: | |
| _LOGGER.error(f"Timeout exceeded while waiting for model {self._model_name} to be ready") | |
| raise PyTritonClientTimeoutError( | |
| f"Timeout exceeded while waiting for model {self._model_name} to be ready" | |
| ) from e | |
| def _validate_input(self, input_name, input_data): | |
| if input_data.dtype == object and not isinstance(input_data.reshape(-1)[0], bytes): | |
| raise RuntimeError( | |
| f"Numpy array for {input_name!r} input with dtype=object should contain encoded strings \ | |
| \\(e.g. into utf-8\\). Element type: {type(input_data.reshape(-1)[0])}" | |
| ) | |
| if input_data.dtype.type == np.str_: | |
| raise RuntimeError( | |
| "Unicode inputs are not supported. " | |
| f"Encode numpy array for {input_name!r} input (ex. with np.char.encode(array, 'utf-8'))." | |
| ) | |
| async def _execute_infer(self, model_config, inputs_wrapped, outputs_wrapped, parameters, headers): | |
| try: | |
| _LOGGER.debug(f"Sending InferRequest for {self._model_name}") | |
| kwargs = self._get_infer_extra_args() | |
| response = await self._infer_client.infer( | |
| model_name=self._model_name, | |
| model_version=self._model_version or "", | |
| inputs=inputs_wrapped, | |
| headers=headers, | |
| outputs=outputs_wrapped, | |
| request_id=self._next_request_id, | |
| parameters=parameters, | |
| **kwargs, | |
| ) | |
| except asyncio.exceptions.TimeoutError as e: | |
| # HTTP aio client raises asyncio.exceptions.TimeoutError for timeout errors | |
| message = f"Timeout exceeded while running inference for {self._model_name}" | |
| _LOGGER.error(message) | |
| raise PyTritonClientTimeoutError(message) from e | |
| except tritonclient.utils.InferenceServerException as e: | |
| message = f"Error occurred on Triton Inference Server side:\n {e.message()}" | |
| _LOGGER.error(message) | |
| if "Deadline Exceeded" in e.message(): | |
| # GRPC aio client raises InferenceServerException with message "Deadline Exceeded" | |
| # for timeout errors | |
| raise PyTritonClientTimeoutError(message) from e | |
| else: | |
| raise PyTritonClientInferenceServerError(message) from e | |
| _LOGGER.debug(f"Received InferResponse for {self._model_name}") | |
| outputs = {output_spec.name: response.as_numpy(output_spec.name) for output_spec in model_config.outputs} | |
| return outputs | |
| async def _infer(self, inputs: _IOType, parameters, headers) -> Dict[str, np.ndarray]: | |
| if not self._model_ready: | |
| _LOGGER.debug(f"Waiting for model {self._model_name} config") | |
| await self._wait_and_init_model_config(self._init_timeout_s) | |
| _LOGGER.debug(f"Model wait finished for {self._model_name}") | |
| _LOGGER.debug(f"Obtaining config for {self._model_name}") | |
| model_config = await self.model_config | |
| _LOGGER.debug(f"Model config for {self._model_name} obtained") | |
| if isinstance(inputs, Tuple): | |
| inputs = {input_spec.name: input_data for input_spec, input_data in zip(model_config.inputs, inputs)} | |
| inputs_wrapped = [] | |
| for input_name, input_data in inputs.items(): | |
| if isinstance(input_data, np.ndarray): | |
| self._validate_input(input_name, input_data) | |
| triton_dtype = tritonclient.utils.np_to_triton_dtype(input_data.dtype) | |
| infer_input = self._triton_client_lib.InferInput(input_name, input_data.shape, triton_dtype) | |
| infer_input.set_data_from_numpy(input_data) | |
| input_wrapped = infer_input | |
| inputs_wrapped.append(input_wrapped) | |
| else: | |
| raise PyTritonClientValueError( | |
| f"Input {input_name} is not a numpy array. Got {type(input_data)} instead." | |
| ) | |
| outputs_wrapped = [ | |
| self._triton_client_lib.InferRequestedOutput(output_spec.name) for output_spec in model_config.outputs | |
| ] | |
| return await self._execute_infer(model_config, inputs_wrapped, outputs_wrapped, parameters, headers) | |
| def _handle_lazy_init(self): | |
| # Asynchronous lazy initialization is done in __aenter__ method | |
| pass | |
| def _get_init_extra_args(self): | |
| # The inference timeout is used for both the HTTP and the GRPC protocols. However, | |
| # the way the timeout is passed to the client differs depending on the protocol. | |
| # For the HTTP protocol, the timeout is set in the ``__init__`` method as ``conn_timeout`` for both connection and request timeouts. | |
| # For the GRPC protocol, the timeout | |
| # is passed to the infer method as ``client_timeout``. | |
| # Both protocols support timeouts correctly and will raise an exception | |
| # if the network request or the inference process takes longer than the timeout. | |
| # This is a design choice of the underlying tritonclient library. | |
| if self._triton_url.scheme != "http": | |
| return {} | |
| kwargs = { | |
| # This value sets the maximum time allowed for both connection and network requests in both model loading and inference process | |
| "conn_timeout": self._inference_timeout_s, | |
| } | |
| return kwargs | |
| def _get_infer_extra_args(self): | |
| if self._triton_url.scheme == "http": | |
| return {} | |
| # For the GRPC protocol, the timeout is passed to the infer method as client_timeout | |
| # This timeout applies to the whole inference process and each network request | |
| # The ``infer`` supports also timeout argument for both GRPC and HTTP. | |
| # It is applied at server side and supported only for dynamic batching. | |
| # However, it is not used here yet and planned for future release | |
| kwargs = {"client_timeout": self._inference_timeout_s} | |
| return kwargs | |
| @contextlib.contextmanager | |
| def _hub_context(): | |
| hub = gevent.get_hub() | |
| try: | |
| yield hub | |
| finally: | |
| hub.destroy() | |
| _INIT = "init" | |
| _WAIT_FOR_MODEL = "wait_for_model" | |
| _MODEL_CONFIG = "model_config" | |
| _INFER_BATCH = "infer_batch" | |
| _INFER_SAMPLE = "infer_sample" | |
| _CLOSE = "close" | |
| class FuturesModelClient: | |
| """A client for interacting with a model deployed on the Triton Inference Server using concurrent.futures. | |
| This client allows asynchronous inference requests using a thread pool executor. It can be used to perform inference | |
| on a model by providing input data and receiving the corresponding output data. The client can be used in a `with` | |
| statement to ensure proper resource management. | |
| Example usage with context manager: | |
| ```python | |
| with FuturesModelClient("localhost", "MyModel") as client: | |
| result_future = client.infer_sample(input1=input1_data, input2=input2_data) | |
| # do something else | |
| print(result_future.result()) | |
| ``` | |
| Usage without context manager: | |
| ```python | |
| client = FuturesModelClient("localhost", "MyModel") | |
| result_future = client.infer_sample(input1=input1_data, input2=input2_data) | |
| # do something else | |
| print(result_future.result()) | |
| client.close() | |
| ``` | |
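| A non-blocking client (illustrative sketch; ``input1_data`` is a placeholder numpy array) raises | |
| PyTritonClientQueueFullError instead of blocking when the internal request queue is full: | |
| ```python | |
| client = FuturesModelClient("localhost", "MyModel", max_queue_size=8, non_blocking=True) | |
| try: | |
| future = client.infer_sample(input1=input1_data) | |
| except PyTritonClientQueueFullError: | |
| ... # back off and retry later | |
| client.close() | |
| ``` | |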
| """ | |
| def __init__( | |
| self, | |
| url: str, | |
| model_name: str, | |
| model_version: Optional[str] = None, | |
| *, | |
| max_workers: int = 128, | |
| max_queue_size: int = 128, | |
| non_blocking: bool = False, | |
| init_timeout_s: Optional[float] = None, | |
| inference_timeout_s: Optional[float] = None, | |
| ): | |
| """Initializes the FuturesModelClient for a given model. | |
| Args: | |
| url: The Triton Inference Server url, e.g. `grpc://localhost:8001`. | |
| model_name: The name of the model to interact with. | |
| model_version: The version of the model to interact with. If None, the latest version will be used. | |
| max_workers: The maximum number of threads that can be used to execute the given calls. If None, there is no limit on the number of threads. | |
| max_queue_size: The maximum number of requests that can be queued. If None, there is no limit on the number of requests. | |
| non_blocking: If True, the client will raise a PyTritonClientQueueFullError if the queue is full. If False, the client will block until the queue is not full. | |
| init_timeout_s: Timeout in seconds for the server and model to become ready. If not passed, the default timeout of 60 seconds is used. | |
| inference_timeout_s: Timeout in seconds for a single model inference request. If not passed, the default timeout of 60 seconds is used. | |
| """ | |
| self._url = url | |
| self._model_name = model_name | |
| self._model_version = model_version | |
| self._threads = [] | |
| self._max_workers = max_workers | |
| self._max_queue_size = max_queue_size | |
| self._non_blocking = non_blocking | |
| if self._max_workers is not None and self._max_workers <= 0: | |
| raise ValueError("max_workers must be greater than 0") | |
| if self._max_queue_size is not None and self._max_queue_size <= 0: | |
| raise ValueError("max_queue_size must be greater than 0") | |
| kwargs = {} | |
| if self._max_queue_size is not None: | |
| kwargs["maxsize"] = self._max_queue_size | |
| self._queue = Queue(**kwargs) | |
| self._queue.put((_INIT, None, None)) | |
| self._init_timeout_s = _DEFAULT_FUTURES_INIT_TIMEOUT_S if init_timeout_s is None else init_timeout_s | |
| self._inference_timeout_s = inference_timeout_s | |
| self._closed = False | |
| self._lock = Lock() | |
| self._existing_client = None | |
| def __enter__(self): | |
| """Create context for using FuturesModelClient as a context manager.""" | |
| return self | |
| def __exit__(self, exc_type, exc_value, traceback): | |
| """Close resources used by FuturesModelClient instance when exiting from the context.""" | |
| self.close() | |
| def close(self, wait=True): | |
| """Close resources used by FuturesModelClient. | |
| This method closes the resources used by the FuturesModelClient instance, including the Triton Inference Server connections. | |
| Once this method is called, the FuturesModelClient instance should not be used again. | |
| Args: | |
| wait: If True, then shutdown will not return until all running futures have finished executing. | |
| """ | |
| if self._closed: | |
| return | |
| _LOGGER.debug("Closing FuturesModelClient.") | |
| self._closed = True | |
| for _ in range(len(self._threads)): | |
| self._queue.put((_CLOSE, None, None)) | |
| if wait: | |
| _LOGGER.debug("Waiting for futures to finish.") | |
| for thread in self._threads: | |
| thread.join() | |
| def wait_for_model(self, timeout_s: float) -> Future: | |
| """Returns a Future object which result will be None when the model is ready. | |
| Typical usage: | |
| ```python | |
| with FuturesModelClient("localhost", "BERT") as client: | |
| future = client.wait_for_model(300.) | |
| # do something else | |
| future.result() # blocks for the remaining part of timeout_s, | |
| # then returns None if the model is ready | |
| # or raises PyTritonClientTimeoutError | |
| ``` | |
| Args: | |
| timeout_s: The maximum amount of time to wait for the model to be ready, in seconds. | |
| Returns: | |
| A Future object whose result is None when the model is ready. | |
| """ | |
| return self._execute( | |
| name=_WAIT_FOR_MODEL, | |
| request=timeout_s, | |
| ) | |
| def model_config(self) -> Future: | |
| """Obtain the configuration of the model deployed on the Triton Inference Server. | |
| This method returns a Future object that will contain the TritonModelConfig object when it is ready. | |
| The client will wait up to init_timeout_s for the server to become ready before obtaining the model configuration. | |
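| Example usage: | |
| ```python | |
| with FuturesModelClient("localhost", "MyModel") as client: | |
| config = client.model_config().result() | |
| print(config) | |
| ``` | |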
| Returns: | |
| A Future object that will contain the TritonModelConfig object when it is ready. | |
| Raises: | |
| PyTritonClientClosedError: If the FuturesModelClient is closed. | |
| """ | |
| return self._execute(name=_MODEL_CONFIG) | |
| def infer_sample( | |
| self, | |
| *inputs, | |
| parameters: Optional[Dict[str, Union[str, int, bool]]] = None, | |
| headers: Optional[Dict[str, Union[str, int, bool]]] = None, | |
| **named_inputs, | |
| ) -> Future: | |
| """Run asynchronous inference on a single data sample and return a Future object. | |
| This method allows the user to perform inference on a single data sample by providing input data and receiving the | |
| corresponding output data. The method returns a Future object that wraps a dictionary of inference results, where dictionary keys are output names. | |
| Example usage: | |
| ```python | |
| with FuturesModelClient("localhost", "BERT") as client: | |
| result_future = client.infer_sample(input1=input1_data, input2=input2_data) | |
| # do something else | |
| print(result_future.result()) | |
| ``` | |
| Inference inputs can be provided either as positional or keyword arguments: | |
| ```python | |
| future = client.infer_sample(input1, input2) | |
| future = client.infer_sample(a=input1, b=input2) | |
| ``` | |
| Mixing of argument passing conventions is not supported and will raise PyTritonClientValueError. | |
| Args: | |
| *inputs: Inference inputs provided as positional arguments. | |
| parameters: Optional dictionary of inference parameters. | |
| headers: Optional dictionary of HTTP headers for the inference request. | |
| **named_inputs: Inference inputs provided as named arguments. | |
| Returns: | |
| A Future object wrapping a dictionary of inference results, where dictionary keys are output names. | |
| Raises: | |
| PyTritonClientClosedError: If the FuturesModelClient is closed. | |
| """ | |
| return self._execute( | |
| name=_INFER_SAMPLE, | |
| request=(inputs, parameters, headers, named_inputs), | |
| ) | |
| def infer_batch( | |
| self, | |
| *inputs, | |
| parameters: Optional[Dict[str, Union[str, int, bool]]] = None, | |
| headers: Optional[Dict[str, Union[str, int, bool]]] = None, | |
| **named_inputs, | |
| ) -> Future: | |
| """Run asynchronous inference on batched data and return a Future object. | |
| This method allows the user to perform inference on batched data by providing input data and receiving the corresponding output data. | |
| The method returns a Future object that wraps a dictionary of inference results, where dictionary keys are output names. | |
| Example usage: | |
| ```python | |
| with FuturesModelClient("localhost", "BERT") as client: | |
| future = client.infer_batch(input1_sample, input2_sample) | |
| # do something else | |
| print(future.result()) | |
| ``` | |
| Inference inputs can be provided either as positional or keyword arguments: | |
| ```python | |
| future = client.infer_batch(input1, input2) | |
| future = client.infer_batch(a=input1, b=input2) | |
| ``` | |
| Mixing of argument passing conventions is not supported and will raise PyTritonClientValueError. | |
| Args: | |
| *inputs: Inference inputs provided as positional arguments. | |
| parameters: Optional dictionary of inference parameters. | |
| headers: Optional dictionary of HTTP headers for the inference request. | |
| **named_inputs: Inference inputs provided as named arguments. | |
| Returns: | |
| A Future object wrapping a dictionary of inference results, where dictionary keys are output names. | |
| Raises: | |
| PyTritonClientClosedError: If the FuturesModelClient is closed. | |
| """ | |
| return self._execute(name=_INFER_BATCH, request=(inputs, parameters, headers, named_inputs)) | |
| def _execute(self, name, request=None): | |
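| # Queues the request and returns a Future that a worker thread resolves. In non-blocking | |
| # mode a full queue raises PyTritonClientQueueFullError immediately; otherwise put() may | |
| # block, bounded by inference_timeout_s when it is set, before raising the same error. | |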
| if self._closed: | |
| raise PyTritonClientClosedError("FuturesModelClient is already closed") | |
| self._extend_thread_pool() | |
| future = Future() | |
| if self._non_blocking: | |
| try: | |
| self._queue.put_nowait((future, request, name)) | |
| except Full as e: | |
| raise PyTritonClientQueueFullError("Queue is full") from e | |
| else: | |
| kwargs = {} | |
| if self._inference_timeout_s is not None: | |
| kwargs["timeout"] = self._inference_timeout_s | |
| try: | |
| self._queue.put((future, request, name), **kwargs) | |
| except Full as e: | |
| raise PyTritonClientQueueFullError("Queue is full") from e | |
| return future | |
| def _extend_thread_pool(self): | |
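| # Lazily grows the worker pool: a new thread is started only while requests are pending | |
| # in the queue and the max_workers limit (if any) has not been reached. | |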
| if self._closed: | |
| return | |
| with self._lock: | |
| if not self._queue.empty() and (self._max_workers is None or len(self._threads) < self._max_workers): | |
| _LOGGER.debug("Create new thread") | |
| thread = Thread(target=self._worker) | |
| self._threads.append(thread) | |
| thread.start() | |
| else: | |
| _LOGGER.debug("No need to create new thread") | |
| def _client_request_executor(self, client, request, name): | |
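| # Dispatches a queued request to the matching ModelClient operation; for _MODEL_CONFIG | |
| # the ``client.model_config`` attribute is read rather than called. | |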
| _LOGGER.debug(f"Running {name} for {self._model_name}") | |
| if name == _INFER_SAMPLE: | |
| inputs, parameters, headers, named_inputs = request | |
| result = client.infer_sample( | |
| *inputs, | |
| parameters=parameters, | |
| headers=headers, | |
| **named_inputs, | |
| ) | |
| elif name == _INFER_BATCH: | |
| inputs, parameters, headers, named_inputs = request | |
| result = client.infer_batch( | |
| *inputs, | |
| parameters=parameters, | |
| headers=headers, | |
| **named_inputs, | |
| ) | |
| elif name == _MODEL_CONFIG: | |
| result = client.model_config | |
| elif name == _WAIT_FOR_MODEL: | |
| timeout_s = request | |
| result = client.wait_for_model(timeout_s) | |
| else: | |
| raise PyTritonClientValueError(f"Unknown request name {name}") | |
| self._set_existing_client(client) | |
| return result | |
| def _create_client(self, lazy_init): | |
| _LOGGER.debug(f"Creating ModelClient lazy_init={lazy_init}") | |
| return ModelClient( | |
| self._url, | |
| self._model_name, | |
| self._model_version, | |
| lazy_init=lazy_init, | |
| init_timeout_s=self._init_timeout_s, | |
| inference_timeout_s=self._inference_timeout_s, | |
| ) | |
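| # The first fully initialized ModelClient is cached in _existing_client so that other | |
| # worker threads can create their clients from it via ModelClient.from_existing_client, | |
| # reusing its already fetched model configuration. | |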
| def _set_existing_client(self, client): | |
| if client._model_config is not None: | |
| with self._lock: | |
| if self._existing_client is None: | |
| _LOGGER.debug("Setting existing client") | |
| self._existing_client = client | |
| def _remove_existing_client(self, client): | |
| if client is not None: | |
| with self._lock: | |
| if self._existing_client is not None: | |
| if self._existing_client is client: | |
| _LOGGER.debug("Resetting existing client") | |
| self._existing_client = None | |
| def _worker(self): | |
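| # Worker loop: consumes (future, request, name) tuples from the queue, reuses a single | |
| # ModelClient per thread, and resolves each Future with the result or the raised | |
| # exception. A _CLOSE item ends the thread; _INIT eagerly creates the first client. | |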
| _LOGGER.debug("Starting worker thread") | |
| client = None | |
| # Workaround for AttributeError: '_Threadlocal' object has no attribute 'hub' | |
| # gevent/_hub_local.py", line 77, in gevent._gevent_c_hub_local.get_hub_noargs | |
| with _hub_context(): | |
| while True: | |
| future, request, name = self._queue.get() | |
| if future == _CLOSE: | |
| _LOGGER.debug("Closing thread") | |
| self._queue.task_done() | |
| break | |
| if future == _INIT: | |
| with self._lock: | |
| if self._existing_client is None: | |
| try: | |
| _LOGGER.debug("Initial client creation") | |
| client = self._create_client(False) | |
| _LOGGER.debug("Setting existing client") | |
| self._existing_client = client | |
| except Exception as e: | |
| _LOGGER.warning(f"Error {e} occurred during init for {self._model_name}") | |
| continue | |
| try: | |
| if client is None: | |
| with self._lock: | |
| if self._existing_client is not None: | |
| _LOGGER.debug("Creating new client from existing client") | |
| client = ModelClient.from_existing_client(self._existing_client) | |
| if client is None: | |
| _LOGGER.debug("Creating new client") | |
| client = self._create_client(name == _WAIT_FOR_MODEL) | |
| with client: | |
| self._set_existing_client(client) | |
| while True: | |
| try: | |
| result = self._client_request_executor(client, request, name) | |
| _LOGGER.debug(f"Finished {name} for {self._model_name}") | |
| future.set_result(result) | |
| self._queue.task_done() | |
| except Exception as e: | |
| _LOGGER.error(f"Error {e} occurred during {name} for {self._model_name}") | |
| future.set_exception(e) | |
| self._queue.task_done() | |
| break | |
| future, request, name = self._queue.get() | |
| if future == _CLOSE: | |
| _LOGGER.debug("Closing thread") | |
| self._queue.task_done() | |
| return | |
| except Exception as e: | |
| _LOGGER.error(f"Error {e} occurred during {name} for {self._model_name}") | |
| future.set_exception(e) | |
| self._queue.task_done() | |
| finally: | |
| self._remove_existing_client(client) | |
| client = None | |
| _LOGGER.debug("Finishing worker thread") | |