Tamqeen's picture
Update app.py
539de04 verified
raw
history blame
No virus
4.25 kB
import html
import os
import tempfile

import docx
import gradio as gr
import matplotlib.pyplot as plt
import openai
import pandas as pd
import pdfplumber
import seaborn as sns
# Read the OpenAI API key from the environment rather than embedding it in
# source control. NOTE: a literal key was previously committed on this line;
# treat that key as compromised and revoke it.
openai.api_key = os.environ.get("OPENAI_API_KEY")
def load_file(file):
    """Load an uploaded file into a pandas DataFrame.

    Args:
        file: Either a filesystem path (``str`` / ``os.PathLike``) or an
            object exposing the path as its ``name`` attribute.

    Returns:
        A DataFrame. CSV and Excel files are parsed as tables; PDF and Word
        files are delegated to ``load_pdf`` / ``load_doc``.

    Raises:
        ValueError: If the file extension is not one of the supported types.
    """
    # Path-like inputs are checked first: pathlib.Path also has a ``name``
    # attribute, but it holds only the final component, not the full path.
    if isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
    else:
        path = file.name

    # Compare extensions case-insensitively so "DATA.CSV" is accepted too.
    file_type = path.rsplit(".", 1)[-1].lower()

    if file_type == "csv":
        return pd.read_csv(path)
    if file_type in ("xls", "xlsx"):
        return pd.read_excel(path)
    if file_type == "pdf":
        return load_pdf(file)
    if file_type in ("doc", "docx"):
        # NOTE(review): legacy binary ".doc" files are presumably not readable
        # by python-docx; confirm and consider rejecting them here.
        return load_doc(file)
    raise ValueError("Unsupported file type")
def load_pdf(file):
    """Extract the text of every page of a PDF into a one-row DataFrame.

    Args:
        file: Either a filesystem path (``str`` / ``os.PathLike``) or an
            object exposing the path as its ``name`` attribute.

    Returns:
        A DataFrame with a single ``"text"`` column and a single row holding
        the text of all pages joined by newlines.
    """
    if isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
    else:
        path = file.name

    with pdfplumber.open(path) as pdf:
        # extract_text() may hand back None for a page without extractable
        # text (e.g. a scanned image); substitute "" so the join below does
        # not raise TypeError.
        pages = [page.extract_text() or "" for page in pdf.pages]

    text = "\n".join(pages)
    return pd.DataFrame({"text": [text]})
def load_doc(file):
    """Read the paragraph text of a Word document into a one-row DataFrame.

    Args:
        file: Either a filesystem path (``str`` / ``os.PathLike``) or an
            object exposing the path as its ``name`` attribute.

    Returns:
        A DataFrame with a single ``"text"`` column and a single row holding
        every paragraph's text joined by newlines. Only ``doc.paragraphs`` is
        read, so content outside that collection is not included.
    """
    if isinstance(file, (str, os.PathLike)):
        path = os.fspath(file)
    else:
        path = file.name

    doc = docx.Document(path)
    text = "\n".join(para.text for para in doc.paragraphs)
    return pd.DataFrame({"text": [text]})
def generate_query(prompt):
    """Send *prompt* to the completions endpoint and return the reply text.

    NOTE: ``generate_query`` is defined a second time further down in this
    file. The later definition rebinds the name, so this version is not the
    one callers end up invoking.
    """
    completion = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        max_tokens=150,
    )
    first_choice = completion.choices[0]
    return first_choice.text.strip()
def handle_query(query, df):
    """Answer *query* against *df*.

    Two phrases are answered directly from the DataFrame's shape; any other
    text is tried as a pandas query expression and the matching rows are
    returned as an HTML table. If that fails, the exception message is
    returned instead.

    NOTE: ``handle_query`` is defined a second time further down in this
    file; the later definition is the one bound at import time.
    """
    if "number of columns" in query.lower():
        column_count = df.shape[1]
        return f"The number of columns is {column_count}"

    if "number of rows" in query.lower():
        row_count = df.shape[0]
        return f"The number of rows is {row_count}"

    try:
        # Fall back to treating the text as a pandas query expression.
        matches = df.query(query)
        table_html = matches.to_html()
    except Exception as exc:
        return str(exc)
    return table_html
def draw_chart(query, df):
    """Filter *df* with *query* and save a scatter plot of its first two columns.

    Returns the path of the saved image, or the exception message if any
    step fails.

    NOTE: ``draw_chart`` is defined a second time further down in this file;
    the later definition is the one bound at import time.
    """
    chart_path = '/content/chart.png'
    try:
        subset = df.query(query)
        x_col = subset.columns[0]
        y_col = subset.columns[1]

        sns.scatterplot(data=subset, x=x_col, y=y_col)
        plt.title("Generated Chart")
        plt.xlabel(x_col)
        plt.ylabel(y_col)
        plt.savefig(chart_path)
        plt.close()
    except Exception as err:
        return str(err)
    return chart_path
