Rahkakavee Baskaran commited on
Commit
99fa459
1 Parent(s): c43e95c

add predictor

Browse files
Files changed (1) hide show
  1. semantic_search.py +16 -0
semantic_search.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from sentence_transformers import util
3
+ import torch
4
+
5
+
6
+ def predict(
7
+ query: str,
8
+ corpus_embeddings: torch.Tensor,
9
+ corpus_labels: list,
10
+ model: SentenceTransformer,
11
+ top_k: int = 5,
12
+ ) -> list:
13
+ query_embedding = model.encode([query])
14
+ result = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
15
+ result_predictions: list = [corpus_labels[el["corpus_id"]] for el in result[0]]
16
+ return result_predictions