Theo Alves Da Costa commited on
Commit
3d561c7
1 Parent(s): 91f77da

Corrected major bug

Browse files
Files changed (2) hide show
  1. app.py +149 -64
  2. climateqa/chains.py +42 -5
app.py CHANGED
@@ -20,7 +20,8 @@ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
20
 
21
  # ClimateQ&A imports
22
  from climateqa.llm import get_llm
23
- from climateqa.chains import load_climateqa_chain
 
24
  from climateqa.vectorstore import get_pinecone_vectorstore
25
  from climateqa.retriever import ClimateQARetriever
26
  from climateqa.prompts import audience_prompts
@@ -142,36 +143,49 @@ vectorstore = get_pinecone_vectorstore(embeddings_function)
142
 
143
  from threading import Thread
144
 
 
145
 
146
- def answer_user(message,history):
147
- return message, history + [[message, None]]
148
 
149
- def answer_bot(message,history,audience,sources):
 
150
 
 
151
 
152
- Q = SimpleQueue()
 
 
153
 
154
  llm_reformulation = get_llm(max_tokens = 512,temperature = 0.0,verbose = True,streaming = False)
155
- llm_streaming = get_llm(max_tokens = 1024,temperature = 0.0,verbose = True,streaming = True,
156
- callbacks=[StreamingGradioCallbackHandler(Q),StreamingStdOutCallbackHandler()],
157
- )
158
-
159
  retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,k_summary = 3,k_total = 10)
160
- chain = load_climateqa_chain(retriever,llm_reformulation,llm_streaming)
161
 
 
 
 
 
162
 
163
- if len(sources) == 0:
164
- sources = ["IPCC"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- # if len(message) <= 2:
167
- # complete_response = "**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate and biodiversity issues).**"
168
- # history[-1][1] += "\n\n" + complete_response
169
- # return "", history, ""
170
 
171
- def threaded_chain(query,audience):
172
- response = chain({"query":query,"audience":audience})
173
- Q.put(response)
174
- Q.put(job_done)
175
 
176
  if audience == "Children":
177
  audience_prompt = audience_prompts["children"]
@@ -182,6 +196,57 @@ def answer_bot(message,history,audience,sources):
182
  else:
183
  audience_prompt = audience_prompts["experts"]
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  # history_langchain_format = []
186
  # for human, ai in history:
187
  # history_langchain_format.append(HumanMessage(content=human))
@@ -190,41 +255,42 @@ def answer_bot(message,history,audience,sources):
190
  # for next_token, content in stream(message):
191
  # yield(content)
192
 
193
- thread = Thread(target=threaded_chain, kwargs={"query":message,"audience":audience_prompt})
194
- thread.start()
195
-
196
- history[-1][1] = ""
197
- while True:
198
- next_item = Q.get(block=True) # Blocks until an input is available
199
-
200
- if next_item is job_done:
201
- continue
202
-
203
- elif isinstance(next_item, dict): # assuming LLMResult is a dictionary
204
- response = next_item
205
- if "source_documents" in response and len(response["source_documents"]) > 0:
206
- sources_text = []
207
- for i, d in enumerate(response["source_documents"], 1):
208
- sources_text.append(make_html_source(d, i))
209
- sources_text = "\n\n".join([f"Query used for retrieval:\n{response['question']}"] + sources_text)
210
- # history[-1][1] += next_item["answer"]
211
- # history[-1][1] += "\n\n" + sources_text
212
- yield "", history, sources_text
213
-
214
- else:
215
- sources_text = "⚠️ No relevant passages found in the scientific reports (IPCC and IPBES)"
216
- complete_response = "**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate and biodiversity issues).**"
217
- history[-1][1] += "\n\n" + complete_response
218
- yield "", history, sources_text
219
- break
220
-
221
- elif isinstance(next_item, str):
222
- new_paragraph = history[-1][1] + next_item
223
- new_paragraph = parse_output_llm_with_sources(new_paragraph)
224
- history[-1][1] = new_paragraph
225
- yield "", history, ""
226
-
227
- thread.join()
 
228
 
229
  #---------------------------------------------------------------------------
230
  # ClimateQ&A core functions
@@ -375,6 +441,8 @@ def log_on_azure(file, logs, share_client):
375
  file_client.upload_file(str(logs))
376
 
377
 
 
 
378
 
379
 
380
 
@@ -419,7 +487,9 @@ with gr.Blocks(title="🌍 Climate Q&A", css="style.css", theme=theme) as demo:
419
  show_copy_button=True,show_label = False,elem_id="chatbot",layout = "panel",avatar_images = ("assets/logo4.png",None))
