"""
Utility functions for base LLM classes.
"""
import copy
import json
from abc import ABC, abstractmethod
from typing import List, Optional, Type, Union
from openai.lib import _parsing, _pydantic
from pydantic import BaseModel
from litellm._logging import verbose_logger
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolCallChunk
from litellm.types.utils import Message, ProviderSpecificModelInfo


class BaseLLMModelInfo(ABC):
    def get_provider_info(
        self,
        model: str,
    ) -> Optional[ProviderSpecificModelInfo]:
        """
        Default values that all models of this provider support.
        """
        return None

    @abstractmethod
    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        """
        Returns a list of models supported by this provider.
        """
        return []

    @staticmethod
    @abstractmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        pass

    @staticmethod
    @abstractmethod
    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        pass

    @staticmethod
    @abstractmethod
    def get_base_model(model: str) -> Optional[str]:
        """
        Returns the base model name from the given model name.

        Some providers, like Bedrock, can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0`
        or `converse/anthropic.claude-3-opus-20240229-v1:0`; this function returns
        `anthropic.claude-3-opus-20240229-v1:0` in both cases.
        """
        pass
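

# Illustrative sketch (not part of the original module): a minimal concrete
# subclass showing how a provider might satisfy the interface above. The
# provider name, model ids, and endpoint are all hypothetical.
class _ExampleProviderModelInfo(BaseLLMModelInfo):
    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        return ["example/model-small", "example/model-large"]

    @staticmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        return api_key  # real providers typically fall back to an env var here

    @staticmethod
    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
        return api_base or "https://api.example.com/v1"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        if api_key is None:
            raise ValueError("api_key is required for this (hypothetical) provider")
        headers["Authorization"] = f"Bearer {api_key}"
        return headers

    @staticmethod
    def get_base_model(model: str) -> Optional[str]:
        # strip a route prefix such as `invoke/` or `converse/`
        return model.split("/", 1)[-1]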


def _convert_tool_response_to_message(
    tool_calls: List[ChatCompletionToolCallChunk],
) -> Optional[Message]:
    """
    In JSON mode, the Anthropic API returns the JSON payload as a single tool
    call; convert it to a plain message so the response follows the OpenAI format.
    """
    ## HANDLE JSON MODE - anthropic returns single function call
    json_mode_content_str: Optional[str] = tool_calls[0]["function"].get("arguments")
    try:
        if json_mode_content_str is not None:
            args = json.loads(json_mode_content_str)
            if isinstance(args, dict) and (values := args.get("values")) is not None:
                _message = Message(content=json.dumps(values))
                return _message
            else:
                # often the `values` key is not present in the tool response;
                # relevant issue: https://github.com/BerriAI/litellm/issues/6741
                _message = Message(content=json.dumps(args))
                return _message
    except json.JSONDecodeError:
        # on a JSON decode error, return the original tool response string
        return Message(content=json_mode_content_str)
    return None
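

# Hedged usage sketch (not part of the original module): the tool-call payload
# below is fabricated to show the expected shape; the exact keys are assumed to
# match ChatCompletionToolCallChunk.
def _example_convert_tool_response() -> None:
    tool_calls: List[ChatCompletionToolCallChunk] = [
        {
            "id": "call_123",
            "type": "function",
            "function": {
                "name": "json_tool_call",
                "arguments": '{"values": {"city": "Paris"}}',
            },
            "index": 0,
        }
    ]
    message = _convert_tool_response_to_message(tool_calls)
    # the `values` payload is unwrapped:
    # message.content == '{"city": "Paris"}'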


def _dict_to_response_format_helper(
    response_format: dict, ref_template: Optional[str] = None
) -> dict:
    if ref_template is not None and response_format.get("type") == "json_schema":
        # Deep copy to avoid modifying the original
        modified_format = copy.deepcopy(response_format)
        schema = modified_format["json_schema"]["schema"]

        # Update all $ref values in the schema
        def update_refs(schema):
            # Iterative depth-first walk (avoids recursion limits on deep
            # schemas); `visited` guards against revisiting shared sub-objects.
            stack = [(schema, [])]
            visited = set()

            while stack:
                obj, path = stack.pop()
                obj_id = id(obj)
                if obj_id in visited:
                    continue
                visited.add(obj_id)

                if isinstance(obj, dict):
                    if "$ref" in obj:
                        ref_path = obj["$ref"]
                        model_name = ref_path.split("/")[-1]
                        obj["$ref"] = ref_template.format(model=model_name)

                    for k, v in obj.items():
                        if isinstance(v, (dict, list)):
                            stack.append((v, path + [k]))

                elif isinstance(obj, list):
                    for i, item in enumerate(obj):
                        if isinstance(item, (dict, list)):
                            stack.append((item, path + [i]))

        update_refs(schema)
        return modified_format

    return response_format
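

# Hedged usage sketch (not part of the original module): shows how a `$ref`
# such as `#/$defs/Address` is rewritten via `ref_template`; the schema and
# template are fabricated for illustration.
def _example_ref_rewrite() -> None:
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "schema": {
                "properties": {"home": {"$ref": "#/$defs/Address"}},
            },
        },
    }
    rewritten = _dict_to_response_format_helper(
        response_format, ref_template="/components/schemas/{model}"
    )
    # only the last path segment of the original $ref is kept:
    # rewritten["json_schema"]["schema"]["properties"]["home"]["$ref"]
    # == "/components/schemas/Address"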


def type_to_response_format_param(
    response_format: Optional[Union[Type[BaseModel], dict]],
    ref_template: Optional[str] = None,
) -> Optional[dict]:
    """
    Re-implementation of openai's 'type_to_response_format_param' function.

    Used for converting a Pydantic model to the API's JSON schema format.
    """
    if response_format is None:
        return None

    if isinstance(response_format, dict):
        return _dict_to_response_format_helper(response_format, ref_template)

    # type checkers don't narrow the negation of a `TypeGuard` as it isn't
    # a safe default behaviour, but we know that at this point the
    # `response_format` can only be a `type`
    if not _parsing._completions.is_basemodel_type(response_format):
        raise TypeError(f"Unsupported response_format type - {response_format}")

    if ref_template is not None:
        schema = response_format.model_json_schema(ref_template=ref_template)
    else:
        schema = _pydantic.to_strict_json_schema(response_format)

    return {
        "type": "json_schema",
        "json_schema": {
            "schema": schema,
            "name": response_format.__name__,
            "strict": True,
        },
    }
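

# Hedged usage sketch (not part of the original module): `Ticket` is a
# hypothetical Pydantic model; the exact schema emitted depends on openai's
# `to_strict_json_schema`, so only the outer structure is shown here.
def _example_type_to_response_format() -> None:
    class Ticket(BaseModel):
        title: str
        priority: int

    param = type_to_response_format_param(Ticket)
    # param looks roughly like:
    # {"type": "json_schema",
    #  "json_schema": {"schema": {...}, "name": "Ticket", "strict": True}}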


def map_developer_role_to_system_role(
    messages: List[AllMessageValues],
) -> List[AllMessageValues]:
    """
    Translate `developer` role to `system` role for non-OpenAI providers.
    """
    new_messages: List[AllMessageValues] = []
    for m in messages:
        if m["role"] == "developer":
            verbose_logger.debug(
                "Translating developer role to system role for non-OpenAI providers."
            )  # ensure user knows what's happening with their input
            new_messages.append({"role": "system", "content": m["content"]})
        else:
            new_messages.append(m)
    return new_messages
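

# Hedged usage sketch (not part of the original module): demonstrates the role
# translation on a fabricated message list.
def _example_map_developer_role() -> None:
    messages: List[AllMessageValues] = [
        {"role": "developer", "content": "Always answer in JSON."},
        {"role": "user", "content": "Hi!"},
    ]
    translated = map_developer_role_to_system_role(messages)
    # translated[0] == {"role": "system", "content": "Always answer in JSON."}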