|
import gradio as gr |
|
import time |
|
import os |
|
from spinoza_project.source.backend.llm_utils import ( |
|
get_llm_api, |
|
get_vectorstore_api, |
|
) |
|
from spinoza_project.source.frontend.utils import ( |
|
init_env, |
|
parse_output_llm_with_sources, |
|
) |
|
from spinoza_project.source.frontend.gradio_utils import ( |
|
get_sources, |
|
set_prompts, |
|
get_config, |
|
get_prompts, |
|
get_assets, |
|
get_theme, |
|
get_init_prompt, |
|
get_synthesis_prompt, |
|
get_qdrants, |
|
get_qdrants_public, |
|
start_agents, |
|
end_agents, |
|
next_call, |
|
zip_longest_fill, |
|
reformulate, |
|
answer, |
|
) |
|
|
|
from assets.utils_javascript import ( |
|
accordion_trigger, |
|
accordion_trigger_end, |
|
accordion_trigger_spinoza, |
|
accordion_trigger_spinoza_end, |
|
update_footer, |
|
) |
|
|
|
# --- Environment and configuration -----------------------------------------
# init_env() runs before get_config(), mirroring the original start-up order.
init_env()
config = get_config()

# --- Prompt templates --------------------------------------------------------
print("Loading Prompts")
prompts = get_prompts(config)
(chat_qa_prompts, chat_reformulation_prompts) = set_prompts(prompts, config)
synthesis_prompt_template = get_synthesis_prompt(config)
|
|
|
|
|
# --- Language model ----------------------------------------------------------
print("Building LLM")
if os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME"):
    # When the deployment variable is set, an empty model name is handed to
    # get_llm_api and the Groq entry of the config is not read at all.
    groq_model_name = ""
else:
    groq_model_name = config["groq_model_name"]
llm = get_llm_api(groq_model_name)
|
|
|
|
|
# --- Vector stores -----------------------------------------------------------
print("Loading Databases")
qdrants = get_qdrants(config)

if not os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME"):
    # Without the deployment variable the public Qdrant collections are merged
    # into `qdrants` (public entries win on key clashes) and the two API-backed
    # stores are disabled.
    qdrants_public = get_qdrants_public(config)
    qdrants = {**qdrants, **qdrants_public}
    bdd_presse = None
    bdd_afp = None
else:
    # With the deployment variable set, the press and AFP stores are obtained
    # through get_vectorstore_api instead.
    bdd_presse = get_vectorstore_api("presse")
    bdd_afp = get_vectorstore_api("afp")
|
|
|
|
|
# --- Front-end assets --------------------------------------------------------
# `css` and `theme` are passed to gr.Blocks, `source_information` is rendered
# as Markdown in the "Source information" tab, and `init_prompt` is the first
# message shown in the Spinoza chatbot.
(css, source_information) = get_assets()
theme = get_theme()
init_prompt = get_init_prompt()
|
|
|
|
|
def reformulate_questions(
    question,
    llm=llm,
    chat_reformulation_prompts=chat_reformulation_prompts,
    config=config,
):
    """Stream the reformulation of ``question`` for every configured tab.

    ``reformulate`` is called once per key of ``config["tabs"]``; the results
    are combined with ``zip_longest_fill`` and each combined element is yielded
    after a 20 ms pause.

    Args:
        question: The user question to reformulate.
        llm: Model handle forwarded to ``reformulate``.
        chat_reformulation_prompts: Prompts forwarded to ``reformulate``.
        config: Application configuration; its ``"tabs"`` keys drive the loop.

    Yields:
        Whatever ``zip_longest_fill`` produces for the per-tab results
        (presumably one partial reformulation per tab -- confirm against
        ``zip_longest_fill``).
    """
    per_tab_streams = [
        reformulate(llm, chat_reformulation_prompts, question, tab, config=config)
        for tab in config["tabs"]
    ]
    for step in zip_longest_fill(*per_tab_streams):
        time.sleep(0.02)
        yield step
|
|
|
|
|
def retrieve_sources(
    *questions,
    qdrants=qdrants,
    bdd_presse=bdd_presse,
    bdd_afp=bdd_afp,
    config=config,
):
    """Fetch the sources for the given questions via ``get_sources``.

    Args:
        *questions: The questions to look up, passed on as a single tuple.
        qdrants: Vector stores forwarded to ``get_sources``.
        bdd_presse: Press store forwarded to ``get_sources`` (may be ``None``).
        bdd_afp: AFP store forwarded to ``get_sources`` (may be ``None``).
        config: Application configuration forwarded to ``get_sources``.

    Returns:
        A flat tuple: the formatted sources first, followed by every element
        of the text sources returned by ``get_sources``.
    """
    retrieved = get_sources(questions, qdrants, bdd_presse, bdd_afp, config)
    formated_sources, text_sources = retrieved
    return (formated_sources, *text_sources)