420
 
421
  # bot.like(vote,None,None)
422
-
 
 
423
  with gr.Row(elem_id = "input-message"):
424
  textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=1,lines = 1,interactive = True)
425
  # submit_button = gr.Button(">",scale = 1,elem_id = "submit-button")
@@ -472,12 +542,14 @@ with gr.Blocks(title="🌍 Climate Q&A", css="style.css", theme=theme) as demo:
472
  )
473
 
474
  with gr.Tab("📚 Citations",elem_id = "tab-citations"):
475
- sources_textbox = gr.Markdown(show_label=False, elem_id="sources-textbox")
 
476
 
477
  with gr.Tab("⚙️ Configuration",elem_id = "tab-config"):
478
 
479
  gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
480
 
 
481
  dropdown_sources = gr.CheckboxGroup(
482
  ["IPCC", "IPBES"],
483
  label="Select reports",
@@ -492,14 +564,27 @@ with gr.Blocks(title="🌍 Climate Q&A", css="style.css", theme=theme) as demo:
492
  interactive=True,
493
  )
494
 
 
 
 
 
495
 
496
  # textbox.submit(predict_climateqa,[textbox,bot],[None,bot,sources_textbox])
497
- textbox.submit(answer_user, [textbox, bot], [textbox, bot], queue=True).then(
498
- answer_bot, [textbox,bot,dropdown_audience,dropdown_sources], [textbox,bot,sources_textbox]
499
- )
500
- examples_hidden.change(answer_user, [examples_hidden, bot], [textbox, bot], queue=True).then(
501
- answer_bot, [textbox,bot,dropdown_audience,dropdown_sources], [textbox,bot,sources_textbox]
502
- )
 
 
 
 
 
 
 
 
 
503
  # submit_button.click(answer_user, [textbox, bot], [textbox, bot], queue=True).then(
504
  # answer_bot, [textbox,bot,dropdown_audience,dropdown_sources], [textbox,bot,sources_textbox]
505
  # )
