File size: 819 Bytes
ff4ec71
9ec02db
 
 
ff4ec71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ec02db
 
ff4ec71
9ec02db
ff4ec71
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

from backend.utils import load_embeddings, load_model, load_texts


# Search
def query_search(query: str, n_answers: int, model_name: str):
    """Return the corpus rows most similar to ``query``.

    The query is encoded with the model obtained from
    ``load_model(model_name)`` and scored by cosine similarity against the
    embeddings returned by ``load_embeddings()``. The scores are attached to
    a copy of the table returned by ``load_texts()``, so the loaded table
    itself is never modified.

    Args:
        query: Search text to encode.
        n_answers: Number of top-scoring rows to keep (forwarded to
            ``DataFrame.head``).
        model_name: Identifier forwarded to ``load_model``.

    Returns:
        The ``n_answers`` highest-scoring rows, ordered by descending
        similarity, restricted to the ``func_documentation_string``,
        ``repository_name`` and ``func_code_url`` columns.
    """
    model = load_model(model_name)

    # Creating embeddings
    query_emb = model.encode(query, convert_to_tensor=True)

    print("loading embedding")
    corpus_emb = load_embeddings()
    corpus_texts = load_texts()

    # Getting hits
    # NOTE(review): assumes corpus_emb is 2-D (one row per corpus entry) and
    # row-aligned with corpus_texts -- confirm against the loaders.
    hits = torch.nn.functional.cosine_similarity(
        query_emb[None, :], corpus_emb, dim=1, eps=1e-8
    )

    # ``assign`` returns a new frame; writing the column in place would
    # mutate whatever object ``load_texts`` handed back (possibly a shared
    # or cached one) and leak the "Similarity" column to other callers.
    scored = corpus_texts.assign(Similarity=hits.tolist())

    print(scored)

    ranked = scored.sort_values(by="Similarity", ascending=False)
    return ranked.head(n_answers)[
        ["func_documentation_string", "repository_name", "func_code_url"]
    ]