File size: 1,843 Bytes
40ae8d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from qdrant_client import QdrantClient
import streamlit as st

from qdrant_client.models import Filter
from sentence_transformers import SentenceTransformer

#
class NeuralSearcher:
    """Semantic search over a Qdrant collection using sentence embeddings.

    A text query is encoded with the ``all-MiniLM-L6-v2`` SentenceTransformer
    model (running on CPU) and the resulting vector is used to look up the
    closest points in the configured Qdrant collection.
    """

    def __init__(self, collection_name):
        """Load the encoder model and create the Qdrant client.

        Args:
            collection_name: Name of the Qdrant collection that ``search``
                queries.

        The Qdrant API key is read from ``st.secrets["QDRANT_API_KEY"]``.
        """
        self.collection_name = collection_name
        # Initialize encoder model
        self.model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
        # Initialize Qdrant client against the hosted cluster.
        # For a local instance use QdrantClient("http://localhost:6333").
        self.qdrant_client = QdrantClient(
            url="https://ed55d75f-bb54-4c09-8907-8d112e6278a1.us-east4-0.gcp.cloud.qdrant.io",
            api_key=st.secrets["QDRANT_API_KEY"],
        )

    def search(self, text: str, limit: int = 3, query_filter: "Filter | None" = None):
        """Return the payloads of the points closest to ``text``.

        Args:
            text: Free-text query; it is encoded into a vector before the
                lookup.
            limit: Maximum number of hits to request from Qdrant. Defaults
                to 3, the value that was previously hard-coded.
            query_filter: Optional filter forwarded unchanged to the Qdrant
                client (e.g. a ``qdrant_client.models.Filter`` restricting a
                payload field). ``None`` (the default) applies no filter.

        Returns:
            A list with the ``payload`` of every hit, in the order the
            client returned them. Ids and similarity scores are dropped.
        """
        # Convert text query into vector
        vector = self.model.encode(text).tolist()

        # Use `vector` to search for the closest vectors in the collection
        search_result = self.qdrant_client.search(
            collection_name=self.collection_name,
            query_vector=vector,
            query_filter=query_filter,
            limit=limit,
        )
        # `search_result` holds the found hits; only the stored payload of
        # each hit is of interest to callers of this method.
        return [hit.payload for hit in search_result]