Spaces:
Runtime error
Runtime error
import asyncio | |
import json | |
from asyncio import StreamReader, StreamWriter | |
from json import JSONDecodeError | |
from typing import Optional, Tuple | |
from inference.core import logger | |
from inference.enterprise.stream_management.api.entities import ( | |
CommandContext, | |
CommandResponse, | |
InferencePipelineStatusResponse, | |
ListPipelinesResponse, | |
PipelineInitialisationRequest, | |
) | |
from inference.enterprise.stream_management.api.errors import ( | |
ConnectivityError, | |
ProcessesManagerAuthorisationError, | |
ProcessesManagerClientError, | |
ProcessesManagerInternalError, | |
ProcessesManagerInvalidPayload, | |
ProcessesManagerNotFoundError, | |
ProcessesManagerOperationError, | |
) | |
from inference.enterprise.stream_management.manager.entities import ( | |
ERROR_TYPE_KEY, | |
PIPELINE_ID_KEY, | |
REQUEST_ID_KEY, | |
RESPONSE_KEY, | |
STATUS_KEY, | |
TYPE_KEY, | |
CommandType, | |
ErrorType, | |
OperationStatus, | |
) | |
from inference.enterprise.stream_management.manager.errors import ( | |
CommunicationProtocolError, | |
MalformedHeaderError, | |
MalformedPayloadError, | |
MessageToBigError, | |
TransmissionChannelClosed, | |
) | |
# Maximum number of bytes requested per read while receiving a reply payload.
BUFFER_SIZE = 16 * 1024
# Width, in bytes, of the big-endian length prefix placed before every message body.
HEADER_SIZE = 4
# Translates the error-type value carried in a failed reply (under ERROR_TYPE_KEY)
# into the client-side exception class raised for it; unknown types fall back to
# a generic client error elsewhere.
ERRORS_MAPPING = {
    error_kind.value: exception_type
    for error_kind, exception_type in (
        (ErrorType.INTERNAL_ERROR, ProcessesManagerInternalError),
        (ErrorType.INVALID_PAYLOAD, ProcessesManagerInvalidPayload),
        (ErrorType.NOT_FOUND, ProcessesManagerNotFoundError),
        (ErrorType.OPERATION_ERROR, ProcessesManagerOperationError),
        (ErrorType.AUTHORISATION_ERROR, ProcessesManagerAuthorisationError),
    )
}
class StreamManagerClient:
    """Asynchronous client for the stream-management Processes Manager.

    Every public coroutine opens a fresh TCP connection (via ``send_command``),
    sends one command dict, and converts the reply into an API entity. Replies
    whose status is not success are turned into exceptions by ``dispatch_error``.
    """

    @classmethod
    def init(
        cls,
        host: str,
        port: int,
        operations_timeout: Optional[float] = None,
        header_size: int = HEADER_SIZE,
        buffer_size: int = BUFFER_SIZE,
    ) -> "StreamManagerClient":
        """Create a client, using the module's protocol defaults for optional settings.

        Args:
            host: Processes Manager host name or address.
            port: Processes Manager TCP port.
            operations_timeout: Timeout in seconds applied to each awaited network
                step (connect, drain, each read); ``None`` waits indefinitely.
            header_size: Width in bytes of the length prefix framing each message.
            buffer_size: Maximum number of bytes requested per read when receiving.

        Returns:
            A configured ``StreamManagerClient``.
        """
        # Must be a classmethod: the factory is invoked on the class itself
        # (``StreamManagerClient.init(...)``) and ``cls`` is what gets built.
        return cls(
            host=host,
            port=port,
            operations_timeout=operations_timeout,
            header_size=header_size,
            buffer_size=buffer_size,
        )

    def __init__(
        self,
        host: str,
        port: int,
        operations_timeout: Optional[float],
        header_size: int,
        buffer_size: int,
    ):
        """Store connection settings; no connection is opened here."""
        self._host = host
        self._port = port
        self._operations_timeout = operations_timeout
        self._header_size = header_size
        self._buffer_size = buffer_size

    async def list_pipelines(self) -> ListPipelinesResponse:
        """Send a LIST_PIPELINES command and return the reported pipelines.

        Raises:
            ProcessesManagerClientError (or a subclass): the manager reported failure.
            KeyError: the success reply lacks the expected ``status``/``pipelines`` fields.
        """
        command = {
            TYPE_KEY: CommandType.LIST_PIPELINES,
        }
        response = await self._handle_command(command=command)
        status = response[RESPONSE_KEY][STATUS_KEY]
        context = self._build_context(response=response)
        pipelines = response[RESPONSE_KEY]["pipelines"]
        return ListPipelinesResponse(
            status=status,
            context=context,
            pipelines=pipelines,
        )

    async def initialise_pipeline(
        self, initialisation_request: PipelineInitialisationRequest
    ) -> CommandResponse:
        """Send an INIT command built from ``initialisation_request``.

        Fields that are ``None`` are omitted from the command payload.
        """
        command = initialisation_request.dict(exclude_none=True)
        command[TYPE_KEY] = CommandType.INIT
        response = await self._handle_command(command=command)
        return build_response(response=response)

    async def terminate_pipeline(self, pipeline_id: str) -> CommandResponse:
        """Send a TERMINATE command for ``pipeline_id``."""
        command = {
            TYPE_KEY: CommandType.TERMINATE,
            PIPELINE_ID_KEY: pipeline_id,
        }
        response = await self._handle_command(command=command)
        return build_response(response=response)

    async def pause_pipeline(self, pipeline_id: str) -> CommandResponse:
        """Pause ``pipeline_id``; on the wire this is the MUTE command."""
        command = {
            TYPE_KEY: CommandType.MUTE,
            PIPELINE_ID_KEY: pipeline_id,
        }
        response = await self._handle_command(command=command)
        return build_response(response=response)

    async def resume_pipeline(self, pipeline_id: str) -> CommandResponse:
        """Send a RESUME command for ``pipeline_id``."""
        command = {
            TYPE_KEY: CommandType.RESUME,
            PIPELINE_ID_KEY: pipeline_id,
        }
        response = await self._handle_command(command=command)
        return build_response(response=response)

    async def get_status(self, pipeline_id: str) -> InferencePipelineStatusResponse:
        """Send a STATUS command for ``pipeline_id`` and return its ``report``.

        Raises:
            ProcessesManagerClientError (or a subclass): the manager reported failure.
            KeyError: the success reply lacks the expected ``status``/``report`` fields.
        """
        command = {
            TYPE_KEY: CommandType.STATUS,
            PIPELINE_ID_KEY: pipeline_id,
        }
        response = await self._handle_command(command=command)
        status = response[RESPONSE_KEY][STATUS_KEY]
        context = self._build_context(response=response)
        report = response[RESPONSE_KEY]["report"]
        return InferencePipelineStatusResponse(
            status=status,
            context=context,
            report=report,
        )

    @staticmethod
    def _build_context(response: dict) -> CommandContext:
        """Extract request/pipeline identifiers from a reply; absent keys become ``None``."""
        return CommandContext(
            request_id=response.get(REQUEST_ID_KEY),
            pipeline_id=response.get(PIPELINE_ID_KEY),
        )

    async def _handle_command(self, command: dict) -> dict:
        """Send ``command`` with this client's settings and return the raw reply.

        Raises whatever ``dispatch_error`` raises when the reply is not successful.
        """
        response = await send_command(
            host=self._host,
            port=self._port,
            command=command,
            header_size=self._header_size,
            buffer_size=self._buffer_size,
            timeout=self._operations_timeout,
        )
        if is_request_unsuccessful(response=response):
            dispatch_error(error_response=response)
        return response
