import json
from typing import AsyncIterator, Iterator, List, Optional, Union
import httpx
import litellm
from litellm import verbose_logger
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import StreamingChatCompletionChunk

# Cache for the botocore response-stream shape; populated lazily by
# get_response_stream_shape().
_response_stream_shape_cache = None


class SagemakerError(BaseLLMException):
    """Raised when a SageMaker endpoint returns a non-success response."""
def __init__(
self,
status_code: int,
message: str,
headers: Optional[Union[dict, httpx.Headers]] = None,
):
super().__init__(status_code=status_code, message=message, headers=headers)


class AWSEventStreamDecoder:
    """Decodes the AWS event-stream framing returned by SageMaker's
    InvokeEndpointWithResponseStream API into streaming chat chunks."""

    def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None:
from botocore.parsers import EventStreamJSONParser
self.model = model
self.parser = EventStreamJSONParser()
self.content_blocks: List = []
self.is_messages_api = is_messages_api

    def _chunk_parser_messages_api(
        self, chunk_data: dict
    ) -> StreamingChatCompletionChunk:
        # Messages-API streams are already OpenAI-shaped; validate the dict
        # directly into the typed chunk model.
        return StreamingChatCompletionChunk(**chunk_data)
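
    # Illustrative only: messages-API streams carry OpenAI-style chunk dicts
    # for the parser above, e.g.
    #   {"id": "chatcmpl-123", "object": "chat.completion.chunk",
    #    "choices": [{"index": 0, "delta": {"content": "Hi"}}]}
    # (hypothetical sample, not captured endpoint output).
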
def _chunk_parser(self, chunk_data: dict) -> GChunk:
verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data)
_token = chunk_data.get("token", {}) or {}
_index = chunk_data.get("index", None) or 0
is_finished = False
finish_reason = ""
        _text = _token.get("text", "")
        if _text == "<|endoftext|>":
            # End-of-text sentinel (e.g. from a TGI container): emit an empty,
            # finished chunk so the caller can close the stream.
return GChunk(
text="",
index=_index,
is_finished=True,
finish_reason="stop",
usage=None,
)
return GChunk(
text=_text,
index=_index,
is_finished=is_finished,
finish_reason=finish_reason,
usage=None,
)
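
    # Illustrative only: the TGI-style payloads _chunk_parser expects look like
    #   {"token": {"text": "Hello"}, "index": 0}          -> intermediate text
    #   {"token": {"text": "<|endoftext|>"}, "index": 0}  -> terminal chunk
    # (hypothetical samples, not captured endpoint output).
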
def iter_bytes(
self, iterator: Iterator[bytes]
) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
"""Given an iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
accumulated_json = ""
for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
message = self._parse_message_from_event(event)
if message:
                    # Strip the SSE "data:" prefix and blank-line separators
message = (
litellm.CustomStreamWrapper._strip_sse_data_from_chunk(message)
or ""
)
message = message.replace("\n\n", "")
# Accumulate JSON data
accumulated_json += message
# Try to parse the accumulated JSON
try:
_data = json.loads(accumulated_json)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
# Reset accumulated_json after successful parsing
accumulated_json = ""
except json.JSONDecodeError:
# If it's not valid JSON yet, continue to the next event
continue
# Handle any remaining data after the iterator is exhausted
if accumulated_json:
try:
_data = json.loads(accumulated_json)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
except json.JSONDecodeError:
                # Log any data left unparsed once the iterator is exhausted
                verbose_logger.error(
                    f"Unparseable JSON data remained after the stream ended: {accumulated_json}"
                )
yield None
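
    # Illustrative only: a single JSON payload can arrive split across events,
    # e.g. b'{"token": {"te' followed by b'xt": "Hi"}, "index": 0}' (the split
    # point here is hypothetical). The accumulated_json buffer in iter_bytes
    # and aiter_bytes stitches such fragments back together before json.loads.
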
async def aiter_bytes(
self, iterator: AsyncIterator[bytes]
) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]:
"""Given an async iterator that yields lines, iterate over it & yield every event encountered"""
from botocore.eventstream import EventStreamBuffer
event_stream_buffer = EventStreamBuffer()
accumulated_json = ""
async for chunk in iterator:
event_stream_buffer.add_data(chunk)
for event in event_stream_buffer:
try:
message = self._parse_message_from_event(event)
if message:
verbose_logger.debug(
"sagemaker parsed chunk bytes %s", message
)
                        # Strip the SSE "data:" prefix and blank-line separators
message = (
litellm.CustomStreamWrapper._strip_sse_data_from_chunk(
message
)
or ""
)
message = message.replace("\n\n", "")
# Accumulate JSON data
accumulated_json += message
# Try to parse the accumulated JSON
_data = json.loads(accumulated_json)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
# Reset accumulated_json after successful parsing
accumulated_json = ""
except json.JSONDecodeError:
# If it's not valid JSON yet, continue to the next event
continue
except UnicodeDecodeError as e:
verbose_logger.warning(
f"UnicodeDecodeError: {e}. Attempting to combine with next event."
)
continue
except Exception as e:
verbose_logger.error(
f"Error parsing message: {e}. Attempting to combine with next event."
)
continue
# Handle any remaining data after the iterator is exhausted
if accumulated_json:
try:
_data = json.loads(accumulated_json)
if self.is_messages_api:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
except json.JSONDecodeError:
                # Log any data left unparsed once the iterator is exhausted
                verbose_logger.error(
                    f"Unparseable JSON data remained after the stream ended: {accumulated_json}"
                )
yield None
except Exception as e:
verbose_logger.error(f"Final error parsing accumulated JSON: {e}")

    def _parse_message_from_event(self, event) -> Optional[str]:
        """Return the decoded UTF-8 payload of a single event, or None if empty."""
        response_dict = event.to_response_dict()
parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
if response_dict["status_code"] != 200:
raise ValueError(f"Bad response code, expected 200: {response_dict}")
if "chunk" in parsed_response:
chunk = parsed_response.get("chunk")
if not chunk:
return None
return chunk.get("bytes").decode() # type: ignore[no-any-return]
else:
chunk = response_dict.get("body")
if not chunk:
return None
return chunk.decode() # type: ignore[no-any-return]


def get_response_stream_shape():
    """Lazily load and cache the botocore shape for parsing
    InvokeEndpointWithResponseStream events."""
    global _response_stream_shape_cache
if _response_stream_shape_cache is None:
from botocore.loaders import Loader
from botocore.model import ServiceModel
loader = Loader()
sagemaker_service_dict = loader.load_service_model(
"sagemaker-runtime", "service-2"
)
sagemaker_service_model = ServiceModel(sagemaker_service_dict)
_response_stream_shape_cache = sagemaker_service_model.shape_for(
"InvokeEndpointWithResponseStreamOutput"
)
return _response_stream_shape_cache
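

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of this module's public surface).
# The decoder is driven with the raw byte stream of an
# InvokeEndpointWithResponseStream HTTP response; the endpoint URL, SigV4
# headers, and payload below are hypothetical placeholders.
#
#   decoder = AWSEventStreamDecoder(model="my-endpoint", is_messages_api=False)
#   with httpx.stream(
#       "POST", endpoint_url, headers=sigv4_headers, json=request_payload
#   ) as response:
#       for chunk in decoder.iter_bytes(response.iter_bytes(chunk_size=1024)):
#           if chunk is not None:
#               print(chunk["text"], end="", flush=True)
# ---------------------------------------------------------------------------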