Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,490 Bytes
2eafbc4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import json
import socket
from typing import Optional
from inference.core import logger
from inference.enterprise.stream_management.manager.entities import ErrorType
from inference.enterprise.stream_management.manager.errors import (
MalformedHeaderError,
MalformedPayloadError,
TransmissionChannelClosed,
)
from inference.enterprise.stream_management.manager.serialisation import (
prepare_error_response,
)
def _receive_up_to(source: socket.socket, size: int, buffer_size: int) -> bytes:
    """Read from ``source`` until ``size`` bytes arrived or the peer stopped sending.

    Never requests more than the number of bytes still missing, so data that
    follows the requested span stays in the socket for the next reader. The
    result is shorter than ``size`` only when ``recv`` returned no data.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = source.recv(min(buffer_size, remaining))
        if not chunk:
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def receive_socket_data(
    source: socket.socket, header_size: int, buffer_size: int
) -> dict:
    """Read a single length-prefixed JSON message from ``source``.

    A message consists of a ``header_size``-byte big-endian unsigned integer
    holding the payload length, followed by exactly that many payload bytes
    which are decoded with ``json.loads``.

    Args:
        source: Socket to read from.
        header_size: Number of bytes making up the length header.
        buffer_size: Maximum number of bytes requested per ``recv`` call while
            reading the payload.

    Returns:
        The decoded JSON payload.

    Raises:
        MalformedHeaderError: If fewer than ``header_size`` header bytes could
            be read, or the header announces a non-positive payload size.
        TransmissionChannelClosed: If ``recv`` returned no data before the
            whole payload was read.
        MalformedPayloadError: If the payload is not valid JSON.
    """
    # A stream socket may deliver the header in several pieces, so keep
    # reading until it is complete or no more data arrives.
    header = _receive_up_to(source, size=header_size, buffer_size=header_size)
    if len(header) != header_size:
        raise MalformedHeaderError(
            f"Expected header size: {header_size}, received: {header}"
        )
    payload_size = int.from_bytes(bytes=header, byteorder="big")
    if payload_size <= 0:
        raise MalformedHeaderError(
            f"Header is indicating non positive payload size: {payload_size}"
        )
    # Reads are capped at the announced payload size so bytes belonging to the
    # next message are never consumed here.
    received = _receive_up_to(source, size=payload_size, buffer_size=buffer_size)
    if len(received) < payload_size:
        raise TransmissionChannelClosed(
            "Socket was closed to read before payload was decoded."
        )
    try:
        return json.loads(received)
    except ValueError as error:
        raise MalformedPayloadError(
            "Received payload that is not in a JSON format"
        ) from error
def send_data_trough_socket(
    target: socket.socket,
    header_size: int,
    data: bytes,
    request_id: str,
    recover_from_overflow: bool = True,
    pipeline_id: Optional[str] = None,
) -> None:
    """Send ``data`` to ``target`` prefixed with its length, never raising on send errors.

    The message is a ``header_size``-byte big-endian encoding of ``len(data)``
    followed by ``data`` itself, written with a single ``sendall`` call.

    Args:
        target: Socket to write to.
        header_size: Number of bytes used to encode the payload length.
        data: Payload to send.
        request_id: Identifier passed to ``prepare_error_response`` when an
            error response has to be sent instead.
        recover_from_overflow: When true, an ``OverflowError`` (payload length
            not representable in ``header_size`` bytes) triggers one attempt to
            send an error response instead; when false it is only logged.
        pipeline_id: Optional identifier forwarded to
            ``prepare_error_response``.

    Returns:
        None. Failures while sending are logged, not raised.
    """
    try:
        size_prefix = len(data).to_bytes(header_size, "big")
        target.sendall(size_prefix + data)
    except OverflowError as error:
        if recover_from_overflow:
            # Replace the oversized payload with an error response; the retry
            # disables recovery so this recursion happens at most once.
            fallback = prepare_error_response(
                request_id=request_id,
                error=error,
                error_type=ErrorType.INTERNAL_ERROR,
                pipeline_id=pipeline_id,
            )
            send_data_trough_socket(
                target,
                header_size,
                fallback,
                request_id,
                recover_from_overflow=False,
                pipeline_id=pipeline_id,
            )
        else:
            logger.error(f"OverflowError was suppressed. {error}")
    except Exception as error:
        logger.error(f"Could not send the response through socket. Error: {error}")
|