File size: 956 Bytes
70a4e1e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import numpy as np
def vector_search(query, model, index, num_results=10):
    """Encode a query into vectors and look up its nearest neighbours.

    The query is transformed with a pretrained sentence-level model and the
    resulting vectors are passed to the index's ``search`` method.

    Args:
        query (str or sequence of str): User query. A single string is treated
            as ONE query; a list/tuple of strings is treated as a batch with
            one query per element.
        model: Object exposing ``encode(list_of_str)``, e.g. a
            ``sentence_transformers.SentenceTransformer``.
        index: Already-loaded index exposing ``search(vectors, k=...)``,
            e.g. a FAISS index.
        num_results (int): Number of results to return per query.

    Returns:
        D: Distances between the results and each query, as returned by
            ``index.search``.
        I: IDs of the results, as returned by ``index.search``.
    """
    # A bare string must be wrapped, not passed through ``list()``:
    # ``list("abc")`` yields ``["a", "b", "c"]``, which would encode and
    # search every character as a separate query.
    if isinstance(query, str):
        sentences = [query]
    else:
        sentences = list(query)

    vector = model.encode(sentences)
    # The search input is cast to float32, as in the original implementation.
    D, I = index.search(np.array(vector).astype("float32"), k=num_results)
    return D, I
def id2details(df, I, column):
    """Look up the values of ``column`` for each ID in the first row of ``I``.

    Args:
        df: DataFrame with an ``rid`` column holding the IDs.
        I: Indexable collection of ID rows; only ``I[0]`` is used.
        column: Label of the column whose values are collected.

    Returns:
        list: One inner list per ID in ``I[0]``, holding the ``column`` values
            of every row whose ``rid`` equals that ID (empty if none match).
    """
    details = []
    for rid in I[0]:
        matching_rows = df[df.rid == rid]
        details.append(list(matching_rows[column]))
    return details
|