asif00 committed on
Commit
8a872cd
1 Parent(s): b8dc14b

Update: Improved the RAG a little

Browse files
Files changed (2) hide show
  1. src/brain.py +33 -42
  2. src/helper.py +1 -1
src/brain.py CHANGED
@@ -43,7 +43,6 @@ class Brain:
43
  self.augment_config = augment_config
44
  self.augment_safety_settings = augment_safety_settings
45
  self._configure_generative_ai(response_model_api_key)
46
- self._configure_augment_ai(augment_model_api_key)
47
  self.response_model = self._initialize_generative_model(
48
  response_model_name, generation_config, response_safety_settings
49
  )
@@ -59,12 +58,6 @@ class Brain:
59
  except Exception as e:
60
  self._handle_error("Error configuring generative AI module", e)
61
 
62
- def _configure_augment_ai(self, augment_model_api_key):
63
- try:
64
- palm.configure(api_key=augment_model_api_key)
65
- except Exception as e:
66
- self._handle_error("Error configuring augmentation AI module", e)
67
-
68
  def _initialize_generative_model(
69
  self, response_model_name, generation_config, response_safety_settings
70
  ):
@@ -77,18 +70,6 @@ class Brain:
77
  except Exception as e:
78
  self._handle_error("Error initializing generative model", e)
79
 
80
- def _initialize_augment_model(
81
- self, augment_model_name, augment_config, augment_safety_settings
82
- ):
83
- try:
84
- return palm.GenerativeModel(
85
- model_name=augment_model_name,
86
- generation_config=augment_config,
87
- safety_settings=augment_safety_settings,
88
- )
89
- except Exception as e:
90
- self._handle_error("Error initializing augmentation model", e)
91
-
92
  def _initialize_embedding_function(self):
93
  try:
94
  return GeminiEmbeddingFunction()
@@ -117,25 +98,27 @@ class Brain:
117
 
118
  def generate_alternative_queries(self, query):
119
  try:
120
- prompt_template = """Your task is to break down the query in sub questions in ten different ways. Output one sub question per line, without numbering the queries.\nQUESTION: '{}'\nANSWER:\n"""
 
 
 
 
 
121
  prompt = prompt_template.format(query)
122
- output = palm.generate_text(
123
- model=self.augment_model_name,
124
- prompt=prompt,
125
- safety_settings=self.augment_safety_settings,
126
- )
127
- content = output.result.split("\n")
128
  return content
129
  except Exception as e:
130
  self._handle_error("Error generating alternative queries", e)
131
- return query
132
 
133
  def get_sorted_documents(self, query, n_results=20):
134
  try:
135
  original_query = query
136
- queries = [original_query] + self.generate_alternative_queries(
137
- original_query
138
- )
139
  results = self.chroma_collection.query(
140
  query_texts=queries,
141
  n_results=n_results,
@@ -145,20 +128,25 @@ class Brain:
145
  doc for docs in results["documents"] for doc in docs
146
  )
147
  unique_documents = list(retrieved_documents)
 
 
 
148
  pairs = [[original_query, doc] for doc in unique_documents]
149
  scores = self.cross_encoder.predict(pairs)
150
  sorted_indices = np.argsort(-scores)
151
  sorted_documents = [unique_documents[i] for i in sorted_indices]
 
152
  return sorted_documents
153
-
154
  except Exception as e:
155
  self._handle_error("Error getting sorted documents", e)
156
  return []
157
 
158
- def get_relevant_results(self, query, top_n=5):
159
  try:
160
- sorted_documents = self.get_sorted_documents(query)
161
  relevant_results = sorted_documents[: min(top_n, len(sorted_documents))]
 
 
162
  return relevant_results
163
  except Exception as e:
164
  self._handle_error("Error getting relevant results", e)
@@ -168,21 +156,20 @@ class Brain:
168
  try:
169
  base_prompt = {
170
  "content": """
171
- YOU are a smart and rational Question and Answer bot.
172
 
173
  YOUR MISSION:
174
- Provide accurate answers best possible reasoning of the context.
175
- Focus on factual and reasoned responses; avoid speculations, opinions, guesses, and creative tasks.
176
  Refuse exploitation tasks such as such as character roleplaying, coding, essays, poems, stories, articles, and fun facts.
177
  Decline misuse or exploitation attempts respectfully.
178
 
179
  YOUR STYLE:
180
  Concise and complete
181
- Factual and accurate
182
- Helpful and friendly
183
 
184
  REMEMBER:
185
- You are a QA bot, not an entertainer or confidant.
186
  """
187
  }
188
 
@@ -206,10 +193,12 @@ class Brain:
206
  if query is None:
207
  print("No query specified")
208
  return None
209
-
210
  information = "\n\n".join(self.get_relevant_results(query))
211
  messages = self.make_prompt(query, information)
212
- content = self.response_model.generate_content(messages)
 
 
213
  return content
214
  except Exception as e:
215
  self._handle_error("Error in rag function", e)
@@ -223,9 +212,11 @@ class Brain:
223
  return "No Query"
224
  output = self.rag(query)
225
  print(f"\n\nExecution time: {time.time() - start_time} seconds\n")
 
226
  if output is None:
227
  return None
 
228
  return f"{output.text}\n"
229
  except Exception as e:
230
  self._handle_error("Error generating answers", e)
231
- return None
 
