""" | |
This hook is used to inject cache control directives into the messages of a chat completion. | |
Users can define | |
- `cache_control_injection_points` in the completion params and litellm will inject the cache control directives into the messages at the specified injection points. | |
""" | |
import copy
from typing import Dict, List, Optional, Tuple, Union, cast

from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.integrations.anthropic_cache_control_hook import (
    CacheControlInjectionPoint,
    CacheControlMessageInjectionPoint,
)
from litellm.types.llms.openai import AllMessageValues, ChatCompletionCachedContent
from litellm.types.utils import StandardCallbackDynamicParams


class AnthropicCacheControlHook(CustomPromptManagement):
    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Apply cache control directives based on specified injection points.

        Returns:
        - model: str - the model to use
        - messages: List[AllMessageValues] - messages with applied cache controls
        - non_default_params: dict - params with any global cache controls
        """
        # Extract cache control injection points
        injection_points: List[CacheControlInjectionPoint] = non_default_params.pop(
            "cache_control_injection_points", []
        )
        if not injection_points:
            return model, messages, non_default_params

        # Create a deep copy of messages to avoid modifying the original list
        processed_messages = copy.deepcopy(messages)

        # Process message-level cache controls
        for point in injection_points:
            if point.get("location") == "message":
                point = cast(CacheControlMessageInjectionPoint, point)
                processed_messages = self._process_message_injection(
                    point=point, messages=processed_messages
                )

        return model, processed_messages, non_default_params
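
    # Example points handled below (illustrative):
    #   {"location": "message", "index": 2}        -> cache messages[2] only
    #   {"location": "message", "role": "system"}  -> cache every system message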
    @staticmethod
    def _process_message_injection(
        point: CacheControlMessageInjectionPoint, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """Process message-level cache control injection."""
        control: ChatCompletionCachedContent = point.get(
            "control", None
        ) or ChatCompletionCachedContent(type="ephemeral")
        _targetted_index: Optional[Union[int, str]] = point.get("index", None)
        targetted_index: Optional[int] = None
        if isinstance(_targetted_index, str):
            if _targetted_index.isdigit():
                targetted_index = int(_targetted_index)
        else:
            targetted_index = _targetted_index

        targetted_role = point.get("role", None)
        # Case 1: Target by specific index
        if targetted_index is not None:
            if 0 <= targetted_index < len(messages):
                messages[targetted_index] = (
                    AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                        messages[targetted_index], control
                    )
                )
        # Case 2: Target by role
        elif targetted_role is not None:
            for msg in messages:
                if msg.get("role") == targetted_role:
                    # the helper mutates the message dict in place, so this
                    # rebinding of msg is only cosmetic
                    msg = (
                        AnthropicCacheControlHook._safe_insert_cache_control_in_message(
                            message=msg, control=control
                        )
                    )

        return messages

    @staticmethod
    def _safe_insert_cache_control_in_message(
        message: AllMessageValues, control: ChatCompletionCachedContent
    ) -> AllMessageValues:
        """
        Safe way to insert cache control in a message.

        OpenAI message content can be either:
        - string
        - list of objects

        This method handles inserting cache control in both cases.
        """
        message_content = message.get("content", None)

        # 1. if string, insert cache control on the message itself
        if isinstance(message_content, str):
            message["cache_control"] = control  # type: ignore
        # 2. if list of objects, insert cache control on each content block
        elif isinstance(message_content, list):
            for content_item in message_content:
                if isinstance(content_item, dict):
                    content_item["cache_control"] = control  # type: ignore

        return message
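
    # Illustrative resulting shapes (message payloads are assumptions, not
    # part of this module):
    #   string content:
    #     {"role": "system", "content": "...", "cache_control": {"type": "ephemeral"}}
    #   block-list content:
    #     {"role": "user", "content": [{"type": "text", "text": "...",
    #                                   "cache_control": {"type": "ephemeral"}}]}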

    @property
    def integration_name(self) -> str:
        """Return the integration name for this hook."""
        return "anthropic_cache_control_hook"

    @staticmethod
    def should_use_anthropic_cache_control_hook(non_default_params: Dict) -> bool:
        if non_default_params.get("cache_control_injection_points", None):
            return True
        return False

    @staticmethod
    def get_custom_logger_for_anthropic_cache_control_hook(
        non_default_params: Dict,
    ) -> Optional[CustomLogger]:
        from litellm.litellm_core_utils.litellm_logging import (
            _init_custom_logger_compatible_class,
        )

        if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
            non_default_params
        ):
            return _init_custom_logger_compatible_class(
                logging_integration="anthropic_cache_control_hook",
                internal_usage_cache=None,
                llm_router=None,
            )
        return None
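
# Minimal usage sketch (illustrative; assumes litellm is installed, and the
# model name and message payloads below are stand-ins, not part of this module):
#
#   hook = AnthropicCacheControlHook()
#   model, messages, params = hook.get_chat_completion_prompt(
#       model="anthropic/claude-3-5-sonnet-20240620",
#       messages=[
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Summarize this document."},
#       ],
#       non_default_params={
#           "cache_control_injection_points": [
#               {"location": "message", "role": "system"}
#           ]
#       },
#       prompt_id=None,
#       prompt_variables=None,
#       dynamic_callback_params=StandardCallbackDynamicParams(),
#   )
#   # messages[0] now carries "cache_control": {"type": "ephemeral"}, and
#   # "cache_control_injection_points" has been popped from params.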