Théo ALVES DA COSTA committed

Commit 37b1e7a
1 Parent(s): e92f501

Fixed bugs with multi LLMs

Files changed (3):
  1. app.py +7 -57
  2. climateqa/engine/rag.py +8 -17
  3. climateqa/engine/utils.py +23 -6
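
Note on the fix: the old code in app.py matched astream_log paths built from auto-generated run names (e.g. "/logs/AzureChatOpenAI:2/streamed_output_str/-"), which only exist when that particular model class sits in the chain, so swapping in a different LLM provider silently broke streaming. This commit pins explicit run names ("reformulation", "find_documents", "answer") on the sub-chains via new helpers in climateqa/engine/utils.py, making the paths provider-agnostic.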
app.py CHANGED
@@ -146,88 +146,38 @@ async def chat(query,history,audience,sources,reports):
     if len(reports) == 0:
         reports = []
 
-
     retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,min_size = 200,reports = reports,k_summary = 3,k_total = 15,threshold=0.5)
     rag_chain = make_rag_chain(retriever,llm)
-
-    # gradio_format = make_pairs([a.content for a in history]) + [(query, "")]
-    # history = history + [(query,"")]
-    # print(history)
-    # print(gradio_format)
-
-    # # reset memory
-    # memory.clear()
-    # for message in history:
-    #     memory.chat_memory.add_message(message)
 
     inputs = {"query": query,"audience": audience_prompt}
     result = rag_chain.astream_log(inputs) #{"callbacks":[MyCustomAsyncHandler()]})
     # result = rag_chain.stream(inputs)
 
-    reformulated_question_path_id = "/logs/flatten_dict/final_output"
-    retriever_path_id = "/logs/Retriever/final_output"
-    streaming_output_path_id = "/logs/AzureChatOpenAI:2/streamed_output_str/-"
-    final_output_path_id = "/streamed_output/-"
+    path_reformulation = "/logs/reformulation/final_output"
+    path_retriever = "/logs/find_documents/final_output"
+    path_answer = "/logs/answer/streamed_output_str/-"
 
     docs_html = ""
     output_query = ""
     output_language = ""
     gallery = []
 
-    # for output in result:
-
-    #     if "language" in output:
-    #         output_language = output["language"]
-    #     if "question" in output:
-    #         output_query = output["question"]
-    #     if "docs" in output:
-
-    #         try:
-    #             docs = output['docs'] # List[Document]
-    #             docs_html = []
-    #             for i, d in enumerate(docs, 1):
-    #                 docs_html.append(make_html_source(d, i))
-    #             docs_html = "".join(docs_html)
-    #         except TypeError:
-    #             print("No documents found")
-    #             continue
-
-    #     if "answer" in output:
-    #         new_token = output["answer"] # str
-    #         time.sleep(0.03)
-    #         answer_yet = history[-1][1] + new_token
-    #         answer_yet = parse_output_llm_with_sources(answer_yet)
-    #         history[-1] = (query,answer_yet)
-
-    #     yield history,docs_html,output_query,output_language,gallery
-
-
-
-    # async def fallback_iterator(iterable):
-    #     async for item in iterable:
-    #         try:
-    #             yield item
-    #         except Exception as e:
-    #             print(f"Error in fallback iterator: {e}")
-    #             raise gr.Error(f"ClimateQ&A Error: {e}\nThe error has been noted, try another question and if the error remains, you can contact us :)")
-
     try:
         async for op in result:
 
-
             op = op.ops[0]
             # print("ITERATION",op)
 
-            if op['path'] == reformulated_question_path_id: # reforulated question
+            if op['path'] == path_reformulation: # reformulated question
                 try:
                     output_language = op['value']["language"] # str
                     output_query = op["value"]["question"]
                 except Exception as e:
                     raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)")
 
-            elif op['path'] == retriever_path_id: # documents
+            elif op['path'] == path_retriever: # documents
                 try:
-                    docs = op['value']['documents'] # List[Document]
+                    docs = op['value']['docs'] # List[Document]
                     docs_html = []
                     for i, d in enumerate(docs, 1):
                         docs_html.append(make_html_source(d, i))
@@ -237,7 +187,7 @@ async def chat(query,history,audience,sources,reports):
                     print("op: ",op)
                     continue
 
-            elif op['path'] == streaming_output_path_id: # final answer
+            elif op['path'] == path_answer: # final answer
                 new_token = op['value'] # str
                 time.sleep(0.01)
                 answer_yet = history[-1][1] + new_token
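
For context, a minimal sketch (not part of the commit) of how these stable paths are consumed; rag_chain is assumed to come from make_rag_chain below, so its sub-runs carry the names set in this commit:

    path_retriever = "/logs/find_documents/final_output"
    path_answer = "/logs/answer/streamed_output_str/-"

    async def collect_answer(rag_chain, inputs):
        answer, docs = "", []
        # astream_log yields RunLogPatch objects carrying JSON-Patch operations
        async for patch in rag_chain.astream_log(inputs):
            for op in patch.ops:
                if op["path"] == path_answer:       # one streamed token (str)
                    answer += op["value"]
                elif op["path"] == path_retriever:  # retriever run finished
                    docs = op["value"]["docs"]      # List[Document]
        return answer, docs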
climateqa/engine/rag.py CHANGED
@@ -8,8 +8,7 @@ from langchain_core.prompts.base import format_document
 
 from climateqa.engine.reformulation import make_reformulation_chain
 from climateqa.engine.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
-from climateqa.engine.utils import pass_values, flatten_dict
-
+from climateqa.engine.utils import pass_values, flatten_dict,prepare_chain,rename_chain
 
 DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
 
@@ -44,21 +43,13 @@ def make_rag_chain(retriever,llm):
     prompt_without_docs = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
 
     # ------- CHAIN 0 - Reformulation
-    reformulation_chain = make_reformulation_chain(llm)
-    reformulation = (
-        {"reformulation":reformulation_chain,**pass_values(["audience","query"])}
-        | RunnablePassthrough()
-        | flatten_dict
-    )
-
+    reformulation = make_reformulation_chain(llm)
+    reformulation = prepare_chain(reformulation,"reformulation")
 
     # ------- CHAIN 1
     # Retrieved documents
-    find_documents = {
-        "docs": itemgetter("question") | retriever,
-        **pass_values(["question","audience","language","query"])
-    } | RunnablePassthrough()
-
+    find_documents = {"docs": itemgetter("question") | retriever} | RunnablePassthrough()
+    find_documents = prepare_chain(find_documents,"find_documents")
 
     # ------- CHAIN 2
     # Construct inputs for the llm
@@ -69,15 +60,15 @@ def make_rag_chain(retriever,llm):
 
     # ------- CHAIN 3
     # Bot answer
-
+    llm_final = rename_chain(llm,"answer")
 
     answer_with_docs = {
-        "answer": input_documents | prompt | llm | StrOutputParser(),
+        "answer": input_documents | prompt | llm_final | StrOutputParser(),
         **pass_values(["question","audience","language","query","docs"]),
     }
 
     answer_without_docs = {
-        "answer": prompt_without_docs | llm | StrOutputParser(),
+        "answer": prompt_without_docs | llm_final | StrOutputParser(),
         **pass_values(["question","audience","language","query","docs"]),
     }
 
climateqa/engine/utils.py CHANGED
@@ -1,10 +1,29 @@
-
-from typing import Any, Dict, Iterable, Tuple, Union
 from operator import itemgetter
+from typing import Any, Dict, Iterable, Tuple
+from langchain_core.runnables import RunnablePassthrough
+
 
 def pass_values(x):
-    if not isinstance(x,list): x = [x]
-    return {k:itemgetter(k) for k in x}
+    if not isinstance(x, list):
+        x = [x]
+    return {k: itemgetter(k) for k in x}
+
+
+def prepare_chain(chain,name):
+    chain = propagate_inputs(chain)
+    chain = rename_chain(chain,name)
+    return chain
+
+
+def propagate_inputs(chain):
+    chain_with_values = {
+        "outputs": chain,
+        "inputs": RunnablePassthrough()
+    } | RunnablePassthrough() | flatten_dict
+    return chain_with_values
+
+def rename_chain(chain,name):
+    return chain.with_config({"run_name":name})
 
 
 # Drawn from langchain utils and modified to remove the parent key
@@ -48,5 +67,3 @@ def flatten_dict(
     """
     flat_dict = {k: v for k, v in _flatten_dict(nested_dict, parent_key, sep)}
     return flat_dict
-
-
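
To see what propagate_inputs buys, a small usage sketch (the example step and values are made up): the wrapped chain's result lands under "outputs", the untouched inputs under "inputs", and flatten_dict (which strips parent keys, per the comment above) merges both into one flat dict, so downstream pass_values still finds the original keys.

    from langchain_core.runnables import RunnableLambda
    from climateqa.engine.utils import prepare_chain

    # A toy "reformulation" step that only produces a new key.
    reformulate = RunnableLambda(lambda d: {"question": d["query"].rstrip("?")})
    chain = prepare_chain(reformulate, "reformulation")

    result = chain.invoke({"query": "what is warming?", "audience": "experts"})
    # The inputs survive alongside the new output:
    # {"question": "what is warming", "query": "what is warming?", "audience": "experts"}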