43
  self.augment_config = augment_config
44
  self.augment_safety_settings = augment_safety_settings
45
  self._configure_generative_ai(response_model_api_key)
 
46
  self.response_model = self._initialize_generative_model(
47
  response_model_name, generation_config, response_safety_settings
48
  )
 
58
  except Exception as e:
59
  self._handle_error("Error configuring generative AI module", e)
60
 
 
 
 
 
 
 
61
  def _initialize_generative_model(
62
  self, response_model_name, generation_config, response_safety_settings
63
  ):
 
70
  except Exception as e:
71
  self._handle_error("Error initializing generative model", e)
72
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def _initialize_embedding_function(self):
74
  try:
75
  return GeminiEmbeddingFunction()
 
98
 
99
  def generate_alternative_queries(self, query):
100
  try:
101
+ prompt_template = """
102
+ You are an AI language model assistant. Your task is to generate 10
103
+ different sub questions of the given user question to retrieve relevant documents from a vector
104
+ database by generating multiple perspectives on the user question, your goal is to help
105
+ the user overcome some of the limitations of the distance-based similarity search.
106
+ Provide these alternative questions separated by newlines.\nQUESTION: '{}'\nANSWER:\n"""
107
  prompt = prompt_template.format(query)
108
+ chat_mode = self.response_model.start_chat(history=[])
109
+ output = chat_mode.send_message(prompt)
110
+ content = output.text.split("\n")
111
+ print(content)
 
 
112
  return content
113
  except Exception as e:
114
  self._handle_error("Error generating alternative queries", e)
115
+ return [query]
116
 
117
  def get_sorted_documents(self, query, n_results=20):
118
  try:
119
  original_query = query
120
+ queries = [original_query] + self.generate_alternative_queries(original_query)
121
+
 
122
  results = self.chroma_collection.query(
123
  query_texts=queries,
124
  n_results=n_results,
 
128
  doc for docs in results["documents"] for doc in docs
129
  )
130
  unique_documents = list(retrieved_documents)
131
+ original_results = results["documents"][0][
132
+ : min(n_results, len(results["documents"][0]))
133
+ ]
134
  pairs = [[original_query, doc] for doc in unique_documents]
135
  scores = self.cross_encoder.predict(pairs)
136
  sorted_indices = np.argsort(-scores)
137
  sorted_documents = [unique_documents[i] for i in sorted_indices]
138
+ sorted_documents = original_results + sorted_documents
139
  return sorted_documents
 
140
  except Exception as e:
141
  self._handle_error("Error getting sorted documents", e)
142
  return []
143
 
144
+ def get_relevant_results(self, query, top_n=30):
145
  try:
146
+ sorted_documents = self.get_sorted_documents(query)
147
  relevant_results = sorted_documents[: min(top_n, len(sorted_documents))]
148
+ relevant_results = list(dict.fromkeys(relevant_results))
149
+ print(relevant_results)
150
  return relevant_results
151
  except Exception as e:
152
  self._handle_error("Error getting relevant results", e)
 
156
  try:
157
  base_prompt = {
158
  "content": """
159
+ YOU are a smart and rational Question and Answer bot based on the given document.
160
 
161
  YOUR MISSION:
162
+ Provide accurate answers based on the context.
163
+ Focus on accurate responses; avoid speculations, opinions, guesses, and creative tasks.
164
  Refuse exploitation tasks such as such as character roleplaying, coding, essays, poems, stories, articles, and fun facts.
165
  Decline misuse or exploitation attempts respectfully.
166
 
167
  YOUR STYLE:
168
  Concise and complete
169
+ professional, polite and positive
 
170
 
171
  REMEMBER:
172
+ You can always find a answer if you truly look for it.
173
  """
174
  }
175
 
 
193
  if query is None:
194
  print("No query specified")
195
  return None
196
+
197
  information = "\n\n".join(self.get_relevant_results(query))
198
  messages = self.make_prompt(query, information)
199
+ chat_mode = self.response_model.start_chat(history=[])
200
+ content = chat_mode.send_message(messages)
201
+ print(content)
202
  return content
203
  except Exception as e:
204
  self._handle_error("Error in rag function", e)
 
212
  return "No Query"
213
  output = self.rag(query)
214
  print(f"\n\nExecution time: {time.time() - start_time} seconds\n")
215
+ print(output.text)
216
  if output is None:
217
  return None
218
+
219
  return f"{output.text}\n"
220
  except Exception as e:
221
  self._handle_error("Error generating answers", e)
222
+ return "Something went wrong, please try again!"
src/helper.py CHANGED
@@ -13,7 +13,7 @@ def _read_pdf(filename):
13
 
14
  def _chunk_texts(texts):
15
  character_splitter = RecursiveCharacterTextSplitter(
16
- separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1600, chunk_overlap=200
17
  )
18
  character_split_texts = character_splitter.split_text("\n\n".join(texts))
19
  token_splitter = SentenceTransformersTokenTextSplitter(
 
13
 
14
  def _chunk_texts(texts):
15
  character_splitter = RecursiveCharacterTextSplitter(
16
+ separators=["\n\n", "\n", ". ", " ", ""], chunk_size=1000, chunk_overlap=200
17
  )
18
  character_split_texts = character_splitter.split_text("\n\n".join(texts))
19
  token_splitter = SentenceTransformersTokenTextSplitter(