""" | |
Utility functions for base LLM classes. | |
""" | |
import copy | |
import json | |
from abc import ABC, abstractmethod | |
from typing import List, Optional, Type, Union | |
from openai.lib import _parsing, _pydantic | |
from pydantic import BaseModel | |
from litellm._logging import verbose_logger | |
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolCallChunk | |
from litellm.types.utils import Message, ProviderSpecificModelInfo | |
class BaseLLMModelInfo(ABC):
    def get_provider_info(
        self,
        model: str,
    ) -> Optional[ProviderSpecificModelInfo]:
        """
        Default values all models of this provider support.
        """
        return None

    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        """
        Returns a list of models supported by this provider.
        """
        return []

    @staticmethod
    @abstractmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        pass

    @staticmethod
    @abstractmethod
    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        pass

    @staticmethod
    @abstractmethod
    def get_base_model(model: str) -> Optional[str]:
        """
        Returns the base model name from the given model name.

        Some providers, like bedrock, can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0`
        or `converse/anthropic.claude-3-opus-20240229-v1:0`; this function returns
        `anthropic.claude-3-opus-20240229-v1:0`.
        """
        pass
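
# A minimal sketch of a concrete subclass, for illustration only. The provider
# name, env var, and URL below are hypothetical, not a real litellm provider
# (`validate_environment` is omitted for brevity):
#
#     class ExampleModelInfo(BaseLLMModelInfo):
#         @staticmethod
#         def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
#             return api_key or os.environ.get("EXAMPLE_API_KEY")
#
#         @staticmethod
#         def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
#             return api_base or "https://api.example.com/v1"
#
#         @staticmethod
#         def get_base_model(model: str) -> Optional[str]:
#             # strip an optional route prefix, e.g. "invoke/my-model" -> "my-model"
#             return model.split("/")[-1]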
def _convert_tool_response_to_message(
    tool_calls: List[ChatCompletionToolCallChunk],
) -> Optional[Message]:
    """
    In JSON mode, the Anthropic API returns the JSON schema as a tool call; convert
    it to a message so the response follows the OpenAI format.
    """
    ## HANDLE JSON MODE - anthropic returns a single function call
    json_mode_content_str: Optional[str] = tool_calls[0]["function"].get("arguments")
    try:
        if json_mode_content_str is not None:
            args = json.loads(json_mode_content_str)
            if isinstance(args, dict) and (values := args.get("values")) is not None:
                _message = Message(content=json.dumps(values))
                return _message
            else:
                # the `values` key is often not present in the tool response
                # relevant issue: https://github.com/BerriAI/litellm/issues/6741
                _message = Message(content=json.dumps(args))
                return _message
    except json.JSONDecodeError:
        # on a JSON decode error, return the original tool response string
        return Message(content=json_mode_content_str)
    return None
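
# Illustrative sketch (not executed at runtime): a JSON-mode Anthropic tool call
# shaped roughly like
#
#     tool_calls = [
#         {
#             "id": "toolu_123",
#             "type": "function",
#             "function": {"name": "json_tool_call", "arguments": '{"values": {"name": "litellm"}}'},
#         }
#     ]
#
# would be converted by `_convert_tool_response_to_message(tool_calls)` into a
# `Message` whose `content` is '{"name": "litellm"}', matching the OpenAI format.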
def _dict_to_response_format_helper(
    response_format: dict, ref_template: Optional[str] = None
) -> dict:
    if ref_template is not None and response_format.get("type") == "json_schema":
        # Deep copy to avoid modifying the original
        modified_format = copy.deepcopy(response_format)
        schema = modified_format["json_schema"]["schema"]

        # Update all $ref values in the schema
        def update_refs(schema):
            stack = [(schema, [])]
            visited = set()

            while stack:
                obj, path = stack.pop()
                obj_id = id(obj)
                if obj_id in visited:
                    continue
                visited.add(obj_id)

                if isinstance(obj, dict):
                    if "$ref" in obj:
                        ref_path = obj["$ref"]
                        model_name = ref_path.split("/")[-1]
                        obj["$ref"] = ref_template.format(model=model_name)

                    for k, v in obj.items():
                        if isinstance(v, (dict, list)):
                            stack.append((v, path + [k]))

                elif isinstance(obj, list):
                    for i, item in enumerate(obj):
                        if isinstance(item, (dict, list)):
                            stack.append((item, path + [i]))

        update_refs(schema)
        return modified_format

    return response_format
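
# Illustrative sketch of the $ref rewrite above, with a hypothetical template:
#
#     rf = {
#         "type": "json_schema",
#         "json_schema": {
#             "schema": {"properties": {"pet": {"$ref": "#/$defs/Pet"}}},
#         },
#     }
#     _dict_to_response_format_helper(rf, ref_template="/schemas/{model}.json")
#     # the nested "$ref" becomes "/schemas/Pet.json"; the input dict itself
#     # is left untouched thanks to the deep copy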
def type_to_response_format_param(
    response_format: Optional[Union[Type[BaseModel], dict]],
    ref_template: Optional[str] = None,
) -> Optional[dict]:
    """
    Re-implementation of openai's 'type_to_response_format_param' function.

    Used for converting a pydantic object to an API schema.
    """
    if response_format is None:
        return None

    if isinstance(response_format, dict):
        return _dict_to_response_format_helper(response_format, ref_template)

    # type checkers don't narrow the negation of a `TypeGuard` as it isn't
    # a safe default behaviour, but we know that at this point the
    # `response_format` can only be a `type`
    if not _parsing._completions.is_basemodel_type(response_format):
        raise TypeError(f"Unsupported response_format type - {response_format}")

    if ref_template is not None:
        schema = response_format.model_json_schema(ref_template=ref_template)
    else:
        schema = _pydantic.to_strict_json_schema(response_format)

    return {
        "type": "json_schema",
        "json_schema": {
            "schema": schema,
            "name": response_format.__name__,
            "strict": True,
        },
    }
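
# Hedged usage sketch: given a Pydantic model, this returns an OpenAI-style
# `response_format` dict (the exact schema contents come from
# `_pydantic.to_strict_json_schema`):
#
#     class Person(BaseModel):
#         name: str
#
#     type_to_response_format_param(Person)
#     # -> {"type": "json_schema",
#     #     "json_schema": {"schema": {...}, "name": "Person", "strict": True}}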
def map_developer_role_to_system_role(
    messages: List[AllMessageValues],
) -> List[AllMessageValues]:
    """
    Translate the `developer` role to the `system` role for non-OpenAI providers.
    """
    new_messages: List[AllMessageValues] = []
    for m in messages:
        if m["role"] == "developer":
            verbose_logger.debug(
                "Translating developer role to system role for non-OpenAI providers."
            )  # ensure the user knows what's happening with their input
            new_messages.append({"role": "system", "content": m["content"]})
        else:
            new_messages.append(m)
    return new_messages
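
# Example of the translation, for illustration:
#
#     map_developer_role_to_system_role(
#         [{"role": "developer", "content": "Be terse."},
#          {"role": "user", "content": "Hi"}]
#     )
#     # -> [{"role": "system", "content": "Be terse."},
#     #     {"role": "user", "content": "Hi"}]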