OMG / inference /enterprise /stream_management /api /stream_manager_client.py
Fucius's picture
Upload 422 files
df6c67d verified
raw
history blame contribute delete
No virus
9.55 kB
import asyncio
import json
from asyncio import StreamReader, StreamWriter
from json import JSONDecodeError
from typing import Optional, Tuple
from inference.core import logger
from inference.enterprise.stream_management.api.entities import (
CommandContext,
CommandResponse,
InferencePipelineStatusResponse,
ListPipelinesResponse,
PipelineInitialisationRequest,
)
from inference.enterprise.stream_management.api.errors import (
ConnectivityError,
ProcessesManagerAuthorisationError,
ProcessesManagerClientError,
ProcessesManagerInternalError,
ProcessesManagerInvalidPayload,
ProcessesManagerNotFoundError,
ProcessesManagerOperationError,
)
from inference.enterprise.stream_management.manager.entities import (
ERROR_TYPE_KEY,
PIPELINE_ID_KEY,
REQUEST_ID_KEY,
RESPONSE_KEY,
STATUS_KEY,
TYPE_KEY,
CommandType,
ErrorType,
OperationStatus,
)
from inference.enterprise.stream_management.manager.errors import (
CommunicationProtocolError,
MalformedHeaderError,
MalformedPayloadError,
MessageToBigError,
TransmissionChannelClosed,
)
# Default maximum number of bytes requested per read while receiving a
# response payload (see ``receive_message``).
BUFFER_SIZE = 16 * 1024
# Default width, in bytes, of the length prefix that frames every message
# exchanged with the Processes Manager.
HEADER_SIZE = 4
# Translates the ``ErrorType`` value reported by the manager into the
# client-side exception class that ``dispatch_error`` raises. Error types not
# listed here fall back to ``ProcessesManagerClientError``.
ERRORS_MAPPING = {
    reported_type.value: exception_class
    for reported_type, exception_class in (
        (ErrorType.INTERNAL_ERROR, ProcessesManagerInternalError),
        (ErrorType.INVALID_PAYLOAD, ProcessesManagerInvalidPayload),
        (ErrorType.NOT_FOUND, ProcessesManagerNotFoundError),
        (ErrorType.OPERATION_ERROR, ProcessesManagerOperationError),
        (ErrorType.AUTHORISATION_ERROR, ProcessesManagerAuthorisationError),
    )
}
class StreamManagerClient:
    """Async client for the stream-management Processes Manager.

    Every public coroutine builds a command dictionary, delivers it through
    ``send_command`` (one new connection per call) and turns the decoded
    reply into an API response entity. Replies whose status is not a success
    are converted into exceptions by ``dispatch_error``.
    """

    @classmethod
    def init(
        cls,
        host: str,
        port: int,
        operations_timeout: Optional[float] = None,
        header_size: int = HEADER_SIZE,
        buffer_size: int = BUFFER_SIZE,
    ) -> "StreamManagerClient":
        """Create a client, supplying module defaults for framing parameters."""
        return cls(
            host=host,
            port=port,
            operations_timeout=operations_timeout,
            header_size=header_size,
            buffer_size=buffer_size,
        )

    def __init__(
        self,
        host: str,
        port: int,
        operations_timeout: Optional[float],
        header_size: int,
        buffer_size: int,
    ):
        # Connection target.
        self._host = host
        self._port = port
        # Per-operation timeout in seconds forwarded to ``send_command``;
        # ``None`` disables the timeout.
        self._operations_timeout = operations_timeout
        # Framing parameters forwarded to ``send_command``.
        self._header_size = header_size
        self._buffer_size = buffer_size

    async def list_pipelines(self) -> ListPipelinesResponse:
        """Request the list of pipelines known to the manager."""
        reply = await self._handle_command(
            command={TYPE_KEY: CommandType.LIST_PIPELINES}
        )
        status, context = self._extract_status_and_context(reply)
        return ListPipelinesResponse(
            status=status,
            context=context,
            pipelines=reply[RESPONSE_KEY]["pipelines"],
        )

    async def initialise_pipeline(
        self, initialisation_request: PipelineInitialisationRequest
    ) -> CommandResponse:
        """Ask the manager to start a pipeline described by ``initialisation_request``.

        Fields that are ``None`` on the request are omitted from the command.
        """
        command = {
            **initialisation_request.dict(exclude_none=True),
            TYPE_KEY: CommandType.INIT,
        }
        reply = await self._handle_command(command=command)
        return build_response(response=reply)

    async def terminate_pipeline(self, pipeline_id: str) -> CommandResponse:
        """Send a TERMINATE command for ``pipeline_id``."""
        return await self._send_pipeline_command(CommandType.TERMINATE, pipeline_id)

    async def pause_pipeline(self, pipeline_id: str) -> CommandResponse:
        """Send a MUTE command (the manager's pause operation) for ``pipeline_id``."""
        return await self._send_pipeline_command(CommandType.MUTE, pipeline_id)

    async def resume_pipeline(self, pipeline_id: str) -> CommandResponse:
        """Send a RESUME command for ``pipeline_id``."""
        return await self._send_pipeline_command(CommandType.RESUME, pipeline_id)

    async def get_status(self, pipeline_id: str) -> InferencePipelineStatusResponse:
        """Fetch the status report of ``pipeline_id``."""
        reply = await self._handle_command(
            command={TYPE_KEY: CommandType.STATUS, PIPELINE_ID_KEY: pipeline_id}
        )
        status, context = self._extract_status_and_context(reply)
        return InferencePipelineStatusResponse(
            status=status,
            context=context,
            report=reply[RESPONSE_KEY]["report"],
        )

    async def _send_pipeline_command(
        self, command_type: CommandType, pipeline_id: str
    ) -> CommandResponse:
        """Send a simple ``{type, pipeline_id}`` command and wrap the reply."""
        reply = await self._handle_command(
            command={TYPE_KEY: command_type, PIPELINE_ID_KEY: pipeline_id}
        )
        return build_response(response=reply)

    @staticmethod
    def _extract_status_and_context(reply: dict) -> tuple:
        """Return ``(status, CommandContext)`` taken from a decoded reply."""
        status = reply[RESPONSE_KEY][STATUS_KEY]
        context = CommandContext(
            request_id=reply.get(REQUEST_ID_KEY),
            pipeline_id=reply.get(PIPELINE_ID_KEY),
        )
        return status, context

    async def _handle_command(self, command: dict) -> dict:
        """Deliver ``command`` and return the reply, raising on unsuccessful status."""
        reply = await send_command(
            host=self._host,
            port=self._port,
            command=command,
            header_size=self._header_size,
            buffer_size=self._buffer_size,
            timeout=self._operations_timeout,
        )
        if not is_request_unsuccessful(response=reply):
            return reply
        # ``dispatch_error`` always raises, so control does not reach the return.
        dispatch_error(error_response=reply)
        return reply
