import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union

import httpx

import litellm
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.cohere import CohereV2ChatResponse
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolCallChunk
from litellm.types.utils import ModelResponse, Usage

from ..common_utils import CohereError
from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
from ..common_utils import validate_environment as cohere_validate_environment

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class CohereV2ChatConfig(BaseConfig):
    """
    Configuration class for Cohere's v2 chat API interface.

    Args:
        preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
        chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
        generation_id (str, optional): Unique identifier for the generated reply.
        response_id (str, optional): Unique identifier for the response.
        conversation_id (str, optional): An alternative to chat_history; creates or resumes a persisted conversation.
        prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
        connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
        search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
        documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
        temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
        max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
        k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
        p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
        frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
        presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
        tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
        tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
        seed (int, optional): A seed to assist reproducibility of the model's response.
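
    Example (illustrative values):
        CohereV2ChatConfig(max_tokens=512, temperature=0.3) stores the non-None
        arguments as class-level defaults (see __init__ below), so
        CohereV2ChatConfig.max_tokens == 512 until it is overridden or reset.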
| """ | |
    preamble: Optional[str] = None
    chat_history: Optional[list] = None
    generation_id: Optional[str] = None
    response_id: Optional[str] = None
    conversation_id: Optional[str] = None
    prompt_truncation: Optional[str] = None
    connectors: Optional[list] = None
    search_queries_only: Optional[bool] = None
    documents: Optional[list] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    k: Optional[int] = None
    p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    tools: Optional[list] = None
    tool_results: Optional[list] = None
    seed: Optional[int] = None

    def __init__(
        self,
        preamble: Optional[str] = None,
        chat_history: Optional[list] = None,
        generation_id: Optional[str] = None,
        response_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        prompt_truncation: Optional[str] = None,
        connectors: Optional[list] = None,
        search_queries_only: Optional[bool] = None,
        documents: Optional[list] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        k: Optional[int] = None,
        p: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        tools: Optional[list] = None,
        tool_results: Optional[list] = None,
        seed: Optional[int] = None,
    ) -> None:
        # store any non-None constructor arguments as class-level defaults
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        return cohere_validate_environment(
            headers=headers,
            model=model,
            messages=messages,
            optional_params=optional_params,
            api_key=api_key,
        )

    def get_supported_openai_params(self, model: str) -> List[str]:
        return [
            "stream",
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "presence_penalty",
            "stop",
            "n",
            "tools",
            "tool_choice",
            "seed",
            "extra_headers",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
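        # Illustrative mapping (names taken from the branches below): OpenAI-style
        # params become Cohere request fields, e.g.
        #   {"top_p": 0.9, "stop": ["\n\n"], "n": 2}
        #     -> {"p": 0.9, "stop_sequences": ["\n\n"], "num_generations": 2}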
        for param, value in non_default_params.items():
            if param == "stream":
                optional_params["stream"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "max_tokens":
                optional_params["max_tokens"] = value
            if param == "n":
                optional_params["num_generations"] = value
            if param == "top_p":
                optional_params["p"] = value
            if param == "frequency_penalty":
                optional_params["frequency_penalty"] = value
            if param == "presence_penalty":
                optional_params["presence_penalty"] = value
            if param == "stop":
                optional_params["stop_sequences"] = value
            if param == "tools":
                optional_params["tools"] = value
            if param == "seed":
                optional_params["seed"] = value
        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        ## Load Config
        for k, v in litellm.CohereChatConfig.get_config().items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v
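        # i.e. values passed directly to the completion() call take precedence;
        # class-level config values only fill in parameters the caller did not set.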

        most_recent_message, chat_history = cohere_messages_pt_v2(
            messages=messages, model=model, llm_provider="cohere_chat"
        )

        ## Handle Tool Calling
        if "tools" in optional_params:
            _is_function_call = True
            cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
            optional_params["tools"] = cohere_tools
        if isinstance(most_recent_message, dict):
            optional_params["tool_results"] = [most_recent_message]
        elif isinstance(most_recent_message, str):
            optional_params["message"] = most_recent_message

        ## If the last chat history message is from the 'user' and 'tool_results' is given,
        ## force_single_step=True must be set, else the Cohere API fails.
        if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
            optional_params["force_single_step"] = True

        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise CohereError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        try:
            cohere_v2_chat_response = CohereV2ChatResponse(**raw_response_json)  # type: ignore
        except Exception:
            raise CohereError(message=raw_response.text, status_code=422)

        cohere_content = cohere_v2_chat_response["message"].get("content", None)
        if cohere_content is not None:
            model_response.choices[0].message.content = "".join(  # type: ignore
                [
                    content.get("text", "")
                    for content in cohere_content
                    if content is not None
                ]
            )

        ## ADD CITATIONS
        if "citations" in cohere_v2_chat_response:
            setattr(model_response, "citations", cohere_v2_chat_response["citations"])

        ## Tool calling response
        cohere_tools_response = cohere_v2_chat_response["message"].get("tool_calls", [])
        if cohere_tools_response is not None and cohere_tools_response != []:
            # convert cohere_tools_response to OpenAI response format
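            # e.g. (illustrative shape, following the loop below): a Cohere tool call
            #   {"id": "call_1", "type": "function", "function": {"name": "get_current_weather", "arguments": "{...}"}}
            # is passed through unchanged, with an added "index" field so it matches
            # the OpenAI ChatCompletionToolCallChunk shape.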
            tool_calls: List[ChatCompletionToolCallChunk] = []
            for index, tool in enumerate(cohere_tools_response):
                tool_call: ChatCompletionToolCallChunk = {
                    **tool,  # type: ignore
                    "index": index,
                }
                tool_calls.append(tool_call)
            _message = litellm.Message(
                tool_calls=tool_calls,
                content=None,
            )
            model_response.choices[0].message = _message  # type: ignore

        ## CALCULATING USAGE - use cohere `usage.tokens` for returning usage
        token_usage = cohere_v2_chat_response["usage"].get("tokens", {})
        prompt_tokens = token_usage.get("input_tokens", 0)
        completion_tokens = token_usage.get("output_tokens", 0)

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response

    def _construct_cohere_tool(
        self,
        tools: Optional[list] = None,
    ):
        if tools is None:
            tools = []
        cohere_tools = []
        for tool in tools:
            cohere_tool = self._translate_openai_tool_to_cohere(tool)
            cohere_tools.append(cohere_tool)
        return cohere_tools

    def _translate_openai_tool_to_cohere(
        self,
        openai_tool: dict,
    ):
        # cohere tools look like this
        """
        {
            "name": "query_daily_sales_report",
            "description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
            "parameter_definitions": {
                "day": {
                    "description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
                    "type": "str",
                    "required": True
                }
            }
        }
        """

        # OpenAI tools look like this
        """
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": ["location"],
                },
            },
        }
        """
        cohere_tool = {
            "name": openai_tool["function"]["name"],
            "description": openai_tool["function"]["description"],
            "parameter_definitions": {},
        }

        for param_name, param_def in openai_tool["function"]["parameters"][
            "properties"
        ].items():
            required_params = (
                openai_tool.get("function", {})
                .get("parameters", {})
                .get("required", [])
            )
            cohere_param_def = {
                "description": param_def.get("description", ""),
                "type": param_def.get("type", ""),
                "required": param_name in required_params,
            }
            cohere_tool["parameter_definitions"][param_name] = cohere_param_def

        return cohere_tool

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        return CohereModelResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return CohereError(status_code=status_code, message=error_message)