@@ -688,6 +773,6 @@ Or around 2 to 4 times more than a typical Google search.
688
  """
689
  )
690
 
691
- demo.queue(concurrency_count=1)
692
 
693
  demo.launch()
 
20
 
21
  # ClimateQ&A imports
22
  from climateqa.llm import get_llm
23
+ from climateqa.chains import load_qa_chain_with_docs,load_qa_chain_with_text
24
+ from climateqa.chains import load_reformulation_chain
25
  from climateqa.vectorstore import get_pinecone_vectorstore
26
  from climateqa.retriever import ClimateQARetriever
27
  from climateqa.prompts import audience_prompts
 
143
 
144
  from threading import Thread
145
 
146
+ import json
147
 
148
+ def answer_user(query,query_example,history):
149
+ return query, history + [[query, ". . ."]]
150
 
151
+ def answer_user_example(query,query_example,history):
152
+ return query_example, history + [[query_example, ". . ."]]
153
 
154
+ def fetch_sources(query,sources):
155
 
156
+ # Prepare default values
157
+ if len(sources) == 0:
158
+ sources = ["IPCC"]
159
 
160
  llm_reformulation = get_llm(max_tokens = 512,temperature = 0.0,verbose = True,streaming = False)
 
 
 
 
161
  retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,k_summary = 3,k_total = 10)
162
+ reformulation_chain = load_reformulation_chain(llm_reformulation)
163
 
164
+ # Calculate language
165
+ output_reformulation = reformulation_chain({"query":query})
166
+ question = output_reformulation["question"]
167
+ language = output_reformulation["language"]
168
 
169
+ # Retrieve docs
170
+ docs = retriever.get_relevant_documents(question)
171
+
172
+ if len(docs) > 0:
173
+
174
+ # Already display the sources
175
+ sources_text = []
176
+ for i, d in enumerate(docs, 1):
177
+ sources_text.append(make_html_source(d, i))
178
+ citations_text = "".join(sources_text)
179
+ docs_text = "\n\n".join([d.page_content for d in docs])
180
+ return "",citations_text,docs_text,question,language
181
+ else:
182
+ sources_text = "⚠️ No relevant passages found in the scientific reports (IPCC and IPBES)"
183
+ citations_text = "**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate and biodiversity issues).**"
184
+ docs_text = ""
185
+ return "",citations_text,docs_text,question,language
186
 
 
 
 
 
187
 
188
+ def answer_bot(query,history,docs,question,language,audience):
 
 
 
189
 
190
  if audience == "Children":
191
  audience_prompt = audience_prompts["children"]
 
196
  else:
197
  audience_prompt = audience_prompts["experts"]
198
 
199
+ # Prepare Queue for streaming LLMs
200
+ Q = SimpleQueue()
201
+
202
+ llm_streaming = get_llm(max_tokens = 1024,temperature = 0.0,verbose = True,streaming = True,
203
+ callbacks=[StreamingGradioCallbackHandler(Q),StreamingStdOutCallbackHandler()],
204
+ )
205
+
206
+ qa_chain = load_qa_chain_with_text(llm_streaming)
207
+
208
+ def threaded_chain(question,audience,language,docs):
209
+ try:
210
+ response = qa_chain({"question":question,"audience":audience,"language":language,"summaries":docs})
211
+ Q.put(response)
212
+ Q.put(job_done)
213
+ except Exception as e:
214
+ print(e)
215
+
216
+ history[-1][1] = ""
217
+
218
+ textbox=gr.Textbox(placeholder=". . .",show_label=False,scale=1,lines = 1,interactive = False)
219
+
220
+
221
+ if len(docs) > 0:
222
+
223
+ # Start thread for streaming
224
+ thread = Thread(
225
+ target=threaded_chain,
226
+ kwargs={"question":question,"audience":audience_prompt,"language":language,"docs":docs}
227
+ )
228
+ thread.start()
229
+
230
+ while True:
231
+ next_item = Q.get(block=True) # Blocks until an input is available
232
+
233
+ if next_item is job_done:
234
+ break
235
+ elif isinstance(next_item, str):
236
+ new_paragraph = history[-1][1] + next_item
237
+ new_paragraph = parse_output_llm_with_sources(new_paragraph)
238
+ history[-1][1] = new_paragraph
239
+ yield textbox,history
240
+ else:
241
+ pass
242
+ thread.join()
243
+ else:
244
+ complete_response = "**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate and biodiversity issues).**"
245
+ history[-1][1] += complete_response
246
+ yield "",history
247
+
248
+
249
+
250
  # history_langchain_format = []
251
  # for human, ai in history:
252
  # history_langchain_format.append(HumanMessage(content=human))
 
255
  # for next_token, content in stream(message):
256
  # yield(content)
257
 
258
+ # thread = Thread(target=threaded_chain, kwargs={"query":message,"audience":audience_prompt})
259
+ # thread.start()
260
+
261
+ # history[-1][1] = ""
262
+ # while True:
263
+ # next_item = Q.get(block=True) # Blocks until an input is available
264
+
265
+ # print(type(next_item))
266
+ # if next_item is job_done:
267
+ # continue
268
+
269
+ # elif isinstance(next_item, dict): # assuming LLMResult is a dictionary
270
+ # response = next_item
271
+ # if "source_documents" in response and len(response["source_documents"]) > 0:
272
+ # sources_text = []
273
+ # for i, d in enumerate(response["source_documents"], 1):
274
+ # sources_text.append(make_html_source(d, i))
275
+ # sources_text = "\n\n".join([f"Query used for retrieval:\n{response['question']}"] + sources_text)
276
+ # # history[-1][1] += next_item["answer"]
277
+ # # history[-1][1] += "\n\n" + sources_text
278
+ # yield "", history, sources_text
279
+
280
+ # else:
281
+ # sources_text = "⚠️ No relevant passages found in the scientific reports (IPCC and IPBES)"
282
+ # complete_response = "**⚠️ No relevant passages found in the climate science reports (IPCC and IPBES), you may want to ask a more specific question (specifying your question on climate and biodiversity issues).**"
283
+ # history[-1][1] += "\n\n" + complete_response
284
+ # yield "", history, sources_text
285
+ # break
286
+
287
+ # elif isinstance(next_item, str):
288
+ # new_paragraph = history[-1][1] + next_item
289
+ # new_paragraph = parse_output_llm_with_sources(new_paragraph)
290
+ # history[-1][1] = new_paragraph
291
+ # yield "", history, ""
292
+
293
+ # thread.join()
294
 
295
  #---------------------------------------------------------------------------
296
  # ClimateQ&A core functions
 
441
  file_client.upload_file(str(logs))
442
 
443
 
444
+ def disable_component():
445
+ return gr.update(interactive = False)
446
 
447
 
448
 
 
487
  show_copy_button=True,show_label = False,elem_id="chatbot",layout = "panel",avatar_images = ("assets/logo4.png",None))
488
 
489
  # bot.like(vote,None,None)
490
+
491
+
492
+
493
  with gr.Row(elem_id = "input-message"):
494
  textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=1,lines = 1,interactive = True)
495
  # submit_button = gr.Button(">",scale = 1,elem_id = "submit-button")
 
542
  )
543
 
544
  with gr.Tab("📚 Citations",elem_id = "tab-citations"):
545
+ sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
546
+ docs_textbox = gr.State("")
547
 
548
  with gr.Tab("⚙️ Configuration",elem_id = "tab-config"):
549
 
550
  gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
551
 
552
+
553
  dropdown_sources = gr.CheckboxGroup(
554
  ["IPCC", "IPBES"],
555
  label="Select reports",
 
564
  interactive=True,
565
  )
566
 
567
+ output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
568
+ output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
569
+
570
+
571
 
572
  # textbox.submit(predict_climateqa,[textbox,bot],[None,bot,sources_textbox])
573
+ (textbox
574
+ .submit(answer_user, [textbox,examples_hidden, bot], [textbox, bot],queue = False)
575
+ .then(disable_component, [examples_questions], [examples_questions],queue = False)
576
+ .success(fetch_sources,[textbox,dropdown_sources], [textbox,sources_textbox,docs_textbox,output_query,output_language])
577
+ .success(answer_bot, [textbox,bot,docs_textbox,output_query,output_language,dropdown_audience], [textbox,bot],queue = True)
578
+ .success(lambda x : textbox,[textbox],[textbox])
579
+ )
580
+
581
+ (examples_hidden
582
+ .change(answer_user_example, [textbox,examples_hidden, bot], [textbox, bot],queue = False)
583
+ .then(disable_component, [examples_questions], [examples_questions],queue = False)
584
+ .success(fetch_sources,[textbox,dropdown_sources], [textbox,sources_textbox,docs_textbox,output_query,output_language])
585
+ .success(answer_bot, [textbox,bot,docs_textbox,output_query,output_language,dropdown_audience], [textbox,bot],queue=True)
586
+ .success(lambda x : textbox,[textbox],[textbox])
587
+ )
588
  # submit_button.click(answer_user, [textbox, bot], [textbox, bot], queue=True).then(
589
  # answer_bot, [textbox,bot,dropdown_audience,dropdown_sources], [textbox,bot,sources_textbox]
590
  # )
 
773
  """
