Fucius's picture
Upload 422 files
df6c67d verified
raw
history blame
9.89 kB
import os
import signal
import socket
import sys
from functools import partial
from multiprocessing import Process, Queue
from socketserver import BaseRequestHandler, BaseServer
from types import FrameType
from typing import Any, Dict, Optional, Tuple
from uuid import uuid4
from inference.core import logger
from inference.enterprise.stream_management.manager.communication import (
receive_socket_data,
send_data_trough_socket,
)
from inference.enterprise.stream_management.manager.entities import (
PIPELINE_ID_KEY,
STATUS_KEY,
TYPE_KEY,
CommandType,
ErrorType,
OperationStatus,
)
from inference.enterprise.stream_management.manager.errors import MalformedPayloadError
from inference.enterprise.stream_management.manager.inference_pipeline_manager import (
InferencePipelineManager,
)
from inference.enterprise.stream_management.manager.serialisation import (
describe_error,
prepare_error_response,
prepare_response,
)
from inference.enterprise.stream_management.manager.tcp_server import RoboflowTCPServer
PROCESSES_TABLE: Dict[str, Tuple[Process, Queue, Queue]] = {}
HEADER_SIZE = 4
SOCKET_BUFFER_SIZE = 16384
HOST = os.getenv("STREAM_MANAGER_HOST", "127.0.0.1")
PORT = int(os.getenv("STREAM_MANAGER_PORT", "7070"))
SOCKET_TIMEOUT = float(os.getenv("STREAM_MANAGER_SOCKET_TIMEOUT", "5.0"))
class InferencePipelinesManagerHandler(BaseRequestHandler):
def __init__(
self,
request: socket.socket,
client_address: Any,
server: BaseServer,
processes_table: Dict[str, Tuple[Process, Queue, Queue]],
):
self._processes_table = processes_table # in this case it's required to set the state of class before superclass init - as it invokes handle()
super().__init__(request, client_address, server)
def handle(self) -> None:
pipeline_id: Optional[str] = None
request_id = str(uuid4())
try:
data = receive_socket_data(
source=self.request,
header_size=HEADER_SIZE,
buffer_size=SOCKET_BUFFER_SIZE,
)
data[TYPE_KEY] = CommandType(data[TYPE_KEY])
if data[TYPE_KEY] is CommandType.LIST_PIPELINES:
return self._list_pipelines(request_id=request_id)
if data[TYPE_KEY] is CommandType.INIT:
return self._initialise_pipeline(request_id=request_id, command=data)
pipeline_id = data[PIPELINE_ID_KEY]
if data[TYPE_KEY] is CommandType.TERMINATE:
self._terminate_pipeline(
request_id=request_id, pipeline_id=pipeline_id, command=data
)
else:
response = handle_command(
processes_table=self._processes_table,
request_id=request_id,
pipeline_id=pipeline_id,
command=data,
)
serialised_response = prepare_response(
request_id=request_id, response=response, pipeline_id=pipeline_id
)
send_data_trough_socket(
target=self.request,
header_size=HEADER_SIZE,
data=serialised_response,
request_id=request_id,
pipeline_id=pipeline_id,
)
except (KeyError, ValueError, MalformedPayloadError) as error:
logger.error(
f"Invalid payload in processes manager. error={error} request_id={request_id}..."
)
payload = prepare_error_response(
request_id=request_id,
error=error,
error_type=ErrorType.INVALID_PAYLOAD,
pipeline_id=pipeline_id,
)
send_data_trough_socket(
target=self.request,
header_size=HEADER_SIZE,
data=payload,
request_id=request_id,
pipeline_id=pipeline_id,
)
except Exception as error:
logger.error(
f"Internal error in processes manager. error={error} request_id={request_id}..."
)
payload = prepare_error_response(
request_id=request_id,
error=error,
error_type=ErrorType.INTERNAL_ERROR,
pipeline_id=pipeline_id,
)
send_data_trough_socket(
target=self.request,
header_size=HEADER_SIZE,
data=payload,
request_id=request_id,
pipeline_id=pipeline_id,
)
def _list_pipelines(self, request_id: str) -> None:
serialised_response = prepare_response(
request_id=request_id,
response={
"pipelines": list(self._processes_table.keys()),
STATUS_KEY: OperationStatus.SUCCESS,
},
pipeline_id=None,
)
send_data_trough_socket(
target=self.request,
header_size=HEADER_SIZE,
data=serialised_response,
request_id=request_id,
)
def _initialise_pipeline(self, request_id: str, command: dict) -> None:
pipeline_id = str(uuid4())
command_queue = Queue()
responses_queue = Queue()
inference_pipeline_manager = InferencePipelineManager.init(
command_queue=command_queue,
responses_queue=responses_queue,
)
inference_pipeline_manager.start()
self._processes_table[pipeline_id] = (
inference_pipeline_manager,
command_queue,
responses_queue,
)
command_queue.put((request_id, command))
response = get_response_ignoring_thrash(
responses_queue=responses_queue, matching_request_id=request_id
)
serialised_response = prepare_response(
request_id=request_id, response=response, pipeline_id=pipeline_id
)
send_data_trough_socket(
target=self.request,
header_size=HEADER_SIZE,
data=serialised_response,
request_id=request_id,
pipeline_id=pipeline_id,
)
def _terminate_pipeline(
self, request_id: str, pipeline_id: str, command: dict
) -> None:
response = handle_command(
processes_table=self._processes_table,
request_id=request_id,
pipeline_id=pipeline_id,
command=command,
)
if response[STATUS_KEY] is OperationStatus.SUCCESS:
logger.info(
f"Joining inference pipeline. pipeline_id={pipeline_id} request_id={request_id}"
)
join_inference_pipeline(
processes_table=self._processes_table, pipeline_id=pipeline_id
)
logger.info(
f"Joined inference pipeline. pipeline_id={pipeline_id} request_id={request_id}"
)
serialised_response = prepare_response(
request_id=request_id, response=response, pipeline_id=pipeline_id
)
send_data_trough_socket(
target=self.request,
header_size=HEADER_SIZE,
data=serialised_response,
request_id=request_id,
pipeline_id=pipeline_id,
)
def handle_command(
processes_table: Dict[str, Tuple[Process, Queue, Queue]],
request_id: str,
pipeline_id: str,
command: dict,
) -> dict:
if pipeline_id not in processes_table:
return describe_error(exception=None, error_type=ErrorType.NOT_FOUND)
_, command_queue, responses_queue = processes_table[pipeline_id]
command_queue.put((request_id, command))
return get_response_ignoring_thrash(
responses_queue=responses_queue, matching_request_id=request_id
)
def get_response_ignoring_thrash(
responses_queue: Queue, matching_request_id: str
) -> dict:
while True:
response = responses_queue.get()
if response[0] == matching_request_id:
return response[1]
logger.warning(
f"Dropping response for request_id={response[0]} with payload={response[1]}"
)
def execute_termination(
signal_number: int,
frame: FrameType,
processes_table: Dict[str, Tuple[Process, Queue, Queue]],
) -> None:
pipeline_ids = list(processes_table.keys())
for pipeline_id in pipeline_ids:
logger.info(f"Terminating pipeline: {pipeline_id}")
processes_table[pipeline_id][0].terminate()
logger.info(f"Pipeline: {pipeline_id} terminated.")
logger.info(f"Joining pipeline: {pipeline_id}")
processes_table[pipeline_id][0].join()
logger.info(f"Pipeline: {pipeline_id} joined.")
logger.info(f"Termination handler completed.")
sys.exit(0)
def join_inference_pipeline(
processes_table: Dict[str, Tuple[Process, Queue, Queue]], pipeline_id: str
) -> None:
inference_pipeline_manager, command_queue, responses_queue = processes_table[
pipeline_id
]
inference_pipeline_manager.join()
del processes_table[pipeline_id]
if __name__ == "__main__":
signal.signal(
signal.SIGINT, partial(execute_termination, processes_table=PROCESSES_TABLE)
)
signal.signal(
signal.SIGTERM, partial(execute_termination, processes_table=PROCESSES_TABLE)
)
with RoboflowTCPServer(
server_address=(HOST, PORT),
handler_class=partial(
InferencePipelinesManagerHandler, processes_table=PROCESSES_TABLE
),
socket_operations_timeout=SOCKET_TIMEOUT,
) as tcp_server:
logger.info(
f"Inference Pipeline Processes Manager is ready to accept connections at {(HOST, PORT)}"
)
tcp_server.serve_forever()