Spaces:
Sleeping
Sleeping
from typing import List, Sequence, Tuple | |
from sentence_transformers import SentenceTransformer | |
from src.similarity_scorer import SimilarityScorer | |
from src.vectorizer import Vectorizer | |
class PromptSearchEngine:
    """Semantic search over a fixed corpus of prompts.

    The corpus is embedded once at construction time; each query is then
    embedded with the same vectorizer and ranked against the stored
    corpus embeddings.
    """

    def __init__(self, prompts: Sequence[str]) -> None:
        """Initialize the search engine by vectorizing the prompt corpus.

        The vectorized corpus is later used to find the top n prompts most
        similar to a user's input prompt.

        Args:
            prompts: The sequence of raw prompts from the dataset.
        """
        self.vectorizer = Vectorizer(SentenceTransformer("all-MiniLM-L6-v2"))
        self.scorer = SimilarityScorer()
        self.prompts = prompts
        # Embed the corpus once up front so queries only pay for one embedding.
        self.embeddings = self.vectorizer.transform(prompts)

    def most_similar(self, query: str, n: int = 5) -> List[Tuple[float, str]]:
        """Return the top n prompts from the corpus most similar to ``query``.

        The query is vectorized with the engine's vectorizer and scored
        against the stored corpus embeddings with the scorer's
        ``cosine_similarity``.

        Args:
            query: The raw query prompt input from the user.
            n: The number of similar prompts to return. Values larger than
                the corpus size return the whole corpus; non-positive
                values return an empty list.

        Returns:
            A list of ``(similarity, prompt)`` tuples ordered from most to
            least similar. Prompts are returned verbatim and scores are
            plain Python floats.
        """
        # ``[-n:]`` with n == 0 would select the *entire* array (``[-0:]`` is
        # ``[0:]``), so handle "nothing requested" explicitly.
        if n <= 0:
            return []

        query_embedding = self.vectorizer.transform(query)
        # NOTE(review): assumes the scorer returns a 1-D array with one score
        # per corpus prompt, in corpus order -- confirm against SimilarityScorer.
        similarities = self.scorer.cosine_similarity(query_embedding, self.embeddings)

        # argsort is ascending: reverse for best-first, then keep the first n.
        top_n_indices = similarities.argsort()[::-1][:n]

        # Convert scores to built-in floats to honor the declared return type.
        return [(float(similarities[i]), self.prompts[i]) for i in top_n_indices]