774
  )
775
 
776
+ demo.queue(concurrency_count=16)
777
 
778
  demo.launch()
climateqa/chains.py CHANGED
@@ -3,7 +3,7 @@
3
  import json
4
 
5
  from langchain import PromptTemplate, LLMChain
6
- from langchain.chains import RetrievalQAWithSourcesChain
7
  from langchain.chains import TransformChain, SequentialChain
8
  from langchain.chains.qa_with_sources import load_qa_with_sources_chain
9
 
@@ -37,11 +37,48 @@ def load_reformulation_chain(llm):
37
  return reformulation_chain
38
 
39
 
40
-
41
-
42
- def load_answer_chain(retriever,llm):
43
  prompt = PromptTemplate(template=answer_prompt, input_variables=["summaries", "question","audience","language"])
44
  qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff",prompt = prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  # This could be improved by providing a document prompt to avoid modifying page_content in the docs
47
  # See here https://github.com/langchain-ai/langchain/issues/3523
@@ -59,7 +96,7 @@ def load_answer_chain(retriever,llm):
59
  def load_climateqa_chain(retriever,llm_reformulation,llm_answer):
60
 
61
  reformulation_chain = load_reformulation_chain(llm_reformulation)
62
- answer_chain = load_answer_chain(retriever,llm_answer)
63
 
