# spinoza_public / app.py
# Commit 3911020 by Msvr: "Initial commit"
import gradio as gr
import time
import os
import yaml
from qdrant_client import models
from tqdm import tqdm
from collections import defaultdict
import pandas as pd
from spinoza_project.source.backend.llm_utils import (
get_llm_api,
)
from spinoza_project.source.frontend.utils import (
init_env,
parse_output_llm_with_sources,
)
from spinoza_project.source.frontend.gradio_utils import (
get_sources,
set_prompts,
get_config,
get_prompts,
get_assets,
get_theme,
get_init_prompt,
get_synthesis_prompt,
get_qdrants,
start_agents,
end_agents,
next_call,
zip_longest_fill,
reformulate,
answer,
get_text,
update_translation,
)
from assets.utils_javascript import (
accordion_trigger,
accordion_trigger_end,
accordion_trigger_spinoza,
accordion_trigger_spinoza_end,
update_footer
)
init_env()
# Load the project configuration. The encoding is pinned to UTF-8 so that
# decoding does not depend on the platform's default locale encoding.
# `full_load` is kept as-is: this is a project-local file, not user input.
with open("./spinoza_project/config.yaml", encoding="utf-8") as f:
    config = yaml.full_load(f)
# --- Prompts -----------------------------------------------------------------
print("Loading Prompts")
prompts = get_prompts(config)
(chat_qa_prompts, chat_reformulation_prompts) = set_prompts(prompts, config)
synthesis_prompt_template = get_synthesis_prompt(config)

# --- LLM client --------------------------------------------------------------
# A missing "groq_model_name" key falls back to the empty string.
print("Building LLM")
groq_model_name = config.get("groq_model_name", "")
llm = get_llm_api(groq_model_name)
# --- Databases ---------------------------------------------------------------
print("Loading Databases")
qdrants, df_qdrants = get_qdrants(config)

# Build one DataFrame per distinct "Source" value (without the "Source" column).
# Missing "Filter" values become "Unknown", and any column whose every value is
# "Unknown" is dropped for that source.
dataframes_by_source = {}
for source in df_qdrants["Source"].unique():
    source_df = df_qdrants[df_qdrants["Source"] == source].drop(columns=["Source"])
    source_df["Filter"] = source_df["Filter"].fillna("Unknown")
    unknown_share = (source_df == "Unknown").mean()
    all_unknown_columns = unknown_share[unknown_share == 1.0].index
    if len(all_unknown_columns) > 0:
        print(f"Deleting following columns for {source}: {all_unknown_columns.tolist()}")
        source_df = source_df.drop(columns=all_unknown_columns)
    dataframes_by_source[source] = source_df
# --- Static assets -----------------------------------------------------------
print("Loading assets")
(
    css,
    source_information_fr,
    source_information_en,
    about_contact_fr,
    about_contact_en,
) = get_assets()
theme = get_theme()
init_prompt = get_init_prompt()

# --- TRANSLATIONS dictionary ------------------------------------------------------
# The tab names (the keys of config["tabs"]) are registered with the translations.
list_tabs = list(config["tabs"])
update_translation(list_tabs, config)
def get_source_df(source_name):
    """Return the cleaned DataFrame stored for ``source_name``.

    An unknown source name yields a new, empty ``pd.DataFrame`` instead of
    raising.
    """
    if source_name in dataframes_by_source:
        return dataframes_by_source[source_name]
    return pd.DataFrame()
# UI language code -> language name handed to the prompt helpers.
LANGUAGE_MAPPING = dict(
    fr="french/français",
    en="english/anglais",
)
def reformulate_questions(
    lang_component,
    question,
    llm=llm,
    chat_reformulation_prompts=chat_reformulation_prompts,
    config=config,
):
    """Stream the per-tab reformulations of ``question``.

    ``lang_component`` is either a language code or an object exposing it as
    ``.value``; codes missing from ``LANGUAGE_MAPPING`` fall back to French.
    One ``reformulate`` stream is created per tab in ``config["tabs"]``. The
    streams are combined with ``zip_longest_fill`` and each combined step is
    yielded after a short pause.
    """
    lang_code = getattr(lang_component, "value", lang_component)
    language = LANGUAGE_MAPPING.get(lang_code, "french/français")
    per_tab_streams = [
        reformulate(language, llm, chat_reformulation_prompts, question, tab, config=config)
        for tab in config["tabs"]
    ]
    for combined_step in zip_longest_fill(*per_tab_streams):
        time.sleep(0.02)  # brief pause between streamed updates
        yield combined_step
