File size: 1,930 Bytes
268c7f9
 
 
 
 
3556e6f
268c7f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from typing import List, Sequence, Tuple

from sentence_transformers import SentenceTransformer

from src.similarity_scorer import SimilarityScorer
from src.vectorizer import Vectorizer


class PromptSearchEngine:
    """Semantic search over a fixed corpus of prompts.

    The corpus is embedded once at construction time. Each query is embedded
    on demand and ranked against the stored corpus embeddings by cosine
    similarity.
    """

    def __init__(
        self,
        prompts: Sequence[str],
        model_name: str = "all-MiniLM-L6-v2",
    ) -> None:
        """Initialize the search engine by vectorizing the prompt corpus.

        The vectorized corpus is reused by ``most_similar`` to find the
        prompts most similar to a user's query prompt.

        Args:
            prompts: The sequence of raw prompts from the dataset.
            model_name: Name of the SentenceTransformer model used to embed
                both the corpus and incoming queries. Defaults to
                ``"all-MiniLM-L6-v2"``.
        """
        self.vectorizer = Vectorizer(SentenceTransformer(model_name))
        self.scorer = SimilarityScorer()

        self.prompts = prompts
        # Embed the corpus once up front so each query only pays for
        # embedding the query itself.
        self.embeddings = self.vectorizer.transform(prompts)

    def most_similar(self, query: str, n: int = 5) -> List[Tuple[float, str]]:
        """Return the top n most similar prompts from the corpus.

        The query is vectorized with the same Vectorizer as the corpus and
        then scored against every corpus embedding with cosine similarity.

        Args:
            query: The raw query prompt input from the user.
            n: The number of similar prompts to return. Values <= 0 yield
                an empty list; values larger than the corpus return the
                whole corpus.

        Returns:
            Up to n ``(similarity, prompt)`` pairs, ordered from most to
            least similar. Prompts are returned verbatim; similarity scores
            are plain Python floats.
        """
        # Guard non-positive n: ``argsort()[-n:]`` with n == 0 slices as
        # ``[0:]`` (the whole corpus) and a negative n selects an arbitrary
        # subset, so handle these explicitly.
        if n <= 0:
            return []

        query_embedding = self.vectorizer.transform(query)

        # NOTE(review): indexing self.prompts with positions from argsort
        # assumes the scorer returns a 1-D array (presumably NumPy) with one
        # score per corpus prompt — confirm against SimilarityScorer.
        similarities = self.scorer.cosine_similarity(query_embedding, self.embeddings)

        # Indices by descending similarity, truncated to the top n.
        # Reversing the full ascending order keeps the previous tie ordering.
        top_n_indices = similarities.argsort()[::-1][:n]

        # float() converts NumPy scalars to built-in floats so the result
        # matches the annotated return type (and is JSON-serializable).
        return [(float(similarities[i]), self.prompts[i]) for i in top_n_indices]