async def send_command(
    host: str,
    port: int,
    command: dict,
    header_size: int,
    buffer_size: int,
    timeout: Optional[float] = None,
) -> dict:
    """Send one command to the Processes Manager and return its decoded JSON reply.

    A new connection is opened for every call. ``command`` is written as a
    length-prefixed JSON frame, one length-prefixed reply is read, the connection
    is closed, and the reply is parsed as JSON.

    Args:
        host: Processes Manager host.
        port: Processes Manager TCP port.
        command: JSON-serialisable command payload.
        header_size: Width in bytes of the length prefix.
        buffer_size: Maximum bytes requested per read of the reply.
        timeout: Timeout in seconds for each awaited network step; ``None`` = no limit.

    Returns:
        The reply decoded by ``json.loads``.

    Raises:
        MalformedPayloadError: the reply is not valid (UTF-8) JSON.
        ConnectivityError: connecting or an OS-level I/O step failed or timed out.
        Errors raised by ``send_message`` / ``receive_message`` propagate unchanged.
    """
    try:
        reader, writer = await establish_socket_connection(
            host=host, port=port, timeout=timeout
        )
        try:
            await send_message(
                writer=writer, message=command, header_size=header_size, timeout=timeout
            )
            data = await receive_message(
                reader, header_size=header_size, buffer_size=buffer_size, timeout=timeout
            )
        finally:
            # Always release the socket, including when sending/receiving failed;
            # previously an error there leaked the open connection.
            writer.close()
            try:
                await writer.wait_closed()
            except OSError as close_error:
                # Best-effort close: a failure here must not mask the reply we
                # already have or the error that is already propagating.
                logger.debug(
                    f"Ignoring error while closing connection to Process Manager: {close_error}"
                )
        return json.loads(data)
    except (JSONDecodeError, UnicodeDecodeError) as error:
        # json.loads on bytes raises UnicodeDecodeError (not JSONDecodeError) for
        # invalid UTF-8, which is still a malformed payload.
        raise MalformedPayloadError(
            f"Could not decode response. Cause: {error}"
        ) from error
    except (OSError, asyncio.TimeoutError) as error:
        raise ConnectivityError("Could not communicate with Process Manager") from error
