Tamqeen's picture
Update app.py
4213c4a verified
import html
import os
import tempfile

import docx
import gradio as gr
import matplotlib.pyplot as plt
import openai
import pandas as pd
import pdfplumber
import seaborn as sns
# Read the OpenAI API key from the environment instead of embedding it in the
# source. On a hosted deployment, set OPENAI_API_KEY as a secret/environment
# variable. If the variable is unset the key is None and the failure surfaces
# when the first API request is made.
# NOTE: any key that was ever committed to this file must be treated as leaked
# and revoked with the provider.
openai.api_key = os.environ.get("OPENAI_API_KEY")
def load_file(file):
    """Load an uploaded file into a pandas DataFrame.

    The loader is picked from the extension of ``file.name``. The extension is
    compared case-insensitively, so ``REPORT.CSV`` is treated the same as
    ``report.csv``.

    Args:
        file: Object exposing the path of the uploaded file as ``.name``.

    Returns:
        The DataFrame read by pandas for CSV/Excel input, or the DataFrame
        returned by ``load_pdf`` / ``load_doc`` for PDF and Word input.

    Raises:
        ValueError: If the extension is not one of the supported types.
    """
    # Lower-case the extension: uploads often arrive as ".CSV", ".PDF", etc.
    file_type = file.name.split('.')[-1].lower()
    if file_type == 'csv':
        return pd.read_csv(file.name)
    if file_type in ('xls', 'xlsx'):
        return pd.read_excel(file.name)
    if file_type == 'pdf':
        return load_pdf(file)
    if file_type in ('doc', 'docx'):
        # NOTE(review): python-docx documents support for .docx only; a legacy
        # binary .doc will probably fail inside load_doc -- confirm.
        return load_doc(file)
    raise ValueError("Unsupported file type")
def load_pdf(file):
    """Extract the text of a PDF into a one-row DataFrame.

    Args:
        file: Object exposing the path of the PDF as ``.name``.

    Returns:
        A DataFrame with a single ``text`` column whose only cell holds the
        text of every page, joined with newlines.
    """
    with pdfplumber.open(file.name) as pdf:
        # extract_text() can return None for a page with no extractable text
        # (seen with older pdfplumber releases, e.g. scanned pages); fall back
        # to "" so the join below cannot raise TypeError.
        pages = [page.extract_text() or "" for page in pdf.pages]
    text = "\n".join(pages)
    return pd.DataFrame({"text": [text]})
def load_doc(file):
    """Read a Word document and return its paragraph text as a DataFrame.

    Args:
        file: Object exposing the path of the document as ``.name``.

    Returns:
        A DataFrame with a single ``text`` column whose only cell holds the
        text of all paragraphs, joined with newlines.
    """
    document = docx.Document(file.name)
    paragraph_texts = (paragraph.text for paragraph in document.paragraphs)
    return pd.DataFrame({"text": ["\n".join(paragraph_texts)]})
