|
from typing import Any |
|
|
|
from qdrant_client import QdrantClient, models |
|
from uuid import uuid4 |
|
from transformers import PreTrainedModel |
|
|
|
from src.config import QDRANT_COLLECTION_NAME, QDRANT_URL, QDRANT_API_KEY, EMBEDDING_MODEL |
|
from src.embeddings import TextEmbedder |
|
|
|
|
|
class QdrantStore:
    """Thin convenience wrapper around a ``QdrantClient``.

    Keeps a local cache of known collection names so that operations on a
    missing collection fail fast with a ``ValueError``.

    Args:
        client: An already-configured ``QdrantClient``.
        collection_config: Optional keyword arguments forwarded verbatim to
            ``QdrantClient.create_collection``; must contain
            ``"collection_name"``. When given, that collection is created on
            construction if it does not exist yet.
    """

    def __init__(self, client: QdrantClient, collection_config: dict | None = None):
        self.client = client
        # Snapshot of collections present at construction time; refreshed
        # lazily by _check_collection_name on a cache miss.
        self.collection_names = {c.name for c in client.get_collections().collections}

        if collection_config is not None:
            self.create_collection(collection_config)

    def create_collection(self, collection_config: dict):
        """Create a collection unless it already exists, and cache its name.

        Args:
            collection_config: Keyword arguments for
                ``QdrantClient.create_collection``; must include
                ``"collection_name"``.

        Raises:
            KeyError: If ``"collection_name"`` is missing from the config.
        """
        collection_name = collection_config["collection_name"]
        if not self.client.collection_exists(collection_name):
            self.client.create_collection(**collection_config)
        # Either way the collection exists server-side now, so cache it.
        self.collection_names.add(collection_name)

    def _check_collection_name(self, collection_name: str) -> None:
        """Raise ``ValueError`` if ``collection_name`` does not exist.

        The local cache is consulted first; on a miss the server is asked
        directly, so collections created after this store was constructed
        (e.g. by another process) are still accepted and then cached.
        """
        if collection_name in self.collection_names:
            return
        if not self.client.collection_exists(collection_name):
            raise ValueError(f"Collection: {collection_name} does not exist.")
        self.collection_names.add(collection_name)

    def upsert_points(self,
                      vectors: Any | list[Any],
                      payloads: dict | list[dict],
                      collection_name: str) -> list[str]:
        """Insert points into a collection under freshly generated UUID ids.

        Args:
            vectors: One vector per payload -- a list of vectors, or an
                array-like exposing ``tolist()`` (e.g. a 2-D numpy array).
                When ``payloads`` is a single dict, ``vectors`` is that
                point's single vector. A dict of named vectors is passed
                through to ``models.Batch`` unchanged.
            payloads: A single payload dict or a sequence of payload dicts.
            collection_name: Target collection; must exist.

        Returns:
            The generated point ids, in the same order as ``payloads``
            (an empty list when there is nothing to insert).

        Raises:
            ValueError: If the collection does not exist, or if the number of
                vectors does not match the number of payloads.
        """
        self._check_collection_name(collection_name)

        # Array-likes (numpy arrays, torch tensors) -> plain Python lists,
        # which is what models.Batch validates against.
        if hasattr(vectors, "tolist"):
            vectors = vectors.tolist()

        if isinstance(payloads, dict):
            # A single point. Iterating a bare dict would yield its keys and
            # create one id per key, so wrap payload and vector instead.
            payloads = [payloads]
            if isinstance(vectors, dict):
                vectors = {name: [vec] for name, vec in vectors.items()}
            else:
                vectors = [vectors]
        else:
            # Materialize so a generator is not exhausted by the id pass.
            payloads = list(payloads)

        if not payloads:
            return []

        if not isinstance(vectors, dict) and len(vectors) != len(payloads):
            raise ValueError(
                f"Got {len(vectors)} vectors but {len(payloads)} payloads."
            )

        ids = [str(uuid4()) for _ in payloads]
        self.client.upsert(
            collection_name=collection_name,
            points=models.Batch(ids=ids, payloads=payloads, vectors=vectors),
        )
        return ids

    def delete_points(self,
                      filters: dict[str, list[models.FieldCondition]],
                      collection_name: str):
        """Delete every point matching a filter.

        Args:
            filters: Keyword arguments for ``models.Filter``, e.g.
                ``{"must": [models.FieldCondition(...)]}``.
            collection_name: Target collection; must exist.

        Raises:
            ValueError: If the collection does not exist.
        """
        self._check_collection_name(collection_name)

        self.client.delete(
            collection_name=collection_name,
            points_selector=models.Filter(**filters),
        )

    def delete_points_by_match(self,
                               key_value: tuple[str, list[str] | str],
                               collection_name: str):
        """Delete points whose payload ``key`` equals any of the given values.

        Args:
            key_value: ``(key, values)`` where ``values`` is a single value or
                a list of values; a point matches if its payload ``key``
                equals any of them.
            collection_name: Target collection; must exist.

        Raises:
            ValueError: If the collection does not exist.
        """
        key, values = key_value
        if isinstance(values, str):
            values = [values]
        else:
            values = list(values)

        if not values:
            # Nothing can match an empty value set; skip the request.
            return

        match_filter = {
            "must": [models.FieldCondition(key=key, match=models.MatchAny(any=values))]
        }
        self.delete_points(match_filter, collection_name)

    def get_topk_points_single(self,
                               query: Any | str,
                               collection_name: str,
                               k=5):
        """Return the ``"text"`` payload field of the ``k`` nearest points.

        Args:
            query: Anything ``QdrantClient.query_points`` accepts as ``query``
                (typically a single embedding vector).
            collection_name: Collection to search; must exist.
            k: Maximum number of results.

        Returns:
            A list of ``payload["text"]`` values, best match first.

        Raises:
            ValueError: If the collection does not exist.
            KeyError: If a returned point has no ``"text"`` payload field.
        """
        self._check_collection_name(collection_name)

        responses = self.client.query_points(
            collection_name=collection_name,
            query=query,
            limit=k,
            # Only the "text" field is read below; don't transfer the rest.
            with_payload=["text"],
        )
        return [point.payload["text"] for point in responses.points]
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: embed a sample question and print the stored texts
    # of the nearest points in the configured collection.
    store = QdrantStore(QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY))
    embedder = TextEmbedder(modelname=EMBEDDING_MODEL)

    question = "How to filter a dataframe"
    question_vectors = embedder.embed_text(question)

    top_texts = store.get_topk_points_single(
        question_vectors[0],
        collection_name=QDRANT_COLLECTION_NAME,
    )
    print(top_texts)
|
|