import json
from typing import AsyncIterator, Iterator, List, Optional, Union

import httpx

import litellm
from litellm import verbose_logger
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import StreamingChatCompletionChunk
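
# Cached botocore shape for the SageMaker response stream; loaded lazily by
# get_response_stream_shape() below so the service model is only read once.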
_response_stream_shape_cache = None


class SagemakerError(BaseLLMException):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, httpx.Headers]] = None,
    ):
        super().__init__(status_code=status_code, message=message, headers=headers)


class AWSEventStreamDecoder:
    """Decodes the AWS event stream returned by a SageMaker endpoint into
    streaming chunks (GenericStreamingChunk or StreamingChatCompletionChunk)."""

    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
        from botocore.parsers import EventStreamJSONParser

        self.model = model
        self.parser = EventStreamJSONParser()
        self.content_blocks: List = []
        self.is_messages_api = is_messages_api

    def _chunk_parser_messages_api(
        self, chunk_data: dict
    ) -> StreamingChatCompletionChunk:
        # Messages-API responses are already OpenAI-shaped chunks.
        openai_chunk = StreamingChatCompletionChunk(**chunk_data)
        return openai_chunk

    def _chunk_parser(self, chunk_data: dict) -> GChunk:
        verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
        _token = chunk_data.get("token", {}) or {}
        _index = chunk_data.get("index", None) or 0
        is_finished = False
        finish_reason = ""

        _text = _token.get("text", "")
        if _text == "<|endoftext|>":
            # The endpoint signals completion with an end-of-text token;
            # surface it as a finished chunk with an empty delta.
            return GChunk(
                text="",
                index=_index,
                is_finished=True,
                finish_reason="stop",
                usage=None,
            )
        return GChunk(
            text=_text,
            index=_index,
            is_finished=is_finished,
            finish_reason=finish_reason,
            usage=None,
        )
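
    # Illustrative only (assumed TGI-style payload, not from the original
    # source): {"token": {"text": "Hi"}, "index": 0} parses to
    # GChunk(text="Hi", index=0, is_finished=False, finish_reason="", usage=None).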

    def iter_bytes(
        self, iterator: Iterator[bytes]
    ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
        """Given an iterator that yields bytes, iterate over it & yield every event encountered"""
        from botocore.eventstream import EventStreamBuffer

        event_stream_buffer = EventStreamBuffer()
        accumulated_json = ""

        for chunk in iterator:
            event_stream_buffer.add_data(chunk)
            for event in event_stream_buffer:
                message = self._parse_message_from_event(event)
                if message:
                    # remove the "data:" prefix and trailing "\n\n"
                    message = (
                        litellm.CustomStreamWrapper._strip_sse_data_from_chunk(message)
                        or ""
                    )
                    message = message.replace("\n\n", "")

                    # Accumulate JSON data
                    accumulated_json += message

                    # Try to parse the accumulated JSON
                    try:
                        _data = json.loads(accumulated_json)
                        if self.is_messages_api:
                            yield self._chunk_parser_messages_api(chunk_data=_data)
                        else:
                            yield self._chunk_parser(chunk_data=_data)
                        # Reset accumulated_json after successful parsing
                        accumulated_json = ""
                    except json.JSONDecodeError:
                        # Not valid JSON yet; keep accumulating across events
                        continue

        # Handle any remaining data after the iterator is exhausted
        if accumulated_json:
            try:
                _data = json.loads(accumulated_json)
                if self.is_messages_api:
                    yield self._chunk_parser_messages_api(chunk_data=_data)
                else:
                    yield self._chunk_parser(chunk_data=_data)
            except json.JSONDecodeError:
                # Log any unparseable data left at the end
                verbose_logger.error(
                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
                )
                yield None
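
    # A minimal sketch (assumption, not from the original module) of how this
    # decoder is typically driven, e.g. with an httpx streaming response
    # ("url" and "body" are placeholders):
    #
    #     decoder = AWSEventStreamDecoder(model="my-endpoint")
    #     with httpx.stream("POST", url, content=body) as resp:
    #         for chunk in decoder.iter_bytes(resp.iter_bytes(chunk_size=1024)):
    #             if chunk is not None:
    #                 ...  # handle GChunk / StreamingChatCompletionChunk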

    async def aiter_bytes(
        self, iterator: AsyncIterator[bytes]
    ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
        """Given an async iterator that yields bytes, iterate over it & yield every event encountered"""
        from botocore.eventstream import EventStreamBuffer

        event_stream_buffer = EventStreamBuffer()
        accumulated_json = ""

        async for chunk in iterator:
            event_stream_buffer.add_data(chunk)
            for event in event_stream_buffer:
                try:
                    message = self._parse_message_from_event(event)
                    if message:
                        verbose_logger.debug(
                            "sagemaker parsed chunk bytes %s", message
                        )
                        # remove the "data:" prefix and trailing "\n\n"
                        message = (
                            litellm.CustomStreamWrapper._strip_sse_data_from_chunk(
                                message
                            )
                            or ""
                        )
                        message = message.replace("\n\n", "")

                        # Accumulate JSON data
                        accumulated_json += message

                        # Try to parse the accumulated JSON
                        _data = json.loads(accumulated_json)
                        if self.is_messages_api:
                            yield self._chunk_parser_messages_api(chunk_data=_data)
                        else:
                            yield self._chunk_parser(chunk_data=_data)
                        # Reset accumulated_json after successful parsing
                        accumulated_json = ""
                except json.JSONDecodeError:
                    # Not valid JSON yet; keep accumulating across events
                    continue
                except UnicodeDecodeError as e:
                    verbose_logger.warning(
                        f"UnicodeDecodeError: {e}. Attempting to combine with next event."
                    )
                    continue
                except Exception as e:
                    verbose_logger.error(
                        f"Error parsing message: {e}. Attempting to combine with next event."
                    )
                    continue

        # Handle any remaining data after the iterator is exhausted
        if accumulated_json:
            try:
                _data = json.loads(accumulated_json)
                if self.is_messages_api:
                    yield self._chunk_parser_messages_api(chunk_data=_data)
                else:
                    yield self._chunk_parser(chunk_data=_data)
            except json.JSONDecodeError:
                # Log any unparseable data left at the end
                verbose_logger.error(
                    f"Warning: Unparseable JSON data remained: {accumulated_json}"
                )
                yield None
            except Exception as e:
                verbose_logger.error(f"Final error parsing accumulated JSON: {e}")

    def _parse_message_from_event(self, event) -> Optional[str]:
        response_dict = event.to_response_dict()
        parsed_response = self.parser.parse(response_dict, get_response_stream_shape())

        if response_dict["status_code"] != 200:
            raise ValueError(f"Bad response code, expected 200: {response_dict}")

        if "chunk" in parsed_response:
            chunk = parsed_response.get("chunk")
            if not chunk:
                return None
            return chunk.get("bytes").decode()  # type: ignore[no-any-return]
        else:
            chunk = response_dict.get("body")
            if not chunk:
                return None
            return chunk.decode()  # type: ignore[no-any-return]


def get_response_stream_shape():
    global _response_stream_shape_cache
    if _response_stream_shape_cache is None:
        from botocore.loaders import Loader
        from botocore.model import ServiceModel

        loader = Loader()
        sagemaker_service_dict = loader.load_service_model(
            "sagemaker-runtime", "service-2"
        )
        sagemaker_service_model = ServiceModel(sagemaker_service_dict)
        _response_stream_shape_cache = sagemaker_service_model.shape_for(
            "InvokeEndpointWithResponseStreamOutput"
        )
    return _response_stream_shape_cache
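

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module).
# "my-endpoint" is a placeholder model name; the payloads assume the
# TGI-style {"token": {"text": ...}, "index": ...} streaming format.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    decoder = AWSEventStreamDecoder(model="my-endpoint")

    # An ordinary token chunk becomes an unfinished GenericStreamingChunk.
    print(decoder._chunk_parser(chunk_data={"token": {"text": "Hello"}, "index": 0}))

    # The end-of-text sentinel maps to a finished chunk with finish_reason="stop".
    print(
        decoder._chunk_parser(
            chunk_data={"token": {"text": "<|endoftext|>"}, "index": 0}
        )
    )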