def generate_query(prompt):
    """Ask the chat model to answer ``prompt`` and return its reply text.

    The request uses the ``gpt-3.5-turbo`` model with a generic system message
    followed by ``prompt`` as the user message.

    Args:
        prompt: Text sent as the user message.

    Returns:
        The content of the first returned choice, with surrounding whitespace
        stripped.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=conversation,
    )
    first_choice = completion['choices'][0]
    return first_choice['message']['content'].strip()
def handle_query(query, df):
    """Answer ``query`` against ``df``.

    Two phrases are recognised case-insensitively and answered directly:
    "number of columns" (checked first) and "number of rows". Any other text
    is passed to ``DataFrame.query`` and the filtered rows are rendered as an
    HTML table.

    Args:
        query: The query text.
        df: The DataFrame to inspect or filter.

    Returns:
        A sentence with the column/row count, the HTML table of matching rows,
        or -- when evaluating the query raises -- the exception message.
    """
    lowered = query.lower()
    if "number of columns" in lowered:
        column_count = df.shape[1]
        return f"The number of columns is {column_count}"
    if "number of rows" in lowered:
        row_count = df.shape[0]
        return f"The number of rows is {row_count}"

    # NOTE(review): the text is evaluated as an expression by DataFrame.query;
    # confirm that is acceptable for the callers' input.
    try:
        return df.query(query).to_html()
    except Exception as error:
        # Best-effort: report the problem as the response text.
        return str(error)
def draw_chart(query, df, output_path=None):
    """Filter ``df`` with ``query`` and save a scatter plot of the result.

    The first column of the filtered frame is plotted on the x axis and the
    second on the y axis.

    Args:
        query: Expression passed to ``DataFrame.query``.
        df: The DataFrame to filter and plot.
        output_path: Where to write the PNG. When omitted, a uniquely named
            temporary ``.png`` file is created, so the function does not
            depend on a particular directory existing and simultaneous calls
            do not overwrite each other's image.

    Returns:
        The path of the saved image on success; on any failure, the exception
        message as a string (callers should check that the returned path
        exists before using it).
    """
    fig = None
    try:
        result_df = df.query(query)
        x_col, y_col = result_df.columns[0], result_df.columns[1]
        # Draw on an explicit figure rather than pyplot's implicit current one
        # so a figure left over from an earlier call can never be reused.
        fig, ax = plt.subplots()
        sns.scatterplot(data=result_df, x=x_col, y=y_col, ax=ax)
        ax.set_title("Generated Chart")
        ax.set_xlabel(x_col)
        ax.set_ylabel(y_col)
        if output_path is None:
            # Reserve a unique file name; delete=False keeps the file around
            # for the caller after the handle is closed.
            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
                output_path = tmp.name
        fig.savefig(output_path)
        return output_path
    except Exception as e:
        return str(e)
    finally:
        # Always release the figure, including when plotting or saving failed.
        if fig is not None:
            plt.close(fig)
def chatbot(file, input_text):
    """Gradio callback: answer ``input_text`` about the uploaded ``file``.

    Steps: load the file into a DataFrame, turn the input text into a query
    with ``generate_query``, answer it with ``handle_query``, and -- when the
    generated query mentions "chart" or "graph" -- also draw a chart.

    Args:
        file: The uploaded file object, or None when nothing was uploaded.
        input_text: The text typed by the user.

    Returns:
        A ``(image_path_or_None, response)`` tuple. The first element is only
        set when a chart file was actually produced; any error is reported in
        the response text rather than raised.
    """
    # Guard the inputs up front so the user gets a clear message instead of a
    # raw attribute error, and no API call is made for an empty query.
    if file is None:
        return None, "Please upload a file before submitting a query."
    if not input_text or not input_text.strip():
        return None, "Please enter a query."

    try:
        # Load the file into a DataFrame
        df = load_file(file)
        # Generate a query from the input text
        query = generate_query(input_text)
        # Handle the query and generate a response
        response = handle_query(query, df)
        # If the query is suitable for generating a chart, do so
        lowered = query.lower()
        if "chart" in lowered or "graph" in lowered:
            chart_result = draw_chart(query, df)
            # draw_chart returns an error message instead of a path when it
            # fails; only hand a real file to the image output.
            if os.path.isfile(chart_result):
                return chart_result, response
            note = html.escape(chart_result)
            return None, f"{response}<p>Chart could not be generated: {note}</p>"
        # Return the query response
        return None, response
    except Exception as e:
        # Top-level boundary for the UI: show the error as the response.
        return None, str(e)
# Build the Gradio UI: a file upload plus a two-line query box as inputs, and
# an image (chart) plus an HTML panel (response) as outputs.
upload_input = gr.File(type="file", label="Upload File")
query_input = gr.Textbox(lines=2, placeholder="Enter your query here...")

iface = gr.Interface(
    fn=chatbot,
    inputs=[upload_input, query_input],
    outputs=["image", "html"],
    title="Data Analyst Chatbot",
    description="Upload a file and enter a query to get responses based on the data.",
)

# Start serving the interface.
iface.launch()