"""Client for the Stream Management TCP API.

Commands and responses are JSON documents framed on the wire with a
fixed-size, big-endian length header.
"""
import asyncio
import json
from asyncio import StreamReader, StreamWriter
from json import JSONDecodeError
from typing import Optional, Tuple

from inference.core import logger
from inference.enterprise.stream_management.api.entities import (
    CommandContext,
    CommandResponse,
    InferencePipelineStatusResponse,
    ListPipelinesResponse,
    PipelineInitialisationRequest,
)
from inference.enterprise.stream_management.api.errors import (
    ConnectivityError,
    ProcessesManagerAuthorisationError,
    ProcessesManagerClientError,
    ProcessesManagerInternalError,
    ProcessesManagerInvalidPayload,
    ProcessesManagerNotFoundError,
    ProcessesManagerOperationError,
)
from inference.enterprise.stream_management.manager.entities import (
    ERROR_TYPE_KEY,
    PIPELINE_ID_KEY,
    REQUEST_ID_KEY,
    RESPONSE_KEY,
    STATUS_KEY,
    TYPE_KEY,
    CommandType,
    ErrorType,
    OperationStatus,
)
from inference.enterprise.stream_management.manager.errors import (
    CommunicationProtocolError,
    MalformedHeaderError,
    MalformedPayloadError,
    MessageToBigError,
    TransmissionChannelClosed,
)

BUFFER_SIZE = 16384
HEADER_SIZE = 4

# Maps error types reported by the manager to client-side exception classes.
ERRORS_MAPPING = {
    ErrorType.INTERNAL_ERROR.value: ProcessesManagerInternalError,
    ErrorType.INVALID_PAYLOAD.value: ProcessesManagerInvalidPayload,
    ErrorType.NOT_FOUND.value: ProcessesManagerNotFoundError,
    ErrorType.OPERATION_ERROR.value: ProcessesManagerOperationError,
    ErrorType.AUTHORISATION_ERROR.value: ProcessesManagerAuthorisationError,
}


class StreamManagerClient:
    @classmethod
    def init(
        cls,
        host: str,
        port: int,
        operations_timeout: Optional[float] = None,
        header_size: int = HEADER_SIZE,
        buffer_size: int = BUFFER_SIZE,
    ) -> "StreamManagerClient":
        return cls(
            host=host,
            port=port,
            operations_timeout=operations_timeout,
            header_size=header_size,
            buffer_size=buffer_size,
        )

    def __init__(
        self,
        host: str,
        port: int,
        operations_timeout: Optional[float],
        header_size: int,
        buffer_size: int,
    ):
        self._host = host
        self._port = port
        self._operations_timeout = operations_timeout
        self._header_size = header_size
        self._buffer_size = buffer_size

    async def list_pipelines(self) -> ListPipelinesResponse:
        command = {
            TYPE_KEY: CommandType.LIST_PIPELINES,
        }
        response = await self._handle_command(command=command)
        status = response[RESPONSE_KEY][STATUS_KEY]
        context = CommandContext(
            request_id=response.get(REQUEST_ID_KEY),
            pipeline_id=response.get(PIPELINE_ID_KEY),
        )
        pipelines = response[RESPONSE_KEY]["pipelines"]
        return ListPipelinesResponse(
            status=status,
            context=context,
            pipelines=pipelines,
        )

    async def initialise_pipeline(
        self, initialisation_request: PipelineInitialisationRequest
    ) -> CommandResponse:
        command = initialisation_request.dict(exclude_none=True)
        command[TYPE_KEY] = CommandType.INIT
        response = await self._handle_command(command=command)
        return build_response(response=response)

    async def terminate_pipeline(self, pipeline_id: str) -> CommandResponse:
        command = {
            TYPE_KEY: CommandType.TERMINATE,
            PIPELINE_ID_KEY: pipeline_id,
        }
        response = await self._handle_command(command=command)
        return build_response(response=response)

    async def pause_pipeline(self, pipeline_id: str) -> CommandResponse:
        command = {
            TYPE_KEY: CommandType.MUTE,
            PIPELINE_ID_KEY: pipeline_id,
        }
        response = await self._handle_command(command=command)
        return build_response(response=response)

    async def resume_pipeline(self, pipeline_id: str) -> CommandResponse:
        command = {
            TYPE_KEY: CommandType.RESUME,
            PIPELINE_ID_KEY: pipeline_id,
        }
        response = await self._handle_command(command=command)
        return build_response(response=response)

    async def get_status(self, pipeline_id: str) -> InferencePipelineStatusResponse:
        command = {
            TYPE_KEY: CommandType.STATUS,
            PIPELINE_ID_KEY: pipeline_id,
        }
        response = await self._handle_command(command=command)
        status = response[RESPONSE_KEY][STATUS_KEY]
        context = CommandContext(
            request_id=response.get(REQUEST_ID_KEY),
            pipeline_id=response.get(PIPELINE_ID_KEY),
        )
        report = response[RESPONSE_KEY]["report"]
        return InferencePipelineStatusResponse(
            status=status,
            context=context,
            report=report,
        )

    async def _handle_command(self, command: dict) -> dict:
        # Send the command over a fresh connection and translate failure
        # responses into client-side exceptions.
        response = await send_command(
            host=self._host,
            port=self._port,
            command=command,
            header_size=self._header_size,
            buffer_size=self._buffer_size,
            timeout=self._operations_timeout,
        )
        if is_request_unsuccessful(response=response):
            dispatch_error(error_response=response)
        return response


async def send_command(
    host: str,
    port: int,
    command: dict,
    header_size: int,
    buffer_size: int,
    timeout: Optional[float] = None,
) -> dict:
    try:
        reader, writer = await establish_socket_connection(
            host=host, port=port, timeout=timeout
        )
        await send_message(
            writer=writer, message=command, header_size=header_size, timeout=timeout
        )
        data = await receive_message(
            reader, header_size=header_size, buffer_size=buffer_size, timeout=timeout
        )
        writer.close()
        await writer.wait_closed()
        return json.loads(data)
    except JSONDecodeError as error:
        raise MalformedPayloadError(
            f"Could not decode response. Cause: {error}"
        ) from error
    except (OSError, asyncio.TimeoutError) as error:
        raise ConnectivityError(
            "Could not communicate with Process Manager"
        ) from error


async def establish_socket_connection(
    host: str, port: int, timeout: Optional[float] = None
) -> Tuple[StreamReader, StreamWriter]:
    return await asyncio.wait_for(asyncio.open_connection(host, port), timeout=timeout)


async def send_message(
    writer: StreamWriter,
    message: dict,
    header_size: int,
    timeout: Optional[float] = None,
) -> None:
    try:
        # Frame the JSON payload with a fixed-size, big-endian length header.
        body = json.dumps(message).encode("utf-8")
        header = len(body).to_bytes(length=header_size, byteorder="big")
        payload = header + body
        writer.write(payload)
        await asyncio.wait_for(writer.drain(), timeout=timeout)
    except TypeError as error:
        raise MalformedPayloadError(
            f"Could not serialise message. Details: {error}"
        ) from error
    except OverflowError as error:
        raise MessageToBigError(
            f"Could not send message due to size overflow. Details: {error}"
        ) from error
    except asyncio.TimeoutError as error:
        raise ConnectivityError(
            "Could not communicate with Process Manager"
        ) from error
    except Exception as error:
        raise CommunicationProtocolError(
            f"Could not send message. Cause: {error}"
        ) from error


async def receive_message(
    reader: StreamReader,
    header_size: int,
    buffer_size: int,
    timeout: Optional[float] = None,
) -> bytes:
    # Read the length header first, then keep reading until the whole payload
    # has been received.
    header = await asyncio.wait_for(reader.read(header_size), timeout=timeout)
    if len(header) != header_size:
        raise MalformedHeaderError("Header size mismatch")
    payload_size = int.from_bytes(bytes=header, byteorder="big")
    received = b""
    while len(received) < payload_size:
        chunk = await asyncio.wait_for(reader.read(buffer_size), timeout=timeout)
        if len(chunk) == 0:
            raise TransmissionChannelClosed(
                "Socket was closed before the full payload was received."
            )
        received += chunk
    return received


def is_request_unsuccessful(response: dict) -> bool:
    return (
        response.get(RESPONSE_KEY, {}).get(STATUS_KEY, OperationStatus.FAILURE.value)
        != OperationStatus.SUCCESS.value
    )


def dispatch_error(error_response: dict) -> None:
    response_payload = error_response.get(RESPONSE_KEY, {})
    error_type = response_payload.get(ERROR_TYPE_KEY)
    error_class = response_payload.get("error_class", "N/A")
    error_message = response_payload.get("error_message", "N/A")
    logger.error(
        f"Error in ProcessesManagerClient. error_type={error_type} error_class={error_class} "
        f"error_message={error_message}"
    )
    if error_type in ERRORS_MAPPING:
        raise ERRORS_MAPPING[error_type](
            f"Error in ProcessesManagerClient. Error type: {error_type}. Details: {error_message}"
        )
    raise ProcessesManagerClientError(
        f"Error in ProcessesManagerClient. Error type: {error_type}. Details: {error_message}"
    )


def build_response(response: dict) -> CommandResponse:
    status = response[RESPONSE_KEY][STATUS_KEY]
    context = CommandContext(
        request_id=response.get(REQUEST_ID_KEY),
        pipeline_id=response.get(PIPELINE_ID_KEY),
    )
    return CommandResponse(
        status=status,
        context=context,
    )
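

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of driving the client against a running Stream Manager.
# The host and port below are placeholders, not documented defaults; point
# them at wherever your manager process is listening. It also assumes that
# ListPipelinesResponse.pipelines is a list of pipeline identifiers accepted
# by get_status().
if __name__ == "__main__":

    async def _demo() -> None:
        client = StreamManagerClient.init(host="127.0.0.1", port=7070)
        # List pipelines currently registered in the manager.
        listing = await client.list_pipelines()
        print(listing.pipelines)
        # If any pipeline exists, fetch and print its status report.
        if listing.pipelines:
            status = await client.get_status(pipeline_id=listing.pipelines[0])
            print(status.report)

    asyncio.run(_demo())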