|
|
|
|
|
def answer_questions(
    *questions_sources, llm=llm, chat_qa_prompts=chat_qa_prompts, config=config
):
    """Stream the answers to the per-tab questions.

    The positional arguments are split in the middle (integer division): the
    first half holds the questions, the second half the matching sources.
    ``answer`` is called for each (question, source, tab) triple, the results
    are combined with ``zip_longest_fill`` and every combined element is
    converted into chatbot histories after a 20 ms pause.

    Args:
        *questions_sources: Questions followed by their sources.
        llm: Model handle forwarded to ``answer``.
        chat_qa_prompts: Prompts forwarded to ``answer``.
        config: Application configuration; its ``"tabs"`` keys are paired
            with the questions and sources.

    Yields:
        A list with one single-turn history ``[(question, parsed_answer)]``
        per question, where ``parsed_answer`` comes from
        ``parse_output_llm_with_sources``.
    """
    half = len(questions_sources) // 2
    questions = list(questions_sources[:half])
    sources = list(questions_sources[half:])

    per_tab_streams = [
        answer(llm, chat_qa_prompts, question, source, tab, config)
        for question, source, tab in zip(questions, sources, config["tabs"])
    ]

    for step in zip_longest_fill(*per_tab_streams):
        time.sleep(0.02)
        histories = []
        for question, partial_answer in zip(questions, step):
            histories.append(
                [(question, parse_output_llm_with_sources(partial_answer))]
            )
        yield histories
|
|
|
|
|
def get_synthesis(
    question,
    *answers,
    llm=llm,
    synthesis_prompt_template=synthesis_prompt_template,
    config=config,
):
    """Stream a synthesis of the per-tab answers as a chatbot history.

    Each answer is paired with the tab of the same position in
    ``config["tabs"]``. Answers whose string form is shorter than 100
    characters are ignored (presumably empty or failed answers -- confirm the
    threshold with the answer format). The remaining ones are prefixed with
    their tab name, stripped of ``<p>`` / ``</p>\\n`` markup and sent to the
    model together with the question.

    Args:
        question: The question shown in the chatbot and sent to the model.
        *answers: One answer per configured tab, in tab order.
        llm: Model handle; its ``stream`` method is called with the template
            and the prompt variables.
        synthesis_prompt_template: Template forwarded to ``llm.stream``.
        config: Application configuration; its ``"tabs"`` keys label answers.

    Yields:
        A single-turn chatbot history ``[(question, message)]``. When no
        answer is usable, one history carrying a fixed fallback message is
        yielded; otherwise one history per element streamed by the model.
    """
    # zip() pairs tabs and answers without positional indexing, so a shorter
    # `answers` tuple no longer raises IndexError.
    usable_answers = [
        f"{tab}\n{tab_answer}".replace("<p>", "").replace("</p>\n", "")
        for tab, tab_answer in zip(config["tabs"], answers)
        if len(str(tab_answer)) >= 100
    ]

    if not usable_answers:
        # This function is a generator: a value given to `return` is never
        # delivered to the consumer, so the fallback message must be yielded
        # (in the same history format as the streamed results) to be shown.
        yield [
            (
                question,
                "Aucune source n'a pu être identifiée pour répondre, veuillez modifier votre question",
            )
        ]
        return

    prompt_variables = {
        "question": question.replace("<p>", "").replace("</p>\n", ""),
        "answers": "\n\n".join(usable_answers),
    }
    for elt in llm.stream(synthesis_prompt_template, prompt_variables):
        time.sleep(0.01)
        yield [(question, parse_output_llm_with_sources(elt))]