async def establish_socket_connection(
    host: str, port: int, timeout: Optional[float] = None
) -> Tuple[StreamReader, StreamWriter]:
    """Open a TCP connection to ``host:port`` and return its reader/writer pair.

    The connection attempt is bounded by ``timeout`` seconds (``None`` waits
    indefinitely); on expiry ``asyncio.wait_for`` raises a timeout error, and
    connection failures from ``asyncio.open_connection`` propagate unchanged.
    """
    connection_attempt = asyncio.open_connection(host, port)
    stream_pair = await asyncio.wait_for(connection_attempt, timeout=timeout)
    return stream_pair
async def send_message(
    writer: StreamWriter,
    message: dict,
    header_size: int,
    timeout: Optional[float] = None,
) -> None:
    """Write ``message`` to ``writer`` as one length-prefixed UTF-8 JSON frame.

    The frame is the body length encoded big-endian (unsigned) in ``header_size``
    bytes, followed by the JSON body. Waits for the write buffer to drain.

    Raises:
        MalformedPayloadError: ``message`` is not JSON-serialisable.
        MessageToBigError: the body length does not fit in ``header_size`` bytes.
        ConnectivityError: draining did not finish within ``timeout`` seconds.
        CommunicationProtocolError: any other failure while sending.
    """
    try:
        body = json.dumps(message).encode("utf-8")
        # int.to_bytes raises OverflowError if the length needs more than header_size bytes.
        header = len(body).to_bytes(length=header_size, byteorder="big")
        payload = header + body
        writer.write(payload)
        await asyncio.wait_for(writer.drain(), timeout=timeout)
    except TypeError as error:
        # Chain the cause (previously dropped) so the serialisation failure is debuggable.
        raise MalformedPayloadError(
            f"Could not serialise message. Details: {error}"
        ) from error
    except OverflowError as error:
        raise MessageToBigError(
            f"Could not send message due to size overflow. Details: {error}"
        ) from error
    except asyncio.TimeoutError as error:
        raise ConnectivityError("Could not communicate with Process Manager") from error
    except Exception as error:
        raise CommunicationProtocolError(
            f"Could not send message. Cause: {error}"
        ) from error
