# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""InferenceHandler class.
Create a connection between the Python and Triton Inference Server to pass request and response between the
python model deployed on Triton and model in current environment.
Examples of use:
model_config = ModelConfigParser.from_file("repo/model/config.pbtxt")
handler = InferenceHandler(
model_callable=infer_func,
model_config=model_config,
shared_memory_socket="ipc:///tmp/shared-memory-socker.ipc"
)
handler.run()
"""
import enum
import inspect
import itertools
import logging
import threading as th
import traceback
import typing
from typing import Callable
import zmq # pytype: disable=import-error
from pytriton.exceptions import PyTritonRuntimeError, PyTritonUnrecoverableError
from pytriton.model_config.triton_model_config import TritonModelConfig
from pytriton.proxy.communication import (
InferenceHandlerRequests,
InferenceHandlerResponses,
MetaRequestResponse,
TensorStore,
)
from pytriton.proxy.types import Request
from pytriton.proxy.validators import TritonResultsValidator
LOGGER = logging.getLogger(__name__)
class InferenceHandlerEvent(enum.Enum):
"""Represents proxy backend event."""
STARTED = "started"
UNRECOVERABLE_ERROR = "unrecoverable-error"
FINISHED = "finished"
InferenceEventsHandler = typing.Callable[["InferenceHandler", InferenceHandlerEvent, typing.Optional[typing.Any]], None]
class _ResponsesIterator:
"""Iterator for gathering all responses (also from generator)."""
def __init__(self, responses, decoupled: bool = False):
"""Init _ResponsesIterator object.
Args:
responses: responses from model callable
decoupled: whether the model is decoupled (streams partial responses)
"""
self._is_generator = inspect.isgenerator(responses)
if self._is_generator and not decoupled:
raise PyTritonRuntimeError(
"Results generator is not supported for non-decoupled models. "
"Don't know how to aggregate partial results from generator into single response. "
"Enable streaming in model configuration or return list of responses."
)
self._responses = responses
self._iterator = None
def __iter__(self):
"""Return iterator."""
if self._is_generator:
self._iterator = self._responses
else:
self._iterator = iter([self._responses]) if self._responses is not None else iter([])
self._responses = None
return self
def __next__(self):
"""Return next response."""
return next(self._iterator)
class InferenceHandler(th.Thread):
"""InferenceHandler class for handling the communication between Triton and local model."""
def __init__(
self,
model_callable: Callable,
model_config: TritonModelConfig,
shared_memory_socket: str,
data_store_socket: str,
zmq_context: zmq.Context,
validator: TritonResultsValidator,
):
"""Create a PythonBackend object.
Args:
model_callable: A model callable used to process requests and return responses
model_config: Triton model configuration
shared_memory_socket: Socket path for shared memory communication
data_store_socket: Socket path for data store communication
zmq_context: ZeroMQ context
validator: Result validator instance
"""
super().__init__()
self._model_config = model_config
self._model_callable = model_callable
self._validator = validator
self.stopped = False
self._tensor_store = TensorStore(data_store_socket)
self.zmq_context = zmq_context
self.socket = None
self.shared_memory_socket = shared_memory_socket
self._proxy_backend_events_observers: typing.List[InferenceEventsHandler] = []
def run(self) -> None:
"""Start the InferenceHandler communication."""
self.socket = self.zmq_context.socket(zmq.REP)
model_name = self._model_config.model_name
try:
LOGGER.debug(f"Binding IPC socket at {self.shared_memory_socket}.")
self.socket.bind(self.shared_memory_socket)
self._tensor_store.connect()
while not self.stopped:
LOGGER.debug(f"Waiting for requests from proxy model for {model_name}.")
request_payload = self.socket.recv()
requests = InferenceHandlerRequests.from_bytes(request_payload).requests
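# Requests carry only tensor ids and parameters; the actual input arrays are fetched from the TensorStore.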
LOGGER.debug(f"Preparing inputs for {model_name}.")
inputs = [
Request(
data={
input_name: self._tensor_store.get(tensor_id)
for input_name, tensor_id in request.data.items()
},
parameters=request.parameters,
)
for request in requests
]
try:
LOGGER.debug(f"Processing inference callback for {model_name}.")
responses = self._model_callable(inputs)
responses_iterator = _ResponsesIterator(responses, decoupled=self._model_config.decoupled)
for responses in responses_iterator:
LOGGER.debug(f"Validating outputs for {self._model_config.model_name}.")
self._validator.validate_responses(inputs, responses)
LOGGER.debug(f"Copying outputs to shared memory for {model_name}.")
output_arrays_with_coords = [
(response_idx, output_name, tensor)
for response_idx, response in enumerate(responses)
for output_name, tensor in response.items()
]
tensor_ids = self._tensor_store.put([tensor for _, _, tensor in output_arrays_with_coords])
responses = [{} for _ in range(len(responses))]
for (response_idx, output_name, _), tensor_id in zip(output_arrays_with_coords, tensor_ids):
responses[response_idx][output_name] = tensor_id
responses = InferenceHandlerResponses(
responses=[
MetaRequestResponse(idx=idx, data=response, eos=False)
for idx, response in enumerate(responses)
],
)
LOGGER.debug(f"Sending response: {responses}")
self.socket.send(responses.as_bytes())
self.socket.recv() # wait for ack
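# All partial responses have been sent; follow up with an end-of-stream (eos) marker for each request.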
responses = InferenceHandlerResponses(
responses=[MetaRequestResponse(idx=idx, eos=True) for idx in range(len(requests))]
)
LOGGER.debug(f"Send eos response to proxy model for {model_name}.")
self.socket.send(responses.as_bytes())
except PyTritonUnrecoverableError:
error = traceback.format_exc()
responses = InferenceHandlerResponses(error=error)
LOGGER.error(
"Unrecoverable error thrown during calling model callable. "
"Shutting down Triton Inference Server. "
f"{error}"
)
self.stopped = True
self._notify_proxy_backend_observers(InferenceHandlerEvent.UNRECOVERABLE_ERROR, error)
LOGGER.debug(f"Send response to proxy model for {model_name}.")
self.socket.send(responses.as_bytes())
except Exception:
error = traceback.format_exc()
responses = InferenceHandlerResponses(error=error)
LOGGER.error(f"Error occurred during calling model callable: {error}")
self.socket.send(responses.as_bytes())
finally:
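# Release the shared-memory blocks that held the input tensors for this batch of requests.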
for tensor_id in itertools.chain(*[request.data.values() for request in requests]):
self._tensor_store.release_block(tensor_id)
except zmq.error.ContextTerminated:
LOGGER.info("Context was terminated. InferenceHandler will be closed.")
except Exception as exception:
LOGGER.error("Internal proxy backend error. InferenceHandler will be closed.")
LOGGER.exception(exception)
finally:
LOGGER.info("Closing socket")
socket_close_timeout_s = 0
self.socket.close(linger=socket_close_timeout_s)
LOGGER.info("Closing TensorStore")
self._tensor_store.close()
LOGGER.info("Leaving proxy backend thread")
self._notify_proxy_backend_observers(InferenceHandlerEvent.FINISHED, None)
def stop(self) -> None:
"""Stop the InferenceHandler communication."""
LOGGER.info("Closing proxy")
self.stopped = True
self.join()
def on_proxy_backend_event(self, proxy_backend_event_handle_fn: InferenceEventsHandler):
"""Register InferenceEventsHandler callable.
Args:
proxy_backend_event_handle_fn: function to be called when a proxy backend event arises
"""
self._proxy_backend_events_observers.append(proxy_backend_event_handle_fn)
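# A minimal sketch of registering an observer (the callback name below is hypothetical):
#
#     def log_inference_event(handler, event, context):
#         LOGGER.info("InferenceHandler event %s: %s", event, context)
#
#     handler.on_proxy_backend_event(log_inference_event)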
def _notify_proxy_backend_observers(self, event: InferenceHandlerEvent, context: typing.Optional[typing.Any]):
for proxy_backend_event_handle_fn in self._proxy_backend_events_observers:
proxy_backend_event_handle_fn(self, event, context)