def retrieve_sources(
    *questions,
    filters_dict,
    qdrants=qdrants,
    config=config,
):
    """Fetch sources for ``questions`` and flatten the result for Gradio outputs.

    ``questions`` is the tuple of positional arguments, forwarded unchanged to
    ``get_sources``. ``filters_dict`` may be ``None``, which is treated as an
    empty filter mapping.

    Returns:
        ``(formated_sources, *text_sources)``: the formatted sources followed by
        every entry of ``text_sources`` as returned by ``get_sources``.
    """
    active_filters = {} if filters_dict is None else filters_dict
    formated_sources, text_sources = get_sources(
        questions, active_filters, qdrants, config
    )
    return (formated_sources,) + tuple(text_sources)
def retrieve_sources_wrapper(*args):
    """Adapt the Gradio inputs for ``retrieve_sources``.

    The last positional argument is the filters state; every argument before
    it is gathered into a list of questions.
    """
    filters = args[-1]
    question_list = list(args[:-1])
    # NOTE(review): the list is passed as ONE positional argument, so inside
    # `retrieve_sources` the `questions` tuple is `(question_list,)` rather than
    # the individual questions. Confirm `get_sources` expects this nesting.
    return retrieve_sources(question_list, filters_dict=filters)
def answer_questions(
    lang_component,
    *questions_sources,
    llm=llm,
    chat_qa_prompts=chat_qa_prompts,
    config=config
):
    """Stream one answer per tab, formatted as chatbot histories.

    ``questions_sources`` holds the questions in its first half and the
    matching source texts in its second half. Each question is paired with its
    source and with the tab at the same position in ``config["tabs"]``.

    Yields:
        A list with one single-turn history ``[(question, parsed_answer)]`` per
        question, built from the current step of the combined answer streams.
    """
    lang_code = getattr(lang_component, "value", lang_component)
    language = LANGUAGE_MAPPING.get(lang_code, "french/français")
    midpoint = len(questions_sources) // 2
    questions = list(questions_sources[:midpoint])
    sources = list(questions_sources[midpoint:])
    per_tab_streams = [
        answer(language, llm, chat_qa_prompts, q, src, tab, config)
        for q, src, tab in zip(questions, sources, config["tabs"])
    ]
    for partial_answers in zip_longest_fill(*per_tab_streams):
        time.sleep(0.02)  # brief pause between streamed updates
        yield [
            [(q, parse_output_llm_with_sources(text))]
            for q, text in zip(questions, partial_answers)
        ]
def get_synthesis(
    lang_component,
    question,
    *answers,
    llm=llm,
    synthesis_prompt_template=synthesis_prompt_template,
    config=config,
):
    """Stream the final synthesis built from the per-tab answers.

    Args:
        lang_component: language code, or an object exposing it as ``.value``;
            codes missing from ``LANGUAGE_MAPPING`` fall back to French.
        question: the user's question.
        *answers: one value per tab of ``config["tabs"]``, in the same order
            (the event chain passes each tab's Chatbot value). Extra values are
            ignored; missing ones are treated as absent instead of raising.
        llm, synthesis_prompt_template, config: bound at definition time.

    Yields:
        Single-turn chatbot histories ``[(question, text)]``. When no tab
        produced a usable answer, one fallback message is yielded instead.
    """
    lang = getattr(lang_component, "value", lang_component)
    language = LANGUAGE_MAPPING.get(lang, "french/français")

    # Only answers whose string form has at least 100 characters are included.
    answer_blocks = [
        f"{tab}\n{tab_answer}".replace("<p>", "").replace("</p>\n", "")
        for tab, tab_answer in zip(config["tabs"], answers)
        if len(str(tab_answer)) >= 100
    ]

    if not answer_blocks:
        # This function is a generator: a `return <message>` would end up in
        # StopIteration.value and never reach the chatbot. Yield it instead.
        yield [(
            question,
            "Aucune source n'a pu être identifiée pour répondre, veuillez modifier votre question",
        )]
        return

    for elt in llm.stream(
        synthesis_prompt_template,
        {
            "question": question.replace("<p>", "").replace("</p>\n", ""),
            "answers": "\n\n".join(answer_blocks),
            "language": language,
        },
    ):
        time.sleep(0.01)  # brief pause between streamed updates
        yield [(question, parse_output_llm_with_sources(elt))]