|
|
|
|
|
# NOTE(review): the nesting below is reconstructed from the statement order;
# confirm the layout (input row inside the center panel, right panel inside
# the chatbot row) against the rendered page.
with gr.Blocks(
    title="🔍 Spinoza",
    css=css,
    js=update_footer(),
    theme=theme,
) as demo:
    # Per-session state holders. Several of them are not wired to any event in
    # this section of the file but are kept so the set of components is the same.
    chatbots = {}
    question = gr.State("")
    docs_textbox = gr.State([""])
    agent_questions = {elt: gr.State("") for elt in config["tabs"]}
    component_sources = {elt: gr.State("") for elt in config["tabs"]}
    text_sources = {elt: gr.State("") for elt in config["tabs"]}
    tab_states = {elt: gr.State(elt) for elt in config["tabs"]}

    with gr.Tab("Q&A", elem_id="main-component"):
        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2, elem_id="center-panel"):
                with gr.Group(elem_id="chatbot-group"):
                    # One accordion + chatbot per configured tab, then a final
                    # one for the Spinoza synthesis agent.
                    for tab in list(config["tabs"]) + ["Spinoza"]:
                        is_spinoza = tab == "Spinoza"
                        if is_spinoza:
                            agent_name = "Spinoza"
                            elem_id = f"accordion-{tab}"
                            elem_classes = "accordion accordion-agent spinoza-agent"
                        else:
                            source_label = config["source_mapping"][tab]
                            agent_name = f"Agent {source_label}"
                            elem_id = f"accordion-{source_label}"
                            elem_classes = "accordion accordion-agent"

                        # Only the Spinoza accordion starts open and shows the
                        # initial prompt and its own avatar.
                        with gr.Accordion(
                            agent_name,
                            open=is_spinoza,
                            elem_id=elem_id,
                            elem_classes=elem_classes,
                        ):
                            chatbots[tab] = gr.Chatbot(
                                value=[(None, init_prompt)] if is_spinoza else None,
                                show_copy_button=True,
                                show_share_button=False,
                                show_label=False,
                                elem_id=f"chatbot-{agent_name.lower().replace(' ', '-')}",
                                layout="panel",
                                avatar_images=(
                                    "./assets/logos/help.png",
                                    "./assets/logos/spinoza.png" if is_spinoza else None,
                                ),
                            )

                with gr.Row(elem_id="input-message"):
                    ask = gr.Textbox(
                        placeholder="Ask me anything here!",
                        show_label=False,
                        scale=7,
                        lines=1,
                        interactive=True,
                        elem_id="input-textbox",
                    )

            with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                with gr.TabItem("Sources", elem_id="tab-sources", id=0):
                    sources_textbox = gr.HTML(
                        show_label=False, elem_id="sources-textbox"
                    )

    with gr.Tab("Source information", elem_id="source-component"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(source_information)

    with gr.Tab("Contact", elem_id="contact-component"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("For any issue contact **spinoza.support@ekimetrics.com**.")

    # Event pipeline triggered when the user submits a question. Each step is
    # chained on the previous one with `.then`.
    pipeline = ask.submit(
        start_agents, inputs=[], outputs=[chatbots["Spinoza"]], js=accordion_trigger()
    )
    # 1. Reformulate the question once per tab.
    pipeline = pipeline.then(
        fn=reformulate_questions,
        inputs=[ask],
        outputs=[agent_questions[tab] for tab in config["tabs"]],
    )
    # 2. Retrieve the sources for the reformulated questions.
    pipeline = pipeline.then(
        fn=retrieve_sources,
        inputs=[agent_questions[tab] for tab in config["tabs"]],
        outputs=[sources_textbox] + [text_sources[tab] for tab in config["tabs"]],
    )
    # 3. Answer each question from its sources (questions first, then sources).
    pipeline = pipeline.then(
        fn=answer_questions,
        inputs=[agent_questions[tab] for tab in config["tabs"]]
        + [text_sources[tab] for tab in config["tabs"]],
        outputs=[chatbots[tab] for tab in config["tabs"]],
    )
    # 4. Front-end only steps: run the accordion JavaScript hooks.
    pipeline = pipeline.then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_end()
    )
    pipeline = pipeline.then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza()
    )
    # 5. Synthesize the per-tab answers in the Spinoza chatbot.
    # NOTE(review): the question passed here is the reformulation of the tab at
    # index 1, which raises IndexError with fewer than two tabs -- confirm this
    # choice is intentional.
    pipeline = pipeline.then(
        fn=get_synthesis,
        inputs=[agent_questions[list(config["tabs"])[1]]]
        + [chatbots[tab] for tab in config["tabs"]],
        outputs=[chatbots["Spinoza"]],
    )
    pipeline = pipeline.then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza_end()
    )
    pipeline = pipeline.then(fn=end_agents, inputs=[], outputs=[])
|
|
|
|
|
# Start the application only when the file is run as a script: enable the
# request queue on the Blocks app, then launch it in debug mode.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch(debug=True)
|
|