Nuno Machado commited on
Commit
3f9c44b
1 Parent(s): 8d8e1b1

Add faiss search engine

Browse files
Files changed (2) hide show
  1. search/__init__.py +0 -0
  2. search/faiss.py +27 -0
search/__init__.py ADDED
File without changes
search/faiss.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
+ from datasets import Dataset
4
+ from embeddings.encoder import EmbeddingEncoder
5
+
6
+
7
+ class FaissSearchEngine:
8
+ def __init__(self, embeddings: Dataset, encoder: EmbeddingEncoder):
9
+ self.embeddings = embeddings
10
+ # assume dataset has a column "embeddings" with the embeddings
11
+ self.embeddings.add_faiss_index(column="embeddings")
12
+ self.encoder = encoder
13
+
14
+ def search(self, query, k=5):
15
+ # Encode the query using the same model that was used to generate the embeddings
16
+ query_embedding = self.encoder.generate_embeddings(query).numpy()
17
+
18
+ # Search the index using FAISS
19
+ scores, samples = self.embeddings.get_nearest_examples("embeddings", query_embedding, k)
20
+
21
+ # Return the results as a list of dictionaries
22
+ samples_df = pd.DataFrame.from_dict(samples)
23
+ samples_df["scores"] = scores
24
+ #samples_df.sort_values("scores", ascending=False, inplace=True)
25
+ samples_df = samples_df.drop("embeddings", axis=1)
26
+
27
+ return samples_df