async def receive_message(
    reader: StreamReader,
    header_size: int,
    buffer_size: int,
    timeout: Optional[float] = None,
) -> bytes:
    """Read one length-prefixed frame from ``reader`` and return its body.

    The first ``header_size`` bytes are a big-endian length; exactly that many
    body bytes are then read, in reads of at most ``buffer_size`` bytes. Each
    individual read is bounded by ``timeout`` seconds (``None`` = no limit).

    Raises:
        MalformedHeaderError: the stream ended before a full header arrived.
        TransmissionChannelClosed: the stream ended before the full body arrived.
        asyncio.TimeoutError: a single read exceeded ``timeout``.
    """
    try:
        # read(n) may legally return fewer than n bytes on a healthy stream;
        # readexactly waits for the whole header (or fails at EOF).
        header = await asyncio.wait_for(
            reader.readexactly(header_size), timeout=timeout
        )
    except asyncio.IncompleteReadError as error:
        raise MalformedHeaderError("Header size missmatch") from error
    payload_size = int.from_bytes(header, byteorder="big")
    received = bytearray()  # amortised O(1) appends instead of quadratic bytes +=
    while len(received) < payload_size:
        # Never request past the end of this frame, so trailing data is not consumed.
        to_read = min(buffer_size, payload_size - len(received))
        chunk = await asyncio.wait_for(reader.read(to_read), timeout=timeout)
        if not chunk:
            raise TransmissionChannelClosed(
                "Socket was closed to read before payload was decoded."
            )
        received += chunk
    return bytes(received)
def is_request_unsuccessful(response: dict) -> bool:
    """Tell whether a manager reply should be treated as a failure.

    A reply counts as successful only when ``response[RESPONSE_KEY][STATUS_KEY]``
    equals ``OperationStatus.SUCCESS.value``; a missing response section or
    missing status is treated as failure.
    """
    response_section = response.get(RESPONSE_KEY, {})
    reported_status = response_section.get(
        STATUS_KEY, OperationStatus.FAILURE.value
    )
    success_marker = OperationStatus.SUCCESS.value
    return reported_status != success_marker
def dispatch_error(error_response: dict) -> None:
    """Log a failed manager reply and raise the matching client exception.

    The exception class is chosen from ``ERRORS_MAPPING`` by the reply's
    ``ERROR_TYPE_KEY`` value; unknown or missing types raise the generic
    ``ProcessesManagerClientError``. This function never returns normally.
    """
    details = error_response.get(RESPONSE_KEY, {})
    error_type = details.get(ERROR_TYPE_KEY)
    error_class = details.get("error_class", "N/A")
    error_message = details.get("error_message", "N/A")
    logger.error(
        f"Error in ProcessesManagerClient. error_type={error_type} error_class={error_class} "
        f"error_message={error_message}"
    )
    exception_type = ERRORS_MAPPING.get(error_type, ProcessesManagerClientError)
    description = (
        f"Error in ProcessesManagerClient. Error type: {error_type}. Details: {error_message}"
    )
    raise exception_type(description)
def build_response(response: dict) -> CommandResponse:
    """Convert a raw manager reply into a ``CommandResponse``.

    The status is read from ``response[RESPONSE_KEY][STATUS_KEY]`` (``KeyError``
    if absent); request and pipeline identifiers default to ``None`` when missing.
    """
    return CommandResponse(
        status=response[RESPONSE_KEY][STATUS_KEY],
        context=CommandContext(
            request_id=response.get(REQUEST_ID_KEY),
            pipeline_id=response.get(PIPELINE_ID_KEY),
        ),
    )