Spaces:
Runtime error
Runtime error
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import os | |
import time | |
import re | |
import json | |
from auditqa.sample_questions import QUESTIONS | |
from auditqa.reports import POSSIBLE_REPORTS | |
from auditqa.engine.prompts import audience_prompts, answer_prompt_template, llama3_prompt | |
from auditqa.doc_process import process_pdf | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain.llms import HuggingFaceEndpoint | |
from dotenv import load_dotenv | |
# Load secrets from a local .env file (if present) before reading os.environ.
load_dotenv()
try:
    # Token used to authenticate against the Hugging Face inference endpoint in chat().
    HF_token = os.environ["HF_TOKEN"]
except KeyError as err:
    # Fail at start-up with an actionable message instead of a bare KeyError.
    raise RuntimeError(
        "HF_TOKEN is not set: add it to a .env file or to the Space secrets; "
        "it is required to call the Hugging Face inference endpoint."
    ) from err
# Built once at import time and shared by every request; chat() picks the
# "MWTS" or "Consolidated" store from this mapping.
vectorstores = process_pdf()
async def chat(query,history,sources,reports):
    """Answer ``query`` with a retrieve-then-generate pipeline and yield the updated chat.

    Steps: pick a vector store from the "source" selection, retrieve the most
    similar passages (similarity-score threshold 0.6, at most 3 passages), send
    them with the question through ``llama3_prompt`` to the Hugging Face
    inference endpoint, and render the retrieved passages as HTML cards.

    Args:
        query: the user's question.
        history: chatbot history as a sequence of (user, bot) pairs; the bot
            message of the last pair receives the answer.
        sources: value of the "Select source" dropdown. It may be a single
            string, a list of strings (the dropdown's default value is a list),
            or empty/None, which falls back to "Consolidated Reports".
        reports: value of the "specific reports" multiselect (list or None).
            It is logged only and not yet used for retrieval.

    Yields:
        One tuple ``(history, docs_html, output_query, output_language)``:
        - the history as a list of tuples, with the answer appended to the last
          bot message;
        - the HTML of the source cards;
        - two placeholder strings, ``""`` and ``"English"``.
    """
    print(f">> NEW QUESTION : {query}")
    print(f"history:{history}")
    print(f"sources:{sources}")
    print(f"reports:{reports}")

    output_query = ""
    output_language = "English"

    # The audience selector is disabled in the UI, so the audience is fixed.
    # Unknown audiences fall back to the expert prompt.
    audience = "Experts"
    audience_key = {"Children": "children", "General public": "general"}.get(audience, "experts")
    audience_prompt = audience_prompts[audience_key]

    # Normalise the source selection: it can arrive as None, a single string or
    # a list. A plain `sources == "Ministry"` comparison misses list values,
    # including the dropdown's default ["Ministry"].
    if isinstance(sources, str):
        sources = [sources]
    sources = list(sources or []) or ["Consolidated Reports"]
    # `reports` is only logged above; it may be None, so it is not len()-checked.

    if "Ministry" in sources:
        vectorstore = vectorstores["MWTS"]
    else:
        vectorstore = vectorstores["Consolidated"]
    retriever = vectorstore.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.6, "k": 3},
    )

    # Only the raw question is used for retrieval (no reformulation step yet).
    question_lst = [query]
    context_retrieved_lst = []
    context_retrieved = []
    for question in question_lst:
        # ainvoke instead of invoke: this handler is a coroutine, so a blocking
        # call here would block the event loop that also serves other requests.
        context_retrieved = await retriever.ainvoke(question)
        context_retrieved_lst.append("\n\n".join(doc.page_content for doc in context_retrieved))

    # Prompt formatted for Llama-3; see https://huggingface.co/blog/llama3#how-to-prompt-llama-3
    prompt = ChatPromptTemplate.from_template(llama3_prompt)
    llm_qa = HuggingFaceEndpoint(
        endpoint_url="https://mnczdhmrf7lkfd9d.eu-west-1.aws.endpoints.huggingface.cloud",
        task="text-generation",
        huggingfacehub_api_token=HF_token,
    )
    chain = prompt | llm_qa | StrOutputParser()

    answer_lst = []
    for question, context in zip(question_lst, context_retrieved_lst):
        answer = await chain.ainvoke({"context": context, "question": question, 'audience': audience_prompt, 'language': 'english'})
        answer_lst.append(answer)

    # Cards are numbered from 1 so that "[Doc N]" citations can link to "#docN".
    docs_html = "".join(make_html_source(doc, i) for i, doc in enumerate(context_retrieved, 1))

    previous_answer = history[-1][1]
    if previous_answer is None:
        previous_answer = ""
    answer_yet = parse_output_llm_with_sources(previous_answer + answer_lst[0])
    history[-1] = (query, answer_yet)
    history = [tuple(x) for x in history]
    # NOTE: the UI wiring binds only three outputs (chatbot, sources box, query
    # box) to this handler, so output_language is currently not displayed.
    yield history, docs_html, output_query, output_language
