# +-----------------------------------------------+
# |                                               |
# |           Give Feedback / Get Help            |
# | https://github.com/BerriAI/litellm/issues/new |
# |                                               |
# +-----------------------------------------------+
#
#  Thank you users! We ❤️ you! - Krrish & Ishaan

import ast
import hashlib
import json
import time
import traceback
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel

import litellm
from litellm._logging import verbose_logger
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
from litellm.types.caching import *
from litellm.types.utils import EmbeddingResponse, all_litellm_params

from .base_cache import BaseCache
from .disk_cache import DiskCache
from .dual_cache import DualCache  # noqa
from .in_memory_cache import InMemoryCache
from .qdrant_semantic_cache import QdrantSemanticCache
from .redis_cache import RedisCache
from .redis_cluster_cache import RedisClusterCache
from .redis_semantic_cache import RedisSemanticCache
from .s3_cache import S3Cache


def print_verbose(print_statement):
    try:
        verbose_logger.debug(print_statement)
        if litellm.set_verbose:
            print(print_statement)  # noqa
    except Exception:
        pass


class CacheMode(str, Enum):
    default_on = "default_on"
    default_off = "default_off"


#### LiteLLM.Completion / Embedding Cache ####
class Cache:
    def __init__(
        self,
        type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
        mode: Optional[
            CacheMode
        ] = CacheMode.default_on,  # when default_on, cache is always on; when default_off, cache is opt-in
        host: Optional[str] = None,
        port: Optional[str] = None,
        password: Optional[str] = None,
        namespace: Optional[str] = None,
        ttl: Optional[float] = None,
        default_in_memory_ttl: Optional[float] = None,
        default_in_redis_ttl: Optional[float] = None,
        similarity_threshold: Optional[float] = None,
        supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
            "completion",
            "acompletion",
            "embedding",
            "aembedding",
            "atranscription",
            "transcription",
            "atext_completion",
            "text_completion",
            "arerank",
            "rerank",
        ],
        # s3 Bucket, boto3 configuration
        s3_bucket_name: Optional[str] = None,
        s3_region_name: Optional[str] = None,
        s3_api_version: Optional[str] = None,
        s3_use_ssl: Optional[bool] = True,
        s3_verify: Optional[Union[bool, str]] = None,
        s3_endpoint_url: Optional[str] = None,
        s3_aws_access_key_id: Optional[str] = None,
        s3_aws_secret_access_key: Optional[str] = None,
        s3_aws_session_token: Optional[str] = None,
        s3_config: Optional[Any] = None,
        s3_path: Optional[str] = None,
        redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
        redis_semantic_cache_index_name: Optional[str] = None,
        redis_flush_size: Optional[int] = None,
        redis_startup_nodes: Optional[List] = None,
        disk_cache_dir: Optional[str] = None,
        qdrant_api_base: Optional[str] = None,
        qdrant_api_key: Optional[str] = None,
        qdrant_collection_name: Optional[str] = None,
        qdrant_quantization_config: Optional[str] = None,
        qdrant_semantic_cache_embedding_model: str = "text-embedding-ada-002",
        **kwargs,
    ):
""" | |
Initializes the cache based on the given type. | |
Args: | |
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local". | |
# Redis Cache Args | |
host (str, optional): The host address for the Redis cache. Required if type is "redis". | |
port (int, optional): The port number for the Redis cache. Required if type is "redis". | |
password (str, optional): The password for the Redis cache. Required if type is "redis". | |
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis". | |
ttl (float, optional): The ttl for the Redis cache | |
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used. | |
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None. | |
# Qdrant Cache Args | |
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic". | |
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster. | |
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic". | |
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic". | |
# Disk Cache Args | |
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None. | |
# S3 Cache Args | |
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None. | |
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None. | |
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None. | |
s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True. | |
s3_verify (bool, optional): The verify for the s3 cache. Defaults to None. | |
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None. | |
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None. | |
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None. | |
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None. | |
s3_config (dict, optional): The config for the s3 cache. Defaults to None. | |
# Common Cache Args | |
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types. | |
**kwargs: Additional keyword arguments for redis.Redis() cache | |
Raises: | |
ValueError: If an invalid cache type is provided. | |
Returns: | |
None. Cache is set as a litellm param | |
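
        Example (illustrative sketch; uses the default in-memory backend, so no external
        services are required):
            >>> import litellm
            >>> litellm.cache = Cache(type=LiteLLMCacheType.LOCAL, ttl=120)
            >>> litellm.cache.ttl
            120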
""" | |
if type == LiteLLMCacheType.REDIS: | |
if redis_startup_nodes: | |
self.cache: BaseCache = RedisClusterCache( | |
host=host, | |
port=port, | |
password=password, | |
redis_flush_size=redis_flush_size, | |
startup_nodes=redis_startup_nodes, | |
**kwargs, | |
) | |
else: | |
self.cache = RedisCache( | |
host=host, | |
port=port, | |
password=password, | |
redis_flush_size=redis_flush_size, | |
**kwargs, | |
) | |
elif type == LiteLLMCacheType.REDIS_SEMANTIC: | |
self.cache = RedisSemanticCache( | |
host=host, | |
port=port, | |
password=password, | |
similarity_threshold=similarity_threshold, | |
embedding_model=redis_semantic_cache_embedding_model, | |
index_name=redis_semantic_cache_index_name, | |
**kwargs, | |
) | |
elif type == LiteLLMCacheType.QDRANT_SEMANTIC: | |
self.cache = QdrantSemanticCache( | |
qdrant_api_base=qdrant_api_base, | |
qdrant_api_key=qdrant_api_key, | |
collection_name=qdrant_collection_name, | |
similarity_threshold=similarity_threshold, | |
quantization_config=qdrant_quantization_config, | |
embedding_model=qdrant_semantic_cache_embedding_model, | |
) | |
elif type == LiteLLMCacheType.LOCAL: | |
self.cache = InMemoryCache() | |
elif type == LiteLLMCacheType.S3: | |
self.cache = S3Cache( | |
s3_bucket_name=s3_bucket_name, | |
s3_region_name=s3_region_name, | |
s3_api_version=s3_api_version, | |
s3_use_ssl=s3_use_ssl, | |
s3_verify=s3_verify, | |
s3_endpoint_url=s3_endpoint_url, | |
s3_aws_access_key_id=s3_aws_access_key_id, | |
s3_aws_secret_access_key=s3_aws_secret_access_key, | |
s3_aws_session_token=s3_aws_session_token, | |
s3_config=s3_config, | |
s3_path=s3_path, | |
**kwargs, | |
) | |
elif type == LiteLLMCacheType.DISK: | |
self.cache = DiskCache(disk_cache_dir=disk_cache_dir) | |
if "cache" not in litellm.input_callback: | |
litellm.input_callback.append("cache") | |
if "cache" not in litellm.success_callback: | |
litellm.logging_callback_manager.add_litellm_success_callback("cache") | |
if "cache" not in litellm._async_success_callback: | |
litellm.logging_callback_manager.add_litellm_async_success_callback("cache") | |
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"] | |
self.type = type | |
self.namespace = namespace | |
self.redis_flush_size = redis_flush_size | |
self.ttl = ttl | |
self.mode: CacheMode = mode or CacheMode.default_on | |
if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None: | |
self.ttl = default_in_memory_ttl | |
if ( | |
self.type == LiteLLMCacheType.REDIS | |
or self.type == LiteLLMCacheType.REDIS_SEMANTIC | |
) and default_in_redis_ttl is not None: | |
self.ttl = default_in_redis_ttl | |
if self.namespace is not None and isinstance(self.cache, RedisCache): | |
self.cache.namespace = self.namespace | |

    def get_cache_key(self, **kwargs) -> str:
        """
        Get the cache key for the given arguments.

        Args:
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            str: The cache key generated from the arguments, or None if no cache key could be generated.
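
        Example (illustrative sketch; with a local cache and no namespace the key is the
        hex digest of a SHA-256 hash, so it is always 64 characters long):
            >>> cache = Cache(type=LiteLLMCacheType.LOCAL)
            >>> key = cache.get_cache_key(
            ...     model="gpt-3.5-turbo",
            ...     messages=[{"role": "user", "content": "hi"}],
            ... )
            >>> len(key)
            64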
""" | |
cache_key = "" | |
# verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs) | |
preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs) | |
if preset_cache_key is not None: | |
verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key) | |
return preset_cache_key | |
combined_kwargs = ModelParamHelper._get_all_llm_api_params() | |
litellm_param_kwargs = all_litellm_params | |
for param in kwargs: | |
if param in combined_kwargs: | |
param_value: Optional[str] = self._get_param_value(param, kwargs) | |
if param_value is not None: | |
cache_key += f"{str(param)}: {str(param_value)}" | |
elif ( | |
param not in litellm_param_kwargs | |
): # check if user passed in optional param - e.g. top_k | |
if ( | |
litellm.enable_caching_on_provider_specific_optional_params is True | |
): # feature flagged for now | |
if kwargs[param] is None: | |
continue # ignore None params | |
param_value = kwargs[param] | |
cache_key += f"{str(param)}: {str(param_value)}" | |
verbose_logger.debug("\nCreated cache key: %s", cache_key) | |
hashed_cache_key = Cache._get_hashed_cache_key(cache_key) | |
hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs) | |
self._set_preset_cache_key_in_kwargs( | |
preset_cache_key=hashed_cache_key, **kwargs | |
) | |
return hashed_cache_key | |

    def _get_param_value(
        self,
        param: str,
        kwargs: dict,
    ) -> Optional[str]:
        """
        Get the value for the given param from kwargs
        """
        if param == "model":
            return self._get_model_param_value(kwargs)
        elif param == "file":
            return self._get_file_param_value(kwargs)
        return kwargs[param]

    def _get_model_param_value(self, kwargs: dict) -> str:
        """
        Handles getting the value for the 'model' param from kwargs

        1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
        2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
        3. Else use the `model` passed in kwargs
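
        Example (illustrative sketch; a request routed through a model group uses the
        group name, not the underlying deployment name, when building the cache key):
            >>> cache = Cache(type=LiteLLMCacheType.LOCAL)
            >>> cache._get_model_param_value(
            ...     {"model": "azure/gpt-35-turbo", "metadata": {"model_group": "gpt-3.5-turbo"}}
            ... )
            'gpt-3.5-turbo'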
""" | |
metadata: Dict = kwargs.get("metadata", {}) or {} | |
litellm_params: Dict = kwargs.get("litellm_params", {}) or {} | |
metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {} | |
model_group: Optional[str] = metadata.get( | |
"model_group" | |
) or metadata_in_litellm_params.get("model_group") | |
caching_group = self._get_caching_group(metadata, model_group) | |
return caching_group or model_group or kwargs["model"] | |
def _get_caching_group( | |
self, metadata: dict, model_group: Optional[str] | |
) -> Optional[str]: | |
caching_groups: Optional[List] = metadata.get("caching_groups", []) | |
if caching_groups: | |
for group in caching_groups: | |
if model_group in group: | |
return str(group) | |
return None | |
def _get_file_param_value(self, kwargs: dict) -> str: | |
""" | |
Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests | |
""" | |
file = kwargs.get("file") | |
metadata = kwargs.get("metadata", {}) | |
litellm_params = kwargs.get("litellm_params", {}) | |
return ( | |
metadata.get("file_checksum") | |
or getattr(file, "name", None) | |
or metadata.get("file_name") | |
or litellm_params.get("file_name") | |
) | |

    def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
        """
        Get the preset cache key from kwargs["litellm_params"]

        We use _get_preset_cache_keys for two reasons

        1. optional params like max_tokens get transformed for bedrock -> max_new_tokens
        2. avoid doing duplicate / repeated work
        """
        if kwargs:
            if "litellm_params" in kwargs:
                return kwargs["litellm_params"].get("preset_cache_key", None)

        return None

    def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
        """
        Set the calculated cache key in kwargs

        This is used to avoid doing duplicate / repeated work

        Placed in kwargs["litellm_params"]
        """
        if kwargs:
            if "litellm_params" in kwargs:
                kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key

    @staticmethod
    def _get_hashed_cache_key(cache_key: str) -> str:
        """
        Get the hashed cache key for the given cache key.

        Use hashlib to create a sha256 hash of the cache key

        Args:
            cache_key (str): The cache key to hash.

        Returns:
            str: The hashed cache key.
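
        Example (illustrative; SHA-256 hex digests are always 64 characters):
            >>> len(Cache._get_hashed_cache_key("model: gpt-3.5-turbo"))
            64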
""" | |
hash_object = hashlib.sha256(cache_key.encode()) | |
# Hexadecimal representation of the hash | |
hash_hex = hash_object.hexdigest() | |
verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex) | |
return hash_hex | |

    def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
        """
        If a redis namespace is provided, add it to the cache key

        Args:
            hash_hex (str): The hashed cache key.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The final hashed cache key with the redis namespace.
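
        Example (illustrative sketch; the namespace configured on the Cache object is used
        as a prefix when no per-request namespace is passed):
            >>> cache = Cache(type=LiteLLMCacheType.LOCAL, namespace="team-a")
            >>> cache._add_namespace_to_cache_key("abc123")
            'team-a:abc123'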
""" | |
dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {}) | |
namespace = ( | |
dynamic_cache_control.get("namespace") | |
or kwargs.get("metadata", {}).get("redis_namespace") | |
or self.namespace | |
) | |
if namespace: | |
hash_hex = f"{namespace}:{hash_hex}" | |
verbose_logger.debug("Final hashed key: %s", hash_hex) | |
return hash_hex | |

    def generate_streaming_content(self, content):
        """
        Yield a cached completion back as assistant-role delta chunks, so cache hits can be
        returned through the same streaming interface as live responses.
        """
        chunk_size = 5  # Adjust the chunk size as needed
        for i in range(0, len(content), chunk_size):
            yield {
                "choices": [
                    {
                        "delta": {
                            "role": "assistant",
                            "content": content[i : i + chunk_size],
                        }
                    }
                ]
            }
            time.sleep(CACHED_STREAMING_CHUNK_DELAY)

    def _get_cache_logic(
        self,
        cached_result: Optional[Any],
        max_age: Optional[float],
    ):
        """
        Common get cache logic across sync + async implementations
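
        Example (illustrative sketch; an entry older than max_age seconds is treated as a
        miss, otherwise the stored response is returned):
            >>> import time
            >>> cache = Cache(type=LiteLLMCacheType.LOCAL)
            >>> stale = {"timestamp": time.time() - 120, "response": {"id": "chatcmpl-1"}}
            >>> cache._get_cache_logic(cached_result=stale, max_age=60) is None
            True
            >>> fresh = {"timestamp": time.time(), "response": {"id": "chatcmpl-1"}}
            >>> cache._get_cache_logic(cached_result=fresh, max_age=60)
            {'id': 'chatcmpl-1'}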
""" | |
# Check if a timestamp was stored with the cached response | |
if ( | |
cached_result is not None | |
and isinstance(cached_result, dict) | |
and "timestamp" in cached_result | |
): | |
timestamp = cached_result["timestamp"] | |
current_time = time.time() | |
# Calculate age of the cached response | |
response_age = current_time - timestamp | |
# Check if the cached response is older than the max-age | |
if max_age is not None and response_age > max_age: | |
return None # Cached response is too old | |
# If the response is fresh, or there's no max-age requirement, return the cached response | |
# cached_response is in `b{} convert it to ModelResponse | |
cached_response = cached_result.get("response") | |
try: | |
if isinstance(cached_response, dict): | |
pass | |
else: | |
cached_response = json.loads( | |
cached_response # type: ignore | |
) # Convert string to dictionary | |
except Exception: | |
cached_response = ast.literal_eval(cached_response) # type: ignore | |
return cached_response | |
return cached_result | |

    def get_cache(self, **kwargs):
        """
        Retrieves the cached result for the given arguments.

        Args:
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            The cached result if it exists, otherwise None.
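
        Example (illustrative sketch; writes an entry with add_cache and reads it back with
        the same kwargs, using the in-memory backend):
            >>> cache = Cache(type=LiteLLMCacheType.LOCAL)
            >>> request = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}
            >>> cache.add_cache({"id": "chatcmpl-1"}, **request)
            >>> cache.get_cache(**request) == {"id": "chatcmpl-1"}
            True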
""" | |
try: # never block execution | |
if self.should_use_cache(**kwargs) is not True: | |
return | |
messages = kwargs.get("messages", []) | |
if "cache_key" in kwargs: | |
cache_key = kwargs["cache_key"] | |
else: | |
cache_key = self.get_cache_key(**kwargs) | |
if cache_key is not None: | |
cache_control_args: DynamicCacheControl = kwargs.get("cache", {}) | |
max_age = ( | |
cache_control_args.get("s-maxage") | |
or cache_control_args.get("s-max-age") | |
or float("inf") | |
) | |
cached_result = self.cache.get_cache(cache_key, messages=messages) | |
cached_result = self.cache.get_cache(cache_key, messages=messages) | |
return self._get_cache_logic( | |
cached_result=cached_result, max_age=max_age | |
) | |
except Exception: | |
print_verbose(f"An exception occurred: {traceback.format_exc()}") | |
return None | |

    async def async_get_cache(self, **kwargs):
        """
        Async get cache implementation.

        Used for embedding calls in async wrapper
        """
        try:  # never block execution
            if self.should_use_cache(**kwargs) is not True:
                return
            kwargs.get("messages", [])
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
                return self._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            print_verbose(f"An exception occurred: {traceback.format_exc()}")
            return None

    def _add_cache_logic(self, result, **kwargs):
        """
        Common implementation across sync + async add_cache functions
        """
        try:
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            else:
                cache_key = self.get_cache_key(**kwargs)
            if cache_key is not None:
                if isinstance(result, BaseModel):
                    result = result.model_dump_json()

                ## DEFAULT TTL ##
                if self.ttl is not None:
                    kwargs["ttl"] = self.ttl
                ## Get Cache-Controls ##
                _cache_kwargs = kwargs.get("cache", None)
                if isinstance(_cache_kwargs, dict):
                    for k, v in _cache_kwargs.items():
                        if k == "ttl":
                            kwargs["ttl"] = v

                cached_data = {"timestamp": time.time(), "response": result}
                return cache_key, cached_data, kwargs
            else:
                raise Exception("cache key is None")
        except Exception as e:
            raise e

    def add_cache(self, result, **kwargs):
        """
        Adds a result to the cache.

        Args:
            result: The response object to cache.
            **kwargs: kwargs to litellm.completion() or embedding()

        Returns:
            None
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return
            cache_key, cached_data, kwargs = self._add_cache_logic(
                result=result, **kwargs
            )
            self.cache.set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")

    async def async_add_cache(self, result, **kwargs):
        """
        Async implementation of add_cache
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return
            if self.type == "redis" and self.redis_flush_size is not None:
                # high traffic - fill in results in memory and then flush
                await self.batch_cache_write(result, **kwargs)
            else:
                cache_key, cached_data, kwargs = self._add_cache_logic(
                    result=result, **kwargs
                )
                await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")

    def add_embedding_response_to_cache(
        self,
        result: EmbeddingResponse,
        input: str,
        kwargs: dict,
        idx_in_result_data: int = 0,
    ) -> Tuple[str, dict, dict]:
        preset_cache_key = self.get_cache_key(**{**kwargs, "input": input})
        kwargs["cache_key"] = preset_cache_key
        embedding_response = result.data[idx_in_result_data]
        cache_key, cached_data, kwargs = self._add_cache_logic(
            result=embedding_response,
            **kwargs,
        )
        return cache_key, cached_data, kwargs

    async def async_add_cache_pipeline(self, result, **kwargs):
        """
        Async implementation of add_cache for Embedding calls

        Does a bulk write, to prevent using too many clients
        """
        try:
            if self.should_use_cache(**kwargs) is not True:
                return

            # set default ttl if not set
            if self.ttl is not None:
                kwargs["ttl"] = self.ttl

            cache_list = []
            if isinstance(kwargs["input"], list):
                for idx, i in enumerate(kwargs["input"]):
                    (
                        cache_key,
                        cached_data,
                        kwargs,
                    ) = self.add_embedding_response_to_cache(result, i, kwargs, idx)
                    cache_list.append((cache_key, cached_data))
            elif isinstance(kwargs["input"], str):
                cache_key, cached_data, kwargs = self.add_embedding_response_to_cache(
                    result, kwargs["input"], kwargs
                )
                cache_list.append((cache_key, cached_data))

            await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
            # if async_set_cache_pipeline:
            #     await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
            # else:
            #     tasks = []
            #     for val in cache_list:
            #         tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
            #     await asyncio.gather(*tasks)
        except Exception as e:
            verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")

    def should_use_cache(self, **kwargs):
        """
        Returns true if we should use the cache for LLM API calls

        If cache is default_on then this is True
        If cache is default_off then this is only true when the user has opted in to use the cache
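
        Example (illustrative sketch; with mode=default_off the cache is only used when a
        request explicitly opts in via `cache={"use-cache": True}`):
            >>> cache = Cache(type=LiteLLMCacheType.LOCAL, mode=CacheMode.default_off)
            >>> cache.should_use_cache(model="gpt-3.5-turbo")
            False
            >>> cache.should_use_cache(model="gpt-3.5-turbo", cache={"use-cache": True})
            True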
""" | |
if self.mode == CacheMode.default_on: | |
return True | |
# when mode == default_off -> Cache is opt in only | |
_cache = kwargs.get("cache", None) | |
verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache) | |
if _cache and isinstance(_cache, dict): | |
if _cache.get("use-cache", False) is True: | |
return True | |
return False | |

    async def batch_cache_write(self, result, **kwargs):
        cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
        await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)

    async def ping(self):
        cache_ping = getattr(self.cache, "ping", None)
        if cache_ping:
            return await cache_ping()
        return None

    async def delete_cache_keys(self, keys):
        cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys", None)
        if cache_delete_cache_keys:
            return await cache_delete_cache_keys(keys)
        return None

    async def disconnect(self):
        if hasattr(self.cache, "disconnect"):
            await self.cache.disconnect()

    def _supports_async(self) -> bool:
        """
        Internal method to check if the cache type supports async get/set operations

        Only the S3 cache does NOT support async operations
        """
        if self.type and self.type == LiteLLMCacheType.S3:
            return False
        return True


def enable_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
        "atext_completion",
        "text_completion",
        "arerank",
        "rerank",
    ],
    **kwargs,
):
    """
    Enable cache with the specified configuration.

    Args:
        type (Optional[LiteLLMCacheType]): The type of cache to enable ("local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk"). Defaults to "local".
        host (Optional[str]): The host address of the cache server. Defaults to None.
        port (Optional[str]): The port number of the cache server. Defaults to None.
        password (Optional[str]): The password for the cache server. Defaults to None.
        supported_call_types (Optional[List[CachingSupportedCallTypes]]):
            The call types to cache for. Defaults to all supported call types.
        **kwargs: Additional keyword arguments.

    Returns:
        None

    Raises:
        None
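
    Example (illustrative sketch; enables the default in-memory cache, so no external
    services are needed):
        >>> import litellm
        >>> litellm.enable_cache()
        >>> litellm.cache is not None
        True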
""" | |
print_verbose("LiteLLM: Enabling Cache") | |
if "cache" not in litellm.input_callback: | |
litellm.input_callback.append("cache") | |
if "cache" not in litellm.success_callback: | |
litellm.logging_callback_manager.add_litellm_success_callback("cache") | |
if "cache" not in litellm._async_success_callback: | |
litellm.logging_callback_manager.add_litellm_async_success_callback("cache") | |
if litellm.cache is None: | |
litellm.cache = Cache( | |
type=type, | |
host=host, | |
port=port, | |
password=password, | |
supported_call_types=supported_call_types, | |
**kwargs, | |
) | |
print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}") | |
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}") | |


def update_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
        "atext_completion",
        "text_completion",
        "arerank",
        "rerank",
    ],
    **kwargs,
):
    """
    Update the cache for LiteLLM.

    Args:
        type (Optional[LiteLLMCacheType]): The type of cache ("local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk"). Defaults to "local".
        host (Optional[str]): The host of the cache. Defaults to None.
        port (Optional[str]): The port of the cache. Defaults to None.
        password (Optional[str]): The password for the cache. Defaults to None.
        supported_call_types (Optional[List[CachingSupportedCallTypes]]):
            The call types to cache for. Defaults to all supported call types.
        **kwargs: Additional keyword arguments for the cache.

    Returns:
        None
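
    Example (illustrative sketch; replaces any existing cache with a Redis-backed one and
    assumes a Redis instance is reachable at the given host/port):
        >>> update_cache(
        ...     type=LiteLLMCacheType.REDIS,
        ...     host="localhost",
        ...     port="6379",
        ...     password="",
        ... )  # doctest: +SKIP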
""" | |
print_verbose("LiteLLM: Updating Cache") | |
litellm.cache = Cache( | |
type=type, | |
host=host, | |
port=port, | |
password=password, | |
supported_call_types=supported_call_types, | |
**kwargs, | |
) | |
print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}") | |
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}") | |


def disable_cache():
    """
    Disable the cache used by LiteLLM.

    This function removes the cache-related callbacks from input_callback,
    success_callback, and _async_success_callback, and sets litellm.cache to None.

    Parameters:
        None

    Returns:
        None
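
    Example (illustrative sketch; turns caching back off after it was enabled):
        >>> import litellm
        >>> litellm.disable_cache()
        >>> litellm.cache is None
        True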
""" | |
from contextlib import suppress | |
print_verbose("LiteLLM: Disabling Cache") | |
with suppress(ValueError): | |
litellm.input_callback.remove("cache") | |
litellm.success_callback.remove("cache") | |
litellm._async_success_callback.remove("cache") | |
litellm.cache = None | |
print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}") | |