def get_unique_values_filters(df):
    """Return the distinct, non-blank values of ``df['Filter']`` as sorted strings.

    Null values are skipped, as are values whose ``str`` form is empty after
    stripping whitespace. Every kept value is converted with ``str`` before
    sorting.
    """
    kept = []
    for raw_value in df['Filter'].unique():
        if not pd.notna(raw_value):
            continue
        as_text = str(raw_value)
        if as_text.strip() == '':
            continue
        kept.append(as_text)
    kept.sort()
    return kept
def filter_data(filter, source):
    """Return the rows of ``source``'s DataFrame as lists, optionally filtered.

    Args:
        filter: selected "Filter" values; a falsy value (``None`` or an empty
            list) returns every row. The name shadows the builtin but is kept
            for backward compatibility with keyword callers.
        source: key into ``dataframes_by_source``.

    Returns:
        ``df.values.tolist()`` for the matching rows.

    Raises:
        ValueError: if ``source`` is not a known source.
    """
    if source not in dataframes_by_source:
        raise ValueError(f"'{source}' not found within the available sources")
    df = dataframes_by_source[source]
    if filter:
        # Compare as strings, with missing values compared as ''.
        df = df[df['Filter'].fillna('').astype(str).isin(filter)]
    return df.values.tolist()
def update_filters(filters_dict, agent, values):
    """Record or clear the file filters selected for ``agent``.

    Args:
        filters_dict: current filters mapping,
            ``{agent: {"file_filtering_modality": values, ...}}``. ``None`` is
            treated as empty. It is never mutated.
        agent: source key whose selection changed.
        values: the selected filter values. An empty or ``None`` selection
            removes the agent's "file_filtering_modality" entry, and the agent
            itself once it has no entries left.

    Returns:
        The new mapping twice (once for the state, once for the JSON display).
    """
    field = "file_filtering_modality"
    new_filters = dict(filters_dict or {})
    # Copy the agent's inner dict as well, so the caller's state object is not
    # modified through a shared reference.
    agent_filters = dict(new_filters.get(agent, {}))

    if values:
        agent_filters[field] = values
    else:
        agent_filters.pop(field, None)

    if agent_filters:
        new_filters[agent] = agent_filters
    else:
        new_filters.pop(agent, None)
    return new_filters, new_filters
