TheoLvs committed on
Commit caf1faa
1 Parent(s): 001af11

Experimental openalex feature

Ekimetrics_Logo_Color.jpg DELETED
Binary file (76.8 kB)
app.py CHANGED
@@ -1,6 +1,11 @@
 from climateqa.engine.embeddings import get_embeddings_function
 embeddings_function = get_embeddings_function()

+from climateqa.papers.openalex import OpenAlex
+from sentence_transformers import CrossEncoder
+
+reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
+oa = OpenAlex()

 import gradio as gr
 import pandas as pd
@@ -32,6 +37,8 @@ from climateqa.engine.prompts import audience_prompts
 from climateqa.sample_questions import QUESTIONS
 from climateqa.constants import POSSIBLE_REPORTS
 from climateqa.utils import get_image_from_azure_blob_storage
+from climateqa.engine.keywords import make_keywords_chain
+from climateqa.engine.rag import make_rag_papers_chain

 # Load environment variables in local mode
 try:
@@ -141,19 +148,20 @@ async def chat(query,history,audience,sources,reports):
     # result = rag_chain.stream(inputs)

     path_reformulation = "/logs/reformulation/final_output"
+    path_keywords = "/logs/keywords/final_output"
     path_retriever = "/logs/find_documents/final_output"
     path_answer = "/logs/answer/streamed_output_str/-"

     docs_html = ""
     output_query = ""
     output_language = ""
+    output_keywords = ""
     gallery = []

     try:
         async for op in result:

             op = op.ops[0]
-            # print("ITERATION",op)

             if op['path'] == path_reformulation: # reformulated question
                 try:
@@ -162,6 +170,14 @@ async def chat(query,history,audience,sources,reports):
                 except Exception as e:
                     raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)")

+            if op["path"] == path_keywords:
+                try:
+                    output_keywords = op['value']["keywords"] # list of str
+                    output_keywords = " AND ".join(output_keywords)
+                except Exception as e:
+                    pass
+
+
             elif op['path'] == path_retriever: # documents
                 try:
                     docs = op['value']['docs'] # List[Document]
@@ -183,23 +199,13 @@ async def chat(query,history,audience,sources,reports):
                     answer_yet = parse_output_llm_with_sources(answer_yet)
                     history[-1] = (query,answer_yet)

-
-            # elif op['path'] == final_output_path_id:
-            #     final_output = op['value']
-
-            #     if "answer" in final_output:
-
-            #         final_output = final_output["answer"]
-            #         print(final_output)
-            #         answer = history[-1][1] + final_output
-            #         answer = parse_output_llm_with_sources(answer)
-            #         history[-1] = (query,answer)

+
             else:
                 continue

             history = [tuple(x) for x in history]
-            yield history,docs_html,output_query,output_language,gallery
+            yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords

     except Exception as e:
         raise gr.Error(f"{e}")
@@ -267,37 +273,7 @@ async def chat(query,history,audience,sources,reports):
     # gallery = list(set("|".join(gallery).split("|")))
     # gallery = [get_image_from_azure_blob_storage(x) for x in gallery]

-    yield history,docs_html,output_query,output_language,gallery
-
-
-    # memory.save_context(inputs, {"answer": gradio_format[-1][1]})
-    # yield gradio_format, memory.load_memory_variables({})["history"], source_string
-
-# async def chat_with_timeout(query, history, audience, sources, reports, timeout_seconds=2):
-#     async def timeout_gen(async_gen, timeout):
-#         try:
-#             while True:
-#                 try:
-#                     yield await asyncio.wait_for(async_gen.__anext__(), timeout)
-#                 except StopAsyncIteration:
-#                     break
-#         except asyncio.TimeoutError:
-#             raise gr.Error("Operation timed out. Please try again.")
-
-#     return timeout_gen(chat(query, history, audience, sources, reports), timeout_seconds)
-
-
-
-# # A wrapper function that includes a timeout
-# async def chat_with_timeout(query, history, audience, sources, reports, timeout_seconds=2):
-#     try:
-#         # Use asyncio.wait_for to apply a timeout to the chat function
-#         return await asyncio.wait_for(chat(query, history, audience, sources, reports), timeout_seconds)
-#     except asyncio.TimeoutError:
-#         # Handle the timeout error as desired
-#         raise gr.Error("Operation timed out. Please try again.")
-
-
+    yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords


 def make_html_source(source,i):
@@ -392,6 +368,79 @@ def log_on_azure(file, logs, share_client):
     file_client.upload_file(logs)


+def generate_keywords(query):
+    chain = make_keywords_chain(llm)
+    keywords = chain.invoke(query)
+    keywords = " AND ".join(keywords["keywords"])
+    return keywords
+
+
+
+papers_cols_widths = {
+    "doc":50,
+    "id":100,
+    "title":300,
+    "doi":100,
+    "publication_year":100,
+    "abstract":500,
+    "rerank_score":100,
+    "is_oa":50,
+}
+
+papers_cols = list(papers_cols_widths.keys())
+papers_cols_widths = list(papers_cols_widths.values())
+
+async def find_papers(query, keywords, after):
+
+    summary = ""
+
+    df_works = oa.search(keywords,after = after)
+    df_works = df_works.dropna(subset=["abstract"])
+    df_works = oa.rerank(query,df_works,reranker)
+    df_works = df_works.sort_values("rerank_score",ascending=False)
+    G = oa.make_network(df_works)
+
+    height = "750px"
+    network = oa.show_network(G,color_by = "rerank_score",notebook=False,height = height)
+    network_html = network.generate_html()
+
+    network_html = network_html.replace("'", "\"")
+    css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
+    network_html = network_html + css_to_inject
+
+    network_html = f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
+    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
+    allow-scripts allow-same-origin allow-popups
+    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
+    allowpaymentrequest="" frameborder="0" srcdoc='{network_html}'></iframe>"""
+
+    docs = df_works["content"].head(15).tolist()
+
+    df_works = df_works.reset_index(drop = True).reset_index().rename(columns = {"index":"doc"})
+    df_works["doc"] = df_works["doc"] + 1
+    df_works = df_works[papers_cols]
+
+    yield df_works,network_html,summary
+
+    chain = make_rag_papers_chain(llm)
+    result = chain.astream_log({"question": query,"docs": docs,"language":"English"})
+    path_answer = "/logs/StrOutputParser/streamed_output/-"
+
+    async for op in result:
+
+        op = op.ops[0]
+
+        if op['path'] == path_answer: # streamed answer tokens
+            new_token = op['value'] # str
+            summary += new_token
+        else:
+            continue
+        yield df_works,network_html,summary
+
+
 # --------------------------------------------------------------------
 # Gradio
 # --------------------------------------------------------------------
@@ -474,7 +523,7 @@ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main
                     samples.append(group_examples)


-                with gr.Tab("Citations",elem_id = "tab-citations",id = 1):
+                with gr.Tab("Sources",elem_id = "tab-citations",id = 1):
                     sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                     docs_textbox = gr.State("")

@@ -513,6 +562,7 @@ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main



+
     #---------------------------------------------------------------------------------------
     # OTHER TABS
     #---------------------------------------------------------------------------------------
@@ -521,6 +571,28 @@ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main
     with gr.Tab("Figures",elem_id = "tab-images",elem_classes = "max-height other-tabs"):
         gallery_component = gr.Gallery()

+    with gr.Tab("Papers (beta)",elem_id = "tab-papers",elem_classes = "max-height other-tabs"):
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                query_papers = gr.Textbox(placeholder="Question",show_label=False,lines = 1,interactive = True,elem_id="query-papers")
+                keywords_papers = gr.Textbox(placeholder="Keywords",show_label=False,lines = 1,interactive = True,elem_id="keywords-papers")
+                after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers")
+                search_papers = gr.Button("Search",elem_id="search-papers",interactive=True)
+
+            with gr.Column(scale=7):
+
+                with gr.Tab("Summary",elem_id="papers-summary-tab"):
+                    papers_summary = gr.Markdown(visible=True,elem_id="papers-summary")
+
+                with gr.Tab("Relevant papers",elem_id="papers-results-tab"):
+                    papers_dataframe = gr.Dataframe(visible=True,elem_id="papers-table",headers = papers_cols)
+
+                with gr.Tab("Citations network",elem_id="papers-network-tab"):
+                    citations_network = gr.HTML(visible=True,elem_id="papers-citations-network")
+
+
+
     with gr.Tab("About",elem_classes = "max-height other-tabs"):
         with gr.Row():
             with gr.Column(scale=1):
@@ -537,13 +609,13 @@ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main

     (textbox
         .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
-        .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component],concurrency_limit = 8,api_name = "chat_textbox")
+        .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_textbox")
         .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
     )

     (examples_hidden
         .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
-        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component],concurrency_limit = 8,api_name = "chat_examples")
+        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query,output_language,gallery_component,query_papers,keywords_papers],concurrency_limit = 8,api_name = "chat_examples")
         .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
     )

@@ -558,6 +630,9 @@ with gr.Blocks(title="Climate Q&A", css="style.css", theme=theme,elem_id = "main

     dropdown_samples.change(change_sample_questions,dropdown_samples,samples)

+    query_papers.submit(generate_keywords,[query_papers], [keywords_papers])
+    search_papers.click(find_papers,[query_papers,keywords_papers,after], [papers_dataframe,citations_network,papers_summary])
+
     # # textbox.submit(predict_climateqa,[textbox,bot],[None,bot,sources_textbox])
     # (textbox
     #     .submit(answer_user, [textbox,examples_hidden, bot], [textbox, bot],queue = False)
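
Note on the new flow: the Papers tab chains two events — submitting the question textbox calls generate_keywords, and the Search button streams find_papers into the dataframe, network and summary components. Below is a minimal sketch of the same flow outside Gradio (it assumes the objects defined in app.py above; papers_flow is a hypothetical helper, not part of the commit):

import asyncio

async def papers_flow(question):
    # generate_keywords, find_papers, oa and reranker come from app.py above
    keywords = generate_keywords(question)   # e.g. "deep sea mining"
    async for df_works, network_html, summary in find_papers(question, keywords, after=1960):
        pass   # Gradio pushes each yielded triple to the UI; here we just keep the last one
    return df_works, network_html, summary

# asyncio.run(papers_flow("How will El Nino be impacted by climate change?"))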
climateqa/engine/keywords.py ADDED
@@ -0,0 +1,30 @@
+
+from typing import List
+from typing import Literal
+from langchain_core.pydantic_v1 import BaseModel, Field
+from langchain.prompts import ChatPromptTemplate
+from langchain_core.utils.function_calling import convert_to_openai_function
+from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
+
+class KeywordsOutput(BaseModel):
+    """Analyzing the user query to get keywords for a search engine"""
+
+    keywords: list = Field(
+        description="""
+        Generate 1 or 2 relevant keywords from the user query to ask a search engine for scientific research papers.
+
+        Example:
+        - "What is the impact of deep sea mining ?" -> ["deep sea mining"]
+        - "How will El Nino be impacted by climate change" -> ["el nino"]
+        - "Is climate change a hoax" -> ["climate change","hoax"]
+        """
+    )
+
+
+def make_keywords_chain(llm):
+
+    functions = [convert_to_openai_function(KeywordsOutput)]
+    llm_functions = llm.bind(functions = functions,function_call={"name":"KeywordsOutput"})
+
+    chain = llm_functions | JsonOutputFunctionsParser()
+    return chain
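
A possible usage sketch for this chain (assuming llm is a function-calling chat model, such as the ChatOpenAI instance the app already builds):

chain = make_keywords_chain(llm)
result = chain.invoke("What is the impact of deep sea mining ?")
# result is the parsed function-call output, e.g. {"keywords": ["deep sea mining"]}
search_query = " AND ".join(result["keywords"])   # -> "deep sea mining", ready for OpenAlex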
climateqa/engine/prompts.py CHANGED
@@ -60,6 +60,32 @@ Question: {question} - Explained to {audience}
 Answer in {language} with the passages citations:
 """

+
+papers_prompt_template = """
+You are ClimateQ&A, an AI Assistant created by Ekimetrics. You are given a question and extracted abstracts of scientific papers. Provide a clear and structured answer based on the abstracts provided, the context and the guidelines.
+
+Guidelines:
+- If the passages have useful facts or numbers, use them in your answer.
+- When you use information from a passage, mention where it came from by using [Doc i] at the end of the sentence. i stands for the number of the document.
+- Do not use the sentence 'Doc i says ...' to say where information came from.
+- If the same thing is said in more than one document, you can mention all of them like this: [Doc i, Doc j, Doc k]
+- Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
+- If it makes sense, use bullet points and lists to make your answers easier to understand.
+- Use markdown to format your answer and make it easier to read.
+- You do not need to use every passage. Only use the ones that help answer the question.
+- If the documents do not have the information needed to answer the question, just say you do not have enough information.
+
+-----------------------
+Abstracts:
+{context}
+
+-----------------------
+Question: {question}
+Answer in {language} with the passages citations:
+"""
+
+
+
 answer_prompt_images_template = """
 You are ClimateQ&A, an AI Assistant created by Ekimetrics.
 You are given the answer to an environmental question based on passages from the IPCC and IPBES reports and image captions.
climateqa/engine/rag.py CHANGED
@@ -8,7 +8,9 @@ from langchain_core.prompts.base import format_document

 from climateqa.engine.reformulation import make_reformulation_chain
 from climateqa.engine.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
+from climateqa.engine.prompts import papers_prompt_template
 from climateqa.engine.utils import pass_values, flatten_dict,prepare_chain,rename_chain
+from climateqa.engine.keywords import make_keywords_chain

 DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

@@ -21,7 +23,11 @@ def _combine_documents(
     for i,doc in enumerate(docs):
         # chunk_type = "Doc" if doc.metadata["chunk_type"] == "text" else "Image"
         chunk_type = "Doc"
-        doc_string = f"{chunk_type} {i+1}: " + format_document(doc, document_prompt)
+        if isinstance(doc,str):
+            doc_formatted = doc
+        else:
+            doc_formatted = format_document(doc, document_prompt)
+        doc_string = f"{chunk_type} {i+1}: " + doc_formatted
         doc_string = doc_string.replace("\n"," ")
         doc_strings.append(doc_string)

@@ -37,7 +43,6 @@ def get_image_docs(x):

 def make_rag_chain(retriever,llm):

-
     # Construct the prompt
     prompt = ChatPromptTemplate.from_template(answer_prompt_template)
     prompt_without_docs = ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
@@ -46,6 +51,11 @@ def make_rag_chain(retriever,llm):
     reformulation = make_reformulation_chain(llm)
     reformulation = prepare_chain(reformulation,"reformulation")

+    # ------- Find all keywords from the reformulated query
+    keywords = make_keywords_chain(llm)
+    keywords = {"keywords":itemgetter("question") | keywords}
+    keywords = prepare_chain(keywords,"keywords")
+
     # ------- CHAIN 1
     # Retrieved documents
     find_documents = {"docs": itemgetter("question") | retriever} | RunnablePassthrough()
@@ -55,7 +65,7 @@ def make_rag_chain(retriever,llm):
     # Construct inputs for the llm
     input_documents = {
         "context":lambda x : _combine_documents(x["docs"]),
-        **pass_values(["question","audience","language"])
+        **pass_values(["question","audience","language","keywords"])
     }

     # ------- CHAIN 3
@@ -64,12 +74,12 @@ def make_rag_chain(retriever,llm):

     answer_with_docs = {
         "answer": input_documents | prompt | llm_final | StrOutputParser(),
-        **pass_values(["question","audience","language","query","docs"]),
+        **pass_values(["question","audience","language","query","docs","keywords"]),
     }

     answer_without_docs = {
         "answer": prompt_without_docs | llm_final | StrOutputParser(),
-        **pass_values(["question","audience","language","query","docs"]),
+        **pass_values(["question","audience","language","query","docs","keywords"]),
     }

     # def has_images(x):
@@ -87,11 +97,29 @@ def make_rag_chain(retriever,llm):

     # ------- FINAL CHAIN
     # Build the final chain
-    rag_chain = reformulation | find_documents | answer
+    rag_chain = reformulation | keywords | find_documents | answer

     return rag_chain


+def make_rag_papers_chain(llm):
+
+    prompt = ChatPromptTemplate.from_template(papers_prompt_template)
+
+    input_documents = {
+        "context":lambda x : _combine_documents(x["docs"]),
+        **pass_values(["question","language"])
+    }
+
+    chain = input_documents | prompt | llm | StrOutputParser()
+    chain = rename_chain(chain,"answer")
+
+    return chain
+
+
+

 def make_illustration_chain(llm):

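
For reference, the papers chain above is consumed token by token through astream_log, mirroring what app.py does with it. A minimal sketch (assuming llm and a list of abstract strings; summarize_papers is a hypothetical helper):

async def summarize_papers(llm, question, abstracts):
    chain = make_rag_papers_chain(llm)
    result = chain.astream_log({"question": question, "docs": abstracts, "language": "English"})
    summary = ""
    async for patch in result:
        op = patch.ops[0]
        if op["path"] == "/logs/StrOutputParser/streamed_output/-":
            summary += op["value"]   # one streamed token at a time
    return summary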
climateqa/papers/__init__.py ADDED
@@ -0,0 +1,43 @@
+import pandas as pd
+
+from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
+import pyalex
+
+pyalex.config.email = "theo.alvesdacosta@ekimetrics.com"
+
+class OpenAlex():
+    def __init__(self):
+        pass
+
+
+    def search(self,keywords,n_results = 100,after = None,before = None):
+        works = Works().search(keywords)
+
+        for page in works.paginate(per_page=n_results):
+            break
+
+        df_works = pd.DataFrame(page)
+
+        return df_works
+
+
+    def make_network(self):
+        pass
+
+
+    def get_abstract_from_inverted_index(self,index):
+
+        # Determine the maximum index to know the length of the reconstructed array
+        max_index = max([max(positions) for positions in index.values()])
+
+        # Initialize a list with placeholders for all positions
+        reconstructed = [''] * (max_index + 1)
+
+        # Iterate through the inverted index and place each token at its respective position(s)
+        for token, positions in index.items():
+            for position in positions:
+                reconstructed[position] = token
+
+        # Join the tokens to form the reconstructed sentence(s)
+        return ' '.join(reconstructed)
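
OpenAlex ships abstracts as an inverted index (token -> list of positions) rather than plain text; the helper above re-inflates it. A worked example with made-up tokens:

index = {"climate": [0], "change": [1, 4], "drives": [2], "rapid": [3]}
# max position is 4, so 5 slots; each token is written back at its position(s):
# ['climate', 'change', 'drives', 'rapid', 'change']
OpenAlex().get_abstract_from_inverted_index(index)
# -> 'climate change drives rapid change'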
climateqa/papers/openalex.py ADDED
@@ -0,0 +1,142 @@
+import pandas as pd
+import networkx as nx
+import matplotlib.pyplot as plt
+from pyvis.network import Network
+
+from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders
+import pyalex
+
+pyalex.config.email = "theo.alvesdacosta@ekimetrics.com"
+
+class OpenAlex():
+    def __init__(self):
+        pass
+
+
+    def search(self,keywords,n_results = 100,after = None,before = None):
+
+        if isinstance(keywords,str):
+            works = Works().search(keywords)
+            if after is not None:
+                assert isinstance(after,int), "after must be an integer"
+                assert after > 1900, "after must be greater than 1900"
+                works = works.filter(publication_year=f">{after}")
+
+            for page in works.paginate(per_page=n_results):
+                break
+
+            df_works = pd.DataFrame(page)
+            df_works["abstract"] = df_works["abstract_inverted_index"].apply(lambda x: self.get_abstract_from_inverted_index(x))
+            df_works["is_oa"] = df_works["open_access"].map(lambda x : x.get("is_oa",False))
+            df_works["pdf_url"] = df_works["primary_location"].map(lambda x : x.get("pdf_url",None))
+            df_works["content"] = df_works["title"] + "\n" + df_works["abstract"]
+
+        else:
+            df_works = []
+            for keyword in keywords:
+                df_keyword = self.search(keyword,n_results = n_results,after = after,before = before)
+                df_works.append(df_keyword)
+            df_works = pd.concat(df_works,ignore_index=True,axis = 0)
+        return df_works
+
+
+    def rerank(self,query,df,reranker):
+
+        scores = reranker.rank(
+            query,
+            df["content"].tolist(),
+            top_k = len(df),
+        )
+        scores.sort(key = lambda x : x["corpus_id"])
+        scores = [x["score"] for x in scores]
+        df["rerank_score"] = scores
+        return df
+
+
+    def make_network(self,df):
+
+        # Initialize the directed citation graph
+        G = nx.DiGraph()
+
+        for i,row in df.iterrows():
+            paper = row.to_dict()
+            G.add_node(paper['id'], **paper)
+            for reference in paper['referenced_works']:
+                if reference not in G:
+                    pass
+                else:
+                    # G.add_node(reference, id=reference, title="", reference_works=[], original=False)
+                    G.add_edge(paper['id'], reference, relationship="CITING")
+        return G
+
+    def show_network(self,G,height = "750px",notebook = True,color_by = "pagerank"):
+
+        net = Network(height=height, width="100%", bgcolor="#ffffff", font_color="black",notebook = notebook,directed = True,neighborhood_highlight = True)
+        net.force_atlas_2based()
+
+        # Add nodes with size reflecting the PageRank to highlight importance
+        pagerank = nx.pagerank(G)
+
+        if color_by == "pagerank":
+            color_scores = pagerank
+        elif color_by == "rerank_score":
+            color_scores = {node: G.nodes[node].get("rerank_score", 0) for node in G.nodes}
+        else:
+            raise ValueError(f"Unknown color_by value: {color_by}")
+
+        # Normalize the scores to [0, 1] for color mapping
+        min_score = min(color_scores.values())
+        max_score = max(color_scores.values())
+        norm_color_scores = {node: (color_scores[node] - min_score) / (max_score - min_score) for node in color_scores}
+
+
+        for node in G.nodes:
+            info = G.nodes[node]
+            title = info["title"]
+            label = title[:30] + " ..."
+
+            title = [title,f"Year: {info['publication_year']}",f"ID: {info['id']}"]
+            title = "\n".join(title)
+
+            color_value = norm_color_scores[node]
+            # Generating a color from blue (low) to red (high)
+            color = plt.cm.RdBu_r(color_value) # RdBu_r is a matplotlib colormap from blue to red
+            def clamp(x):
+                return int(max(0, min(x*255, 255)))
+            color = tuple([clamp(x) for x in color[:3]])
+            color = '#%02x%02x%02x' % color
+
+            net.add_node(node, title=title,size = pagerank[node]*1000,label = label,color = color)
+
+        # Add edges
+        for edge in G.edges:
+            net.add_edge(edge[0], edge[1],arrowStrikethrough=True,color = "gray")
+
+        # Show the network
+        if notebook:
+            return net.show("network.html")
+        else:
+            return net
+
+
+    def get_abstract_from_inverted_index(self,index):
+
+        if index is None:
+            return ""
+        else:
+
+            # Determine the maximum index to know the length of the reconstructed array
+            max_index = max([max(positions) for positions in index.values()])
+
+            # Initialize a list with placeholders for all positions
+            reconstructed = [''] * (max_index + 1)
+
+            # Iterate through the inverted index and place each token at its respective position(s)
+            for token, positions in index.items():
+                for position in positions:
+                    reconstructed[position] = token
+
+            # Join the tokens to form the reconstructed sentence(s)
+            return ' '.join(reconstructed)
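
Putting the class together, a sketch of the end-to-end pipeline as find_papers uses it (the CrossEncoder is the same reranker app.py loads; the search terms and dates are only illustrative):

from sentence_transformers import CrossEncoder

oa = OpenAlex()
reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")

df = oa.search("deep sea mining", n_results=50, after=2010)
df = df.dropna(subset=["abstract"])
df = oa.rerank("What is the impact of deep sea mining?", df, reranker)
df = df.sort_values("rerank_score", ascending=False)

G = oa.make_network(df)
net = oa.show_network(G, color_by="rerank_score", notebook=False)
html = net.generate_html()   # app.py wraps this HTML in an iframe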
requirements.txt CHANGED
@@ -5,6 +5,9 @@ python-dotenv==1.0.0
 langchain==0.1.4
 langchain_openai==0.0.6
 pinecone-client==3.0.2
-sentence-transformers
+sentence-transformers==2.6.0
 huggingface-hub
-msal
+msal
+pyalex==0.13
+networkx==3.2.1
+pyvis==0.3.2