from sentence_transformers import SentenceTransformer
from sentence_transformers import util
import torch


def predict(
    query: str,
    corpus_embeddings: torch.Tensor,
    corpus_labels: list,
    model: SentenceTransformer,
    top_k: int = 5,
) -> list:
    """Return the labels of the corpus entries most similar to *query*.

    The query is embedded with *model* and ranked against the precomputed
    *corpus_embeddings* using ``util.semantic_search``. Each hit's
    ``corpus_id`` is then mapped back to a label through *corpus_labels*.

    Args:
        query: Text to search for.
        corpus_embeddings: Precomputed corpus embeddings. Presumably row ``i``
            corresponds to ``corpus_labels[i]``; confirm with the caller.
        corpus_labels: Labels indexed by corpus position.
        model: Encoder used to embed the query. It is assumed to be the same
            model that produced *corpus_embeddings* (TODO confirm).
        top_k: Maximum number of hits to return.

    Returns:
        Up to *top_k* labels, in the order ``semantic_search`` reports its
        hits (documented by sentence-transformers as decreasing score).
    """
    # semantic_search operates on a batch of queries, so the single query is
    # wrapped in a list here and only the first hit list is used below.
    embedded_batch = model.encode([query])
    hits_per_query = util.semantic_search(
        embedded_batch, corpus_embeddings, top_k=top_k
    )
    first_query_hits = hits_per_query[0]

    labels: list = [corpus_labels[hit["corpus_id"]] for hit in first_query_hits]
    return labels