with gr.Blocks(
    title="🔍 Spinoza",
    css=css,
    js=update_footer(),
    theme=theme,
) as demo:
    # ------------------------------------------------------------------
    # Component registries and shared state
    # ------------------------------------------------------------------
    accordions_qa = {}  # source name -> Q&A accordion
    accordions_filters = {}  # source name -> filter accordion
    current_language = gr.State(value="fr")
    chatbots = {}  # tab name (plus "Spinoza") -> Chatbot
    question = gr.State("")
    agt_input_flt = {}  # f"{source}_input_flt" -> CheckboxGroup feeding filters_state
    agt_desc = {}  # f"{source}_desc" -> description Markdown
    # f"{source}_input_dsp" -> CheckboxGroup filtering the document table.
    # A plain dict: it holds components, not per-session data, so it does not
    # belong inside a gr.State value.
    agt_input_dsp = {}
    # Every "question_filter" Markdown (one per filter accordion), so that a
    # language switch re-translates all of them, not only the last one created.
    question_filters = []
    docs_textbox = gr.State([""])
    agent_questions = {elt: gr.State("") for elt in config["tabs"]}
    component_sources = {elt: gr.State("") for elt in config["tabs"]}
    text_sources = {elt: gr.State("") for elt in config["tabs"]}
    tab_states = {elt: gr.State(elt) for elt in config["tabs"]}
    filters_state = gr.State({})
    filters_display = gr.JSON(
        label="Filtres sélectionnés",
        value={},
        visible=False,
    )

    with gr.Row(elem_classes="header-row"):
        button_fr = gr.Button("", elem_id="fr-button", elem_classes="lang-button", icon='./assets/logos/france_round.png')
        button_en = gr.Button("", elem_id="en-button", elem_classes="lang-button", icon='./assets/logos/us_round.png')

    with gr.Row(elem_classes="main-row"):
        # ------------------------------ Q&A tab ------------------------------
        with gr.Tab("Q&A", elem_id="main-component"):
            with gr.Row(elem_id="chatbot-row"):
                with gr.Column(scale=2, elem_id="center-panel"):
                    with gr.Row(elem_id="input-message"):
                        ask = gr.Textbox(
                            placeholder=get_text("ask_placeholder", current_language.value),
                            show_label=False,
                            scale=7,
                            lines=1,
                            interactive=True,
                            elem_id="input-textbox",
                        )
                    with gr.Group(elem_id="chatbot-group"):
                        for tab in list(config["tabs"].keys()):
                            src = config['source_mapping'][tab]
                            agent_name = get_text(f"agent_{src}_qa", current_language.value)
                            with gr.Accordion(
                                label=agent_name,
                                open=False,
                                elem_id=f"accordion-{src}",
                                elem_classes="accordion accordion-agent",
                            ) as accordions_qa[src]:
                                chatbots[tab] = gr.Chatbot(
                                    value=None,
                                    show_copy_button=True,
                                    show_share_button=False,
                                    show_label=False,
                                    elem_id=f"chatbot-{agent_name.lower().replace(' ', '-')}",
                                    layout="panel",
                                    avatar_images=(
                                        "./assets/logos/help.png",
                                        "./assets/logos/spinoza.png" if agent_name == "Spinoza" else None,
                                    ),
                                )
                        agent_name = "Spinoza"
                        with gr.Accordion(
                            label=agent_name,
                            open=True,
                            elem_id="accordion-Spinoza",
                            elem_classes="accordion accordion-agent spinoza-agent",
                        ) as accordion_spinoza:
                            chatbots["Spinoza"] = gr.Chatbot(
                                value=[(None, get_text("init_prompt", current_language.value))],
                                show_copy_button=True,
                                show_share_button=False,
                                show_label=False,
                                elem_id=f"chatbot-{agent_name.lower().replace(' ', '-')}",
                                layout="panel",
                                avatar_images=(
                                    "./assets/logos/help.png",
                                    "./assets/logos/spinoza.png",
                                ),
                            )
                with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                    with gr.TabItem("Sources", elem_id="tab-sources", id=0):
                        sources_textbox = gr.HTML(
                            show_label=False, elem_id="sources-textbox"
                        )

        # -------------------------- Source filter tab --------------------------
        with gr.Tab(label=get_text("source_filter_label", current_language.value), elem_id="filter-component") as source_filter_tab:
            source_filter_title = gr.Markdown(value=get_text("source_filter_title", current_language.value))
            source_filter_subtitle = gr.Markdown(value=get_text("source_filter_subtitle", current_language.value))
            with gr.Row(elem_id="filter-row"):
                with gr.Column(scale=2, elem_id="filter-center-panel"):
                    with gr.Group(elem_id="filter-group"):
                        for tab in list(config["tabs"].keys()):
                            src = config['source_mapping'][tab]
                            with gr.Accordion(
                                label=get_text(f"agent_{src}_flt", current_language.value),
                                open=False,
                                elem_id=f"accordion-filter-{src}",
                                elem_classes="accordion accordion-source",
                            ) as accordions_filters[src]:
                                question_filter = gr.Markdown(value=get_text("question_filter", current_language.value))
                                question_filters.append(question_filter)
                                with gr.Tabs():
                                    df = get_source_df(src)
                                    if not df.empty and 'Filter' in df.columns:
                                        filters = get_unique_values_filters(df)
                                        with gr.Row():
                                            var_name = f"{src}_input_flt"
                                            agt_input_flt[var_name] = gr.CheckboxGroup(
                                                list(filters),
                                                label="Filter(s):",
                                            )
                                            agt_input_flt[var_name].change(
                                                fn=update_filters,
                                                inputs=[filters_state, gr.State(src), agt_input_flt[var_name]],
                                                outputs=[filters_state, filters_display],
                                            )
                                    else:
                                        gr.Markdown("**Error:** No data / 'Filter' column doesn't exist...")

        # ------------------------ Source information tab ------------------------
        with gr.Tab(label=get_text("source_informatation_label", current_language.value), elem_id="source-component") as source_information_tab:
            with gr.Row():
                with gr.Column(scale=1):
                    display_info_desc = gr.Markdown(value=get_text("display_info_desc", current_language.value))
                    accordions_inf = {}
                    with gr.Tabs(elem_id="main-tab-disp"):
                        for tab in list(config["tabs"].keys()):
                            src = config['source_mapping'][tab]
                            with gr.Tab(
                                label=get_text(f"agent_{src}_tab", current_language.value),
                                elem_id=f"accordion-{src}-tab",
                                elem_classes="disp-tabs",
                            ) as accordions_inf[src]:
                                agt_desc[f"{src}_desc"] = gr.Markdown(value=get_text(f"{src}_desc", current_language.value))
                                df = get_source_df(src)
                                if not df.empty and 'Filter' in df.columns:
                                    filters = get_unique_values_filters(df)
                                    var_name = f"{src}_input_dsp"
                                    with gr.Row():
                                        agt_input_dsp[var_name] = gr.CheckboxGroup(
                                            list(filters),
                                            label="Filter(s):",
                                        )
                                    output_df = gr.Dataframe(
                                        headers=['Title', 'Pages', 'Filter Category', 'Publishing Date'],
                                        datatype=['str', 'number', 'str', 'number'],
                                        value=df.values.tolist(),
                                        column_widths=[300, 100, 100, 150],
                                        wrap=True,
                                    )
                                    agt_input_dsp[var_name].change(
                                        filter_data,
                                        inputs=[agt_input_dsp[var_name], gr.State(src)],
                                        outputs=[output_df],
                                    )
                                else:
                                    gr.Markdown("**Error:** No data / 'Filter' column doesn't exist...")

        # ------------------------------ Contact tab ------------------------------
        with gr.Tab(label=get_text("contact_label", current_language.value), elem_id="contact-component") as contact_label:
            with gr.Row():
                with gr.Column(scale=1):
                    contact_info = gr.Markdown(value=about_contact_fr)

    # ----------------------------------------------------------------------
    # Question pipeline: reformulate -> retrieve -> answer per tab -> synthesis
    # ----------------------------------------------------------------------
    ask.submit(
        start_agents, inputs=[current_language], outputs=[chatbots["Spinoza"]] + [source_filter_tab], js=accordion_trigger()
    ).then(
        fn=reformulate_questions,
        inputs=[current_language] + [ask],
        outputs=[agent_questions[tab] for tab in config["tabs"]],
    ).then(
        fn=retrieve_sources_wrapper,
        inputs=[agent_questions[tab] for tab in config["tabs"]] + [filters_state],
        outputs=[sources_textbox] + [text_sources[tab] for tab in config["tabs"]],
    ).then(
        fn=answer_questions,
        inputs=[current_language]
        + [agent_questions[tab] for tab in config["tabs"]]
        + [text_sources[tab] for tab in config["tabs"]],
        outputs=[chatbots[tab] for tab in config["tabs"]],
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_end()
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza()
    ).then(
        fn=get_synthesis,
        inputs=[current_language]
        + [ask]
        + [chatbots[tab] for tab in config["tabs"]],
        outputs=[chatbots["Spinoza"]],
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza_end()
    ).then(
        fn=end_agents, inputs=[current_language], outputs=[source_filter_tab]
    )

    # ----------------------------------------------------------------------
    # Language switch (also resets the conversation and the filters)
    # ----------------------------------------------------------------------
    def reset_app(language):
        """Return the updates that restore the initial Q&A state in ``language``."""
        chatbot_updates = {tab: gr.update(value=None) for tab in config["tabs"]}
        chatbot_updates["Spinoza"] = gr.update(value=[(None, get_text("init_prompt", language))])
        checkbox_components = list(agt_input_flt.keys()) + list(agt_input_dsp.keys())
        # A fresh update dict per component, so no two outputs share one dict.
        checkbox_updates = {component: gr.update(value=None) for component in checkbox_components}
        return {
            "chatbots": chatbot_updates,
            # gr.State outputs need the raw value: Gradio does not apply an
            # update dict as a State's new value.
            "filters_state": {},
            "filters_display": gr.update(value={}),
            "ask": gr.update(value="", placeholder=get_text("ask_placeholder", language)),
            "sources_textbox": gr.update(value=""),
            "checkbox_updates": checkbox_updates,
        }

    def _language_updates(language, contact_text):
        """Build the language-switch outputs, in the order of ``language_outputs``.

        The per-component updates are derived from the same registries used to
        build ``language_outputs``, so the two lists always have equal length.
        """
        reset_state = reset_app(language)
        return [
            language,
            reset_state["ask"],
            reset_state["chatbots"]["Spinoza"],
            *[reset_state["chatbots"][tab] for tab in config["tabs"]],
            *[
                gr.update(
                    label=get_text(f"agent_{src_key}_qa", language),
                    open=False,
                    elem_id=f"accordion-{src_key}",
                    elem_classes="accordion accordion-agent",
                )
                for src_key in accordions_qa
            ],
            gr.update(label=get_text("source_filter_label", language), elem_id="filter-component"),
            *[
                gr.update(
                    label=get_text(f"agent_{src_key}_flt", language),
                    elem_id=f"accordion-filter-{src_key}",
                    elem_classes="accordion accordion-source",
                )
                for src_key in accordions_filters
            ],
            gr.update(value=get_text("source_filter_title", language)),
            gr.update(value=get_text("source_filter_subtitle", language)),
            *[gr.update(value=get_text("question_filter", language)) for _ in question_filters],
            gr.update(label=get_text("source_informatation_label", language), elem_id="source-component"),
            gr.update(value=get_text("display_info_desc", language)),
            *[gr.update(value=get_text(desc_key, language)) for desc_key in agt_desc],
            *[
                gr.update(
                    label=get_text(f"agent_{src_key}_tab", language),
                    elem_id=f"accordion-{src_key}-tab",
                    elem_classes="disp-tabs",
                )
                for src_key in accordions_inf
            ],
            gr.update(label=get_text("contact_label", language)),
            gr.update(value=contact_text),
            reset_state["sources_textbox"],
            reset_state["filters_state"],
            reset_state["filters_display"],
            *[reset_state["checkbox_updates"][key] for key in agt_input_flt],
        ]

    def toggle_language_fr():
        """Switch the UI to French and reset the conversation."""
        return _language_updates("fr", about_contact_fr)

    def toggle_language_en():
        """Switch the UI to English and reset the conversation."""
        return _language_updates("en", about_contact_en)

    language_outputs = [
        current_language,
        ask,
        chatbots["Spinoza"],
        *[chatbots[tab] for tab in config["tabs"]],
        *accordions_qa.values(),
        source_filter_tab,
        *accordions_filters.values(),
        source_filter_title,
        source_filter_subtitle,
        *question_filters,
        source_information_tab,
        display_info_desc,
        *agt_desc.values(),
        *accordions_inf.values(),
        contact_label,
        contact_info,
        sources_textbox,
        filters_state,
        filters_display,
        *agt_input_flt.values(),
    ]
    button_fr.click(fn=toggle_language_fr, inputs=[], outputs=language_outputs)
    button_en.click(fn=toggle_language_en, inputs=[], outputs=language_outputs)
if __name__ == "__main__":
    # Enable queuing on the app, then start the server in debug mode with a
    # share link requested.
    queued_demo = demo.queue()
    queued_demo.launch(debug=True, share=True)