File size: 2,500 Bytes
f626f3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41120f1
f626f3c
 
41120f1
f626f3c
 
 
41120f1
 
f626f3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41120f1
f626f3c
 
 
41120f1
f626f3c
 
41120f1
f626f3c
 
 
 
41120f1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import gradio as gr
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.agents.agent_types import AgentType
from langchain_experimental.agents.agent_toolkits import create_csv_agent
import pandas as pd
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
import numpy as np

# Define a function to create the CSV agent
def create_csv_agent_instance(llm, file_path):
    """Build a LangChain CSV agent for ``file_path`` that is driven by ``llm``.

    Args:
        llm: Language model instance forwarded to ``create_csv_agent``.
        file_path: Location of the CSV file the agent should work on.

    Returns:
        The agent object produced by ``create_csv_agent``.
    """
    # Fixed agent configuration: verbose output, OpenAI-functions agent type.
    agent_options = {
        "verbose": True,
        "agent_type": AgentType.OPENAI_FUNCTIONS,
    }
    return create_csv_agent(llm, file_path, **agent_options)

# Define the function to perform QA and optionally plot graphs
def qa_app(csv_file, question):
    """Answer a question about an uploaded CSV file and optionally plot it.

    Args:
        csv_file: The upload handed over by the file input. Either an object
            exposing the file path as ``.name`` or a plain path string is
            accepted.
        question: Natural-language question about the data. When it contains
            "plot" or "graph" (case-insensitive) a plot is produced as well.

    Returns:
        A ``(answer, plot)`` tuple. ``plot`` is the result of ``plot_data``
        when a plot was requested and ``None`` otherwise. On any failure the
        answer is an ``"Error in processing: ..."`` message and the plot is
        ``None``, so the UI always receives two values.
    """
    # Function-scope import keeps this block self-contained.
    import os

    try:
        if csv_file is None:
            raise ValueError("no CSV file was uploaded")

        # Accept both an upload wrapper (path stored in ``.name``) and a bare
        # path string, so the handler does not depend on the wrapper type.
        csv_path = getattr(csv_file, "name", csv_file)
        df = pd.read_csv(csv_path)

        # The API key comes from the environment; fail with a clear message
        # instead of an obscure error when it has not been configured.
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise RuntimeError(
                "the OPENAI_API_KEY environment variable is not set"
            )

        llm = ChatOpenAI(
            temperature=0,
            model="gpt-3.5-turbo-16k",
            openai_api_key=api_key,
        )
        agent = create_csv_agent_instance(llm, csv_path)
        response = agent.run(question)

        # Check if the user's question implies a request for a plot.
        lowered = question.lower()
        wants_plot = "plot" in lowered or "graph" in lowered
        graph_output = plot_data(df) if wants_plot else None

        return response, graph_output
    except Exception as e:
        # Top-level UI boundary: report the problem in the answer box rather
        # than letting the request crash.
        return f"Error in processing: {str(e)}", None

def plot_data(df):
    """Render a quick overview plot of ``df`` and return it as a PIL image.

    If the first two columns are both numeric they are drawn as a scatter
    plot (first column on x, second on y). Otherwise, including when there is
    only one column, the value counts of the first column are drawn as a bar
    chart.

    Args:
        df: DataFrame with at least one column.

    Returns:
        The image obtained by opening the PNG rendering of the figure.

    Raises:
        ValueError: If ``df`` has no columns.
    """
    if len(df.columns) == 0:
        raise ValueError("cannot plot a DataFrame without columns")

    def _is_number_column(series):
        # Booleans are treated as categorical, as np.number did before.
        # The pandas helpers also cope with extension dtypes.
        kinds = pd.api.types
        return kinds.is_numeric_dtype(series) and not kinds.is_bool_dtype(series)

    # Positional access so unusual or duplicated column labels cannot break
    # the lookup.
    x_name = df.columns[0]
    x_values = df.iloc[:, 0]
    has_second = len(df.columns) > 1
    y_values = df.iloc[:, 1] if has_second else None

    # Work on an explicit figure/axes pair instead of pyplot's implicit
    # "current figure", and always close it, even if drawing fails.
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        if has_second and _is_number_column(x_values) and _is_number_column(y_values):
            ax.scatter(x_values, y_values)  # numeric vs numeric
            ax.set_ylabel(str(df.columns[1]))
        else:
            x_values.value_counts().plot(kind='bar', ax=ax)  # categorical data
            # The bars show frequencies of the first column, not the second
            # column's values.
            ax.set_ylabel('Count')
        ax.set_title('Data Distribution')
        ax.set_xlabel(str(x_name))
        ax.grid(True)
        buf = BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

# Set up the Gradio interface
# Two inputs (CSV upload + question) map onto qa_app's parameters; the two
# outputs receive the (answer, plot) tuple it returns.
demo = gr.Interface(
    fn=qa_app,
    inputs=[
        gr.File(label="Upload CSV file"),
        gr.Textbox(label="Question")
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        # No ``optional`` keyword here: it is not needed for an output, and
        # qa_app simply returns None for this slot when no plot was requested.
        gr.Image(label="Generated Plot", type="pil"),
    ],
    title="Data Analysis Chatbot",
    description="Upload a CSV file, ask a question about the data. Include 'plot' or 'graph' in your question to generate graphs."
)

# Launch the Gradio app with debugging enabled to trace any runtime issues
demo.launch(debug=True)