Spaces:
Sleeping
Sleeping
from sentence_transformers import SentenceTransformer | |
from sentence_transformers import util | |
import torch | |
def predict(
    query: str,
    corpus_embeddings: torch.Tensor,
    corpus_labels: list,
    model: SentenceTransformer,
    top_k: int = 5,
) -> list:
    """Return the labels of the corpus entries most similar to ``query``.

    The query is embedded with ``model`` and searched against
    ``corpus_embeddings`` with ``util.semantic_search`` (its default score
    function is cosine similarity). The labels of the best-scoring hits are
    returned.

    Args:
        query: Text to search for.
        corpus_embeddings: Pre-computed corpus embeddings. Row ``i`` is assumed
            to belong to ``corpus_labels[i]`` -- TODO confirm against the code
            that builds them.
        corpus_labels: One label per corpus row, indexed by the ``corpus_id``
            values returned by the search.
        model: Encoder used for ``query``. Presumably the same model that
            produced ``corpus_embeddings``; scores are not comparable otherwise.
        top_k: Maximum number of labels to return.

    Returns:
        Up to ``top_k`` labels, in the order ``util.semantic_search`` ranks the
        hits (highest score first). Labels are not de-duplicated, so a label
        shared by several corpus rows can appear more than once.
    """
    # Keep the query embedding as a torch tensor (on the model's device) rather
    # than the default NumPy array, avoiding a host round-trip before
    # semantic_search aligns it with corpus_embeddings.
    query_embedding = model.encode([query], convert_to_tensor=True)
    # A single query was encoded, so only the first hit list is relevant.
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
    return [corpus_labels[hit["corpus_id"]] for hit in hits]