# Spaces:
# Runtime error
# Runtime error
import logging
import os
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from langchain.agents.agent_types import AgentType
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain_experimental.agents.agent_toolkits import create_csv_agent
from PIL import Image
# Define a function to create the CSV agent.
def create_csv_agent_instance(llm, file_path):
    """Return a LangChain CSV agent bound to *llm* and the CSV at *file_path*.

    The agent is created with ``AgentType.OPENAI_FUNCTIONS`` and verbose
    logging turned on.

    Args:
        llm: Chat model the agent reasons with (``qa_app`` passes a
            ``ChatOpenAI`` instance).
        file_path: Path of the CSV file the agent should answer questions about.

    Returns:
        The agent object built by ``create_csv_agent`` for these arguments.
    """
    # NOTE(review): per the langchain_experimental docs, CSV/pandas agents answer
    # by executing model-generated Python through a REPL tool, so a user-supplied
    # question can lead to arbitrary code running on this server. Sandbox the
    # process before exposing it publicly. Newer langchain_experimental releases
    # also refuse to build the agent unless ``allow_dangerous_code=True`` is
    # passed; confirm against the installed version.
    return create_csv_agent(
        llm,
        file_path,
        agent_type=AgentType.OPENAI_FUNCTIONS,
        verbose=True,
    )
# Define the function to perform QA and optionally plot graphs.
def qa_app(csv_file, question):
    """Answer a natural-language question about an uploaded CSV, optionally plotting it.

    Args:
        csv_file: Value produced by the ``gr.File`` input. Older Gradio versions
            pass a tempfile-like object whose ``.name`` is the path; newer ones
            pass the path itself as a ``str``. Both are accepted. ``None`` means
            nothing was uploaded.
        question: The user's question. If it contains "plot" or "graph"
            (case-insensitive), a plot of the data is generated as well.

    Returns:
        An ``(answer, image)`` tuple. ``image`` is the PIL image returned by
        ``plot_data``, or ``None`` when no plot was requested. On any failure the
        answer is a message starting with "Error in processing:" and the image is
        ``None``.
    """
    # Guard clauses: give the user actionable feedback instead of a raw
    # AttributeError, and don't spend an LLM call on an empty question.
    if csv_file is None:
        return "Please upload a CSV file first.", None
    if not question or not question.strip():
        return "Please enter a question about the data.", None

    try:
        # Path-like values are used as-is; file wrappers expose the path as `.name`.
        if isinstance(csv_file, (str, os.PathLike)):
            file_path = csv_file
        else:
            file_path = csv_file.name

        df = pd.read_csv(file_path)

        # The API key comes from the OPENAI_API_KEY environment variable (e.g. a
        # deployment secret) so it is never hard-coded in the source.
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            return "Error in processing: OPENAI_API_KEY is not set.", None

        llm = ChatOpenAI(
            temperature=0, model="gpt-3.5-turbo-16k", openai_api_key=api_key
        )
        agent = create_csv_agent_instance(llm, file_path)
        response = agent.run(question)

        # Render a plot only when the question explicitly asks for one.
        lowered = question.lower()
        wants_plot = "plot" in lowered or "graph" in lowered
        graph_output = plot_data(df) if wants_plot else None
        return response, graph_output
    except Exception as e:
        # UI boundary: report the failure to the user, but keep the traceback in
        # the logs so it is not silently lost.
        logging.getLogger(__name__).exception("qa_app failed")
        return f"Error in processing: {e}", None
def _is_numeric_column(series):
    """Return True if *series* holds numbers, treating booleans as non-numeric.

    ``pd.api.types.is_numeric_dtype`` understands pandas extension dtypes
    (e.g. ``category``, nullable ``Int64``), on which ``np.issubdtype`` raises
    ``TypeError``. Booleans are excluded to match numpy's view that ``bool`` is
    not a subtype of ``np.number``.
    """
    return pd.api.types.is_numeric_dtype(series) and not pd.api.types.is_bool_dtype(
        series
    )


def plot_data(df):
    """Render a quick overview plot of *df* and return it as a PIL image.

    If the DataFrame has at least two columns and the first two are both
    numeric, they are drawn as a scatter plot (first column on x, second on y).
    Otherwise a bar chart of the value counts of the first column is drawn.

    Args:
        df: The DataFrame to visualise.

    Returns:
        A PIL ``Image`` holding the PNG rendering, or ``None`` when *df* has no
        rows or no columns (there is nothing to plot).
    """
    if df.empty:
        return None

    columns = df.columns
    x_col = columns[0]
    # A scatter plot needs a second column; single-column data falls back to counts.
    two_numeric = (
        len(columns) >= 2
        and _is_numeric_column(df[x_col])
        and _is_numeric_column(df[columns[1]])
    )

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        if two_numeric:
            y_col = columns[1]
            ax.scatter(df[x_col], df[y_col])  # numeric vs numeric
            ax.set_ylabel(y_col)
        else:
            df[x_col].value_counts().plot(kind="bar", ax=ax)  # categorical data
            # Bar heights are occurrence counts, not values of another column.
            ax.set_ylabel("Count")
        ax.set_title("Data Distribution")
        ax.set_xlabel(x_col)
        ax.grid(True)
        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Release the figure even if plotting fails, so figures don't pile up in
        # this long-running server process.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# Set up the Gradio interface.
demo = gr.Interface(
    fn=qa_app,
    inputs=[
        gr.File(label="Upload CSV file"),
        gr.Textbox(label="Question"),
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        # Output components already accept None as "nothing to show", which is
        # what qa_app returns when no plot was requested. The old `optional=True`
        # keyword is not accepted by Gradio 4 components (TypeError at startup);
        # Gradio 3 only warned that it was unused.
        gr.Image(label="Generated Plot", type="pil"),
    ],
    title="Data Analysis Chatbot",
    description="Upload a CSV file, ask a question about the data. Include 'plot' or 'graph' in your question to generate graphs.",
)
# Launch the Gradio app with debugging enabled to trace any runtime issues.
demo.launch(debug=True)