64
  climateqa_chain = SequentialChain(
65
  chains = [reformulation_chain,answer_chain],
 
3
  import json
4
 
5
  from langchain import PromptTemplate, LLMChain
6
+ from langchain.chains import RetrievalQAWithSourcesChain,QAWithSourcesChain
7
  from langchain.chains import TransformChain, SequentialChain
8
  from langchain.chains.qa_with_sources import load_qa_with_sources_chain
9
 
 
37
  return reformulation_chain
38
 
39
 
40
+ def load_combine_documents_chain(llm):
 
 
41
  prompt = PromptTemplate(template=answer_prompt, input_variables=["summaries", "question","audience","language"])
42
  qa_chain = load_qa_with_sources_chain(llm, chain_type="stuff",prompt = prompt)
43
+ return qa_chain
44
+
45
+ def load_qa_chain_with_docs(llm):
46
+ """Load a QA chain with documents.
47
+ Useful when you already have retrieved docs
48
+
49
+ To be called with this input
50
+
51
+ ```
52
+ output = chain({
53
+ "question":query,
54
+ "audience":"experts climate scientists",
55
+ "docs":docs,
56
+ "language":"English",
57
+ })
58
+ ```
59
+ """
60
+
61
+ qa_chain = load_combine_documents_chain(llm)
62
+ chain = QAWithSourcesChain(
63
+ input_docs_key = "docs",
64
+ combine_documents_chain = qa_chain,
65
+ return_source_documents = True,
66
+ )
67
+ return chain
68
+
69
+
70
+ def load_qa_chain_with_text(llm):
71
+
72
+ prompt = PromptTemplate(
73
+ template = answer_prompt,
74
+ input_variables=["question","audience","language","summaries"],
75
+ )
76
+ qa_chain = LLMChain(llm = llm,prompt = prompt)
77
+ return qa_chain
78
+
79
+
80
+ def load_qa_chain_with_retriever(retriever,llm):
81
+ qa_chain = load_combine_documents_chain(llm)
82
 
83
  # This could be improved by providing a document prompt to avoid modifying page_content in the docs
84
  # See here https://github.com/langchain-ai/langchain/issues/3523
 
96
  def load_climateqa_chain(retriever,llm_reformulation,llm_answer):
97
 
98
  reformulation_chain = load_reformulation_chain(llm_reformulation)
99
+ answer_chain = load_qa_chain_with_retriever(retriever,llm_answer)
100
 
101
  climateqa_chain = SequentialChain(
102
  chains = [reformulation_chain,answer_chain],