Spaces:
Running
Running
""" | |
async with websockets.connect( # type: ignore | |
url, | |
extra_headers={ | |
"api-key": api_key, # type: ignore | |
}, | |
) as backend_ws: | |
forward_task = asyncio.create_task( | |
forward_messages(websocket, backend_ws) | |
) | |
try: | |
while True: | |
message = await websocket.receive_text() | |
await backend_ws.send(message) | |
except websockets.exceptions.ConnectionClosed: # type: ignore | |
forward_task.cancel() | |
finally: | |
if not forward_task.done(): | |
forward_task.cancel() | |
try: | |
await forward_task | |
except asyncio.CancelledError: | |
pass | |
""" | |
import asyncio | |
import concurrent.futures | |
import json | |
from typing import Any, Dict, List, Optional, Union | |
import litellm | |
from litellm._logging import verbose_logger | |
from litellm.types.llms.openai import ( | |
OpenAIRealtimeStreamResponseBaseObject, | |
OpenAIRealtimeStreamSessionEvents, | |
) | |
from .litellm_logging import Logging as LiteLLMLogging | |
# Module-wide thread pool (at most 10 workers) that
# ``RealTimeStreaming.log_messages`` submits its synchronous logging work to.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)

# Realtime event types stored for logging when
# ``litellm.logged_real_time_event_types`` is left unset (None).
DefaultLoggedRealTimeEventTypes = ["session.created", "response.create", "response.done"]
class RealTimeStreaming:
    """Proxy a realtime session between a client websocket and a backend websocket.

    Client -> backend traffic is forwarded by ``client_ack_messages``.
    Backend -> client traffic is forwarded by ``backend_to_client_send_messages``,
    which also records selected backend events and hands them to the logging
    object once that forwarding loop ends.
    """

    def __init__(
        self,
        websocket: Any,
        backend_ws: Any,
        logging_obj: Optional[LiteLLMLogging] = None,
    ):
        """
        Args:
            websocket: client-side connection; this class awaits its
                ``receive_text()`` and ``send_text()`` methods.
            backend_ws: upstream connection; this class awaits its ``recv()``
                and ``send()`` methods.
            logging_obj: optional LiteLLM logging object that receives the client
                input (``pre_call``) and the stored events
                (``async_success_handler`` / ``success_handler``).
        """
        self.websocket = websocket
        self.backend_ws = backend_ws
        self.logging_obj = logging_obj
        # Parsed backend events selected for logging (see _should_store_message).
        self.messages: List[
            Union[
                OpenAIRealtimeStreamResponseBaseObject,
                OpenAIRealtimeStreamSessionEvents,
            ]
        ] = []
        # Most recent client message passed to store_input.
        self.input_message: Union[str, Dict] = {}
        # Strong references to fire-and-forget logging tasks; the event loop
        # only keeps weak references, so an unreferenced task may be collected.
        self._logging_tasks: set = set()
        _logged_real_time_event_types = litellm.logged_real_time_event_types
        if _logged_real_time_event_types is None:
            _logged_real_time_event_types = DefaultLoggedRealTimeEventTypes
        # Compared against "*" (store every event type); otherwise used as a
        # container of event type names to store.
        self.logged_real_time_event_types = _logged_real_time_event_types

    def _should_store_message(
        self,
        message_obj: Union[
            dict,
            OpenAIRealtimeStreamSessionEvents,
            OpenAIRealtimeStreamResponseBaseObject,
        ],
    ) -> bool:
        """Return True if the event's ``"type"`` is selected for logging."""
        if self.logged_real_time_event_types == "*":
            return True
        return message_obj["type"] in self.logged_real_time_event_types

    def store_message(self, message: Union[str, bytes]):
        """Parse one backend event and keep it if its type is selected for logging.

        Args:
            message: a JSON-encoded event; ``bytes`` are decoded as UTF-8.

        Raises:
            Whatever ``json.loads`` or the typed-event constructors raise for
            malformed input. Constructor errors are logged at debug level before
            being re-raised.
        """
        if isinstance(message, bytes):
            message = message.decode("utf-8")
        message_obj = json.loads(message)
        try:
            if message_obj.get("type") in ("session.created", "session.updated"):
                message_obj = OpenAIRealtimeStreamSessionEvents(**message_obj)  # type: ignore
            else:
                message_obj = OpenAIRealtimeStreamResponseBaseObject(**message_obj)  # type: ignore
        except Exception as e:
            verbose_logger.debug(f"Error parsing message for logging: {e}")
            raise
        if self._should_store_message(message_obj):
            self.messages.append(message_obj)

    def store_input(self, message: Union[str, dict]):
        """Record the latest client message and pass it to ``logging_obj.pre_call``.

        ``client_ack_messages`` passes the raw text from ``receive_text()``, so
        ``message`` may be a ``str`` as well as a dict.
        """
        self.input_message = message
        if self.logging_obj:
            self.logging_obj.pre_call(input=message, api_key="")

    async def log_messages(self):
        """Hand the stored events to the logging object's success handlers.

        The async handler is scheduled as a task on the running loop. The sync
        handler is submitted to the module-level thread pool so it does not run
        on the event loop. Neither is awaited here.
        """
        if self.logging_obj:
            ## ASYNC LOGGING
            task = asyncio.create_task(
                self.logging_obj.async_success_handler(self.messages)
            )
            # Keep the task alive until it finishes, then drop the reference.
            self._logging_tasks.add(task)
            task.add_done_callback(self._logging_tasks.discard)
            ## SYNC LOGGING
            # Pass the callable and its argument separately: calling it here
            # would run it inline and submit only its return value.
            executor.submit(self.logging_obj.success_handler, self.messages)

    async def backend_to_client_send_messages(self):
        """Forward backend events to the client until the loop stops.

        Each forwarded event is also passed to ``store_message``. A failure there
        is logged and skipped so it cannot end the session. The stored events
        are handed to ``log_messages`` when the loop exits for any reason,
        including task cancellation.
        """
        import websockets

        try:
            while True:
                message = await self.backend_ws.recv()
                await self.websocket.send_text(message)
                ## LOGGING
                try:
                    self.store_message(message)
                except Exception as e:
                    # Logging is best-effort: an unparseable event must not
                    # terminate forwarding for the whole session.
                    verbose_logger.debug(f"Skipping realtime event for logging: {e}")
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            # The websockets library reports the connection closed: normal end.
            pass
        except Exception as e:
            # Any other error ends forwarding; leave a trace instead of
            # swallowing it silently.
            verbose_logger.debug(f"Realtime backend-to-client forwarding stopped: {e}")
        finally:
            await self.log_messages()

    async def client_ack_messages(self):
        """Forward client messages to the backend, recording each via ``store_input``.

        Returns normally when the websockets library reports a closed
        connection; any other exception propagates to the caller.
        """
        import websockets

        try:
            while True:
                message = await self.websocket.receive_text()
                ## LOGGING
                self.store_input(message=message)
                ## FORWARD TO BACKEND
                await self.backend_ws.send(message)
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            pass

    async def bidirectional_forward(self):
        """Run both forwarding directions until the client-to-backend loop ends.

        Backend-to-client forwarding runs as a background task. When the
        client-to-backend loop finishes (normally or via an exception), that task
        is cancelled if still running and awaited. Its ``finally`` (which calls
        ``log_messages``) therefore completes before this coroutine returns.
        """
        import websockets

        forward_task = asyncio.create_task(self.backend_to_client_send_messages())
        try:
            await self.client_ack_messages()
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            forward_task.cancel()
        finally:
            if not forward_task.done():
                forward_task.cancel()
                try:
                    await forward_task
                except asyncio.CancelledError:
                    pass