Spaces:
Runtime error
Runtime error
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
""" Utils to interact with the Triton Inference Server | |
""" | |
import typing | |
from urllib.parse import urlparse | |
import torch | |
class TritonRemoteModel:
    """A wrapper over a model served by the Triton Inference Server.

    It can be configured to communicate over GRPC or HTTP. It accepts Torch
    Tensors as input and returns them as outputs.
    """

    def __init__(self, url: str):
        """Connect to a Triton server and bind to the first model in its repository.

        Keyword arguments:
        url: Fully qualified address of the Triton server - for e.g. grpc://localhost:8000
            A ``grpc`` scheme selects the GRPC client; any other scheme selects
            the HTTP client. Only the URL's network location is passed to the client.
        """
        parsed_url = urlparse(url)
        if parsed_url.scheme == "grpc":
            from tritonclient.grpc import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton GRPC client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository.models[0].name
            # as_json=True so the metadata can be indexed by key like the HTTP client's response.
            self.metadata = self.client.get_model_metadata(self.model_name, as_json=True)
        else:
            from tritonclient.http import InferenceServerClient, InferInput

            self.client = InferenceServerClient(parsed_url.netloc)  # Triton HTTP client
            model_repository = self.client.get_model_repository_index()
            self.model_name = model_repository[0]["name"]
            self.metadata = self.client.get_model_metadata(self.model_name)

        # Defined once for both transports; `InferInput` is whichever class the
        # branch above imported, captured by the closure.
        def create_input_placeholders() -> typing.List[InferInput]:
            # Build fresh, unpopulated inputs from the model metadata. int() normalizes
            # shape entries, which may not already be ints in the metadata.
            return [
                InferInput(i["name"], [int(s) for s in i["shape"]], i["datatype"])
                for i in self.metadata["inputs"]
            ]

        self._create_input_placeholders_fn = create_input_placeholders

    def runtime(self):
        """Return the model's ``backend`` from its metadata, falling back to ``platform``.

        Returns None when the metadata contains neither key.
        """
        return self.metadata.get("backend", self.metadata.get("platform"))

    def __call__(
        self, *args, **kwargs
    ) -> typing.Union[torch.Tensor, typing.List[torch.Tensor]]:
        """Invoke the model. Parameters can be provided via args or kwargs.

        args, if provided, are assumed to match the order of inputs of the model.
        kwargs are matched with the model input names.

        Returns:
            A single tensor when the model metadata lists exactly one output;
            otherwise a list of tensors in the order of the metadata's outputs.

        Raises:
            RuntimeError / KeyError: propagated from input validation in ``_create_inputs``.
        """
        inputs = self._create_inputs(*args, **kwargs)
        response = self.client.infer(model_name=self.model_name, inputs=inputs)
        result = [
            torch.as_tensor(response.as_numpy(output["name"]))
            for output in self.metadata["outputs"]
        ]
        return result[0] if len(result) == 1 else result

    def _create_inputs(self, *args, **kwargs):
        """Create the model's input placeholders and fill them with the given tensors.

        Exactly one of ``args`` (positional, in model input order) or ``kwargs``
        (keyed by model input name) must be non-empty. Keyword arguments that do
        not name a model input are ignored.

        Raises:
            RuntimeError: if no inputs, both args and kwargs, or the wrong number
                of positional inputs are given.
            KeyError: if a model input is missing from ``kwargs``.
        """
        args_len, kwargs_len = len(args), len(kwargs)
        if not args_len and not kwargs_len:
            raise RuntimeError("No inputs provided.")
        if args_len and kwargs_len:
            raise RuntimeError("Cannot specify args and kwargs at the same time")

        placeholders = self._create_input_placeholders_fn()
        if args_len:
            if args_len != len(placeholders):
                raise RuntimeError(f"Expected {len(placeholders)} inputs, got {args_len}.")
            values = args
        else:
            missing = [p.name for p in placeholders if p.name not in kwargs]
            if missing:
                raise KeyError(f"Missing model input(s): {missing}")
            values = [kwargs[p.name] for p in placeholders]

        for placeholder, value in zip(placeholders, values):
            # detach() so tensors that require grad can still be converted to numpy.
            placeholder.set_data_from_numpy(value.detach().cpu().numpy())
        return placeholders