def make_html_source(source,i):
    """Render one retrieved document as an HTML card for the Sources tab.

    Args:
        source: a retrieved document. It must expose ``page_content`` and a
            ``metadata`` mapping with ``'source'``, ``'file_path'`` and
            ``'page'`` keys.
        i: 1-based position of the document. It appears in the card title and
            as the ``doc{i}`` element id targeted by in-answer citation links.

    Returns:
        The card as an HTML string. Text taken from the document is
        HTML-escaped, so characters such as ``<`` or ``&`` in PDF text cannot
        break the markup.
    """
    import html

    meta = source.metadata
    content = html.escape(source.page_content.strip())
    name = html.escape(str(meta['source']))
    file_path = html.escape(str(meta['file_path']))
    # NOTE(review): if 'page' is 0-based (as some PDF loaders produce), the
    # "#page=" fragment, which PDF viewers treat as 1-based, is one page early.
    # Confirm against process_pdf().
    page = int(meta['page'])
    card = f"""
<div class="card" id="doc{i}">
    <div class="card-content">
        <h2>Doc {i} - {file_path} - Page {page}</h2>
        <p>{content}</p>
    </div>
    <div class="card-footer">
        <span>{name}</span>
        <a href="{file_path}#page={page}" target="_blank" class="pdf-link">
            <span role="img" aria-label="Open PDF">🔗</span>
        </a>
    </div>
</div>
"""
    return card
def parse_output_llm_with_sources(output):
    """Replace ``[Doc N]`` citations in LLM output with superscript anchor links.

    Single citations (``[Doc 1]``, ``[Doc1]``) and grouped ones
    (``[Doc 1, Doc 3]``) become links to the ``#doc{N}`` card anchors; all
    other text is returned unchanged.
    """
    # With one capturing group, re.split puts plain text at even indices and the
    # captured citation groups at odd indices. Using the position, rather than
    # checking whether a piece starts with "Doc", keeps ordinary text such as
    # "Documents show ..." from being mistaken for a citation.
    content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
    parts = []
    for index, part in enumerate(content_parts):
        if index % 2 == 0:
            parts.append(part)
            continue
        numbers = [subpart.lower().replace("doc", "").strip() for subpart in part.split(",")]
        parts.append("".join(
            f"""<a href="#doc{number}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{number}</sup></span></a>"""
            for number in numbers
        ))
    return "".join(parts)
