| """ | |
| This contains LLMCachingHandler | |
| This exposes two methods: | |
| - async_get_cache | |
| - async_set_cache | |
| This file is a wrapper around caching.py | |
| This class is used to handle caching logic specific for LLM API requests (completion / embedding / text_completion / transcription etc) | |
| It utilizes the (RedisCache, s3Cache, RedisSemanticCache, QdrantSemanticCache, InMemoryCache, DiskCache) based on what the user has setup | |
| In each method it will call the appropriate method from caching.py | |
| """ | |

import asyncio
import datetime
import inspect
import threading
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    Generator,
    List,
    Optional,
    Tuple,
    Union,
)

from pydantic import BaseModel

import litellm
from litellm._logging import print_verbose, verbose_logger
from litellm.caching.caching import S3Cache
from litellm.litellm_core_utils.logging_utils import (
    _assemble_complete_response_from_streaming_chunks,
)
from litellm.types.rerank import RerankResponse
from litellm.types.utils import (
    CallTypes,
    Embedding,
    EmbeddingResponse,
    ModelResponse,
    TextCompletionResponse,
    TranscriptionResponse,
    Usage,
)

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
    from litellm.utils import CustomStreamWrapper
else:
    LiteLLMLoggingObj = Any
    CustomStreamWrapper = Any


class CachingHandlerResponse(BaseModel):
    """
    This is the response object for the caching handler. We need to separate embedding cached responses
    from (completion / text_completion / transcription) cached responses.

    For embeddings there can be a cache hit for some of the inputs in the list and a cache miss for others.
    """

    cached_result: Optional[Any] = None
    final_embedding_cached_response: Optional[EmbeddingResponse] = None
    embedding_all_elements_cache_hit: bool = False  # set to True when every element in the embedding input list has a cache hit; if True, return final_embedding_cached_response and skip the API call
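
# Illustrative sketch (not part of the original source): how a partial embedding cache hit
# might be represented with CachingHandlerResponse. The concrete values are assumptions.
#
#   partial_hit = CachingHandlerResponse(
#       cached_result=None,
#       final_embedding_cached_response=EmbeddingResponse(),  # data for cache misses is merged in later
#       embedding_all_elements_cache_hit=False,  # at least one input still needs an API call
#   )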


class LLMCachingHandler:
    def __init__(
        self,
        original_function: Callable,
        request_kwargs: Dict[str, Any],
        start_time: datetime.datetime,
    ):
        self.async_streaming_chunks: List[ModelResponse] = []
        self.sync_streaming_chunks: List[ModelResponse] = []
        self.request_kwargs = request_kwargs
        self.original_function = original_function
        self.start_time = start_time
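
    # Illustrative sketch (not part of the original source): how a caller such as litellm's
    # completion wrapper might construct this handler; the surrounding names are assumptions.
    #
    #   handler = LLMCachingHandler(
    #       original_function=litellm.acompletion,
    #       request_kwargs={"model": "gpt-3.5-turbo", "messages": [...]},
    #       start_time=datetime.datetime.now(),
    #   )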

    async def _async_get_cache(
        self,
        model: str,
        original_function: Callable,
        logging_obj: LiteLLMLoggingObj,
        start_time: datetime.datetime,
        call_type: str,
        kwargs: Dict[str, Any],
        args: Optional[Tuple[Any, ...]] = None,
    ) -> CachingHandlerResponse:
        """
        Internal method to get from the cache.
        Handles different call types (embeddings, chat/completions, text_completion, transcription)
        and accordingly returns the cached response

        Args:
            model: str:
            original_function: Callable:
            logging_obj: LiteLLMLoggingObj:
            start_time: datetime.datetime:
            call_type: str:
            kwargs: Dict[str, Any]:
            args: Optional[Tuple[Any, ...]] = None:

        Returns:
            CachingHandlerResponse:
        Raises:
            None
        """
        from litellm.utils import CustomStreamWrapper

        args = args or ()

        final_embedding_cached_response: Optional[EmbeddingResponse] = None
        embedding_all_elements_cache_hit: bool = False
        cached_result: Optional[Any] = None
        if (
            (kwargs.get("caching", None) is None and litellm.cache is not None)
            or kwargs.get("caching", False) is True
        ) and (
            kwargs.get("cache", {}).get("no-cache", False) is not True
        ):  # allow users to control returning cached responses from the completion function
            if litellm.cache is not None and self._is_call_type_supported_by_cache(
                original_function=original_function
            ):
                verbose_logger.debug("Checking Async Cache")
                cached_result = await self._retrieve_from_cache(
                    call_type=call_type,
                    kwargs=kwargs,
                    args=args,
                )

                if cached_result is not None and not isinstance(cached_result, list):
                    verbose_logger.debug("Cache Hit!")
                    cache_hit = True
                    end_time = datetime.datetime.now()
                    model, _, _, _ = litellm.get_llm_provider(
                        model=model,
                        custom_llm_provider=kwargs.get("custom_llm_provider", None),
                        api_base=kwargs.get("api_base", None),
                        api_key=kwargs.get("api_key", None),
                    )
                    self._update_litellm_logging_obj_environment(
                        logging_obj=logging_obj,
                        model=model,
                        kwargs=kwargs,
                        cached_result=cached_result,
                        is_async=True,
                    )

                    call_type = original_function.__name__

                    cached_result = self._convert_cached_result_to_model_response(
                        cached_result=cached_result,
                        call_type=call_type,
                        kwargs=kwargs,
                        logging_obj=logging_obj,
                        model=model,
                        custom_llm_provider=kwargs.get("custom_llm_provider", None),
                        args=args,
                    )
                    if kwargs.get("stream", False) is False:
                        # LOG SUCCESS
                        self._async_log_cache_hit_on_callbacks(
                            logging_obj=logging_obj,
                            cached_result=cached_result,
                            start_time=start_time,
                            end_time=end_time,
                            cache_hit=cache_hit,
                        )
                    cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
                        **kwargs
                    )
                    if (
                        isinstance(cached_result, BaseModel)
                        or isinstance(cached_result, CustomStreamWrapper)
                    ) and hasattr(cached_result, "_hidden_params"):
                        cached_result._hidden_params["cache_key"] = cache_key  # type: ignore
                    return CachingHandlerResponse(cached_result=cached_result)
                elif (
                    call_type == CallTypes.aembedding.value
                    and cached_result is not None
                    and isinstance(cached_result, list)
                    and litellm.cache is not None
                    and not isinstance(
                        litellm.cache.cache, S3Cache
                    )  # s3 doesn't support bulk writing. Exclude.
                ):
                    (
                        final_embedding_cached_response,
                        embedding_all_elements_cache_hit,
                    ) = self._process_async_embedding_cached_response(
                        final_embedding_cached_response=final_embedding_cached_response,
                        cached_result=cached_result,
                        kwargs=kwargs,
                        logging_obj=logging_obj,
                        start_time=start_time,
                        model=model,
                    )
                    return CachingHandlerResponse(
                        final_embedding_cached_response=final_embedding_cached_response,
                        embedding_all_elements_cache_hit=embedding_all_elements_cache_hit,
                    )
        verbose_logger.debug(f"CACHE RESULT: {cached_result}")
        return CachingHandlerResponse(
            cached_result=cached_result,
            final_embedding_cached_response=final_embedding_cached_response,
        )
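
    # Illustrative sketch (not part of the original source): roughly how a caller could use the
    # async cache lookup before dispatching to the provider; the variable names are assumptions.
    #
    #   caching_response = await handler._async_get_cache(
    #       model="gpt-3.5-turbo",
    #       original_function=litellm.acompletion,
    #       logging_obj=logging_obj,
    #       start_time=start_time,
    #       call_type=CallTypes.acompletion.value,
    #       kwargs=kwargs,
    #   )
    #   if caching_response.cached_result is not None:
    #       return caching_response.cached_result  # cache hit, skip the API call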

    def _sync_get_cache(
        self,
        model: str,
        original_function: Callable,
        logging_obj: LiteLLMLoggingObj,
        start_time: datetime.datetime,
        call_type: str,
        kwargs: Dict[str, Any],
        args: Optional[Tuple[Any, ...]] = None,
    ) -> CachingHandlerResponse:
        from litellm.utils import CustomStreamWrapper

        args = args or ()
        new_kwargs = kwargs.copy()
        new_kwargs.update(
            convert_args_to_kwargs(
                self.original_function,
                args,
            )
        )
        cached_result: Optional[Any] = None
        if litellm.cache is not None and self._is_call_type_supported_by_cache(
            original_function=original_function
        ):
            print_verbose("Checking Sync Cache")
            cached_result = litellm.cache.get_cache(**new_kwargs)
            if cached_result is not None:
                if "detail" in cached_result:
                    # implies an error occurred
                    pass
                else:
                    call_type = original_function.__name__

                    cached_result = self._convert_cached_result_to_model_response(
                        cached_result=cached_result,
                        call_type=call_type,
                        kwargs=kwargs,
                        logging_obj=logging_obj,
                        model=model,
                        custom_llm_provider=kwargs.get("custom_llm_provider", None),
                        args=args,
                    )

                    # LOG SUCCESS
                    cache_hit = True
                    end_time = datetime.datetime.now()
                    (
                        model,
                        custom_llm_provider,
                        dynamic_api_key,
                        api_base,
                    ) = litellm.get_llm_provider(
                        model=model or "",
                        custom_llm_provider=kwargs.get("custom_llm_provider", None),
                        api_base=kwargs.get("api_base", None),
                        api_key=kwargs.get("api_key", None),
                    )
                    self._update_litellm_logging_obj_environment(
                        logging_obj=logging_obj,
                        model=model,
                        kwargs=kwargs,
                        cached_result=cached_result,
                        is_async=False,
                    )

                    threading.Thread(
                        target=logging_obj.success_handler,
                        args=(cached_result, start_time, end_time, cache_hit),
                    ).start()
                    cache_key = litellm.cache._get_preset_cache_key_from_kwargs(
                        **kwargs
                    )
                    if (
                        isinstance(cached_result, BaseModel)
                        or isinstance(cached_result, CustomStreamWrapper)
                    ) and hasattr(cached_result, "_hidden_params"):
                        cached_result._hidden_params["cache_key"] = cache_key  # type: ignore
                    return CachingHandlerResponse(cached_result=cached_result)
        return CachingHandlerResponse(cached_result=cached_result)
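
    # Illustrative sketch (not part of the original source): the per-request cache controls that the
    # cache-read gating honors. Only the "no-cache" / "no-store" keys are taken from this file; the
    # call shape below is an assumption.
    #
    #   litellm.completion(
    #       model="gpt-3.5-turbo",
    #       messages=[{"role": "user", "content": "hi"}],
    #       cache={"no-cache": True},  # bypass cache reads for this request
    #   )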

    def _process_async_embedding_cached_response(
        self,
        final_embedding_cached_response: Optional[EmbeddingResponse],
        cached_result: List[Optional[Dict[str, Any]]],
        kwargs: Dict[str, Any],
        logging_obj: LiteLLMLoggingObj,
        start_time: datetime.datetime,
        model: str,
    ) -> Tuple[Optional[EmbeddingResponse], bool]:
        """
        For embedding responses there can be a cache hit for some of the inputs in the list and a cache miss for others.

        This function processes the cached embedding responses and returns the final embedding cached response
        and a boolean indicating whether all elements in the list had a cache hit.

        Args:
            final_embedding_cached_response: Optional[EmbeddingResponse]:
            cached_result: List[Optional[Dict[str, Any]]]:
            kwargs: Dict[str, Any]:
            logging_obj: LiteLLMLoggingObj:
            start_time: datetime.datetime:
            model: str:

        Returns:
            Tuple[Optional[EmbeddingResponse], bool]:
                The final embedding cached response and a boolean indicating if all elements in the list had a cache hit
        """
        embedding_all_elements_cache_hit: bool = False
        remaining_list = []
        non_null_list = []
        for idx, cr in enumerate(cached_result):
            if cr is None:
                remaining_list.append(kwargs["input"][idx])
            else:
                non_null_list.append((idx, cr))
        original_kwargs_input = kwargs["input"]
        kwargs["input"] = remaining_list
        if len(non_null_list) > 0:
            print_verbose(f"EMBEDDING CACHE HIT! - {len(non_null_list)}")
            final_embedding_cached_response = EmbeddingResponse(
                model=kwargs.get("model"),
                data=[None] * len(original_kwargs_input),
            )
            final_embedding_cached_response._hidden_params["cache_hit"] = True

            prompt_tokens = 0
            for val in non_null_list:
                idx, cr = val  # (idx, cr) tuple
                if cr is not None:
                    final_embedding_cached_response.data[idx] = Embedding(
                        embedding=cr["embedding"],
                        index=idx,
                        object="embedding",
                    )
                    if isinstance(original_kwargs_input[idx], str):
                        from litellm.utils import token_counter

                        prompt_tokens += token_counter(
                            text=original_kwargs_input[idx],
                            count_response_tokens=True,
                        )
            ## USAGE
            usage = Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=0,
                total_tokens=prompt_tokens,
            )
            final_embedding_cached_response.usage = usage
        if len(remaining_list) == 0:
            # LOG SUCCESS
            cache_hit = True
            embedding_all_elements_cache_hit = True
            end_time = datetime.datetime.now()
            (
                model,
                custom_llm_provider,
                dynamic_api_key,
                api_base,
            ) = litellm.get_llm_provider(
                model=model,
                custom_llm_provider=kwargs.get("custom_llm_provider", None),
                api_base=kwargs.get("api_base", None),
                api_key=kwargs.get("api_key", None),
            )
            self._update_litellm_logging_obj_environment(
                logging_obj=logging_obj,
                model=model,
                kwargs=kwargs,
                cached_result=final_embedding_cached_response,
                is_async=True,
                is_embedding=True,
            )
            self._async_log_cache_hit_on_callbacks(
                logging_obj=logging_obj,
                cached_result=final_embedding_cached_response,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
            )
            return final_embedding_cached_response, embedding_all_elements_cache_hit
        return final_embedding_cached_response, embedding_all_elements_cache_hit
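
    # Illustrative sketch (not part of the original source): possible shape of `cached_result`
    # for a three-element embedding request where only the middle input missed the cache.
    #
    #   cached_result = [
    #       {"embedding": [0.1, 0.2], "index": 0, "object": "embedding"},
    #       None,  # miss -> its input stays in kwargs["input"] for the real API call
    #       {"embedding": [0.3, 0.4], "index": 2, "object": "embedding"},
    #   ]
    #   # -> final_embedding_cached_response.data == [Embedding(...), None, Embedding(...)]
    #   # -> embedding_all_elements_cache_hit == False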

    def combine_usage(self, usage1: Usage, usage2: Usage) -> Usage:
        return Usage(
            prompt_tokens=usage1.prompt_tokens + usage2.prompt_tokens,
            completion_tokens=usage1.completion_tokens + usage2.completion_tokens,
            total_tokens=usage1.total_tokens + usage2.total_tokens,
        )
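
    # Illustrative sketch (not part of the original source): combine_usage adds the token counts
    # field by field, e.g.
    #
    #   self.combine_usage(
    #       Usage(prompt_tokens=10, completion_tokens=0, total_tokens=10),
    #       Usage(prompt_tokens=5, completion_tokens=7, total_tokens=12),
    #   )
    #   # -> Usage(prompt_tokens=15, completion_tokens=7, total_tokens=22)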

    def _combine_cached_embedding_response_with_api_result(
        self,
        _caching_handler_response: CachingHandlerResponse,
        embedding_response: EmbeddingResponse,
        start_time: datetime.datetime,
        end_time: datetime.datetime,
    ) -> EmbeddingResponse:
        """
        Combines the cached embedding response with the API EmbeddingResponse.

        For caching there can be a cache hit for some of the inputs in the list and a cache miss for others;
        this function merges the two into a single EmbeddingResponse.

        Args:
            _caching_handler_response: CachingHandlerResponse:
            embedding_response: EmbeddingResponse:

        Returns:
            EmbeddingResponse:
        """
        if _caching_handler_response.final_embedding_cached_response is None:
            return embedding_response

        idx = 0
        final_data_list = []
        for item in _caching_handler_response.final_embedding_cached_response.data:
            if item is None and embedding_response.data is not None:
                final_data_list.append(embedding_response.data[idx])
                idx += 1
            else:
                final_data_list.append(item)

        _caching_handler_response.final_embedding_cached_response.data = final_data_list
        _caching_handler_response.final_embedding_cached_response._hidden_params[
            "cache_hit"
        ] = True
        _caching_handler_response.final_embedding_cached_response._response_ms = (
            end_time - start_time
        ).total_seconds() * 1000

        ## USAGE
        if (
            _caching_handler_response.final_embedding_cached_response.usage is not None
            and embedding_response.usage is not None
        ):
            _caching_handler_response.final_embedding_cached_response.usage = self.combine_usage(
                usage1=_caching_handler_response.final_embedding_cached_response.usage,
                usage2=embedding_response.usage,
            )

        return _caching_handler_response.final_embedding_cached_response

    def _async_log_cache_hit_on_callbacks(
        self,
        logging_obj: LiteLLMLoggingObj,
        cached_result: Any,
        start_time: datetime.datetime,
        end_time: datetime.datetime,
        cache_hit: bool,
    ):
        """
        Helper function to log the success of a cached result on callbacks

        Args:
            logging_obj (LiteLLMLoggingObj): The logging object.
            cached_result: The cached result.
            start_time (datetime): The start time of the operation.
            end_time (datetime): The end time of the operation.
            cache_hit (bool): Whether it was a cache hit.
        """
        asyncio.create_task(
            logging_obj.async_success_handler(
                cached_result, start_time, end_time, cache_hit
            )
        )
        threading.Thread(
            target=logging_obj.success_handler,
            args=(cached_result, start_time, end_time, cache_hit),
        ).start()

    async def _retrieve_from_cache(
        self, call_type: str, kwargs: Dict[str, Any], args: Tuple[Any, ...]
    ) -> Optional[Any]:
        """
        Internal method to
        - get cache key
        - check what type of cache is used - Redis, RedisSemantic, Qdrant, S3
        - async get cache value
        - return the cached value

        Args:
            call_type: str:
            kwargs: Dict[str, Any]:
            args: Optional[Tuple[Any, ...]] = None:

        Returns:
            Optional[Any]:
        Raises:
            None
        """
        if litellm.cache is None:
            return None

        new_kwargs = kwargs.copy()
        new_kwargs.update(
            convert_args_to_kwargs(
                self.original_function,
                args,
            )
        )
        cached_result: Optional[Any] = None
        if call_type == CallTypes.aembedding.value and isinstance(
            new_kwargs["input"], list
        ):
            tasks = []
            for idx, i in enumerate(new_kwargs["input"]):
                preset_cache_key = litellm.cache.get_cache_key(
                    **{**new_kwargs, "input": i}
                )
                tasks.append(litellm.cache.async_get_cache(cache_key=preset_cache_key))
            cached_result = await asyncio.gather(*tasks)
            ## check if cached result is None ##
            if cached_result is not None and isinstance(cached_result, list):
                # set cached_result to None if all elements are None
                if all(result is None for result in cached_result):
                    cached_result = None
        else:
            if litellm.cache._supports_async() is True:
                cached_result = await litellm.cache.async_get_cache(**new_kwargs)
            else:  # for s3 caching. [NOT RECOMMENDED IN PROD - this will slow down responses since boto3 is sync]
                cached_result = litellm.cache.get_cache(**new_kwargs)
        return cached_result
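
    # Illustrative sketch (not part of the original source): for a list embedding input the method
    # above builds one cache key per element and gathers the lookups concurrently, roughly:
    #
    #   keys = [
    #       litellm.cache.get_cache_key(**{**new_kwargs, "input": i})
    #       for i in new_kwargs["input"]
    #   ]
    #   results = await asyncio.gather(
    #       *[litellm.cache.async_get_cache(cache_key=k) for k in keys]
    #   )
    #   # results is aligned with the input list, e.g. [None, {"embedding": [...]}, ...]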

    def _convert_cached_result_to_model_response(
        self,
        cached_result: Any,
        call_type: str,
        kwargs: Dict[str, Any],
        logging_obj: LiteLLMLoggingObj,
        model: str,
        args: Tuple[Any, ...],
        custom_llm_provider: Optional[str] = None,
    ) -> Optional[
        Union[
            ModelResponse,
            TextCompletionResponse,
            EmbeddingResponse,
            RerankResponse,
            TranscriptionResponse,
            CustomStreamWrapper,
        ]
    ]:
        """
        Internal method to process the cached result

        Checks the call type and converts the cached result to the appropriate model response object,
        e.g. if the call type is text_completion -> returns a TextCompletionResponse object.

        Args:
            cached_result: Any:
            call_type: str:
            kwargs: Dict[str, Any]:
            logging_obj: LiteLLMLoggingObj:
            model: str:
            custom_llm_provider: Optional[str] = None:
            args: Optional[Tuple[Any, ...]] = None:

        Returns:
            Optional[Any]:
        """
        from litellm.utils import convert_to_model_response_object

        if (
            call_type == CallTypes.acompletion.value
            or call_type == CallTypes.completion.value
        ) and isinstance(cached_result, dict):
            if kwargs.get("stream", False) is True:
                cached_result = self._convert_cached_stream_response(
                    cached_result=cached_result,
                    call_type=call_type,
                    logging_obj=logging_obj,
                    model=model,
                )
            else:
                cached_result = convert_to_model_response_object(
                    response_object=cached_result,
                    model_response_object=ModelResponse(),
                )
        if (
            call_type == CallTypes.atext_completion.value
            or call_type == CallTypes.text_completion.value
        ) and isinstance(cached_result, dict):
            if kwargs.get("stream", False) is True:
                cached_result = self._convert_cached_stream_response(
                    cached_result=cached_result,
                    call_type=call_type,
                    logging_obj=logging_obj,
                    model=model,
                )
            else:
                cached_result = TextCompletionResponse(**cached_result)
        elif (
            call_type == CallTypes.aembedding.value
            or call_type == CallTypes.embedding.value
        ) and isinstance(cached_result, dict):
            cached_result = convert_to_model_response_object(
                response_object=cached_result,
                model_response_object=EmbeddingResponse(),
                response_type="embedding",
            )
        elif (
            call_type == CallTypes.arerank.value or call_type == CallTypes.rerank.value
        ) and isinstance(cached_result, dict):
            cached_result = convert_to_model_response_object(
                response_object=cached_result,
                model_response_object=None,
                response_type="rerank",
            )
        elif (
            call_type == CallTypes.atranscription.value
            or call_type == CallTypes.transcription.value
        ) and isinstance(cached_result, dict):
            hidden_params = {
                "model": "whisper-1",
                "custom_llm_provider": custom_llm_provider,
                "cache_hit": True,
            }
            cached_result = convert_to_model_response_object(
                response_object=cached_result,
                model_response_object=TranscriptionResponse(),
                response_type="audio_transcription",
                hidden_params=hidden_params,
            )

        if (
            hasattr(cached_result, "_hidden_params")
            and cached_result._hidden_params is not None
            and isinstance(cached_result._hidden_params, dict)
        ):
            cached_result._hidden_params["cache_hit"] = True
        return cached_result
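
    # Summary of the call-type -> response-object mapping performed above (comment added for
    # readability, not part of the original source):
    #
    #   completion / acompletion             -> ModelResponse (CustomStreamWrapper if stream=True)
    #   text_completion / atext_completion   -> TextCompletionResponse (CustomStreamWrapper if stream=True)
    #   embedding / aembedding               -> EmbeddingResponse
    #   rerank / arerank                     -> RerankResponse
    #   transcription / atranscription       -> TranscriptionResponse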

    def _convert_cached_stream_response(
        self,
        cached_result: Any,
        call_type: str,
        logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> CustomStreamWrapper:
        from litellm.utils import (
            CustomStreamWrapper,
            convert_to_streaming_response,
            convert_to_streaming_response_async,
        )

        _stream_cached_result: Union[AsyncGenerator, Generator]
        if (
            call_type == CallTypes.acompletion.value
            or call_type == CallTypes.atext_completion.value
        ):
            _stream_cached_result = convert_to_streaming_response_async(
                response_object=cached_result,
            )
        else:
            _stream_cached_result = convert_to_streaming_response(
                response_object=cached_result,
            )
        return CustomStreamWrapper(
            completion_stream=_stream_cached_result,
            model=model,
            custom_llm_provider="cached_response",
            logging_obj=logging_obj,
        )
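
    # Illustrative sketch (not part of the original source): a cached streaming response is replayed
    # through CustomStreamWrapper, so callers can iterate it like a live stream; names are assumptions.
    #
    #   wrapper = self._convert_cached_stream_response(
    #       cached_result=cached_dict,
    #       call_type=CallTypes.acompletion.value,
    #       logging_obj=logging_obj,
    #       model="gpt-3.5-turbo",
    #   )
    #   async for chunk in wrapper:
    #       ...  # chunks come from the cache; custom_llm_provider is "cached_response"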

    async def async_set_cache(
        self,
        result: Any,
        original_function: Callable,
        kwargs: Dict[str, Any],
        args: Optional[Tuple[Any, ...]] = None,
    ):
        """
        Internal method that checks the type of the result and the cache in use, and adds the result to the cache accordingly.

        Args:
            result: Any:
            original_function: Callable:
            kwargs: Dict[str, Any]:
            args: Optional[Tuple[Any, ...]] = None:

        Returns:
            None
        Raises:
            None
        """
        if litellm.cache is None:
            return

        new_kwargs = kwargs.copy()
        new_kwargs.update(
            convert_args_to_kwargs(
                original_function,
                args,
            )
        )
        # [OPTIONAL] ADD TO CACHE
        if self._should_store_result_in_cache(
            original_function=original_function, kwargs=new_kwargs
        ):
            if (
                isinstance(result, litellm.ModelResponse)
                or isinstance(result, litellm.EmbeddingResponse)
                or isinstance(result, TranscriptionResponse)
                or isinstance(result, RerankResponse)
            ):
                if (
                    isinstance(result, EmbeddingResponse)
                    and litellm.cache is not None
                    and not isinstance(
                        litellm.cache.cache, S3Cache
                    )  # s3 doesn't support bulk writing. Exclude.
                ):
                    asyncio.create_task(
                        litellm.cache.async_add_cache_pipeline(result, **new_kwargs)
                    )
                elif isinstance(litellm.cache.cache, S3Cache):
                    threading.Thread(
                        target=litellm.cache.add_cache,
                        args=(result,),
                        kwargs=new_kwargs,
                    ).start()
                else:
                    asyncio.create_task(
                        litellm.cache.async_add_cache(
                            result.model_dump_json(), **new_kwargs
                        )
                    )
            else:
                asyncio.create_task(litellm.cache.async_add_cache(result, **new_kwargs))
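
    # Summary of the write-path dispatch above (comment added for readability, not part of the
    # original source):
    #
    #   EmbeddingResponse + non-S3 cache -> litellm.cache.async_add_cache_pipeline(...)  (bulk write)
    #   S3Cache                          -> litellm.cache.add_cache(...) on a background thread (boto3 is sync)
    #   other response objects           -> litellm.cache.async_add_cache(result.model_dump_json(), ...)
    #   non-model results                -> litellm.cache.async_add_cache(result, ...)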

    def sync_set_cache(
        self,
        result: Any,
        kwargs: Dict[str, Any],
        args: Optional[Tuple[Any, ...]] = None,
    ):
        """
        Sync internal method to add the result to the cache
        """
        new_kwargs = kwargs.copy()
        new_kwargs.update(
            convert_args_to_kwargs(
                self.original_function,
                args,
            )
        )
        if litellm.cache is None:
            return

        if self._should_store_result_in_cache(
            original_function=self.original_function, kwargs=new_kwargs
        ):
            litellm.cache.add_cache(result, **new_kwargs)

        return

    def _should_store_result_in_cache(
        self, original_function: Callable, kwargs: Dict[str, Any]
    ) -> bool:
        """
        Helper function to determine if the result should be stored in the cache.

        Returns:
            bool: True if the result should be stored in the cache, False otherwise.
        """
        return (
            (litellm.cache is not None)
            and litellm.cache.supported_call_types is not None
            and (str(original_function.__name__) in litellm.cache.supported_call_types)
            and (kwargs.get("cache", {}).get("no-store", False) is not True)
        )

    def _is_call_type_supported_by_cache(
        self,
        original_function: Callable,
    ) -> bool:
        """
        Helper function to determine if the call type is supported by the cache.

        Call types are acompletion, aembedding, atext_completion, atranscription, arerank,
        defined in `litellm.types.utils.CallTypes`.

        Returns:
            bool: True if the call type is supported by the cache, False otherwise.
        """
        if (
            litellm.cache is not None
            and litellm.cache.supported_call_types is not None
            and str(original_function.__name__) in litellm.cache.supported_call_types
        ):
            return True
        return False
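
    # Illustrative sketch (not part of the original source): supported_call_types comes from the
    # user's litellm.Cache setup; the exact constructor arguments shown are assumptions.
    #
    #   litellm.cache = litellm.Cache(
    #       type="redis",
    #       supported_call_types=["acompletion", "aembedding"],  # only these call types get cached
    #   )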

    async def _add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
        """
        Internal method to add the streaming response to the cache

        - If the processed chunk has a 'finish_reason', assemble a complete litellm.ModelResponse object and cache it
        - Else, the chunk is appended to self.async_streaming_chunks
        """
        complete_streaming_response: Optional[
            Union[ModelResponse, TextCompletionResponse]
        ] = _assemble_complete_response_from_streaming_chunks(
            result=processed_chunk,
            start_time=self.start_time,
            end_time=datetime.datetime.now(),
            request_kwargs=self.request_kwargs,
            streaming_chunks=self.async_streaming_chunks,
            is_async=True,
        )
        # if a complete_streaming_response is assembled, add it to the cache
        if complete_streaming_response is not None:
            await self.async_set_cache(
                result=complete_streaming_response,
                original_function=self.original_function,
                kwargs=self.request_kwargs,
            )

    def _sync_add_streaming_response_to_cache(self, processed_chunk: ModelResponse):
        """
        Sync internal method to add the streaming response to the cache
        """
        complete_streaming_response: Optional[
            Union[ModelResponse, TextCompletionResponse]
        ] = _assemble_complete_response_from_streaming_chunks(
            result=processed_chunk,
            start_time=self.start_time,
            end_time=datetime.datetime.now(),
            request_kwargs=self.request_kwargs,
            streaming_chunks=self.sync_streaming_chunks,
            is_async=False,
        )
        # if a complete_streaming_response is assembled, add it to the cache
        if complete_streaming_response is not None:
            self.sync_set_cache(
                result=complete_streaming_response,
                kwargs=self.request_kwargs,
            )

    def _update_litellm_logging_obj_environment(
        self,
        logging_obj: LiteLLMLoggingObj,
        model: str,
        kwargs: Dict[str, Any],
        cached_result: Any,
        is_async: bool,
        is_embedding: bool = False,
    ):
        """
        Helper function to update the LiteLLMLoggingObj environment variables.

        Args:
            logging_obj (LiteLLMLoggingObj): The logging object to update.
            model (str): The model being used.
            kwargs (Dict[str, Any]): The keyword arguments from the original function call.
            cached_result (Any): The cached result to log.
            is_async (bool): Whether the call is asynchronous or not.
            is_embedding (bool): Whether the call is for embeddings or not.

        Returns:
            None
        """
        litellm_params = {
            "logger_fn": kwargs.get("logger_fn", None),
            "acompletion": is_async,
            "api_base": kwargs.get("api_base", ""),
            "metadata": kwargs.get("metadata", {}),
            "model_info": kwargs.get("model_info", {}),
            "proxy_server_request": kwargs.get("proxy_server_request", None),
            "stream_response": kwargs.get("stream_response", {}),
        }

        if litellm.cache is not None:
            litellm_params[
                "preset_cache_key"
            ] = litellm.cache._get_preset_cache_key_from_kwargs(**kwargs)
        else:
            litellm_params["preset_cache_key"] = None

        logging_obj.update_environment_variables(
            model=model,
            user=kwargs.get("user", None),
            optional_params={},
            litellm_params=litellm_params,
            input=(
                kwargs.get("messages", "")
                if not is_embedding
                else kwargs.get("input", "")
            ),
            api_key=kwargs.get("api_key", None),
            original_response=str(cached_result),
            additional_args=None,
            stream=kwargs.get("stream", False),
        )


def convert_args_to_kwargs(
    original_function: Callable,
    args: Optional[Tuple[Any, ...]] = None,
) -> Dict[str, Any]:
    # Get the signature of the original function
    signature = inspect.signature(original_function)

    # Get parameter names in the order they appear in the original function
    param_names = list(signature.parameters.keys())

    # Create a mapping of positional arguments to parameter names
    args_to_kwargs = {}
    if args:
        for index, arg in enumerate(args):
            if index < len(param_names):
                param_name = param_names[index]
                args_to_kwargs[param_name] = arg

    return args_to_kwargs
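

# Illustrative sketch (not part of the original source): convert_args_to_kwargs maps positional
# arguments onto the wrapped function's parameter names via inspect.signature; the helper below
# is hypothetical.
#
#   def fake_completion(model, messages, temperature=None): ...
#
#   convert_args_to_kwargs(fake_completion, ("gpt-3.5-turbo", [{"role": "user", "content": "hi"}]))
#   # -> {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}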