# ClimateQ&A — Gradio front-end application entry point
| # Import necessary libraries | |
| import os | |
| import gradio as gr | |
| from azure.storage.fileshare import ShareServiceClient | |
| # Import custom modules | |
| from climateqa.engine.embeddings import get_embeddings_function | |
| from climateqa.engine.llm import get_llm | |
| from climateqa.engine.vectorstore import get_azure_search_vectorstore | |
| from climateqa.engine.reranker import get_reranker | |
| from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc | |
| from climateqa.engine.chains.retrieve_papers import find_papers | |
| from climateqa.chat import start_chat, chat_stream, finish_chat | |
| from front.tabs import create_config_modal, cqa_tab, create_about_tab | |
| from front.tabs import MainTabPanel, ConfigPanel | |
| from front.tabs.tab_drias import create_drias_tab | |
| from front.tabs.tab_ipcc import create_ipcc_tab | |
| from front.utils import process_figures | |
| from gradio_modal import Modal | |
| from utils import create_user_id | |
| import logging | |
# Quiet noisy loggers and TensorFlow startup messages.
logging.basicConfig(level=logging.WARNING)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppresses TensorFlow INFO and WARNING logs
logging.getLogger().setLevel(logging.WARNING)

# Load environment variables from a .env file in local mode. In hosted
# deployments the variables come from the platform, so a missing
# python-dotenv package (or any load failure) is deliberately non-fatal.
try:
    from dotenv import load_dotenv

    load_dotenv()
except Exception:
    # Best-effort by design: the broad catch is kept so a malformed .env
    # cannot prevent startup (the unused `as e` binding was removed).
    pass
# Gradio theme shared by every tab of the app (Poppins with sans-serif fallbacks).
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)
# Azure File Share credentials — used below to persist chat data via share_client.
# Environment reads are intentionally strict (os.environ[...]): a missing
# variable should fail fast at startup rather than later at first write.
account_key = os.environ["BLOB_ACCOUNT_KEY"]
if len(account_key) == 86:
    # Storage account keys are base64-encoded; an 86-char value has lost its
    # trailing "==" padding (presumably stripped by the secrets store —
    # TODO confirm), so restore it before building the credential.
    account_key += "=="
credential = {
    "account_key": account_key,
    "account_name": os.environ["BLOB_ACCOUNT_NAME"],
}
account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"  # fixed file-share name inside the storage account
service = ShareServiceClient(account_url=account_url, credential=credential)
share_client = service.get_share_client(file_share_name)

# Anonymous per-session identifier passed to the chat pipeline for logging.
user_id = create_user_id()
# Create the shared embedding function and the three Azure AI Search
# vectorstores used by the agents: IPCC/IPBES documents, OWID graphs
# (indexed on their "description" field), and the regional index.
embeddings_function = get_embeddings_function()
vectorstore = get_azure_search_vectorstore(embeddings=embeddings_function, index_name="climateqa-ipx")
vectorstore_graphs = get_azure_search_vectorstore(embeddings=embeddings_function, index_name="climateqa-owid", text_key="description")
vectorstore_region = get_azure_search_vectorstore(embeddings=embeddings_function, index_name="climateqa-v2")

# Streaming Azure LLM with deterministic output (temperature 0).
llm = get_llm(provider="azure", max_tokens=3000, temperature=0.0, streaming=True)

# Both branches of the former `GRADIO_ENV == "local"` conditional selected the
# same "nano" reranker, so the dead conditional (and its hard
# os.environ["GRADIO_ENV"] read, which raised KeyError when unset) was removed.
reranker = get_reranker("nano")
# Main production agent used by the ClimateQ&A tab.
agent = make_graph_agent(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    vectorstore_region=vectorstore_region,
    reranker=reranker,
    threshold_docs=0.2,  # relevance threshold for retrieved documents
)
# Proof-of-concept agent ("v4") used by the "France - Local Q&A" tab.
agent_poc = make_graph_agent_poc(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    vectorstore_region=vectorstore_region,
    reranker=reranker,
    threshold_docs=0,  # TODO put back default 0.2
    version="v4",
)
async def chat(
    query,
    history,
    audience,
    sources,
    reports,
    relevant_content_sources_selection,
    search_only,
):
    """Stream a ClimateQ&A answer for the main tab.

    Normalizes unset UI component values to their defaults, then delegates to
    ``chat_stream`` with the production agent, re-yielding every event.
    """
    print("chat cqa - message received")

    # Fall back to defaults when the UI components were left unset / empty.
    if not audience:
        audience = "Experts"
    if not sources:
        sources = ["IPCC", "IPBES"]
    if not reports:
        reports = []
    if not relevant_content_sources_selection:
        relevant_content_sources_selection = ["Figures (IPCC/IPBES)"]
    search_only = bool(search_only)  # coerce None to False

    stream = chat_stream(
        agent,
        query,
        history,
        audience,
        sources,
        reports,
        relevant_content_sources_selection,
        search_only,
        share_client,
        user_id,
    )
    async for event in stream:
        yield event
