from typing import List, Sequence, Tuple

from sentence_transformers import SentenceTransformer

from src.similarity_scorer import SimilarityScorer
from src.vectorizer import Vectorizer


class PromptSearchEngine:
    """Find the prompts in a fixed corpus most similar to a query prompt."""

    def __init__(
        self,
        prompts: Sequence[str],
        *,
        model_name: str = "all-MiniLM-L6-v2",
    ) -> None:
        """Initialize search engine by vectorizing prompt corpus.

        The vectorized prompt corpus is used to find the top n most similar
        prompts w.r.t. the user's input prompt.

        Args:
            prompts: The sequence of raw prompts from the dataset.
            model_name: Name of the SentenceTransformer model wrapped by the
                Vectorizer. Defaults to "all-MiniLM-L6-v2".
        """
        self.vectorizer = Vectorizer(SentenceTransformer(model_name))
        self.scorer = SimilarityScorer()
        self.prompts = prompts
        # The corpus is embedded once here so each query only embeds itself.
        self.embeddings = self.vectorizer.transform(prompts)

    def most_similar(self, query: str, n: int = 5) -> List[Tuple[float, str]]:
        """Return top n most similar prompts from corpus.

        The query prompt is vectorized with the same Vectorizer as the corpus,
        then scored against every corpus embedding with cosine similarity.

        Args:
            query: The raw query prompt input from the user.
            n: The number of similar prompts returned from the corpus. Values
                <= 0 yield an empty list; values larger than the corpus yield
                the whole corpus, ranked.

        Returns:
            The list of top n most similar prompts from the corpus along with
            similarity scores, ordered from most to least similar. Note that
            returned prompts are verbatim.
        """
        # Guard: with n == 0 the slice below would be [-0:] == [0:], which
        # selects the entire corpus instead of nothing.
        if n <= 0:
            return []

        query_embedding = self.vectorizer.transform(query)
        # NOTE(review): assumes cosine_similarity returns a 1-D array with one
        # score per corpus prompt, aligned with self.prompts — confirm.
        similarities = self.scorer.cosine_similarity(query_embedding, self.embeddings)

        # argsort is ascending: take the last n indices, then reverse them so
        # the highest score comes first.
        top_n_indices = similarities.argsort()[-n:][::-1]

        # float() turns array scalars into plain Python floats, matching the
        # declared return type.
        return [(float(similarities[i]), self.prompts[i]) for i in top_n_indices]