Spaces:
Sleeping
Sleeping
import typing | |
from urllib.parse import urlparse | |
import torch | |
class TritonRemoteModel: | |
""" A wrapper over a model served by the Triton Inference Server. It can | |
be configured to communicate over GRPC or HTTP. It accepts Torch Tensors | |
as input and returns them as outputs. | |
""" | |
def __init__(self, url: str): | |
""" | |
Keyword arguments: | |
url: Fully qualified address of the Triton server - for e.g. grpc://localhost:8000 | |
""" | |
parsed_url = urlparse(url) | |
if parsed_url.scheme == "grpc": | |
from tritonclient.grpc import InferenceServerClient, InferInput | |
self.client = InferenceServerClient(parsed_url.netloc) # Triton GRPC client | |
model_repository = self.client.get_model_repository_index() | |
self.model_name = model_repository.models[0].name | |
self.metadata = self.client.get_model_metadata(self.model_name, as_json=True) | |
def create_input_placeholders() -> typing.List[InferInput]: | |
return [ | |
InferInput(i['name'], [int(s) for s in i["shape"]], i['datatype']) for i in self.metadata['inputs']] | |
else: | |
from tritonclient.http import InferenceServerClient, InferInput | |
self.client = InferenceServerClient(parsed_url.netloc) # Triton HTTP client | |
model_repository = self.client.get_model_repository_index() | |
self.model_name = model_repository[0]['name'] | |
self.metadata = self.client.get_model_metadata(self.model_name) | |
def create_input_placeholders() -> typing.List[InferInput]: | |
return [ | |
InferInput(i['name'], [int(s) for s in i["shape"]], i['datatype']) for i in self.metadata['inputs']] | |
self._create_input_placeholders_fn = create_input_placeholders | |
def runtime(self): | |
"""Returns the model runtime""" | |
return self.metadata.get("backend", self.metadata.get("platform")) | |
def __call__(self, *args, **kwargs) -> typing.Union[torch.Tensor, typing.Tuple[torch.Tensor, ...]]: | |
""" Invokes the model. Parameters can be provided via args or kwargs. | |
args, if provided, are assumed to match the order of inputs of the model. | |
kwargs are matched with the model input names. | |
""" | |
inputs = self._create_inputs(*args, **kwargs) | |
response = self.client.infer(model_name=self.model_name, inputs=inputs) | |
result = [] | |
for output in self.metadata['outputs']: | |
tensor = torch.as_tensor(response.as_numpy(output['name'])) | |
result.append(tensor) | |
return result[0] if len(result) == 1 else result | |
def _create_inputs(self, *args, **kwargs): | |
args_len, kwargs_len = len(args), len(kwargs) | |
if not args_len and not kwargs_len: | |
raise RuntimeError("No inputs provided.") | |
if args_len and kwargs_len: | |
raise RuntimeError("Cannot specify args and kwargs at the same time") | |
placeholders = self._create_input_placeholders_fn() | |
if args_len: | |
if args_len != len(placeholders): | |
raise RuntimeError(f"Expected {len(placeholders)} inputs, got {args_len}.") | |
for input, value in zip(placeholders, args): | |
input.set_data_from_numpy(value.cpu().numpy()) | |
else: | |
for input in placeholders: | |
value = kwargs[input.name] | |
input.set_data_from_numpy(value.cpu().numpy()) | |
return placeholders | |