def generate_query(prompt):
    """Ask the chat model to respond to *prompt* and return the reply text.

    Works with both generations of the ``openai`` package: releases from 1.0
    onwards expose ``openai.chat.completions.create``, while earlier releases
    expose ``openai.ChatCompletion.create`` (which 1.0+ removed).

    Args:
        prompt: The user's text, sent as the single user message.

    Returns:
        The model's reply with surrounding whitespace stripped; an empty
        string if the reply carries no text content.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

    if hasattr(openai, "chat"):
        # openai>=1.0: module-level client, attribute-style response objects.
        response = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=messages,
        )
        content = response.choices[0].message.content
    else:
        # openai<1.0: legacy resource class, dict-style response objects.
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=messages,
        )
        content = response['choices'][0]['message']['content']

    # The content field can be None; avoid calling .strip() on it.
    return (content or "").strip()
def handle_query(query, df):
    """Answer *query* against *df* and return text or HTML for display.

    Args:
        query: Text to interpret. The phrases "number of columns" and
            "number of rows" (case-insensitive) are answered directly;
            anything else is evaluated with ``DataFrame.query``.
        df: The DataFrame to inspect.

    Returns:
        A sentence with the requested count, the matching rows rendered as
        an HTML table, or the HTML-escaped exception message if evaluating
        the query fails.
    """
    lowered = query.lower()
    if "number of columns" in lowered:
        return f"The number of columns is {df.shape[1]}"
    if "number of rows" in lowered:
        return f"The number of rows is {df.shape[0]}"

    try:
        # NOTE(security): in this file `query` is model-generated text, and it
        # is handed to pandas' expression evaluator unchanged. Review before
        # exposing this to untrusted users.
        result_df = df.query(query)
        return result_df.to_html()
    except Exception as e:
        # This file renders the return value as HTML, so escape the message;
        # otherwise text such as "<unknown>" in a SyntaxError is treated as a
        # tag and disappears from the page.
        return html.escape(str(e), quote=False)
def draw_chart(query, df):
    """Filter *df* with a pandas query and save a scatter plot of the result.

    The first column of the filtered frame is plotted on the x axis and the
    second on the y axis.

    Args:
        query: Expression passed to ``DataFrame.query``.
        df: Source DataFrame; the filtered result needs at least two columns.

    Returns:
        The path of the saved PNG on success. On any failure the exception
        message is returned instead, so callers must check the value before
        treating it as a file path.
    """
    try:
        result_df = df.query(query)
        x_col, y_col = result_df.columns[0], result_df.columns[1]

        # Write to a unique file in the system temp directory: a fixed path
        # may not exist on the host and would be overwritten by concurrent
        # requests.
        fd, chart_path = tempfile.mkstemp(prefix="chart_", suffix=".png")
        os.close(fd)
        try:
            sns.scatterplot(data=result_df, x=x_col, y=y_col)
            plt.title("Generated Chart")
            plt.xlabel(x_col)
            plt.ylabel(y_col)
            plt.savefig(chart_path)
        except Exception:
            # Do not leave an empty placeholder file behind.
            os.remove(chart_path)
            raise
        finally:
            # Always release the figure, even when plotting fails part-way.
            plt.close()
        return chart_path
    except Exception as e:
        return str(e)
def chatbot(file, input_text):
    """Answer a question about an uploaded file, optionally with a chart.

    Args:
        file: The uploaded file, in any form accepted by ``load_file``.
        input_text: The user's question.

    Returns:
        A ``(chart_path, html)`` tuple. ``chart_path`` is the path of a
        generated image, or ``None`` when no chart was produced. ``html`` is
        the answer, or a message describing what went wrong.
    """
    # Fail fast with a readable message instead of surfacing an
    # AttributeError (or spending an API call) on missing input.
    if file is None:
        return None, "Please upload a file."
    if not input_text or not input_text.strip():
        return None, "Please enter a query."

    try:
        # Load the file into a DataFrame
        df = load_file(file)
        # Generate a query from the input text
        query = generate_query(input_text)
        # Handle the query and generate a response
        response = handle_query(query, df)

        # If the query is suitable for generating a chart, do so
        lowered = query.lower()
        if "chart" in lowered or "graph" in lowered:
            chart_result = draw_chart(query, df)
            # draw_chart returns its error text in place of a path when
            # plotting fails; only hand real files to the image output.
            if os.path.isfile(chart_result):
                return chart_result, response
            note = html.escape(str(chart_result), quote=False)
            return None, f"{response}<p>Chart could not be generated: {note}</p>"

        # Return the query response
        return None, response
    except Exception as e:
        # The second output is rendered as HTML, so escape the message.
        return None, html.escape(str(e), quote=False)
# Create a Gradio interface: a file upload plus a free-text query in, an
# optional chart image plus an HTML answer out.
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.File(type="filepath", label="Upload File"),
        gr.Textbox(lines=2, placeholder="Enter your query here..."),
    ],
    outputs=["image", "html"],
    title="Data Analyst Chatbot",
    description="Upload a file and enter a query to get responses based on the data.",
)

# Launch only when executed as a script, so importing this module (for
# example from a test) does not start a web server as a side effect.
if __name__ == "__main__":
    iface.launch()