""" | |
Qdrant Semantic Cache implementation | |
Has 4 methods: | |
- set_cache | |
- get_cache | |
- async_set_cache | |
- async_get_cache | |
""" | |
import ast
import asyncio
import json
from typing import Any, cast

import litellm
from litellm._logging import print_verbose
from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
from litellm.types.utils import EmbeddingResponse

from .base_cache import BaseCache

class QdrantSemanticCache(BaseCache):
    def __init__(  # noqa: PLR0915
        self,
        qdrant_api_base=None,
        qdrant_api_key=None,
        collection_name=None,
        similarity_threshold=None,
        quantization_config=None,
        embedding_model="text-embedding-ada-002",
        host_type=None,
    ):
        import os

        from litellm.llms.custom_httpx.http_handler import (
            _get_httpx_client,
            get_async_httpx_client,
            httpxSpecialProvider,
        )
        from litellm.secret_managers.main import get_secret_str

        if collection_name is None:
            raise Exception("collection_name must be provided, passed None")
        self.collection_name = collection_name
        print_verbose(
            f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
        )

        if similarity_threshold is None:
            raise Exception("similarity_threshold must be provided, passed None")
        self.similarity_threshold = similarity_threshold
        self.embedding_model = embedding_model
        headers = {}

        # check if defined as os.environ/ variable; if so, resolve the secret
        if qdrant_api_base:
            if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
                "os.environ/"
            ):
                qdrant_api_base = get_secret_str(qdrant_api_base)
        if qdrant_api_key:
            if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
                "os.environ/"
            ):
                qdrant_api_key = get_secret_str(qdrant_api_key)

        qdrant_api_base = (
            qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
        )
        qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
        headers = {"Content-Type": "application/json"}
        if qdrant_api_key:
            headers["api-key"] = qdrant_api_key

        if qdrant_api_base is None:
            raise ValueError("Qdrant url must be provided")

        self.qdrant_api_base = qdrant_api_base
        self.qdrant_api_key = qdrant_api_key
        print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")

        self.headers = headers

        self.sync_client = _get_httpx_client()
        self.async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.Caching
        )

        if quantization_config is None:
            print_verbose(
                "Quantization config is not provided. Default binary quantization will be used."
            )

        # check whether the collection already exists; create it with the
        # chosen quantization config if it does not
        collection_exists = self.sync_client.get(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
            headers=self.headers,
        )
        if collection_exists.status_code != 200:
            raise ValueError(
                f"Error from qdrant checking if /collections exist {collection_exists.text}"
            )

        if collection_exists.json()["result"]["exists"]:
            collection_details = self.sync_client.get(
                url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
                headers=self.headers,
            )
            self.collection_info = collection_details.json()
            print_verbose(
                f"Collection already exists.\nCollection details:{self.collection_info}"
            )
        else:
            if quantization_config is None or quantization_config == "binary":
                quantization_params = {
                    "binary": {
                        "always_ram": False,
                    }
                }
            elif quantization_config == "scalar":
                quantization_params = {
                    "scalar": {
                        "type": "int8",
                        "quantile": QDRANT_SCALAR_QUANTILE,
                        "always_ram": False,
                    }
                }
            elif quantization_config == "product":
                quantization_params = {
                    "product": {"compression": "x16", "always_ram": False}
                }
            else:
                raise Exception(
                    "Quantization config must be one of 'scalar', 'binary' or 'product'"
                )

            new_collection_status = self.sync_client.put(
                url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
                json={
                    "vectors": {"size": QDRANT_VECTOR_SIZE, "distance": "Cosine"},
                    "quantization_config": quantization_params,
                },
                headers=self.headers,
            )
            if new_collection_status.json()["result"]:
                collection_details = self.sync_client.get(
                    url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
                    headers=self.headers,
                )
                self.collection_info = collection_details.json()
                print_verbose(
                    f"New collection created.\nCollection details:{self.collection_info}"
                )
            else:
                raise Exception("Error while creating new collection")
    def _get_cache_logic(self, cached_response: Any):
        if cached_response is None:
            return cached_response
        try:
            cached_response = json.loads(
                cached_response
            )  # Convert string to dictionary
        except Exception:
            # not valid JSON; fall back to parsing a Python-literal string
            cached_response = ast.literal_eval(cached_response)
        return cached_response
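
    # Illustration (assumed values) of the two payload formats handled above:
    #   '{"role": "assistant"}'  -> parsed by json.loads
    #   "{'role': 'assistant'}"  -> json.loads raises, ast.literal_eval succeeds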
    def set_cache(self, key, value, **kwargs):
        print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
        import uuid

        # get the prompt
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        # create an embedding for prompt
        embedding_response = cast(
            EmbeddingResponse,
            litellm.embedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            ),
        )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        value = str(value)
        assert isinstance(value, str)

        data = {
            "points": [
                {
                    "id": str(uuid.uuid4()),
                    "vector": embedding,
                    "payload": {
                        "text": prompt,
                        "response": value,
                    },
                },
            ]
        }
        self.sync_client.put(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
            headers=self.headers,
            json=data,
        )
        return
    def get_cache(self, key, **kwargs):
        print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")

        # get the messages
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        # convert to embedding
        embedding_response = cast(
            EmbeddingResponse,
            litellm.embedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            ),
        )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        data = {
            "vector": embedding,
            "params": {
                "quantization": {
                    "ignore": False,
                    "rescore": True,
                    "oversampling": 3.0,
                }
            },
            "limit": 1,
            "with_payload": True,
        }

        search_response = self.sync_client.post(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
            headers=self.headers,
            json=data,
        )
        results = search_response.json()["result"]

        if results is None:
            return None
        if isinstance(results, list):
            if len(results) == 0:
                return None

        similarity = results[0]["score"]
        cached_prompt = results[0]["payload"]["text"]

        # check similarity, if more than self.similarity_threshold, return results
        print_verbose(
            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
        )
        if similarity >= self.similarity_threshold:
            # cache hit !
            cached_value = results[0]["payload"]["response"]
            print_verbose(
                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
            )
            return self._get_cache_logic(cached_response=cached_value)
        else:
            # cache miss !
            return None
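
    # For reference, the search call above expects a Qdrant response shaped
    # roughly like this (illustrative values, not taken from the source):
    #   {"result": [{"score": 0.93, "payload": {"text": "...", "response": "..."}}]}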
    async def async_set_cache(self, key, value, **kwargs):
        import uuid

        from litellm.proxy.proxy_server import llm_model_list, llm_router

        print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")

        # get the prompt
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        # create an embedding for prompt
        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and self.embedding_model in router_model_names:
            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
                metadata={
                    "user_api_key": user_api_key,
                    "semantic-cache-embedding": True,
                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
                },
            )
        else:
            # convert to embedding
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        value = str(value)
        assert isinstance(value, str)

        data = {
            "points": [
                {
                    "id": str(uuid.uuid4()),
                    "vector": embedding,
                    "payload": {
                        "text": prompt,
                        "response": value,
                    },
                },
            ]
        }

        await self.async_client.put(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
            headers=self.headers,
            json=data,
        )
        return
    async def async_get_cache(self, key, **kwargs):
        print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
        from litellm.proxy.proxy_server import llm_model_list, llm_router

        # get the messages
        messages = kwargs["messages"]
        prompt = ""
        for message in messages:
            prompt += message["content"]

        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and self.embedding_model in router_model_names:
            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
                metadata={
                    "user_api_key": user_api_key,
                    "semantic-cache-embedding": True,
                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
                },
            )
        else:
            # convert to embedding
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )

        # get the embedding
        embedding = embedding_response["data"][0]["embedding"]

        data = {
            "vector": embedding,
            "params": {
                "quantization": {
                    "ignore": False,
                    "rescore": True,
                    "oversampling": 3.0,
                }
            },
            "limit": 1,
            "with_payload": True,
        }

        search_response = await self.async_client.post(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
            headers=self.headers,
            json=data,
        )
        results = search_response.json()["result"]

        if results is None:
            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
            return None
        if isinstance(results, list):
            if len(results) == 0:
                kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
                return None

        similarity = results[0]["score"]
        cached_prompt = results[0]["payload"]["text"]

        # check similarity, if more than self.similarity_threshold, return results
        print_verbose(
            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
        )

        # update kwargs["metadata"] with similarity, don't rewrite the original metadata
        kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity

        if similarity >= self.similarity_threshold:
            # cache hit !
            cached_value = results[0]["payload"]["response"]
            print_verbose(
                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
            )
            return self._get_cache_logic(cached_response=cached_value)
        else:
            # cache miss !
            return None
    async def _collection_info(self):
        return self.collection_info
    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)
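
# Async usage sketch (not part of the original module; assumes the hypothetical
# cache instance from the module-level example above). Note that the pipeline
# forwards the same kwargs — including `messages` — to every set_cache call:
#
#   async def warm_cache(cache: QdrantSemanticCache, messages):
#       await cache.async_set_cache_pipeline(
#           cache_list=[("k1", '{"answer": "one"}'), ("k2", '{"answer": "two"}')],
#           messages=messages,
#       )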