|
import torch |
|
from backend.utils import load_model, load_embeddings, load_texts |
|
|
|
|
|
def query_search(query: str, n_answers: int, model_name: str):
    """Return the corpus entries most similar to a free-text query.

    The query is embedded with the model obtained from ``load_model`` and
    scored against the embeddings from ``load_embeddings`` using cosine
    similarity. The table from ``load_texts`` is ranked by that score and
    its best rows are returned.

    Args:
        query: Text to search for.
        n_answers: Number of top-ranked rows to keep (forwarded to
            ``head``).
        model_name: Identifier handed to ``load_model``.

    Returns:
        The ``"Description"`` and ``"Code"`` columns of the first
        ``n_answers`` rows after sorting by similarity, most similar first.
        The table returned by ``load_texts`` is left unmodified.
    """
    model = load_model(model_name)
    query_emb = model.encode(query, convert_to_tensor=True)

    print("loading embedding")
    corpus_emb = load_embeddings()
    corpus_texts = load_texts()

    # The model and the stored embeddings may live on different devices;
    # cosine_similarity refuses mixed-device operands. Moving the (small)
    # query tensor is a no-op when both already share a device.
    query_emb = query_emb.to(corpus_emb.device)

    # NOTE(review): assumes corpus_emb is (n_rows, dim) and that its row
    # order matches load_texts() -- confirm against backend.utils.
    hits = torch.nn.functional.cosine_similarity(
        query_emb[None, :], corpus_emb, dim=1, eps=1e-8
    )

    # assign() scores a copy, so a cached/shared table returned by
    # load_texts() does not pick up a "Similarity" column as a side effect.
    ranked = corpus_texts.assign(Similarity=hits.tolist()).sort_values(
        by="Similarity", ascending=False
    )
    return ranked.head(n_answers)[["Description", "Code"]]
|
|