Update: Improved the RAG a little
- src/brain.py +33 -42
- src/helper.py +1 -1
src/brain.py
CHANGED
@@ -43,7 +43,6 @@ class Brain:
         self.augment_config = augment_config
         self.augment_safety_settings = augment_safety_settings
         self._configure_generative_ai(response_model_api_key)
-        self._configure_augment_ai(augment_model_api_key)
         self.response_model = self._initialize_generative_model(
             response_model_name, generation_config, response_safety_settings
         )
@@ -59,12 +58,6 @@ class Brain:
         except Exception as e:
             self._handle_error("Error configuring generative AI module", e)
 
-    def _configure_augment_ai(self, augment_model_api_key):
-        try:
-            palm.configure(api_key=augment_model_api_key)
-        except Exception as e:
-            self._handle_error("Error configuring augmentation AI module", e)
-
     def _initialize_generative_model(
         self, response_model_name, generation_config, response_safety_settings
     ):
@@ -77,18 +70,6 @@ class Brain:
         except Exception as e:
             self._handle_error("Error initializing generative model", e)
 
-    def _initialize_augment_model(
-        self, augment_model_name, augment_config, augment_safety_settings
-    ):
-        try:
-            return palm.GenerativeModel(
-                model_name=augment_model_name,
-                generation_config=augment_config,
-                safety_settings=augment_safety_settings,
-            )
-        except Exception as e:
-            self._handle_error("Error initializing augmentation model", e)
-
     def _initialize_embedding_function(self):
         try:
             return GeminiEmbeddingFunction()
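
Note on the setup this removal implies: with the separate PaLM augment model dropped, one configured Gemini client handles both query expansion and answer generation. A minimal sketch of that setup, assuming the standard `google.generativeai` package; the model name and config values below are placeholders, not the repository's actual settings.

```python
import google.generativeai as genai

# Configure the SDK once with a single API key (placeholder value).
genai.configure(api_key="YOUR_GEMINI_API_KEY")

# One Gemini model instance now covers what the removed PaLM augment model did:
# both query expansion and final answer generation go through it.
response_model = genai.GenerativeModel(
    model_name="gemini-pro",                  # placeholder model name
    generation_config={"temperature": 0.3},   # placeholder config
    safety_settings=None,                     # SDK defaults in this sketch
)
print(response_model.model_name)
```
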
@@ -117,25 +98,27 @@ class Brain:
 
     def generate_alternative_queries(self, query):
         try:
-            prompt_template = """
+            prompt_template = """
+            You are an AI language model assistant. Your task is to generate 10
+            different sub questions of the given user question to retrieve relevant documents from a vector
+            database by generating multiple perspectives on the user question, your goal is to help
+            the user overcome some of the limitations of the distance-based similarity search.
+            Provide these alternative questions separated by newlines.\nQUESTION: '{}'\nANSWER:\n"""
             prompt = prompt_template.format(query)
-            )
-            content = output.result.split("\n")
+            chat_mode = self.response_model.start_chat(history=[])
+            output = chat_mode.send_message(prompt)
+            content = output.text.split("\n")
+            print(content)
             return content
         except Exception as e:
             self._handle_error("Error generating alternative queries", e)
-            return query
+            return [query]
 
     def get_sorted_documents(self, query, n_results=20):
         try:
             original_query = query
-            queries = [original_query] + self.generate_alternative_queries(
-            )
+            queries = [original_query] + self.generate_alternative_queries(original_query)
+
             results = self.chroma_collection.query(
                 query_texts=queries,
                 n_results=n_results,
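
For context on the shape of `results` used in the next hunk: Chroma returns one ranked document list per query text, so the multi-query call yields a list of lists that then gets flattened. A small illustrative sketch, assuming an ad-hoc chromadb collection with its default embedding function (the project builds its own `chroma_collection` elsewhere):

```python
import chromadb

# Illustrative collection and documents, not the repository's index.
client = chromadb.Client()
collection = client.get_or_create_collection("docs")
collection.add(documents=["Doc about embeddings.", "Doc about chunking."], ids=["1", "2"])

queries = ["What is RAG?", "How does retrieval augmented generation work?"]
results = collection.query(query_texts=queries, n_results=2)

# results["documents"] is a list of lists: one ranked list per query text.
flattened = [doc for docs in results["documents"] for doc in docs]
print(flattened)
```
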
@@ -145,20 +128,25 @@ class Brain:
                 doc for docs in results["documents"] for doc in docs
             )
             unique_documents = list(retrieved_documents)
+            original_results = results["documents"][0][
+                : min(n_results, len(results["documents"][0]))
+            ]
             pairs = [[original_query, doc] for doc in unique_documents]
             scores = self.cross_encoder.predict(pairs)
             sorted_indices = np.argsort(-scores)
             sorted_documents = [unique_documents[i] for i in sorted_indices]
+            sorted_documents = original_results + sorted_documents
             return sorted_documents
-
         except Exception as e:
             self._handle_error("Error getting sorted documents", e)
             return []
 
-    def get_relevant_results(self, query, top_n=
+    def get_relevant_results(self, query, top_n=30):
         try:
-            sorted_documents =
+            sorted_documents = self.get_sorted_documents(query)
             relevant_results = sorted_documents[: min(top_n, len(sorted_documents))]
+            relevant_results = list(dict.fromkeys(relevant_results))
+            print(relevant_results)
             return relevant_results
         except Exception as e:
             self._handle_error("Error getting relevant results", e)
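
The re-ranking step pairs the original query with each candidate document, scores the pairs with a cross-encoder, and `dict.fromkeys` then deduplicates while preserving rank order. A minimal sketch of that pattern, assuming a common sentence-transformers checkpoint (the checkpoint the repository actually loads may differ):

```python
import numpy as np
from sentence_transformers import CrossEncoder

# Assumed checkpoint; the repository may load a different cross-encoder.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "What does chunk overlap do?"
candidates = [
    "Chunk overlap repeats a few tokens between neighbouring chunks.",
    "The sky appears blue because of Rayleigh scattering.",
    "Chunk overlap repeats a few tokens between neighbouring chunks.",  # duplicate
]

# Score (query, document) pairs and sort candidates from most to least relevant.
pairs = [[query, doc] for doc in candidates]
scores = cross_encoder.predict(pairs)
ranked = [candidates[i] for i in np.argsort(-scores)]

# dict.fromkeys keeps the first occurrence, so duplicates drop out
# without disturbing the relevance order.
ranked_unique = list(dict.fromkeys(ranked))
print(ranked_unique)
```
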
@@ -168,21 +156,20 @@ class Brain:
         try:
             base_prompt = {
                 "content": """
-        YOU are a smart and rational Question and Answer bot.
+        YOU are a smart and rational Question and Answer bot based on the given document.
 
         YOUR MISSION:
-        Provide accurate answers
-        Focus on
+        Provide accurate answers based on the context.
+        Focus on accurate responses; avoid speculations, opinions, guesses, and creative tasks.
         Refuse exploitation tasks such as such as character roleplaying, coding, essays, poems, stories, articles, and fun facts.
         Decline misuse or exploitation attempts respectfully.
 
         YOUR STYLE:
         Concise and complete
-
-        Helpful and friendly
+        professional, polite and positive
 
         REMEMBER:
-        You
+        You can always find a answer if you truly look for it.
         """
             }
 
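
The diff only shows the reworded system prompt; the surrounding `make_prompt` assembly is not visible here. As a purely hypothetical illustration of how such a base prompt, the retrieved passages, and the user question can be combined, with names and layout that are assumptions rather than the project's code:

```python
def make_prompt_sketch(base_prompt: dict, query: str, information: str) -> str:
    # Hypothetical assembly: system instructions, then the retrieved context,
    # then the user question, joined into a single message for the chat model.
    return (
        f"{base_prompt['content']}\n\n"
        f"CONTEXT:\n{information}\n\n"
        f"QUESTION: {query}\nANSWER:"
    )

base = {"content": "YOU are a smart and rational Question and Answer bot based on the given document."}
print(make_prompt_sketch(base, "What is chunk overlap?", "Chunk overlap repeats tokens between chunks."))
```
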
@@ -206,10 +193,12 @@ class Brain:
             if query is None:
                 print("No query specified")
                 return None
+
             information = "\n\n".join(self.get_relevant_results(query))
             messages = self.make_prompt(query, information)
+            chat_mode = self.response_model.start_chat(history=[])
+            content = chat_mode.send_message(messages)
+            print(content)
             return content
         except Exception as e:
             self._handle_error("Error in rag function", e)
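
Generation in `rag` now goes through the same Gemini chat interface as query expansion: the assembled prompt is sent with `send_message`, and the caller reads `.text` from the response, as the `answer` hunk below shows. A minimal sketch of that call pattern, assuming `google.generativeai` with a placeholder key, model name, and context standing in for the real retrieval output:

```python
import google.generativeai as genai

genai.configure(api_key="YOUR_GEMINI_API_KEY")        # placeholder key
model = genai.GenerativeModel(model_name="gemini-pro")  # placeholder model name

# Placeholder passages standing in for get_relevant_results(query).
information = "Chunk overlap repeats a few tokens between neighbouring chunks."
query = "What does chunk overlap do?"
prompt = f"Answer using only this context:\n{information}\n\nQUESTION: {query}"

chat = model.start_chat(history=[])
response = chat.send_message(prompt)
print(response.text)  # the grounded answer text the caller formats and returns
```
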
@@ -223,9 +212,11 @@ class Brain:
                 return "No Query"
             output = self.rag(query)
             print(f"\n\nExecution time: {time.time() - start_time} seconds\n")
+            print(output.text)
             if output is None:
                 return None
+
             return f"{output.text}\n"
         except Exception as e:
             self._handle_error("Error generating answers", e)
-            return
+            return "Something went wrong, please try again!"
src/helper.py
CHANGED
@@ -13,7 +13,7 @@ def _read_pdf(filename):
 
 def _chunk_texts(texts):
     character_splitter = RecursiveCharacterTextSplitter(
-        separators=["\n\n", "\n", ". ", " ", ""], chunk_size=
+        separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1000, chunk_overlap=200
     )
     character_split_texts = character_splitter.split_text("\n\n".join(texts))
     token_splitter = SentenceTransformersTokenTextSplitter(
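
For reference, this is roughly what the two-stage splitting does: a character-level pass with the new 1000-character chunks and 200-character overlap, followed by a sentence-transformers token splitter so chunks fit the embedding model's window. A small sketch assuming LangChain's text splitters; the token splitter's parameters below are assumed values, since they are not shown in this diff.

```python
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    SentenceTransformersTokenTextSplitter,
)

texts = ["First page of the PDF ...", "Second page of the PDF ..."]

# Stage 1: character-based split using the new chunk_size/chunk_overlap values.
character_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1000, chunk_overlap=200
)
character_chunks = character_splitter.split_text("\n\n".join(texts))

# Stage 2: token-based split; tokens_per_chunk is an assumed value.
token_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0, tokens_per_chunk=256)
token_chunks = []
for chunk in character_chunks:
    token_chunks.extend(token_splitter.split_text(chunk))

print(f"{len(character_chunks)} character chunks -> {len(token_chunks)} token chunks")
```
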