async def chat_poc(
    query,
    history,
    audience,
    sources,
    reports,
    relevant_content_sources_selection,
    search_only,
):
    """Stream an answer for the "France - Local Q&A" (POC) tab.

    Delegates to ``chat_stream`` with the proof-of-concept agent,
    re-yielding every event.
    """
    print("chat poc - message received")

    # Consistency fix: apply the same default normalization as chat(); this
    # wrapper previously forwarded raw (possibly None) component values.
    audience = audience or "Experts"
    sources = sources or ["IPCC", "IPBES"]
    reports = reports or []
    relevant_content_sources_selection = relevant_content_sources_selection or ["Figures (IPCC/IPBES)"]
    search_only = bool(search_only)

    async for event in chat_stream(
        agent_poc,
        query,
        history,
        audience,
        sources,
        reports,
        relevant_content_sources_selection,
        search_only,
        share_client,
        user_id,
    ):
        yield event
| # -------------------------------------------------------------------- | |
| # Gradio | |
| # -------------------------------------------------------------------- | |
| # Function to update modal visibility | |
def update_config_modal_visibility(config_open):
    """Toggle the configuration modal.

    Args:
        config_open: Current visibility flag (held in a gr.State).

    Returns:
        tuple: (Modal update with the flipped visibility, new flag value so
        the gr.State stays in sync with the modal).
    """
    # A stray debug print of the state value was removed here.
    new_config_visibility_status = not config_open
    return Modal(visible=new_config_visibility_status), new_config_visibility_status
def update_sources_number_display(
    sources_textbox, figures_cards, current_graphs, papers_html
):
    """Refresh tab labels with the number of items each panel currently holds.

    Item counts are derived from the rendered HTML: each source/figure/paper
    card starts with an <h2> heading, and each graph is an <iframe>.
    """
    counts = {
        "sources": sources_textbox.count("<h2>"),
        "figures": figures_cards.count("<h2>"),
        "graphs": current_graphs.count("<iframe"),
        "papers": papers_html.count("<h2>"),
    }
    recommended_total = counts["figures"] + counts["graphs"] + counts["papers"]

    # Order matches the outputs wired in event_handling:
    # recommended-content tab first, then the four inner tabs.
    return (
        gr.update(label=f"Recommended content ({recommended_total})"),
        gr.update(label=f"Sources ({counts['sources']})"),
        gr.update(label=f"Figures ({counts['figures']})"),
        gr.update(label=f"Graphs ({counts['graphs']})"),
        gr.update(label=f"Papers ({counts['papers']})"),
    )
def config_event_handling(
    main_tabs_components: list[MainTabPanel], config_componenets: ConfigPanel
):
    """Wire every tab's config button, plus the modal's close button, to the
    visibility toggle so any of them opens/closes the shared config modal."""
    config_open = config_componenets.config_open
    config_modal = config_componenets.config_modal

    toggle_buttons = [config_componenets.close_config_modal_button]
    toggle_buttons.extend(tab.config_button for tab in main_tabs_components)

    for toggle_button in toggle_buttons:
        toggle_button.click(
            fn=update_config_modal_visibility,
            inputs=[config_open],
            outputs=[config_modal, config_open],
        )
