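# Snapshot metadata (residue copied from a web file viewer): author "timeki",
# commit c9346b3 "front UI change", file size 30.3 kB; viewer links: raw, history, blame.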
# NOTE(review): imports are interleaved with module-level setup; the two calls
# below (get_embeddings_function(), OpenAlex()) run at import time, before the
# remaining imports are executed.
from climateqa.engine.embeddings import get_embeddings_function
# Embedding function shared by both Pinecone vector stores created further down.
embeddings_function = get_embeddings_function()
from climateqa.knowledge.openalex import OpenAlex
from sentence_transformers import CrossEncoder
# reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
# OpenAlex client used by find_papers to search, rerank and graph papers.
oa = OpenAlex()
import gradio as gr
from gradio_modal import Modal
import pandas as pd
import numpy as np
import os
import time
import re
import json
from gradio import ChatMessage
# from gradio_modal import Modal
from io import BytesIO
import base64
from datetime import datetime
from azure.storage.fileshare import ShareServiceClient
from utils import create_user_id
# NOTE(review): duplicate of the Modal import above; harmless.
from gradio_modal import Modal
from PIL import Image
from langchain_core.runnables.schema import StreamEvent
# ClimateQ&A imports
from climateqa.engine.llm import get_llm
from climateqa.engine.vectorstore import get_pinecone_vectorstore
# from climateqa.knowledge.retriever import ClimateQARetriever
from climateqa.engine.reranker import get_reranker
# NOTE(review): get_embeddings_function is imported three times in this header; harmless.
from climateqa.engine.embeddings import get_embeddings_function
from climateqa.engine.chains.prompts import audience_prompts
from climateqa.sample_questions import QUESTIONS
from climateqa.constants import POSSIBLE_REPORTS, OWID_CATEGORIES
from climateqa.utils import get_image_from_azure_blob_storage
from climateqa.engine.keywords import make_keywords_chain
from climateqa.engine.chains.answer_rag import make_rag_papers_chain
from climateqa.engine.graph import make_graph_agent
from climateqa.engine.embeddings import get_embeddings_function
from front.utils import serialize_docs,process_figures,make_html_df
from climateqa.event_handler import init_audience, handle_retrieved_documents, stream_answer,handle_retrieved_owid_graphs
# Load environment variables in local mode.
# Best-effort: any failure (e.g. python-dotenv not installed) is ignored so that
# deployments that set real environment variables still start.
try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception as e:
    pass
import requests
# Gradio theme: blue primary hue, red secondary hue, and the Poppins Google
# font with generic system sans-serif fallbacks.
theme = gr.themes.Base(
    font=[
        gr.themes.GoogleFont("Poppins"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    secondary_hue="red",
    primary_hue="blue",
)
# Placeholder system message. init_prompt is re-assigned to the chatbot greeting
# later in the module, but system_template keeps the empty content bound here.
init_prompt = ""
system_template = dict(role="system", content=init_prompt)
# Azure file-share client that log_on_azure uploads question and feedback logs to.
account_key = os.environ["BLOB_ACCOUNT_KEY"]
# An 86-character key gets "==" appended (presumably restoring base64 padding
# that was stripped from the stored secret).
account_key = account_key + "==" if len(account_key) == 86 else account_key
credential = dict(
    account_key=account_key,
    account_name=os.environ["BLOB_ACCOUNT_NAME"],
)
account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"
service = ShareServiceClient(credential=credential, account_url=account_url)
share_client = service.get_share_client(file_share_name)
# Identifier attached to every logged question (see chat).
user_id = create_user_id()
CITATION_LABEL = "BibTeX citation for ClimateQ&A"
CITATION_TEXT = r"""@misc{climateqa,
author={Théo Alves Da Costa, Timothée Bohe},
title={ClimateQ&A, AI-powered conversational assistant for climate change and biodiversity loss},
year={2024},
howpublished= {\url{https://climateqa.com}},
}
@software{climateqa,
author = {Théo Alves Da Costa, Timothée Bohe},
publisher = {ClimateQ&A},
title = {ClimateQ&A, AI-powered conversational assistant for climate change and biodiversity loss},
}
"""
# Create vectorstore and retriever.
# Two Pinecone indexes share the same embedding function: the main report index
# and the OWID index (configured with text_key="title").
vectorstore = get_pinecone_vectorstore(
    embeddings_function,
    index_name=os.getenv("PINECONE_API_INDEX"),
)
vectorstore_graphs = get_pinecone_vectorstore(
    embeddings_function,
    index_name=os.getenv("PINECONE_API_INDEX_OWID"),
    text_key="title",
)
# OpenAI LLM configured with max_tokens=1024 and temperature=0.0.
llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
reranker = get_reranker("nano")
# Graph agent combining the LLM, both vector stores and the reranker; driven by chat().
agent = make_graph_agent(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    reranker=reranker,
)
def _last_step_title(history):
    """Return the step title carried by the last chat message, or None.

    Entries without a ``metadata`` attribute (e.g. plain dicts), or whose
    metadata is empty or has no "title", yield None instead of raising.
    """
    metadata = getattr(history[-1], "metadata", None) or {}
    return metadata.get("title")


def _message_content(message):
    """Return the content of a chat message stored either as a dict or as a ChatMessage."""
    if isinstance(message, dict):
        return message.get("content")
    return message.content


async def chat(query, history, audience, sources, reports, relevant_content_sources):
    """Run the ClimateQ&A agent on ``query`` and stream UI updates.

    Feeds the query to the graph agent (intent categorization, query
    transformation, retrieval, answering) and, after every agent event, yields
    ``(history, docs_html, output_query, output_language, related_contents, graphs_html)``.
    Once the stream ends, the exchange is logged to Azure (unless
    ``GRADIO_ENV`` is "local") and a final tuple is yielded.

    Args:
        query: the user question.
        history: chatbot messages; presumably dicts as passed in by the Gradio
            chatbot, plus ``ChatMessage`` objects appended here.
        audience: audience label, converted with ``init_audience``.
        sources: report sources; defaults to IPCC, IPBES and IPOS when empty.
        reports: selected reports (currently not forwarded to the agent).
        relevant_content_sources: extra content sources forwarded to the agent.

    Raises:
        gr.Error: if the agent stream fails or the answer cannot be logged.
    """
    date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f">> NEW QUESTION ({date_now}) : {query}")

    audience_prompt = init_audience(audience)

    # Prepare default values
    if sources is None or len(sources) == 0:
        sources = ["IPCC", "IPBES", "IPOS"]
    if reports is None or len(reports) == 0:
        reports = []

    inputs = {
        "user_input": query,
        "audience": audience_prompt,
        "sources_input": sources,
        "relevant_content_sources": relevant_content_sources,
    }
    result = agent.astream_events(inputs, version="v1")

    docs = []
    related_contents = []
    docs_html = ""
    output_query = ""
    output_language = ""
    start_streaming = False
    graphs_html = ""
    # Step name -> (status message shown in the chat, display flag).
    steps_display = {
        "categorize_intent": ("🔄️ Analyzing user message", True),
        "transform_query": ("🔄️ Thinking step by step to answer the question", True),
        "retrieve_documents": ("🔄️ Searching in the knowledge base", False),
    }
    used_documents = []
    answer_message_content = ""

    # Bound up front: the error handler prints `event` even if the stream fails
    # before its first event, and `node` is compared before any event sets it.
    event = None
    node = None
    try:
        async for event in result:
            if "langgraph_node" in event["metadata"]:
                node = event["metadata"]["langgraph_node"]

            if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents":  # when documents are retrieved
                docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(event, history, used_documents)

            elif event["event"] == "on_chain_end" and node == "categorize_intent" and event["name"] == "_write":  # when the query is transformed
                intent = event["data"]["output"]["intent"]
                output_language = event["data"]["output"].get("language", "English")
                history[-1].content = f"Language identified : {output_language} \n Intent identified : {intent}"

            elif event["name"] in steps_display and event["event"] == "on_chain_start":  # display steps
                # Look up by the same key the guard checked.
                event_description, _display_output = steps_display[event["name"]]
                if _last_step_title(history) != event_description:  # a new step begins
                    history.append(ChatMessage(role="assistant", content="", metadata={'title': event_description}))

            elif event["name"] != "transform_query" and event["event"] == "on_chat_model_stream" and node in ["answer_rag", "answer_search", "answer_chitchat"]:  # streaming answer
                history, start_streaming, answer_message_content = stream_answer(history, event, start_streaming, answer_message_content)

            elif event["name"] in ["retrieve_graphs", "retrieve_graphs_ai"] and event["event"] == "on_chain_end":
                graphs_html = handle_retrieved_owid_graphs(event, graphs_html)

            if event["name"] == "transform_query" and event["event"] == "on_chain_end":
                if hasattr(history[-1], "content"):
                    sub_questions = [q["question"] for q in event["data"]["output"]["remaining_questions"]]
                    history[-1].content += "Decompose question into sub-questions: \n\n - " + "\n - ".join(sub_questions)

            yield history, docs_html, output_query, output_language, related_contents, graphs_html

    except Exception as e:
        print(event, "has failed")
        raise gr.Error(f"{e}") from e

    try:
        # Log answer on Azure Blob Storage
        if os.getenv("GRADIO_ENV") != "local":
            timestamp = str(datetime.now().timestamp())
            file = timestamp + ".json"
            # Log the question that was just answered; history[1] is only ever
            # the first question of the conversation.
            prompt = query
            logs = {
                "user_id": str(user_id),
                "prompt": prompt,
                "query": prompt,
                "question": output_query,
                "sources": sources,
                "docs": serialize_docs(docs),
                "answer": _message_content(history[-1]),
                "time": timestamp,
            }
            log_on_azure(file, logs, share_client)
    except Exception as e:
        print(f"Error logging on Azure Blob Storage: {e}")
        raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)") from e

    yield history, docs_html, output_query, output_language, related_contents, graphs_html
def save_feedback(feed: str, user_id):
    """Log a user's feedback message to the Azure file share.

    Missing feedback, or feedback of one character or less, is ignored.

    Args:
        feed: free-text feedback.
        user_id: user identifier. It is converted with ``str`` because it is
            concatenated into the file name and JSON-serialized (``chat`` also
            logs ``str(user_id)``).

    Returns:
        A confirmation message when feedback was logged, otherwise ``None``.
    """
    if not feed or len(feed) <= 1:
        return None
    user_id = str(user_id)
    timestamp = str(datetime.now().timestamp())
    file = user_id + timestamp + ".json"
    logs = {
        "user_id": user_id,
        "feedback": feed,
        "time": timestamp,
    }
    log_on_azure(file, logs, share_client)
    return "Feedback submitted, thank you!"
def log_on_azure(file, logs, share_client):
    """Serialize ``logs`` to JSON and upload it as ``file`` through ``share_client``.

    Args:
        file: name of the file to create on the share.
        logs: JSON-serializable payload.
        share_client: Azure share client exposing ``get_file_client``.
    """
    payload = json.dumps(logs)
    share_client.get_file_client(file).upload_file(payload)
def generate_keywords(query):
    """Extract search keywords for ``query`` with the LLM keyword chain.

    Returns the keywords joined with " AND "; find_papers passes this string
    to ``oa.search``.
    """
    keywords_output = make_keywords_chain(llm).invoke(query)
    return " AND ".join(keywords_output["keywords"])
# Columns of the OpenAlex papers table and their display widths
# (presumably pixels). papers_cols_widths is re-bound from the mapping to the
# list of widths, aligned with papers_cols.
papers_cols_widths = dict(
    id=100,
    title=300,
    doi=100,
    publication_year=100,
    abstract=500,
    is_oa=50,
)
papers_cols = [*papers_cols_widths]
papers_cols_widths = [*papers_cols_widths.values()]
def _network_iframe(network_html, height):
    """Wrap a standalone HTML page (the citation network) in a sandboxed ``<iframe srcdoc>``.

    The page is HTML-escaped, quotes included, before going into the
    ``srcdoc`` attribute. The browser decodes the entities back to the exact
    original page, so the page's own quotes (e.g. apostrophes in paper titles
    inside JSON/JS strings) are left intact.
    """
    import html

    css_to_inject = "<style>#mynetwork { border: none !important; } .card { border: none !important; }</style>"
    srcdoc = html.escape(network_html + css_to_inject, quote=True)
    return f"""<iframe style="width: 100%; height: {height};margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""


async def find_papers(query, after, relevant_content_sources):
    """Search OpenAlex for papers related to ``query`` and stream the results.

    Only runs when "OpenAlex" is among ``relevant_content_sources``; otherwise
    a single ``("", "", "")`` is yielded.

    Yields ``(papers_html, network_html, summary)`` tuples:
    - first with an empty summary, once the paper cards and citation network
      are ready;
    - then once per streamed summary token, with the summary accumulated so far.

    Args:
        query: the user question.
        after: forwarded to ``oa.search`` as ``after`` (the publication date slider).
        relevant_content_sources: selected content sources; may be None or empty.
    """
    if not relevant_content_sources or "OpenAlex" not in relevant_content_sources:
        yield "", "", ""
        return

    summary = ""
    keywords = generate_keywords(query)
    df_works = oa.search(keywords, after=after)
    df_works = df_works.dropna(subset=["abstract"])
    df_works = oa.rerank(query, df_works, reranker)
    df_works = df_works.sort_values("rerank_score", ascending=False)

    # Render at most the top 10 papers. Fewer rows may remain after dropping
    # missing abstracts (assumes make_html_df(df, i) renders the i-th row).
    n_top = min(10, len(df_works))
    docs_html = "".join(make_html_df(df_works, i) for i in range(n_top))

    G = oa.make_network(df_works)
    height = "750px"
    network = oa.show_network(G, color_by="rerank_score", notebook=False, height=height)
    network_html = _network_iframe(network.generate_html(), height)

    docs = df_works["content"].head(10).tolist()
    yield docs_html, network_html, summary

    # Stream the summary of the top papers token by token.
    chain = make_rag_papers_chain(llm)
    result = chain.astream_log({"question": query, "docs": docs, "language": "English"})
    path_answer = "/logs/StrOutputParser/streamed_output/-"
    async for op in result:
        op = op.ops[0]
        if op['path'] == path_answer:  # a new summary token
            summary += op['value']  # str
            yield docs_html, network_html, summary
# --------------------------------------------------------------------
# Gradio
# --------------------------------------------------------------------
# Markdown greeting shown as the chatbot's first assistant message (the
# gr.Chatbot value in the UI below). This re-binds init_prompt, which was ""
# when system_template was built.
init_prompt = """
Hello, I am ClimateQ&A, a conversational assistant designed to help you understand climate change and biodiversity loss. I will answer your questions by **sifting through the IPCC and IPBES scientific reports**.
❓ How to use
- **Language**: You can ask me your questions in any language.
- **Audience**: You can specify your audience (children, general public, experts) to get a more adapted answer.
- **Sources**: You can choose to search in the IPCC or IPBES reports, or both.
⚠️ Limitations
*Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
🛈 Information
Please note that we log your questions for meta-analysis purposes, so avoid sharing any sensitive or personal information.
What do you want to learn ?
"""
def vote(data: gr.LikeData):
    """Print the liked message's value, or the whole like-event payload otherwise."""
    print(data.value if data.liked else data)
def save_graph(saved_graphs_state, embedding, category):
    """Record ``embedding`` under ``category`` in the saved-graphs state, without duplicates.

    The state dict is mutated in place and returned together with a
    "Graph Saved" button update.
    """
    print(f"\nCategory:\n{saved_graphs_state}\n")
    category_graphs = saved_graphs_state.setdefault(category, [])
    if embedding not in category_graphs:
        category_graphs.append(embedding)
    return saved_graphs_state, gr.Button("Graph Saved")
# Gradio UI and event wiring.
# NOTE(review): the original indentation of this section was lost; the nesting
# below is reconstructed from the code and should be checked against the
# intended layout.
with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=theme,elem_id = "main-component") as demo:
    # NOTE(review): not referenced again below (change_completion_status is not wired).
    chat_completed_state = gr.State(0)
    # Receives the graphs HTML yielded by chat; mirrored into the Graphs tab.
    current_graphs = gr.State([])
    # NOTE(review): presumably the state meant for save_graph; not wired below.
    saved_graphs = gr.State({})
    with gr.Tab("ClimateQ&A"):
        with gr.Row(elem_id="chatbot-row"):
            # Left column: conversation and input box.
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    value = [ChatMessage(role="assistant", content=init_prompt)],
                    type = "messages",
                    show_copy_button=True,
                    show_label = False,
                    elem_id="chatbot",
                    layout = "panel",
                    avatar_images = (None,"https://i.ibb.co/YNyd5W2/logo4.png"),
                    max_height="80vh",
                    height="100vh"
                )
                # bot.like(vote,None,None)
                with gr.Row(elem_id = "input-message"):
                    textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
            # Right column: examples, configuration, sources and recommended content.
            with gr.Column(scale=2, variant="panel",elem_id = "right-panel"):
                with gr.Tabs(elem_id = "right_panel_tab") as tabs:
                    with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
                        # Hidden textbox that the examples fill; its change event starts a chat.
                        examples_hidden = gr.Textbox(visible = False)
                        first_key = list(QUESTIONS.keys())[0]
                        dropdown_samples = gr.Dropdown(QUESTIONS.keys(),value = first_key,interactive = True,show_label = True,label = "Select a category of sample questions",elem_id = "dropdown-samples")
                        # One row of examples per category; only the first row starts visible.
                        samples = []
                        for i,key in enumerate(QUESTIONS.keys()):
                            examples_visible = True if i == 0 else False
                            with gr.Row(visible = examples_visible) as group_examples:
                                examples_questions = gr.Examples(
                                    QUESTIONS[key],
                                    [examples_hidden],
                                    examples_per_page=8,
                                    run_on_click=False,
                                    elem_id=f"examples{i}",
                                    api_name=f"examples{i}",
                                    # label = "Click on the example question or enter your own",
                                    # cache_examples=True,
                                )
                            samples.append(group_examples)
                    with gr.Tab("Configuration", id = 10, ) as tab_config:
                        gr.Markdown("Reminders: You can talk in any language, ClimateQ&A is multi-lingual!")
                        with gr.Row():
                            dropdown_sources = gr.CheckboxGroup(
                                ["IPCC", "IPBES","IPOS"],
                                label="Select source",
                                value=["IPCC"],
                                interactive=True,
                            )
                            dropdown_external_sources = gr.CheckboxGroup(
                                ["IPCC figures","OpenAlex", "OurWorldInData"],
                                label="Select database to search for relevant content",
                                value=["IPCC figures"],
                                interactive=True,
                            )
                        # NOTE(review): the selected reports are passed to chat but not forwarded to the agent there.
                        dropdown_reports = gr.Dropdown(
                            POSSIBLE_REPORTS,
                            label="Or select specific reports",
                            multiselect=True,
                            value=None,
                            interactive=True,
                        )
                        dropdown_audience = gr.Dropdown(
                            ["Children","General public","Experts"],
                            label="Select audience",
                            value="Experts",
                            interactive=True,
                        )
                        # Lower publication-date bound for OpenAlex papers; only shown when "OpenAlex" is selected.
                        after = gr.Slider(minimum=1950,maximum=2023,step=1,value=1960,label="Publication date",show_label=True,interactive=True,elem_id="date-papers", visible=False)
                        output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False, visible= False)
                        output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False, visible= False)
                        # Toggle the publication-date slider with the "OpenAlex" checkbox.
                        dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after])
                        # dropdown_external_sources.change(lambda x: gr.update(visible = True ) if "OpenAlex" in x else gr.update(visible=False) , inputs=[dropdown_external_sources], outputs=[after], visible=True)
                    with gr.Tab("Sources",elem_id = "tab-sources",id = 1) as tab_sources:
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                        docs_textbox = gr.State("")
                    with gr.Tab("Recommended content", elem_id="tab-recommended_content",id=2) as tab_recommended_content:
                        with gr.Tabs(elem_id = "group-subtabs") as tabs_recommended_content:
                            with gr.Tab("Figures",elem_id = "tab-figures",id = 3) as tab_figures:
                                # Receives the related contents yielded by chat; rendered by process_figures.
                                sources_raw = gr.State()
                                with Modal(visible=False, elem_id="modal_figure_galery") as modal:
                                    gallery_component = gr.Gallery(object_fit='scale-down',elem_id="gallery-component", height="80vh")
                                show_full_size_figures = gr.Button("Show figures in full size",elem_id="show-figures",interactive=True)
                                show_full_size_figures.click(lambda : Modal(visible=True),None,modal)
                                figures_cards = gr.HTML(show_label=False, elem_id="sources-figures")
                            with gr.Tab("Papers",elem_id = "tab-citations",id = 4) as tab_papers:
                                # btn_summary = gr.Button("Summary")
                                # Simulated pop-up window for the summary
                                with gr.Accordion(visible=True, elem_id="papers-summary-popup", label= "See summary of relevant papers", open= False) as summary_popup:
                                    papers_summary = gr.Markdown("", visible=True, elem_id="papers-summary")
                                # btn_relevant_papers = gr.Button("Relevant papers")
                                # Simulated pop-up window for the relevant papers
                                with gr.Accordion(visible=True, elem_id="papers-relevant-popup",label= "See relevant papers", open= False) as relevant_popup:
                                    papers_html = gr.HTML(show_label=False, elem_id="papers-textbox")
                                # NOTE(review): re-binds docs_textbox created in the Sources tab.
                                docs_textbox = gr.State("")
                                btn_citations_network = gr.Button("Explore papers citations network")
                                # Simulated pop-up window for the citations network.
                                # NOTE(review): re-binds `modal`; the figures button above was already wired to the first Modal.
                                with Modal(visible=False) as modal:
                                    citations_network = gr.HTML("<h3>Citations Network Graph</h3>", visible=True, elem_id="papers-citations-network")
                                btn_citations_network.click(lambda: Modal(visible=True), None, modal)
                            with gr.Tab("Graphs", elem_id="tab-graphs", id=5) as tab_graphs:
                                graphs_container = gr.HTML("<h2>There are no graphs to be displayed at the moment. Try asking another question.</h2>",elem_id="graphs-container")
                                # Mirror the graphs HTML produced by chat into the Graphs tab.
                                current_graphs.change(lambda x : x, inputs=[current_graphs], outputs=[graphs_container])
                    # with gr.Tab("OECD",elem_id = "tab-oecd",id = 6):
                    #     oecd_indicator = "RIVER_FLOOD_RP100_POP_SH"
                    #     oecd_topic = "climate"
                    #     oecd_latitude = "46.8332"
                    #     oecd_longitude = "5.3725"
                    #     oecd_zoom = "5.6442"
                    #     # Create the HTML content with the iframe
                    #     iframe_html = f"""
                    #     <iframe src="https://localdataportal.oecd.org/maps.html?indicator={oecd_indicator}&topic={oecd_topic}&latitude={oecd_latitude}&longitude={oecd_longitude}&zoom={oecd_zoom}"
                    #     width="100%" height="600" frameborder="0" style="border:0;" allowfullscreen></iframe>
                    #     """
                    #     oecd_textbox = gr.HTML(iframe_html, show_label=False, elem_id="oecd-textbox")
    #---------------------------------------------------------------------------------------
    # OTHER TABS
    #---------------------------------------------------------------------------------------
    # with gr.Tab("Settings",elem_id = "tab-config",id = 2):
    #     gr.Markdown("Reminder: You can talk in any language, ClimateQ&A is multi-lingual!")
    #     dropdown_sources = gr.CheckboxGroup(
    #         ["IPCC", "IPBES","IPOS", "OpenAlex"],
    #         label="Select source",
    #         value=["IPCC"],
    #         interactive=True,
    #     )
    #     dropdown_reports = gr.Dropdown(
    #         POSSIBLE_REPORTS,
    #         label="Or select specific reports",
    #         multiselect=True,
    #         value=None,
    #         interactive=True,
    #     )
    #     dropdown_audience = gr.Dropdown(
    #         ["Children","General public","Experts"],
    #         label="Select audience",
    #         value="Experts",
    #         interactive=True,
    #     )
    #     output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
    #     output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
    with gr.Tab("About",elem_classes = "max-height other-tabs"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(
"""
### More info
- See more info at [https://climateqa.com](https://climateqa.com/docs/intro/)
- Feedbacks on this [form](https://forms.office.com/e/1Yzgxm6jbp)
### Citation
"""
                )
                with gr.Accordion(CITATION_LABEL,elem_id="citation", open = False,):
                    # Display the BibTeX citation in a read-only, copyable textbox.
                    gr.Textbox(
                        value=CITATION_TEXT,
                        label="",
                        interactive=False,
                        show_copy_button=True,
                        lines=len(CITATION_TEXT.split('\n')),
                    )

    def start_chat(query,history):
        """Append the user's message, lock the textbox and switch the right panel to the Sources tab (id=1)."""
        history = history + [ChatMessage(role="user", content=query)]
        return (gr.update(interactive = False),gr.update(selected=1),history)

    def finish_chat():
        """Re-enable and clear the input textbox once the answer has been streamed."""
        return gr.update(interactive = True,value = "")

    # Initialize visibility states
    # NOTE(review): the toggles below are only referenced by commented-out button handlers.
    summary_visible = False
    relevant_visible = False

    # Functions to toggle visibility
    def toggle_summary_visibility():
        """Flip the module-level summary_visible flag and show/hide accordingly."""
        global summary_visible
        summary_visible = not summary_visible
        return gr.update(visible=summary_visible)

    def toggle_relevant_visibility():
        """Flip the module-level relevant_visible flag and show/hide accordingly."""
        global relevant_visible
        relevant_visible = not relevant_visible
        return gr.update(visible=relevant_visible)

    def change_completion_status(current_state):
        """Toggle a 0/1 state value. NOTE(review): not wired to any event here."""
        current_state = 1 - current_state
        return current_state

    def update_sources_number_display(sources_textbox, figures_cards, current_graphs, papers_html):
        """Relabel the tabs with item counts.

        Sources, figures and papers are counted by "<h2>" occurrences in their
        HTML; graphs are counted by "<iframe" occurrences. Returns label updates
        for (recommended content, sources, figures, graphs, papers), in that order.
        """
        sources_number = sources_textbox.count("<h2>")
        figures_number = figures_cards.count("<h2>")
        graphs_number = current_graphs.count("<iframe")
        papers_number = papers_html.count("<h2>")
        sources_notif_label = f"Sources ({sources_number})"
        figures_notif_label = f"Figures ({figures_number})"
        graphs_notif_label = f"Graphs ({graphs_number})"
        papers_notif_label = f"Papers ({papers_number})"
        recommended_content_notif_label = f"Recommended content ({figures_number + graphs_number + papers_number})"
        return gr.update(label = recommended_content_notif_label), gr.update(label = sources_notif_label), gr.update(label = figures_notif_label), gr.update(label = graphs_notif_label), gr.update(label = papers_notif_label)

    # Typed question: lock input -> stream the answer -> unlock input.
    # chat yields (history, docs_html, output_query, output_language, related_contents, graphs_html)
    # into the outputs below, in order.
    (textbox
        .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
        .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources] ,[chatbot,sources_textbox,output_query,output_language, sources_raw, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
        .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
        # .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
    )

    # Same pipeline when an example question is selected.
    # NOTE(review): api_name "chat_textbox" duplicates the chain above while the sibling
    # steps use "*_examples"; confirm the intended endpoint name.
    (examples_hidden
        .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources] ,[chatbot,sources_textbox,output_query,output_language, sources_raw, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
        .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
        # .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
    )

    def change_sample_questions(key):
        """Show only the example row whose category matches ``key``."""
        index = list(QUESTIONS.keys()).index(key)
        visible_bools = [False] * len(samples)
        visible_bools[index] = True
        return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]

    # Render figure cards and the full-size gallery from the related contents.
    sources_raw.change(process_figures, inputs=[sources_raw], outputs=[figures_cards, gallery_component])

    # update sources numbers
    sources_textbox.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
    figures_cards.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
    current_graphs.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
    papers_html.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])

    # other questions examples
    dropdown_samples.change(change_sample_questions,dropdown_samples,samples)

    # search for papers (registered separately from the chat chain on the same triggers)
    textbox.submit(find_papers,[textbox,after, dropdown_external_sources], [papers_html,citations_network,papers_summary])
    examples_hidden.change(find_papers,[examples_hidden,after,dropdown_external_sources], [papers_html,citations_network,papers_summary])

    # btn_summary.click(toggle_summary_visibility, outputs=summary_popup)
    # btn_relevant_papers.click(toggle_relevant_visibility, outputs=relevant_popup)
# Enable the request queue, then start the server with server-side rendering disabled.
demo.queue().launch(ssr_mode=False)