import json
from typing import List, Optional

from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    ChatCompletionToolCallChunk,
    ChatCompletionUsageBlock,
    GenericStreamingChunk,
)


class CohereError(BaseLLMException):
    def __init__(self, status_code, message):
        super().__init__(status_code=status_code, message=message)


def validate_environment(
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    api_key: Optional[str] = None,
) -> dict:
    """
    Return the headers to use for a Cohere chat completion request.

    Cohere API ref: https://docs.cohere.com/reference/chat

    Expected headers:
    {
        "Request-Source": "unspecified:litellm",
        "accept": "application/json",
        "content-type": "application/json",
        "Authorization": "bearer $CO_API_KEY",
    }
    """
    headers.update(
        {
            "Request-Source": "unspecified:litellm",
            "accept": "application/json",
            "content-type": "application/json",
        }
    )
    if api_key:
        headers["Authorization"] = f"bearer {api_key}"
    return headers
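
# Illustrative sketch (not executed; the model name and key below are
# placeholders, not values taken from this module):
#   validate_environment({}, model="command-r", messages=[], optional_params={},
#   api_key="<CO_API_KEY>") returns the header set documented above, i.e. the
#   three static headers plus "Authorization": "bearer <CO_API_KEY>".
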
class ModelResponseIterator:
    def __init__(
        self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
    ):
        self.streaming_response = streaming_response
        self.response_iterator = self.streaming_response
        self.content_blocks: List = []
        self.tool_index = -1
        self.json_mode = json_mode

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None

            index = int(chunk.get("index", 0))

            if "text" in chunk:
                text = chunk["text"]
            elif "is_finished" in chunk and chunk["is_finished"] is True:
                is_finished = chunk["is_finished"]
                finish_reason = chunk["finish_reason"]

            if "citations" in chunk:
                provider_specific_fields = {"citations": chunk["citations"]}

            returned_chunk = GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )

            return returned_chunk

        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        try:
            chunk = self.response_iterator.__next__()
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
        except StopIteration:
            raise StopIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

    def convert_str_chunk_to_generic_chunk(self, chunk: str) -> GenericStreamingChunk:
        """
        Convert a string chunk to a GenericStreamingChunk

        Note: This is used for Cohere pass-through streaming logging
        """
        str_line = chunk
        if isinstance(chunk, bytes):  # Handle binary data
            str_line = chunk.decode("utf-8")  # Convert bytes to string
            index = str_line.find("data:")
            if index != -1:
                str_line = str_line[index:]

        data_json = json.loads(str_line)
        return self.chunk_parser(chunk=data_json)
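
    # Illustrative sketch (input shape assumed): a raw byte chunk such as
    #   b'{"index": 0, "text": "Hello"}'
    # is decoded to a string, anything before a "data:" marker is dropped if
    # one is present, and the remaining string is parsed as JSON and handed
    # to chunk_parser above.
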
    # Async iterator
    def __aiter__(self):
        self.async_response_iterator = self.streaming_response.__aiter__()
        return self

    async def __anext__(self):
        try:
            chunk = await self.async_response_iterator.__anext__()
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error receiving chunk from stream: {e}")

        try:
            return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
        except StopAsyncIteration:
            raise StopAsyncIteration
        except ValueError as e:
            raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")