rknl's picture
Update app.py
c86e144 verified
raw
history blame
7.9 kB
import os
import gradio as gr
import base64
from llama_index.core import StorageContext, load_index_from_storage
from dotenv import load_dotenv
from retrieve import get_latest_dir, get_latest_html_file
from graph_handler import query_graph_qa, plot_subgraph
from embed_handler import query_rag_qa
from evaluate import evaluate_llm, reasoning_graph, get_coupon
import base64
# Environment must be populated before any of the os.getenv() lookups below.
load_dotenv()


def _load_persisted_index(persist_dir):
    """Load an index from the storage context persisted at *persist_dir*."""
    storage = StorageContext.from_defaults(persist_dir=persist_dir)
    return load_index_from_storage(storage)


# Paths resolved through the project's `retrieve` helpers, rooted at
# directories named by environment variables.
KG_INDEX_PATH = get_latest_dir(os.getenv("GRAPH_DIR"))
KG_PLOT_PATH = get_latest_html_file(os.getenv("GRAPH_VIS"))
RAG_INDEX_PATH = get_latest_dir(os.getenv("EMBEDDING_DIR"))

# Graph-RAG index (loaded first, as before)
graph_rag_index = _load_persisted_index(KG_INDEX_PATH)

# RAG index
rag_index = _load_persisted_index(RAG_INDEX_PATH)
def query_tqa(query, search_level):
    """
    Send one question through both the Graph-RAG and the RAG pipeline.

    The Graph-RAG index is queried first, then the RAG index; both receive
    the same query and the same search level.

    Args:
        query (str): The question to ask.
        search_level (int): Search level forwarded to both query helpers.

    Returns:
        tuple: Six items - response, reference and reference text from
            Graph-RAG, followed by the same three items from RAG.

    Raises:
        gr.Error: If the query is empty or contains only whitespace.
    """
    stripped_query = query.strip()
    if not stripped_query:
        raise gr.Error("Please enter a query before asking.")

    # The original (unstripped) query is what gets forwarded to the helpers.
    graph_answer, graph_refs, graph_refs_raw = query_graph_qa(
        graph_rag_index, query, search_level
    )
    vanilla_answer, vanilla_refs, vanilla_refs_raw = query_rag_qa(
        rag_index, query, search_level
    )

    graph_part = (graph_answer, graph_refs, graph_refs_raw)
    vanilla_part = (vanilla_answer, vanilla_refs, vanilla_refs_raw)
    return graph_part + vanilla_part
# def eval_llm(query, rag_response, grag_response):
# """
# Evaluate the Graph-RAG and RAG responses using an LLM.
# Args:
# query (str): The query that was asked.
# rag_response (str): The response from the Vanilla-RAG model.
# grag_response (str): The response from the Graph-RAG model.
# Returns:
# str: The evaluation text on various criteria from the LLM.
# """
# if not query.strip() or not rag_response.strip() or not grag_response.strip():
# raise gr.Error("Please ask a query and get responses before evaluating.")
# eval_text = evaluate_llm(query, grag_response, rag_response)
# return eval_text
# def reason_and_plot(query, grag_response, grag_reference):
# """
# Get the reasoning graph for a query and plot the knowledge graph.
# Args:
# query (str): The query to ask the Graph-RAG.
# grag_response (str): The response from the Graph-RAG model.
# grag_reference (str): The reference text from the Graph-RAG model.
# Returns:
# tuple: The reasoning graph and the HTML to plot the knowledge graph.
# """
# if not query.strip() or not grag_response.strip() or not grag_reference.strip():
# raise gr.Error(
# "Please ask a query and get a Graph-RAG response before reasoning."
# )
# graph_reasoning = reasoning_graph(query, grag_response, grag_reference)
# escaped_html = plot_subgraph(grag_reference)
# iframe_html = f'<iframe srcdoc="{escaped_html}" width="100%" height="400px" frameborder="0"></iframe>'
# return graph_reasoning, iframe_html
def show_graph():
    """
    Build the HTML that embeds the most recent graph visualization.

    The newest HTML file under the directory named by the GRAPH_VIS
    environment variable (default "graph_vis") is read, base64-encoded and
    placed in an iframe as a data URI.

    Returns:
        str: The iframe markup, a "not found" message when no file is
            available, or an "Error: ..." message if anything raises.
    """
    vis_root = os.getenv("GRAPH_VIS", "graph_vis")
    try:
        newest_file = get_latest_html_file(vis_root)
        if not newest_file:
            return "No graph visualization found."

        with open(newest_file, encoding="utf-8") as handle:
            page_source = handle.read()

        payload = base64.b64encode(page_source.encode()).decode()
        return f'<iframe src="data:text/html;base64,{payload}" width="100%" height="1000px" frameborder="0"></iframe>'
    except Exception as exc:
        # Any failure is reported as text in the UI rather than raised.
        return "Error: " + str(exc)
