bluenevus's picture
Update app.py
05370c5 verified
raw
history blame
4.97 kB
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import io
import ast
from PIL import Image, ImageDraw
import google.generativeai as genai
import traceback
def _build_prompt(columns, instructions):
    """Return the prompt asking Gemini for three plot specs over *columns*."""
    # The template lines are flush-left on purpose: they are string content,
    # and indenting them would change the text sent to the model.
    return f"""
Analyze the following dataset and instructions:
Data columns: {columns}
Instructions: {instructions}
Based on this, create 3 appropriate visualizations. For each visualization, provide:
1. A title
2. The most suitable plot type (choose from: bar, line, scatter, hist)
3. The column to use for the x-axis
4. The column(s) to use for the y-axis (can be a list for multiple columns, or None for histograms)
5. Any necessary data preprocessing steps (e.g., grouping, sorting, etc.)
Return your response as a Python list of dictionaries:
[
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "preprocessing": "..."}},
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "preprocessing": "..."}},
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "preprocessing": "..."}}
]
"""


def _extract_plot_specs(response_text):
    """Parse the list of plot-spec dicts out of the model's reply text.

    Raises:
        ValueError / SyntaxError: if the reply does not contain a Python
            literal that evaluates to a list (or single dict) of specs.
    """
    code_block = response_text
    if '```python' in code_block:
        code_block = code_block.split('```python')[1].split('```')[0]
    elif '```' in code_block:
        code_block = code_block.split('```')[1]
    code_block = code_block.strip()

    # A generic fence may leave a language tag (e.g. "json") or stray prose
    # around the literal; trim to the outermost [...] unless the reply is a
    # bare dict.
    if not code_block.startswith('{'):
        start, end = code_block.find('['), code_block.rfind(']')
        if start != -1 and end > start:
            code_block = code_block[start:end + 1]

    print("Generated code block:")
    print(code_block)

    # literal_eval only accepts Python literals, so model output is never executed.
    plots = ast.literal_eval(code_block)
    if isinstance(plots, dict):
        plots = [plots]
    if not isinstance(plots, (list, tuple)):
        raise ValueError(
            f"Expected a list of plot specifications, got {type(plots).__name__}"
        )
    return list(plots)


def _render_plot(df, spec):
    """Draw one plot spec against *df* and return the chart as a PIL image."""
    x_col = spec['x']
    kind = spec['plot_type']

    # Normalise y to a list so single-column, multi-column and None all
    # follow the same code path.
    raw_y = spec.get('y')
    if raw_y is None:
        y_cols = []
    elif isinstance(raw_y, (list, tuple)):
        y_cols = list(raw_y)
    else:
        y_cols = [raw_y]
    first_y = y_cols[0] if y_cols else None

    # The model may omit the key or return None; treat both as "no preprocessing".
    preprocessing = str(spec.get('preprocessing') or '')

    # None of the steps below mutate in place, so no defensive copy is needed.
    plot_df = df
    if 'Group data by' in preprocessing and y_cols:
        # Aggregate every requested y column so multi-series plots still
        # find their columns after grouping.
        plot_df = plot_df.groupby(x_col)[y_cols].sum().reset_index()
    if 'Sort' in preprocessing and first_y is not None:
        plot_df = plot_df.sort_values(by=first_y, ascending=False)
    if 'Filter to keep only the top 5' in preprocessing:
        plot_df = plot_df.head(5)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if kind in ('bar', 'line'):
            y_arg = y_cols if len(y_cols) > 1 else first_y
            plot_df.plot(kind=kind, x=x_col, y=y_arg, ax=ax)
        elif kind == 'scatter':
            # pandas scatter plots take exactly one y column.
            plot_df.plot(kind='scatter', x=x_col, y=first_y, ax=ax)
        elif kind == 'hist':
            plot_df[x_col].hist(ax=ax)
        # Any other plot type leaves the axes empty but still titled.

        if first_y is not None:
            y_label = first_y
        else:
            y_label = 'Frequency' if kind == 'hist' else ''
        ax.set_title(spec['title'])
        ax.set_xlabel(x_col)
        ax.set_ylabel(y_label)

        # Use the figure's own methods rather than pyplot's "current figure"
        # so the right figure is always the one saved.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        return Image.open(buf)
    finally:
        # Always release the figure, even when plotting fails.
        plt.close(fig)


def process_file(file, instructions, api_key):
    """Build three visualizations of an uploaded dataset from Gemini's suggestions.

    Args:
        file: The uploaded dataset. Either an object exposing the path as
            ``.name`` or a plain path string. Paths ending in ``.csv``
            (any case) are read as CSV; anything else is read as Excel.
        instructions: Free-text analysis instructions included in the prompt.
        api_key: Gemini API key used to configure the client.

    Returns:
        A list of exactly three PIL images. Missing plots are padded with
        blank white 800x600 images. On any failure, the list holds three
        references to one image with the error message and traceback drawn
        in red; the same text is printed to the console.
    """
    try:
        if file is None:
            raise ValueError("No dataset uploaded.")
        if not api_key:
            raise ValueError("A Gemini API key is required.")

        # Initialize Gemini
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

        # Read uploaded file
        file_path = str(getattr(file, 'name', file))
        if file_path.lower().endswith('.csv'):
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)

        # Ask the model for plot specifications and parse them
        response = model.generate_content(
            _build_prompt(list(df.columns), instructions)
        )
        plots = _extract_plot_specs(response.text)

        # Generate visualizations (at most three), then pad to exactly three
        images = [_render_plot(df, spec) for spec in plots[:3]]
        images.extend(
            Image.new('RGB', (800, 600), (255, 255, 255))
            for _ in range(3 - len(images))
        )
        return images
    except Exception as e:
        # UI boundary: report the failure in the image outputs instead of raising.
        error_message = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)  # Print to console for debugging
        error_image = Image.new('RGB', (800, 400), (255, 255, 255))
        draw = ImageDraw.Draw(error_image)
        draw.text((10, 10), error_message, fill=(255, 0, 0))
        return [error_image] * 3
# Gradio UI: dataset upload, instructions and API key inputs, and three image
# outputs filled by process_file.
# NOTE(review): the original indentation was lost; the Row is assumed to hold
# only the file and instructions inputs. Confirm against the deployed layout.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Data Analysis Dashboard")

    with gr.Row():
        file = gr.File(
            label="Upload Dataset",
            file_types=[".csv", ".xlsx"],
        )
        instructions = gr.Textbox(
            label="Analysis Instructions",
            placeholder="Describe the analysis you want...",
        )

    api_key = gr.Textbox(label="Gemini API Key", type="password")
    submit = gr.Button("Generate Insights", variant="primary")

    # One image slot per visualization, labelled 1 to 3.
    output_images = []
    for slot in range(1, 4):
        output_images.append(gr.Image(label=f"Visualization {slot}"))

    # Clicking the button runs process_file and routes its three images
    # to the three output slots.
    submit.click(
        fn=process_file,
        inputs=[file, instructions, api_key],
        outputs=output_images,
    )

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()