""" | |
Translates from OpenAI's `/v1/chat/completions` to Databricks' `/chat/completions` | |
""" | |
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
    cast,
)

import httpx
from pydantic import BaseModel

from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
    _handle_invalid_parallel_tool_calls,
    _should_convert_tool_call_to_json_mode,
)
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    handle_messages_with_content_list_to_str_conversion,
    strip_name_from_messages,
)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.llms.anthropic import AllAnthropicToolsValues
from litellm.types.llms.databricks import (
    AllDatabricksContentValues,
    DatabricksChoice,
    DatabricksFunction,
    DatabricksResponse,
    DatabricksTool,
)
from litellm.types.llms.openai import (
    AllMessageValues,
    ChatCompletionRedactedThinkingBlock,
    ChatCompletionThinkingBlock,
    ChatCompletionToolChoiceFunctionParam,
    ChatCompletionToolChoiceObjectParam,
)
from litellm.types.utils import (
    ChatCompletionMessageToolCall,
    Choices,
    Message,
    ModelResponse,
    ModelResponseStream,
    ProviderField,
    Usage,
)

from ...anthropic.chat.transformation import AnthropicConfig
from ...openai_like.chat.transformation import OpenAILikeChatConfig
from ..common_utils import DatabricksBase, DatabricksException

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class DatabricksConfig(DatabricksBase, OpenAILikeChatConfig, AnthropicConfig):
    """
    Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stop: Optional[Union[List[str], str]] = None
    n: Optional[int] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        stop: Optional[Union[List[str], str]] = None,
        n: Optional[int] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description"""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Databricks API Key.",
                field_value="dapi...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Databricks API Base.",
                field_value="https://adb-..",
            ),
        ]

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        api_base, headers = self.databricks_validate_environment(
            api_base=api_base,
            api_key=api_key,
            endpoint_type="chat_completions",
            custom_endpoint=False,
            headers=headers,
        )
        # Ensure Content-Type header is set
        headers["Content-Type"] = "application/json"
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        api_base = self._get_api_base(api_base)
        complete_url = f"{api_base}/chat/completions"
        return complete_url
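
    # Sketch of the resulting URL (hypothetical workspace host): with
    # api_base="https://adb-1234.5.azuredatabricks.net/serving-endpoints",
    # get_complete_url returns
    # "https://adb-1234.5.azuredatabricks.net/serving-endpoints/chat/completions".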

    def get_supported_openai_params(self, model: Optional[str] = None) -> list:
        return [
            "stream",
            "stop",
            "temperature",
            "top_p",
            "max_tokens",
            "max_completion_tokens",
            "n",
            "response_format",
            "tools",
            "tool_choice",
            "reasoning_effort",
            "thinking",
        ]

    def convert_anthropic_tool_to_databricks_tool(
        self, tool: Optional[AllAnthropicToolsValues]
    ) -> Optional[DatabricksTool]:
        if tool is None:
            return None

        return DatabricksTool(
            type="function",
            function=DatabricksFunction(
                name=tool["name"],
                parameters=cast(dict, tool.get("input_schema") or {}),
            ),
        )

    def _map_openai_to_dbrx_tool(self, model: str, tools: List) -> List[DatabricksTool]:
        # if the model is not a claude model, send the tools as-is
        if "claude" not in model:
            return tools

        # for claude models, convert each OpenAI tool to the Anthropic format,
        # then to the Databricks format
        anthropic_tools = self._map_tools(tools=tools)
        databricks_tools = [
            cast(DatabricksTool, self.convert_anthropic_tool_to_databricks_tool(tool))
            for tool in anthropic_tools
        ]
        return databricks_tools
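
    # Illustrative mapping (assumed shapes) for a claude model: an OpenAI tool
    #   {"type": "function", "function": {"name": "get_weather",
    #    "parameters": {"type": "object", "properties": {...}}}}
    # round-trips through the Anthropic format (name + input_schema) and comes
    # back as
    #   DatabricksTool(type="function", function=DatabricksFunction(
    #       name="get_weather", parameters={"type": "object", ...}))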

    def map_response_format_to_databricks_tool(
        self,
        model: str,
        value: Optional[dict],
        optional_params: dict,
        is_thinking_enabled: bool,
    ) -> Optional[DatabricksTool]:
        if value is None:
            return None

        tool = self.map_response_format_to_anthropic_tool(
            value, optional_params, is_thinking_enabled
        )

        databricks_tool = self.convert_anthropic_tool_to_databricks_tool(tool)
        return databricks_tool

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
        replace_max_completion_tokens_with_max_tokens: bool = True,
    ) -> dict:
        is_thinking_enabled = self.is_thinking_enabled(non_default_params)
        mapped_params = super().map_openai_params(
            non_default_params, optional_params, model, drop_params
        )
        if "tools" in mapped_params:
            mapped_params["tools"] = self._map_openai_to_dbrx_tool(
                model=model, tools=mapped_params["tools"]
            )

        if (
            "max_completion_tokens" in non_default_params
            and replace_max_completion_tokens_with_max_tokens
        ):
            mapped_params["max_tokens"] = non_default_params[
                "max_completion_tokens"
            ]  # most openai-compatible providers support 'max_tokens', not 'max_completion_tokens'
            mapped_params.pop("max_completion_tokens", None)

        if "response_format" in non_default_params and "claude" in model:
            _tool = self.map_response_format_to_databricks_tool(
                model,
                non_default_params["response_format"],
                mapped_params,
                is_thinking_enabled,
            )

            if _tool is not None:
                self._add_tools_to_optional_params(
                    optional_params=optional_params, tools=[_tool]
                )
                optional_params["json_mode"] = True
                if not is_thinking_enabled:
                    _tool_choice = ChatCompletionToolChoiceObjectParam(
                        type="function",
                        function=ChatCompletionToolChoiceFunctionParam(
                            name=RESPONSE_FORMAT_TOOL_NAME
                        ),
                    )
                    optional_params["tool_choice"] = _tool_choice
            optional_params.pop(
                "response_format", None
            )  # unsupported for claude models - if json_schema -> convert to tool call

        if "reasoning_effort" in non_default_params and "claude" in model:
            optional_params["thinking"] = AnthropicConfig._map_reasoning_effort(
                non_default_params.get("reasoning_effort")
            )
            optional_params.pop("reasoning_effort", None)

        ## handle thinking tokens
        self.update_optional_params_with_thinking_tokens(
            non_default_params=non_default_params, optional_params=mapped_params
        )

        return mapped_params
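
    # Sketch (hypothetical model name, assumed param shapes): for a claude model,
    #   config.map_openai_params(
    #       non_default_params={"response_format": {"type": "json_object"}},
    #       optional_params={}, model="databricks-claude-3-7-sonnet",
    #       drop_params=False)
    # removes "response_format", registers a RESPONSE_FORMAT_TOOL_NAME tool in
    # optional_params, and sets optional_params["json_mode"] = True, so the
    # model's tool call is converted back into plain JSON content downstream.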

    def _should_fake_stream(self, optional_params: dict) -> bool:
        """
        Databricks doesn't support 'response_format' while streaming
        """
        if optional_params.get("response_format") is not None:
            return True

        return False

    def _transform_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> List[AllMessageValues]:
        """
        Databricks does not support:
        - content in list format.
        - 'name' in user message.
        """
        new_messages = []
        for message in messages:
            if isinstance(message, BaseModel):
                _message = message.model_dump(exclude_none=True)
            else:
                _message = message
            new_messages.append(_message)

        new_messages = handle_messages_with_content_list_to_str_conversion(new_messages)
        new_messages = strip_name_from_messages(new_messages)
        return super()._transform_messages(messages=new_messages, model=model)
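
    # Example (sketch) of the normalization performed above:
    #   [{"role": "user", "name": "alice",
    #     "content": [{"type": "text", "text": "Hi"}]}]
    # becomes
    #   [{"role": "user", "content": "Hi"}]
    # (list content flattened to a string, unsupported "name" field stripped).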

    @staticmethod
    def extract_content_str(
        content: Optional[AllDatabricksContentValues],
    ) -> Optional[str]:
        if content is None:
            return None
        if isinstance(content, str):
            return content
        elif isinstance(content, list):
            content_str = ""
            for item in content:
                if item["type"] == "text":
                    content_str += item["text"]
            return content_str
        else:
            raise Exception(f"Unsupported content type: {type(content)}")
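
    # Example: extract_content_str([{"type": "text", "text": "foo"},
    # {"type": "text", "text": "bar"}]) -> "foobar"; non-text items are
    # skipped, a plain string is returned unchanged, and None stays None.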

    @staticmethod
    def extract_reasoning_content(
        content: Optional[AllDatabricksContentValues],
    ) -> Tuple[
        Optional[str],
        Optional[
            List[
                Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]
            ]
        ],
    ]:
        """
        Extract and return the reasoning content and thinking blocks
        """
        if content is None:
            return None, None
        thinking_blocks: Optional[
            List[
                Union[ChatCompletionThinkingBlock, ChatCompletionRedactedThinkingBlock]
            ]
        ] = None
        reasoning_content: Optional[str] = None
        if isinstance(content, list):
            for item in content:
                if item["type"] == "reasoning":
                    # each reasoning item carries a list of summary entries;
                    # concatenate their text and keep one thinking block per entry
                    for summary in item["summary"]:
                        if reasoning_content is None:
                            reasoning_content = ""
                        reasoning_content += summary["text"]
                        thinking_block = ChatCompletionThinkingBlock(
                            type="thinking",
                            thinking=summary["text"],
                            signature=summary["signature"],
                        )
                        if thinking_blocks is None:
                            thinking_blocks = []
                        thinking_blocks.append(thinking_block)
        return reasoning_content, thinking_blocks
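
    # Example (assumed response shape): a content list containing
    #   {"type": "reasoning", "summary": [{"text": "step 1...",
    #    "signature": "sig"}]}
    # yields reasoning_content="step 1..." plus one ChatCompletionThinkingBlock
    # carrying the same text and signature.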

    def _transform_dbrx_choices(
        self, choices: List[DatabricksChoice], json_mode: Optional[bool] = None
    ) -> List[Choices]:
        transformed_choices = []

        for choice in choices:
            ## HANDLE JSON MODE - anthropic returns single function call
            tool_calls = choice["message"].get("tool_calls", None)
            if tool_calls is not None:
                _openai_tool_calls = []
                for _tc in tool_calls:
                    _openai_tc = ChatCompletionMessageToolCall(**_tc)  # type: ignore
                    _openai_tool_calls.append(_openai_tc)
                fixed_tool_calls = _handle_invalid_parallel_tool_calls(
                    _openai_tool_calls
                )

                if fixed_tool_calls is not None:
                    tool_calls = fixed_tool_calls

            translated_message: Optional[Message] = None
            finish_reason: Optional[str] = None
            if tool_calls and _should_convert_tool_call_to_json_mode(
                tool_calls=tool_calls,
                convert_tool_call_to_json_mode=json_mode,
            ):
                # to support response_format on claude models
                json_mode_content_str: Optional[str] = (
                    str(tool_calls[0]["function"].get("arguments", "")) or None
                )
                if json_mode_content_str is not None:
                    translated_message = Message(content=json_mode_content_str)
                    finish_reason = "stop"

            if translated_message is None:
                ## get the content str
                content_str = DatabricksConfig.extract_content_str(
                    choice["message"]["content"]
                )

                ## get the reasoning content
                (
                    reasoning_content,
                    thinking_blocks,
                ) = DatabricksConfig.extract_reasoning_content(
                    choice["message"].get("content")
                )

                translated_message = Message(
                    role="assistant",
                    content=content_str,
                    reasoning_content=reasoning_content,
                    thinking_blocks=thinking_blocks,
                    tool_calls=choice["message"].get("tool_calls"),
                )

            if finish_reason is None:
                finish_reason = choice["finish_reason"]

            translated_choice = Choices(
                finish_reason=finish_reason,
                index=choice["index"],
                message=translated_message,
                logprobs=None,
                enhancements=None,
            )
            transformed_choices.append(translated_choice)

        return transformed_choices

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )

        ## RESPONSE OBJECT
        try:
            completion_response = DatabricksResponse(**raw_response.json())  # type: ignore
        except Exception as e:
            response_headers = getattr(raw_response, "headers", None)
            raise DatabricksException(
                message="Unable to get json response - {}, Original Response: {}".format(
                    str(e), raw_response.text
                ),
                status_code=raw_response.status_code,
                headers=response_headers,
            )

        model_response.model = completion_response["model"]
        model_response.id = completion_response["id"]
        model_response.created = completion_response["created"]
        setattr(model_response, "usage", Usage(**completion_response["usage"]))

        model_response.choices = self._transform_dbrx_choices(  # type: ignore
            choices=completion_response["choices"],
            json_mode=json_mode,
        )

        return model_response

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        return DatabricksChatResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )


class DatabricksChatResponseIterator(BaseModelResponseIterator):
    def __init__(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        super().__init__(streaming_response, sync_stream)

        self.json_mode = json_mode
        # Track the last seen function name across chunks
        self._last_function_name: Optional[str] = None

    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        try:
            translated_choices = []
            for choice in chunk["choices"]:
                tool_calls = choice["delta"].get("tool_calls")
                if tool_calls and self.json_mode:
                    # 1. Check if the function name is set and == RESPONSE_FORMAT_TOOL_NAME
                    # 2. If no function name, just args -> check last function name (saved via state variable)
                    # 3. Convert args to json
                    # 4. Convert json to message
                    # 5. Set content to message.content
                    # 6. Set tool_calls to None
                    from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
                    from litellm.llms.base_llm.base_utils import (
                        _convert_tool_response_to_message,
                    )

                    # Check if this chunk has a function name
                    function_name = tool_calls[0].get("function", {}).get("name")
                    if function_name is not None:
                        self._last_function_name = function_name

                    # If we have a saved function name that matches RESPONSE_FORMAT_TOOL_NAME
                    # or this chunk has the matching function name
                    if (
                        self._last_function_name == RESPONSE_FORMAT_TOOL_NAME
                        or function_name == RESPONSE_FORMAT_TOOL_NAME
                    ):
                        # Convert tool calls to message format
                        message = _convert_tool_response_to_message(tool_calls)
                        if message is not None:
                            if message.content == "{}":  # empty json
                                message.content = ""
                            choice["delta"]["content"] = message.content
                            choice["delta"]["tool_calls"] = None
                elif tool_calls:
                    for _tc in tool_calls:
                        if _tc.get("function", {}).get("arguments") == "{}":
                            _tc["function"]["arguments"] = ""  # avoid invalid json

                # extract the content str
                content_str = DatabricksConfig.extract_content_str(
                    choice["delta"].get("content")
                )

                # extract the reasoning content
                (
                    reasoning_content,
                    thinking_blocks,
                ) = DatabricksConfig.extract_reasoning_content(
                    choice["delta"].get("content")
                )

                choice["delta"]["content"] = content_str
                choice["delta"]["reasoning_content"] = reasoning_content
                choice["delta"]["thinking_blocks"] = thinking_blocks

                translated_choices.append(choice)
            return ModelResponseStream(
                id=chunk["id"],
                object="chat.completion.chunk",
                created=chunk["created"],
                model=chunk["model"],
                choices=translated_choices,
            )
        except KeyError as e:
            raise DatabricksException(
                message=f"KeyError: {e}, Got unexpected response from Databricks: {chunk}",
                status_code=400,
            )
        except Exception as e:
            raise e
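
# Streaming sketch (hypothetical chunk, assuming RESPONSE_FORMAT_TOOL_NAME ==
# "json_tool_call") of the json_mode path above: a delta chunk whose tool call
# targets the response-format tool, e.g.
#   {"id": "...", "created": 1, "model": "m", "choices": [{"index": 0,
#    "delta": {"tool_calls": [{"index": 0, "function":
#        {"name": "json_tool_call", "arguments": "{\"a\": 1}"}}]}}]}
# is rewritten so delta["content"] holds the JSON arguments string and
# delta["tool_calls"] is None, letting callers consume response_format output
# as ordinary streamed content.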