# -------------------------------------------------------------------- | |
# Gradio | |
# -------------------------------------------------------------------- | |
# Set up Gradio Theme | |
# Gradio theme: blue primary / red secondary palette, with Poppins as the main
# font and generic system sans-serif fonts as fallbacks.
theme = gr.themes.Base(
    font=[
        gr.themes.GoogleFont("Poppins"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    secondary_hue="red",
    primary_hue="blue",
)
# Welcome message shown as the first bot turn in the chatbot (Markdown).
# Blank lines separate the sections so headings and lists render as distinct blocks.
init_prompt = """
Hello, I am Audit Q&A, a conversational assistant designed to help you understand audit reports. I will answer your questions by **crawling through the audit reports published by the Auditor General's Office**.

❓ How to use

- **Examples** (tab on right): If this is your first time using this app, we have curated some example questions. Select a particular question from a category of questions.
- **Reports** (tab on right): You can choose to address your question to either a specific report or a collection of reports, like the Consolidated Annual Report, District or Department focused reports. If you don't select one, the Consolidated report is relied upon to answer your question.
- **Sources** (tab on right): This tab displays the paragraphs from the report that the answer relied upon, to help you assess or fact-check whether the answer provided by the Audit Q&A assistant is correct.

⚠️ Limitations

- *Please note that the AI is not perfect and may sometimes give irrelevant answers. If you are not satisfied with the answer, please ask a more specific question or report your feedback to help us improve the system.*
- The Audit Q&A assistant is a generative AI and is therefore not deterministic, so the answer to the same question may change.

What do you want to learn?
"""
# Setting Tabs | |
# Build the Gradio UI. Layout is defined by the nesting of the `with` blocks;
# event handlers must be registered inside the Blocks context.
with gr.Blocks(title="Audit Q&A", css="style.css", theme=theme,elem_id = "main-component") as demo:
    # user_id_state = gr.State([user_id])
    # Main tab: chat column on the left, tabbed helper panel on the right.
    with gr.Tab("AuditQ&A"):
        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2):
                # state = gr.State([system_template])
                # Chat history starts with the welcome message as the bot's first turn.
                chatbot = gr.Chatbot(
                    value=[(None,init_prompt)],
                    show_copy_button=True,show_label = False,elem_id="chatbot",layout = "panel",
                    avatar_images = (None,"data-collection.png"),
                )#,avatar_images = ("assets/logo4.png",None))
                # bot.like(vote,None,None)
                with gr.Row(elem_id = "input-message"):
                    textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
                    # submit = gr.Button("",elem_id = "submit-button",scale = 1,interactive = True,icon = "https://static-00.iconduck.com/assets.00/settings-icon-2048x2046-cw28eevx.png")
            with gr.Column(scale=1, variant="panel",elem_id = "right-panel"):
                # Tab ids are used by start_chat() to switch to "Sources" (id=1).
                with gr.Tabs() as tabs:
                    with gr.TabItem("Examples",elem_id = "tab-examples",id = 0):
                        # Hidden textbox filled when an example is clicked; its
                        # change event starts a chat (see wiring below).
                        examples_hidden = gr.Textbox(visible = False)
                        first_key = list(QUESTIONS.keys())[0]
                        dropdown_samples = gr.Dropdown(QUESTIONS.keys(),value = first_key,interactive = True,show_label = True,label = "Select a category of sample questions",elem_id = "dropdown-samples")
                        # One examples row per question category; only the first
                        # is visible initially, and change_sample_questions()
                        # toggles visibility.
                        samples = []
                        for i,key in enumerate(QUESTIONS.keys()):
                            examples_visible = True if i == 0 else False
                            with gr.Row(visible = examples_visible) as group_examples:
                                examples_questions = gr.Examples(
                                    QUESTIONS[key],
                                    [examples_hidden],
                                    examples_per_page=8,
                                    run_on_click=False,
                                    elem_id=f"examples{i}",
                                    api_name=f"examples{i}",
                                    # label = "Click on the example question or enter your own",
                                    # cache_examples=True,
                                )
                            samples.append(group_examples)
                    with gr.Tab("Reports",elem_id = "tab-config",id = 2):
                        gr.Markdown("Reminder: To get better results select the specific report/reports")
                        # NOTE(review): single-select dropdown (no multiselect=True)
                        # but its default value is a list. chat() presumably
                        # receives a list until the user picks an option, and a
                        # string afterwards. Confirm the intended default.
                        dropdown_sources = gr.Dropdown(
                            ["Consolidated Reports", "District","Ministry"],
                            label="Select source",
                            value=["Ministry"],
                            interactive=True,
                        )
                        # Multiselect of specific reports; value=None initially.
                        dropdown_reports = gr.Dropdown(
                            POSSIBLE_REPORTS,
                            label="Or select specific reports",
                            multiselect=True,
                            value=None,
                            interactive=True,
                        )
                        #dropdown_audience = "Experts"
                        #dropdown_audience = gr.Dropdown(
                        #    ["Children","General public","Experts"],
                        #    label="Select audience",
                        #    value="Experts",
                        #    interactive=True,
                        #)
                        output_query = gr.Textbox(label="Query used for retrieval",show_label = True,elem_id = "reformulated-query",lines = 2,interactive = False)
                        #output_language = gr.Textbox(label="Language",show_label = True,elem_id = "language",lines = 1,interactive = False)
                    with gr.Tab("Sources",elem_id = "tab-citations",id = 1):
                        # Receives the HTML source cards produced by chat().
                        sources_textbox = gr.HTML(show_label=False, elem_id="sources-textbox")
                        docs_textbox = gr.State("")
    # with Modal(visible = False) as config_modal:
    with gr.Tab("About",elem_classes = "max-height other-tabs"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("See more info at [https://www.oag.go.ug/](https://www.oag.go.ug/welcome)")

    def start_chat(query,history):
        """Append the new question to the history, lock the textbox and open the Sources tab (id=1).

        Returns updates for (textbox, tabs, chatbot); the question's bot reply
        is left as None for chat() to fill in.
        """
        history = history + [(query,None)]
        history = [tuple(x) for x in history]
        return (gr.update(interactive = False),gr.update(selected=1),history)

    def finish_chat():
        """Re-enable and clear the textbox once the answer has been produced."""
        return (gr.update(interactive = True,value = ""))

    # Typed question: start_chat -> chat -> finish_chat.
    # NOTE(review): chat() yields four values but only three outputs are bound
    # here, so its fourth value (output_language) is not displayed.
    (textbox
        .submit(start_chat, [textbox,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
        .then(chat, [textbox,chatbot, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query],concurrency_limit = 8,api_name = "chat_textbox")
        .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
    )

    # Clicked example (lands in the hidden textbox): same pipeline as above.
    (examples_hidden
        .change(start_chat, [examples_hidden,chatbot], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
        .then(chat, [examples_hidden,chatbot, dropdown_sources,dropdown_reports], [chatbot,sources_textbox,output_query],concurrency_limit = 8,api_name = "chat_examples")
        .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
    )

    def change_sample_questions(key):
        """Show only the examples row belonging to the selected question category."""
        index = list(QUESTIONS.keys()).index(key)
        visible_bools = [False] * len(samples)
        visible_bools[index] = True
        return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]

    dropdown_samples.change(change_sample_questions,dropdown_samples,samples)
# Enable the request queue on the app, then start the web server.
# Blocks.queue() returns the same Blocks object, so the calls can be chained.
demo.queue().launch()