def reveal_coupon(query, grag_response):
    """
    Look up the coupon for a question and its Graph-RAG answer.

    Args:
        query (str): Query asked to Graph-RAG.
        grag_response (str): Response from the Graph-RAG model.

    Returns:
        The value returned by ``get_coupon`` for this query/response pair.

    Raises:
        gr.Error: If the query or the response is empty or whitespace-only.
    """
    both_present = query.strip() and grag_response.strip()
    if not both_present:
        raise gr.Error("Please ask a query and get a response before revealing the coupon.")
    return get_coupon(query, grag_response)
def _ask_graph_rag(query, search_level):
    """
    Click handler for the "Ask Comfy" button.

    ``query_tqa`` returns six values (three for Graph-RAG, three for RAG),
    but only the Graph-RAG response box is currently wired as an output.
    A Gradio handler should return one value per output component, so this
    adapter hands back just the Graph-RAG response instead of the full tuple.

    Args:
        query (str): The question typed by the user.
        search_level (int): Value of the "Search Level" slider.

    Returns:
        The Graph-RAG response (first item returned by ``query_tqa``).
    """
    # NOTE: query_tqa still queries the vanilla RAG index as well; only its
    # results are dropped here.
    return query_tqa(query, search_level)[0]


# NOTE(review): the nesting of the layout below was reconstructed from a copy
# of the file that had lost its indentation - confirm against the original.
with gr.Blocks() as demo:
    gr.Markdown("# Comfy Virtual Assistant")

    # Query input on the left, search controls on the right.
    with gr.Row():
        with gr.Column(scale=4):
            query_input = gr.Textbox(label="Input Your Query", lines=3)
        with gr.Column(scale=1):
            # NOTE(review): with minimum=1 and step=5 the slider positions are
            # 1, 6, 11, ...; the default of 3 is not one of them - confirm
            # the intended step/default.
            search_level = gr.Slider(
                minimum=1, maximum=50, value=3, step=5, label="Search Level"
            )
            ask_button = gr.Button("Ask Comfy", variant="primary")

    examples = gr.Examples(
        examples=[
            ["Recommend me an apple phone that has more than 10MP camera."],
            ["What is the price of Samsung Galaxy S24 Ultra 12/256Gb Titanium Gray"],
            ["I want a phone with 5000 mAH or more battery"],
        ],
        inputs=[query_input],
    )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Graph-RAG")
            grag_output = gr.Textbox(label="Response", lines=5)
            # grag_reference = gr.Textbox(label="Triplets", lines=3)
            # with gr.Accordion("Extracted Reference (Raw)", open=False):
            #     grag_reference_text = gr.Textbox(label="Raw Reference", lines=5)
        # with gr.Column():
        #     gr.Markdown("### Vanilla RAG")
        #     rag_output = gr.Textbox(label="Response", lines=5)
        #     rag_reference = gr.Textbox(label="Extracted Reference", lines=3)
        #     with gr.Accordion("Extracted Reference (Raw)", open=False):
        #         rag_reference_text = gr.Textbox(label="Raw Reference", lines=5)

    # gr.Markdown("### Coupon")
    # with gr.Row():
    #     with gr.Column():
    #         coupon = gr.Text(label="Coupon", lines=1)
    #     with gr.Column():
    #         reveal = gr.Button("Reveal Coupon", variant="secondary")
    # with gr.Row():
    #     gr.Markdown("### Evaluate and Compare")
    # with gr.Row():
    #     eval_button = gr.Button("Evaluate LLMs", variant="secondary")
    #     grag_performance = gr.Textbox(label="Evaluation", lines=3)
    # with gr.Row():
    #     gr.Markdown("### Graph Reasoning")
    # with gr.Row():
    #     reason_button = gr.Button("Get Graph Reasoning", variant="secondary")
    # with gr.Row():
    #     with gr.Column():
    #         grag_reasoning = gr.Textbox(label="Graph-RAG Reasoning", lines=5)
    #     with gr.Column():
    #         subgraph_plot = gr.HTML()

    with gr.Row():
        plot_button = gr.Button("Plot Knowledge Graph", variant="secondary")
    kg_output = gr.HTML()

    # Only the Graph-RAG response box is live, so the handler must return a
    # single value. To re-enable the other outputs below, switch the handler
    # back to query_tqa, whose six return values match them in order.
    ask_button.click(
        _ask_graph_rag,
        inputs=[query_input, search_level],
        outputs=[
            grag_output,
            # grag_reference,
            # grag_reference_text,
            # rag_output,
            # rag_reference,
            # rag_reference_text,
        ],
    )
    # eval_button.click(
    #     eval_llm,
    #     inputs=[query_input, rag_output, grag_output],
    #     outputs=[grag_performance],
    # )
    # reason_button.click(
    #     reason_and_plot,
    #     inputs=[query_input, grag_output, grag_reference],
    #     outputs=[grag_reasoning, subgraph_plot],
    # )
    plot_button.click(
        show_graph,
        outputs=[kg_output],
    )
    # reveal.click(
    #     reveal_coupon,
    #     inputs=[query_input, grag_output],
    #     outputs=[coupon],
    # )

# Login credentials are read from the ID and PASS environment variables.
# NOTE(review): os.getenv returns None for an unset variable, so the auth
# tuple may contain None - confirm both are always configured.
demo.launch(auth=(os.getenv("ID"), os.getenv("PASS")), share=False)