alexpantex committed on
Commit
5e9edd3
·
verified ·
1 Parent(s): ef09277

Upload scripts/prompt_engine.py with huggingface_hub
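For reference, an upload like this is typically done with the huggingface_hub client. A minimal sketch, assuming a token already configured via `huggingface-cli login`; the repo id below is hypothetical and not taken from this commit:

from huggingface_hub import HfApi

api = HfApi()  # picks up the token stored by `huggingface-cli login`
api.upload_file(
    path_or_fileobj="scripts/prompt_engine.py",   # local file to push
    path_in_repo="scripts/prompt_engine.py",      # destination path in the repo
    repo_id="alexpantex/prompt-search-engine",    # hypothetical repo id
    commit_message="Upload scripts/prompt_engine.py with huggingface_hub",
)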

Files changed (1)
  1. scripts/prompt_engine.py +104 -0
scripts/prompt_engine.py ADDED
@@ -0,0 +1,104 @@
import os
import sys
sys.path.append(sys.path[0].replace('scripts', ''))  # make the project root importable so config.data_paths resolves
import pandas as pd
import numpy as np

from config.data_paths import VECTORDB_PATH

from typing import Sequence, List
import faiss
from sentence_transformers import SentenceTransformer


class Vectorizer:
    def __init__(self, model_name: str) -> None:
        """
        Initialize the vectorizer with a pre-trained embedding model.
        Args:
            model_name: The name of the pre-trained embedding model (compatible with sentence-transformers).
        """
        self.model = SentenceTransformer(model_name)

    def transform(self, prompts: Sequence[str], build_index: bool = False) -> np.ndarray:
        """
        Transform texts into numerical vectors using the specified model.
        Args:
            prompts: The sequence of raw corpus prompts.
            build_index: If True, also build and persist a FAISS index for the embeddings.
        Returns:
            Vectorized prompts as a numpy array.
        """
        embeddings = self.model.encode(prompts, show_progress_bar=True)
        embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)  # normalize embeddings
        if build_index:
            if os.path.isfile(os.path.join(VECTORDB_PATH, 'prompts_index.faiss')):
                print('Embeddings already stored in vector db')
            else:
                index = self._build_index(embeddings)
                faiss.write_index(index, os.path.join(VECTORDB_PATH, 'prompts_index.faiss'))
        return embeddings

    def _build_index(self, embeddings: np.ndarray) -> faiss.IndexFlatIP:
        """
        Build and return a FAISS index for the given embeddings.
        Args:
            embeddings: A numpy array of prompt embeddings.
        Returns:
            FAISS index for efficient similarity search.
        """
        index = faiss.IndexFlatIP(embeddings.shape[1])  # cosine similarity (inner product on normalized vectors)
        index.add(embeddings)
        return index

def cosine_similarity(query_vector: np.ndarray, corpus_vectors: np.ndarray) -> np.ndarray:
    """
    Calculate cosine similarity between prompt vectors.
    Args:
        query_vector: Vectorized prompt query of shape (1, D).
        corpus_vectors: Vectorized prompt corpus of shape (N, D).
    Returns:
        A vector of shape (N,) with values in range [-1, 1] where 1 is maximum similarity.
    """
    return np.dot(corpus_vectors, query_vector.T).flatten()

class PromptSearchEngine:
    def __init__(self, corpus: str, model_name: str = 'all-MiniLM-L6-v2', use_index: bool = False) -> None:
        """
        Initialize the search engine by vectorizing the prompt corpus.
        The vectorized prompt corpus is used to find the top n most similar prompts.
        Args:
            corpus: Path to the parquet dataset with raw prompts.
            model_name: The name of the pre-trained embedding model.
            use_index: If True, retrieve from a persisted FAISS index instead of in-memory cosine similarity.
        """
        self.use_index = use_index
        self.prompts = pd.read_parquet(corpus)['prompt'].to_list()
        self.prompts = self.prompts  # if use_index else np.random.choice(self.prompts, 1000, replace=False)
        self.vectorizer = Vectorizer(model_name)
        self.embeddings = self.vectorizer.transform(self.prompts,
                                                    build_index=use_index)  # build index initially for faster retrieval
        if use_index:
            self.index = faiss.read_index(os.path.join(VECTORDB_PATH, 'prompts_index.faiss'))

    def most_similar(self, query: str, n: int = 5) -> List[dict]:
        """
        Return the top n most similar prompts from the corpus.
        The input query prompt is vectorized with the Vectorizer; the top n most similar corpus prompts
        are then retrieved either from the FAISS index or via the cosine_similarity function.
        Args:
            query: The raw query prompt input from the user.
            n: The number of similar prompts to return from the corpus.
        Returns:
            A list of dicts with the top n most similar corpus prompts and their similarity scores.
            Note that returned prompts are verbatim.
        """
        query_vector = self.vectorizer.transform([query])
        if self.use_index:
            distances, indices = self.index.search(query_vector, n)
            return [{'prompt': self.prompts[idx], 'score': distances[0][i]} for i, idx in enumerate(indices[0])]
        else:
            similarities = cosine_similarity(query_vector, self.embeddings)
            top_indices = np.argsort(-similarities)[:n]  # sort in descending order
            return [{'prompt': self.prompts[i], 'score': similarities[i]} for i in top_indices]
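A minimal usage sketch for the engine above; the parquet path and query text are hypothetical, and the file is assumed to contain a 'prompt' column as the class expects:

from scripts.prompt_engine import PromptSearchEngine

# Brute-force cosine search over the in-memory corpus embeddings
engine = PromptSearchEngine('data/prompts.parquet', model_name='all-MiniLM-L6-v2', use_index=False)
for result in engine.most_similar('a watercolor painting of a fox in a misty forest', n=3):
    print(f"{result['score']:.3f}  {result['prompt']}")

Passing use_index=True instead reads the persisted FAISS index from VECTORDB_PATH and routes retrieval through IndexFlatIP rather than the cosine_similarity helper.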