# rag_csv / app.py — CSV analysis assistant (Hugging Face Space by gael1130, rev 2a3e737)
import os
import gradio as gr
import pandas as pd
from langchain_together import ChatTogether
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_experimental.tools import PythonAstREPLTool
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
# Module-level accumulator of [filename, question, answer] rows rendered in the
# UI history table. NOTE(review): shared across all user sessions of this app.
qa_history = []
def load_model(api_key, model="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0):
    """Create a Together.ai chat model client.

    Args:
        api_key: Together.ai API key.
        model: Model identifier; defaults to Mixtral-8x7B-Instruct (the
            original hard-coded value, now overridable).
        temperature: Sampling temperature; 0 keeps analyses deterministic.

    Returns:
        A configured ChatTogether instance.
    """
    return ChatTogether(
        api_key=api_key,
        model=model,
        temperature=temperature
    )
def create_chain(df, llm):
    """Build a two-pass LCEL chain that answers questions about ``df``.

    Pass 1 forces the model to emit Python code via a tool call; the code is
    executed against ``df`` in a Python REPL tool. Pass 2 re-runs the prompt
    with the tool call and its output appended as chat history, letting the
    model write a natural-language answer.

    Args:
        df: pandas DataFrame loaded from the user's CSV.
        llm: Chat model (a ChatTogether instance from ``load_model``).

    Returns:
        A runnable producing a dict with keys ``tool_output`` (raw REPL
        result) and ``response`` (natural-language answer).
    """
    # REPL tool with the dataframe pre-bound under the name `df`.
    tool = PythonAstREPLTool(locals={"df": df})
    # tool_choice forces a tool call on the first pass, so exactly one
    # tool call is expected in the AI message (see _get_chat_history).
    llm_with_tools = llm.bind_tools([tool], tool_choice=tool.name)
    # Extracts the single tool call's arguments as JSON for direct execution.
    parser = JsonOutputKeyToolsParser(key_name=tool.name, first_tool_only=True)
    system = f"""You have access to a pandas dataframe `df`. Here is the output of `df.head().to_markdown()`:
```
{df.head().to_markdown()}
```
Given a user question, write the Python code to answer it. Don't assume you have access to any libraries other than built-in Python ones and pandas.
Respond directly to the question once you have enough information to answer it."""
    prompt = ChatPromptTemplate.from_messages([
        ("system", system),
        ("human", "{question}"),
        # On pass 2 the AI tool-call message and its ToolMessage result are
        # injected here; on pass 1 the placeholder is empty (optional=True).
        MessagesPlaceholder("chat_history", optional=True),
    ])

    def _get_chat_history(x):
        # Pair the tool-calling AI message with a ToolMessage carrying the
        # REPL output, matched by tool_call id. Assumes exactly one tool call
        # (guaranteed by tool_choice above).
        ai_msg = x["ai_msg"]
        tool_call_id = x["ai_msg"].additional_kwargs["tool_calls"][0]["id"]
        tool_msg = ToolMessage(tool_call_id=tool_call_id, content=str(x["tool_output"]))
        return [ai_msg, tool_msg]

    chain = (
        RunnablePassthrough.assign(ai_msg=prompt | llm_with_tools)  # pass 1: generate code
        .assign(tool_output=itemgetter("ai_msg") | parser | tool)   # execute code against df
        .assign(chat_history=_get_chat_history)                     # feed result back as history
        .assign(response=prompt | llm | StrOutputParser())          # pass 2: NL answer (no tools)
        .pick(["tool_output", "response"])
    )
    return chain
def update_qa_history():
# Convert QA history to DataFrame for display
if not qa_history:
return pd.DataFrame(columns=["CSV File", "Question", "Answer"]).to_markdown()
return pd.DataFrame(qa_history, columns=["CSV File", "Question", "Answer"]).to_markdown()
def process_query(csv_file, api_key, query):
    """Answer a natural-language question about an uploaded CSV.

    Args:
        csv_file: Gradio file object (has a ``.name`` temp path), or None
            when nothing was uploaded.
        api_key: Together.ai API key.
        query: The user's natural-language question.

    Returns:
        Tuple of (response text for the result box, markdown Q&A history).
    """
    if not api_key.strip():
        return "Please provide an API key", update_qa_history()
    # Guard against a missing upload instead of surfacing a raw
    # "'NoneType' object has no attribute 'name'" error to the user.
    if csv_file is None:
        return "Please upload a CSV file", update_qa_history()
    try:
        df = pd.read_csv(csv_file.name)
        llm = load_model(api_key)
        chain = create_chain(df, llm)
        result = chain.invoke({"question": query})
        # Human-readable answer first, raw tool output second.
        response = f"Analysis Result:\n{result['response']}\n\nTechnical Details:\n{result['tool_output']}"
        # Store only the basename, not the temp-file path Gradio provides.
        filename = os.path.basename(csv_file.name)
        qa_history.append([
            filename,
            query,
            result['response']  # store just the human-readable response
        ])
        return response, update_qa_history()
    except Exception as e:
        # UI boundary: surface any failure as text rather than crashing.
        return f"Error: {str(e)}", update_qa_history()
# Create Gradio interface: inputs on the left, result on the right,
# full-width Q&A history table below.
with gr.Blocks(title="CSV Analysis Assistant") as iface:
    gr.Markdown("# CSV Analysis Assistant")
    gr.Markdown("Upload a CSV file and ask questions about it using natural language.")
    # Top section: Split into left (inputs) and right (result)
    with gr.Row():
        # Left column for inputs
        with gr.Column(scale=1):
            file_input = gr.File(label="Upload CSV File")
            api_key = gr.Textbox(label="Together.ai API Key", type="password")
            query = gr.Textbox(label="Your Question")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")
        # Right column for result
        with gr.Column(scale=1):
            output = gr.Textbox(label="Result", lines=10)
    # Bottom section: Full width for history table
    with gr.Row():
        history = gr.Markdown(value="### Question & Answer History\n" + update_qa_history())
    # Handle button events: Submit runs the RAG pipeline and refreshes history.
    submit_btn.click(
        fn=process_query,
        inputs=[file_input, api_key, query],
        outputs=[output, history]
    )

    def clear_inputs():
        # Reset file, api key, question, and result; re-render the (global)
        # history table, which is intentionally NOT cleared.
        return [None, "", "", "", "### Question & Answer History\n" + update_qa_history()]

    clear_btn.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[file_input, api_key, query, output, history]
    )

# For Hugging Face Spaces deployment
iface.launch()