async def send_command(
    host: str,
    port: int,
    command: dict,
    header_size: int,
    buffer_size: int,
    timeout: Optional[float] = None,
) -> dict:
    """Send one command to the Processes Manager and return its decoded reply.

    A new connection is opened for each call. It is closed before returning,
    including when sending or receiving fails.

    Args:
        host: Manager host name or address.
        port: Manager port.
        command: JSON-serialisable command payload.
        header_size: Width in bytes of the length prefix framing each message.
        buffer_size: Maximum bytes requested per read of the response payload.
        timeout: Per-operation timeout in seconds; ``None`` disables it.

    Returns:
        The JSON-decoded response.

    Raises:
        MalformedPayloadError: If the response is not valid encoded JSON.
        ConnectivityError: On socket errors (``OSError``) or timeouts.
        Errors raised by ``send_message`` / ``receive_message`` that are not
        ``OSError`` / timeouts propagate unchanged.
    """
    try:
        reader, writer = await establish_socket_connection(
            host=host, port=port, timeout=timeout
        )
        try:
            await send_message(
                writer=writer, message=command, header_size=header_size, timeout=timeout
            )
            data = await receive_message(
                reader, header_size=header_size, buffer_size=buffer_size, timeout=timeout
            )
        finally:
            # Always release the socket. Previously a failed exchange skipped
            # ``close()`` and leaked the connection.
            writer.close()
        await writer.wait_closed()
        return json.loads(data)
    except (JSONDecodeError, UnicodeDecodeError) as error:
        # ``json.loads`` on bytes first decodes them, which can fail with
        # UnicodeDecodeError before any JSON parsing happens.
        raise MalformedPayloadError(
            f"Could not decode response. Cause: {error}"
        ) from error
    except (OSError, asyncio.TimeoutError) as errors:
        raise ConnectivityError(
            "Could not communicate with Process Manager"
        ) from errors
async def establish_socket_connection(
    host: str, port: int, timeout: Optional[float] = None
) -> Tuple[StreamReader, StreamWriter]:
    """Open a stream connection to ``host:port``.

    Args:
        host: Target host name or address.
        port: Target port.
        timeout: Seconds to wait for the connection; ``None`` waits indefinitely.

    Returns:
        The ``(reader, writer)`` pair produced by ``asyncio.open_connection``.

    Raises:
        asyncio.TimeoutError: If the connection is not established in time.
        OSError: If the connection attempt fails.
    """
    connection_attempt = asyncio.open_connection(host, port)
    reader, writer = await asyncio.wait_for(connection_attempt, timeout=timeout)
    return reader, writer