def event_handling(
    main_tab_components: MainTabPanel,
    config_components: ConfigPanel,
    tab_name="ClimateQ&A",
):
    """Wire all chat-related events for one Q&A tab.

    Connects the textbox, example clicks and follow-up example clicks to the
    start_chat -> chat/chat_poc -> finish_chat pipeline, then hooks up the
    secondary panels (sources, figures, graphs, papers).

    Fixes over the previous revision:
    - the France tab's follow-up examples now call ``chat_poc`` (they
      previously called the main ``chat`` agent by copy-paste mistake);
    - the follow-up example chains register api_names derived from their own
      elem_id instead of duplicating ``examples_hidden``'s;
    - the six near-identical event chains are factored into one helper.
    """
    chatbot = main_tab_components.chatbot
    textbox = main_tab_components.textbox
    tabs = main_tab_components.tabs
    sources_raw = main_tab_components.sources_raw
    new_figures = main_tab_components.new_figures
    current_graphs = main_tab_components.current_graphs
    examples_hidden = main_tab_components.examples_hidden
    sources_textbox = main_tab_components.sources_textbox
    figures_cards = main_tab_components.figures_cards
    gallery_component = main_tab_components.gallery_component
    papers_direct_search = main_tab_components.papers_direct_search
    papers_html = main_tab_components.papers_html
    citations_network = main_tab_components.citations_network
    papers_summary = main_tab_components.papers_summary
    tab_recommended_content = main_tab_components.tab_recommended_content
    tab_sources = main_tab_components.tab_sources
    tab_figures = main_tab_components.tab_figures
    tab_graphs = main_tab_components.tab_graphs
    tab_papers = main_tab_components.tab_papers
    graphs_container = main_tab_components.graph_container
    follow_up_examples = main_tab_components.follow_up_examples
    follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden

    dropdown_sources = config_components.dropdown_sources
    dropdown_reports = config_components.dropdown_reports
    dropdown_external_sources = config_components.dropdown_external_sources
    search_only = config_components.search_only
    dropdown_audience = config_components.dropdown_audience
    after = config_components.after
    output_query = config_components.output_query
    output_language = config_components.output_language

    new_sources_hmtl = gr.State([])
    ttd_data = gr.State([])  # NOTE(review): currently unused — confirm before removing

    def _wire_chat_pipeline(event_trigger, input_component, chat_fn):
        """Chain start_chat -> chat_fn -> finish_chat on one Gradio event.

        event_trigger is a bound event method (textbox.submit or a hidden
        component's .change); input_component supplies the query text and
        names the registered API endpoints via its elem_id.
        """
        (
            event_trigger(
                start_chat,
                [input_component, chatbot, search_only],
                [input_component, tabs, chatbot, sources_raw],
                queue=False,
                api_name=f"start_chat_{input_component.elem_id}",
            )
            .then(
                chat_fn,
                [
                    input_component,
                    chatbot,
                    dropdown_audience,
                    dropdown_sources,
                    dropdown_reports,
                    dropdown_external_sources,
                    search_only,
                ],
                [
                    chatbot,
                    new_sources_hmtl,
                    output_query,
                    output_language,
                    new_figures,
                    current_graphs,
                    follow_up_examples.dataset,
                ],
                concurrency_limit=8,
                api_name=f"chat_{input_component.elem_id}",
            )
            .then(
                finish_chat,
                None,
                [textbox],
                api_name=f"finish_chat_{input_component.elem_id}",
            )
        )

    # Pick the agent wrapper for this tab, then wire the three entry points
    # (typed question, example click, follow-up example click).
    if tab_name == "ClimateQ&A":
        print("chat cqa - message sent")
        selected_chat = chat
    elif tab_name == "France - Local Q&A":
        print("chat poc - message sent")
        selected_chat = chat_poc
    else:
        selected_chat = None

    if selected_chat is not None:
        _wire_chat_pipeline(textbox.submit, textbox, selected_chat)
        _wire_chat_pipeline(examples_hidden.change, examples_hidden, selected_chat)
        _wire_chat_pipeline(
            follow_up_examples_hidden.change, follow_up_examples_hidden, selected_chat
        )

    # Mirror freshly streamed sources / graphs into their display components.
    new_sources_hmtl.change(
        lambda x: x, inputs=[new_sources_hmtl], outputs=[sources_textbox]
    )
    current_graphs.change(
        lambda x: x, inputs=[current_graphs], outputs=[graphs_container]
    )
    new_figures.change(
        process_figures,
        inputs=[sources_raw, new_figures],
        outputs=[sources_raw, figures_cards, gallery_component],
    )

    # Update the item counts shown on the tab labels whenever a panel changes.
    for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
        component.change(
            update_sources_number_display,
            [sources_textbox, figures_cards, current_graphs, papers_html],
            [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers],
        )

    # Search for papers
    for component in [textbox, examples_hidden, papers_direct_search]:
        component.submit(
            find_papers,
            [component, after, dropdown_external_sources],
            [papers_html, citations_network, papers_summary],
        )

    # if tab_name == "France - Local Q&A":  # Not until results are good enough
    #     # Drias search
    #     textbox.submit(ask_vanna, [textbox], [vanna_sql_query, vanna_table, vanna_display])
def main_ui():
    """Build the full Gradio application and return the Blocks instance.

    Creates the config modal, the two Q&A tabs, the DRIAS/IPCC data tabs and
    the about tab, then wires all event handlers.
    """
    css_path = os.path.join(os.getcwd(), "style.css")
    # Fix: the previous revision also open()'d the stylesheet into an unused
    # variable without closing the file handle; gr.Blocks only needs the path.
    with gr.Blocks(
        title="Climate Q&A",
        css_paths=css_path,
        theme=theme,
        elem_id="main-component",
    ) as demo:
        config_components = create_config_modal()

        with gr.Tabs():
            cqa_components = cqa_tab(tab_name="ClimateQ&A")
            local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
            # Return values of the next two tabs are unused here; the calls
            # are kept for their component-creation side effects.
            create_drias_tab(share_client=share_client, user_id=user_id)
            create_ipcc_tab(share_client=share_client, user_id=user_id)
            create_about_tab()

        event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
        event_handling(
            local_cqa_components, config_components, tab_name="France - Local Q&A"
        )
        config_event_handling([cqa_components, local_cqa_components], config_components)

        demo.queue()

    return demo
# Build and launch the app at import time.
# NOTE(review): there is no `if __name__ == "__main__"` guard — presumably the
# file is executed directly as a script / hosted app entry point; confirm the
# deployment before adding one.
demo = main_ui()
demo.launch(ssr_mode=False)  # ssr_mode=False disables server-side rendering