import os
from datetime import datetime

import gradio as gr
from pinecone import Pinecone
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import Pinecone as PineconeVectorstore
from langchain_groq import ChatGroq

from celsius_csrd_chatbot.utils import (
    make_html_source,
    init_env,
    parse_output_llm_with_sources,
)
from celsius_csrd_chatbot.agent import make_graph_agent

init_env()

demo_name = "ESRS_QA"

# Embeddings used for retrieval, backed by a Pinecone index.
hf_model = "BAAI/bge-base-en-v1.5"
embeddings = HuggingFaceBgeEmbeddings(
    model_name=hf_model,
    encode_kwargs={"normalize_embeddings": True},
)
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
index = pc.Index(os.getenv("PINECONE_API_INDEX"))
vectorstore = PineconeVectorstore(index, embeddings, "page_content")

# LLM and graph agent (query categorization -> retrieval -> answering).
llm = ChatGroq(temperature=0, model_name="llama-3.2-90b-text-preview")
agent = make_graph_agent(llm, vectorstore)


async def chat(query, history):
    """Take a query and the chat history, run the agent pipeline
    (query categorization, retrieval, answering) and yield successive
    (chat history in Gradio tuple format, source documents as HTML) updates."""
    date_now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f">> NEW QUESTION ({date_now}): {query}")
    inputs = {"query": query}
    result = agent.astream_events(inputs, version="v1")
    docs_html = ""
    start_streaming = False
    # Intermediate steps surfaced to the user while the agent is working.
    steps_display = {
        "categorize_esrs": "🔄️ Analyzing user query",
        "retrieve_documents": "🔄️ Searching in the knowledge base",
    }
    try:
        async for event in result:
            if event["event"] == "on_chat_model_stream":
                # First token: replace the step placeholder with an empty answer.
                if not start_streaming:
                    start_streaming = True
                    history[-1] = (query, "")
                new_token = event["data"]["chunk"].content
                previous_answer = history[-1][1] or ""
                answer_yet = parse_output_llm_with_sources(previous_answer + new_token)
                history[-1] = (query, answer_yet)
            elif (
                event["name"] == "answer_rag_wrong"
                and event["event"] == "on_chain_stream"
            ):
                history[-1] = (query, event["data"]["chunk"]["answer"])
            elif (
                event["name"] == "retrieve_documents"
                and event["event"] == "on_chain_end"
            ):
                # Render the retrieved documents as HTML for the sources panel.
                try:
                    docs = event["data"]["output"]["documents"]
                    docs_html = "".join(
                        make_html_source(i, doc) for i, doc in enumerate(docs, 1)
                    )
                except Exception as e:
                    print(f"Error getting documents: {e}")
                    print(event)
            # Show a progress message when an intermediate step starts.
            for event_name, event_description in steps_display.items():
                if (
                    event["name"] == event_name
                    and event["event"] == "on_chain_start"
                ):
                    history[-1] = (query, event_description)
            history = [tuple(x) for x in history]
            yield history, docs_html
    except Exception as e:
        raise gr.Error(f"{e}")


with open("./assets/style.css", "r") as f:
    css = f.read()

# Set up the Gradio theme.
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
)
init_prompt = """
Hello, I am ESRS Q&A, a conversational assistant designed to help you understand the content of the European Sustainability Reporting Standards (ESRS). I will answer your questions based **on the official definition of each ESRS as well as complementary guidelines**.

⚠️ Limitations
*Please note that this chatbot is at an early stage: it is not perfect and may sometimes give irrelevant answers. If you are not satisfied with an answer, please ask a more specific question or report your feedback to help us improve the system.*

What do you want to learn?
"""

with gr.Blocks(title=f"{demo_name}", css=css, theme=theme) as demo:
    with gr.Column(visible=True) as bloc_2:
        with gr.Tab("ESRS Q&A"):
            with gr.Row():
                # Left column: chat window and input box.
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(
                        value=[(None, init_prompt)],
                        show_copy_button=True,
                        show_label=False,
                        elem_id="chatbot",
                        layout="panel",
                        avatar_images=(
                            None,
                            "https://i.ibb.co/cN0czLp/celsius-logo.png",
                        ),
                    )
                    state = gr.State([])
                    with gr.Row(elem_id="input-message"):
                        ask = gr.Textbox(
                            placeholder="Ask me anything here!",
                            show_label=False,
                            scale=7,
                            lines=1,
                            interactive=True,
                            elem_id="input-textbox",
                        )
                # Right column: sources retrieved for the current answer.
                with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                    with gr.Tab("Sources", elem_id="tab-citations", id=1):
                        sources_textbox = gr.HTML(
                            show_label=False, elem_id="sources-textbox"
                        )
                        docs_textbox = gr.State("")
        with gr.Tab("About", elem_classes="max-height other-tabs"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("WIP")

    def start_chat(query, history):
        # Append the new question with an empty answer and lock the input box.
        history = history + [(query, None)]
        history = [tuple(x) for x in history]
        return (gr.update(interactive=False), history)

    def finish_chat():
        # Re-enable and clear the input box once the answer is complete.
        return gr.update(interactive=True, value="")

    ask.submit(
        start_chat,
        [ask, chatbot],
        [ask, chatbot],
        queue=False,
        api_name="start_chat_textbox",
    ).then(
        fn=chat,
        inputs=[ask, chatbot],
        outputs=[chatbot, sources_textbox],
    ).then(
        finish_chat, None, [ask], api_name="finish_chat_textbox"
    )

demo.launch(
    share=True,
    debug=True,
)
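
# A minimal sketch (commented out, not part of the app) of how the `chat`
# generator could be smoke-tested outside Gradio, assuming the Pinecone and
# Groq environment variables above are set. The query string is an arbitrary
# example and `_smoke_test` is a hypothetical helper, not project code.
#
# import asyncio
#
# async def _smoke_test():
#     history = [("What does ESRS E1 cover?", None)]
#     async for updated_history, sources_html in chat("What does ESRS E1 cover?", history):
#         pass  # each iteration is one streaming UI update
#     print(updated_history[-1][1])  # final answer text
#
# asyncio.run(_smoke_test())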