Spaces:
Sleeping
Sleeping
File size: 1,930 Bytes
268c7f9 3556e6f 268c7f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from typing import List, Sequence, Tuple
from sentence_transformers import SentenceTransformer
from src.similarity_scorer import SimilarityScorer
from src.vectorizer import Vectorizer
class PromptSearchEngine:
    """Semantic search over a fixed corpus of prompts.

    The corpus is embedded once at construction time; each query is embedded
    on demand and ranked against the stored corpus embeddings by cosine
    similarity.
    """

    def __init__(self, prompts: Sequence[str]) -> None:
        """Initialize the search engine by vectorizing the prompt corpus.

        The vectorized corpus is kept on the instance and later used to find
        the top n prompts most similar to a user's input prompt.

        Args:
            prompts: The sequence of raw prompts from the dataset.
        """
        self.vectorizer = Vectorizer(SentenceTransformer("all-MiniLM-L6-v2"))
        self.scorer = SimilarityScorer()
        self.prompts = prompts
        # Embed the corpus once up front so queries only pay for one embedding.
        self.embeddings = self.vectorizer.transform(prompts)

    def most_similar(self, query: str, n: int = 5) -> List[Tuple[float, str]]:
        """Return the top n most similar prompts from the corpus.

        The query is vectorized with the engine's vectorizer, scored against
        every corpus embedding with the scorer's ``cosine_similarity``, and
        the n best-scoring prompts are returned in descending score order.

        Args:
            query: The raw query prompt input from the user.
            n: The number of similar prompts to return. If ``n`` exceeds the
                corpus size, the whole corpus is returned (ranked). If ``n``
                is zero or negative, an empty list is returned.

        Returns:
            A list of ``(similarity_score, prompt)`` pairs, best match first.
            The returned prompts are verbatim entries from the corpus.
        """
        # Guard: with n == 0 the slice below would be ``[-0:]`` (i.e. the whole
        # array), and a negative n would slice from the front — both would
        # return far more than the caller asked for.
        if n <= 0:
            return []

        # NOTE(review): the corpus is passed to transform() as a sequence but
        # the query as a bare string — confirm Vectorizer.transform accepts
        # both and yields a shape cosine_similarity expects.
        query_embedding = self.vectorizer.transform(query)
        # Assumes the scorer returns a 1-D array-like exposing ``argsort()``
        # with one score per corpus prompt — TODO confirm against the scorer.
        similarities = self.scorer.cosine_similarity(
            query_embedding, self.embeddings
        )

        # argsort is ascending: take the last n indices, then reverse them so
        # the highest similarity comes first.
        top_n_indices = similarities.argsort()[-n:][::-1]

        # Pair each selected score with its verbatim corpus prompt.
        return [(similarities[i], self.prompts[i]) for i in top_n_indices]
|