Musterdatenkatalog / semantic_search.py
Rahkakavee Baskaran
add predictor
99fa459
raw
history blame
503 Bytes
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
import torch
def predict(
    query: str,
    corpus_embeddings: torch.Tensor,
    corpus_labels: list,
    model: SentenceTransformer,
    top_k: int = 5,
) -> list:
    """Return the labels of the corpus entries that best match ``query``.

    The query is encoded with ``model`` and compared against the
    precomputed ``corpus_embeddings`` via ``util.semantic_search``. Each
    hit's ``corpus_id`` is then used as an index into ``corpus_labels``.

    Args:
        query: Free-text search string to encode.
        corpus_embeddings: Precomputed embeddings of the corpus to search.
        corpus_labels: Labels indexed by the ``corpus_id`` values that
            ``util.semantic_search`` reports (presumably aligned
            row-for-row with ``corpus_embeddings`` -- verify at the caller).
        model: Encoder used to embed the query.
        top_k: Maximum number of hits requested from the search.

    Returns:
        The labels of the hits for the query, in the order in which
        ``util.semantic_search`` returned them.
    """
    # The query is wrapped in a list, so the search yields one hit list
    # per encoded query -- here exactly one.
    encoded_query = model.encode([query])
    hits_per_query = util.semantic_search(
        encoded_query, corpus_embeddings, top_k=top_k
    )
    hits_for_query = hits_per_query[0]

    # Translate each hit's corpus index into its label.
    predicted_labels: list = []
    for hit in hits_for_query:
        predicted_labels.append(corpus_labels[hit["corpus_id"]])
    return predicted_labels