async def send_message(
    writer: StreamWriter,
    message: dict,
    header_size: int,
    timeout: Optional[float] = None,
) -> None:
    """Serialise ``message`` as JSON and write it as a single framed message.

    Frame layout: ``header_size`` bytes holding the body length (big-endian),
    followed by the UTF-8 encoded JSON body.

    Args:
        writer: Stream to write to.
        message: JSON-serialisable payload.
        header_size: Width in bytes of the length prefix.
        timeout: Seconds to wait for ``writer.drain()``; ``None`` waits indefinitely.

    Raises:
        MalformedPayloadError: If ``message`` cannot be serialised (TypeError).
        MessageToBigError: If the body length does not fit in ``header_size`` bytes.
        ConnectivityError: If draining the writer times out.
        CommunicationProtocolError: For any other failure while sending.
    """
    try:
        body = json.dumps(message).encode("utf-8")
        header = len(body).to_bytes(length=header_size, byteorder="big")
        payload = header + body
        writer.write(payload)
        await asyncio.wait_for(writer.drain(), timeout=timeout)
    except TypeError as error:
        raise MalformedPayloadError(
            f"Could not serialise message. Details: {error}"
        ) from error
    except OverflowError as error:
        raise MessageToBigError(
            f"Could not send message due to size overflow. Details: {error}"
        ) from error
    except asyncio.TimeoutError as error:
        raise ConnectivityError(
            "Could not communicate with Process Manager"
        ) from error
    except Exception as error:
        # NOTE(review): socket-level OSErrors (e.g. connection reset) land here
        # as CommunicationProtocolError rather than ConnectivityError. That
        # mapping is kept as-is because callers may depend on it.
        raise CommunicationProtocolError(
            f"Could not send message. Cause: {error}"
        ) from error
async def receive_message(
    reader: StreamReader,
    header_size: int,
    buffer_size: int,
    timeout: Optional[float] = None,
) -> bytes:
    """Read one length-prefixed message from ``reader``.

    The message is a ``header_size``-byte big-endian unsigned length followed
    by exactly that many payload bytes. Bytes after the payload are left
    unread in the stream.

    Args:
        reader: Stream to read from.
        header_size: Width in bytes of the length prefix.
        buffer_size: Maximum number of bytes requested per payload read.
        timeout: Seconds allowed for each individual read; ``None`` disables it.

    Returns:
        The raw payload bytes.

    Raises:
        MalformedHeaderError: If the stream ends before a full header arrives.
        TransmissionChannelClosed: If the stream ends before the full payload arrives.
        asyncio.TimeoutError: If any single read exceeds ``timeout``.
    """
    try:
        # ``read(n)`` may return fewer than ``n`` bytes on a healthy connection.
        # ``readexactly`` waits until the whole header is available.
        header = await asyncio.wait_for(
            reader.readexactly(header_size), timeout=timeout
        )
    except asyncio.IncompleteReadError as error:
        raise MalformedHeaderError("Header size missmatch") from error
    payload_size = int.from_bytes(header, byteorder="big")
    received = bytearray()
    while len(received) < payload_size:
        # Request no more than what remains of this message, so bytes sent
        # after it are not swallowed into the payload.
        to_read = min(buffer_size, payload_size - len(received))
        chunk = await asyncio.wait_for(reader.read(to_read), timeout=timeout)
        if not chunk:
            raise TransmissionChannelClosed(
                "Socket was closed to read before payload was decoded."
            )
        received.extend(chunk)
    return bytes(received)
def is_request_unsuccessful(response: dict) -> bool:
    """Return True unless the reply reports ``OperationStatus.SUCCESS``.

    A reply with no response section, or no status inside it, is treated as
    a failure.
    """
    response_section = response.get(RESPONSE_KEY, {})
    reported_status = response_section.get(
        STATUS_KEY, OperationStatus.FAILURE.value
    )
    return reported_status != OperationStatus.SUCCESS.value
def dispatch_error(error_response: dict) -> None:
    """Log a failed manager reply and raise the matching client exception.

    The exception class is looked up in ``ERRORS_MAPPING`` by the reply's
    error type. Unknown types fall back to ``ProcessesManagerClientError``.
    This function never returns normally.
    """
    details = error_response.get(RESPONSE_KEY, {})
    error_type = details.get(ERROR_TYPE_KEY)
    error_class = details.get("error_class", "N/A")
    error_message = details.get("error_message", "N/A")
    logger.error(
        f"Error in ProcessesManagerClient. error_type={error_type} error_class={error_class} "
        f"error_message={error_message}"
    )
    exception_type = ERRORS_MAPPING.get(error_type, ProcessesManagerClientError)
    raise exception_type(
        f"Error in ProcessesManagerClient. Error type: {error_type}. Details: {error_message}"
    )
def build_response(response: dict) -> CommandResponse:
    """Wrap a decoded manager reply into a ``CommandResponse``.

    The status is read from the reply's response section. The request and
    pipeline identifiers come from the top level of the reply and may be absent.
    """
    response_status = response[RESPONSE_KEY][STATUS_KEY]
    command_context = CommandContext(
        request_id=response.get(REQUEST_ID_KEY),
        pipeline_id=response.get(PIPELINE_ID_KEY),
    )
    return CommandResponse(status=response_status, context=command_context)