bluenevus's picture
Update app.py
6013c50 verified
raw
history blame
7.16 kB
import pandas as pd
import matplotlib.pyplot as plt
import io
import ast
from PIL import Image, ImageDraw
import google.generativeai as genai
import traceback
import os
from pywebio import start_server
from pywebio.input import file_upload, input
from pywebio.output import put_text, put_image, put_row, put_column, use_scope, put_buttons
from pywebio.session import run_js, set_env
import base64
import threading
def _load_dataframe(file):
    """Parse an uploaded-file dict into a DataFrame.

    ``file`` must provide ``'filename'`` (str) and ``'content'`` (raw bytes).
    Names ending in ``.csv`` (case-insensitive) are read as CSV; anything else
    is handed to ``pd.read_excel``.
    """
    content = io.BytesIO(file['content'])
    if file['filename'].lower().endswith('.csv'):
        return pd.read_csv(content)
    return pd.read_excel(content)


def _extract_code_block(text):
    """Return the payload of the first Markdown code fence in ``text``.

    Handles ```python fences, bare ``` fences and fences carrying another
    language tag (e.g. ```json). Text without any fence is returned stripped.
    """
    if '```python' in text:
        return text.split('```python', 1)[1].split('```', 1)[0].strip()
    if '```' in text:
        body = text.split('```', 2)[1]
        first_line, newline, rest = body.partition('\n')
        # A fence such as ```json puts a language tag on the opening line;
        # drop it so only the literal itself reaches the parser.
        if newline and first_line.strip().isidentifier():
            body = rest
        return body.strip()
    return text.strip()


def _parse_plot_specs(text):
    """Parse the model's reply into a list of plot-specification dicts.

    Tries a Python literal first (what the prompt asks for) and falls back to
    JSON. Models often answer with ``null``/``true``/``false``, which
    ``ast.literal_eval`` rejects. A single dict is wrapped in a list.

    Raises:
        ValueError: if the text is neither a Python literal nor JSON, or does
            not describe a list/tuple/dict.
    """
    import json

    try:
        specs = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        specs = json.loads(text)
    if isinstance(specs, dict):
        specs = [specs]
    if not isinstance(specs, (list, tuple)):
        raise ValueError(f"Expected a list of plot specifications, got {type(specs).__name__}")
    return list(specs)


def _request_plot_specs(df, instructions):
    """Ask Gemini to propose visualisations for ``df`` and return the parsed specs.

    Only the column names and shape of ``df`` are sent, not its rows. The
    parsed reply is printed to the console for debugging.
    """
    # Initialize Gemini
    api_key = os.environ.get('GEMINI_API_KEY')
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
    # The prompt text is kept flush-left on purpose so that it matches the
    # previous version byte for byte.
    response = model.generate_content(f"""
Analyze the following dataset and instructions:
Data columns: {list(df.columns)}
Data shape: {df.shape}
Instructions: {instructions}
Based on this, create 3 appropriate visualizations that provide meaningful insights. For each visualization:
1. Choose the most suitable plot type (bar, line, scatter, hist, pie, heatmap)
2. Determine appropriate data aggregation (e.g., top 5 categories, yearly averages)
3. Select relevant columns for x-axis, y-axis, and any additional dimensions (color, size)
4. Provide a clear, concise title that explains the insight
Consider data density and choose visualizations that simplify and clarify the information.
Limit the number of data points displayed to ensure readability (e.g., top 5, top 10, yearly).
Return your response as a Python list of dictionaries:
[
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}}
]
""")
    code_block = _extract_code_block(response.text)
    print("Generated code block:")
    print(code_block)
    return _parse_plot_specs(code_block)


def _as_positive_int(value):
    """Coerce a model-supplied ``top_n`` (int, float or numeric str) to a positive int, else ``None``."""
    if isinstance(value, bool):
        return None
    try:
        number = int(value)
    except (TypeError, ValueError):
        return None
    return number if number > 0 else None


def _aggregate(df, x, y, agg_func):
    """Group ``df`` by ``x`` according to ``agg_func``.

    'sum'/'mean' aggregate column ``y``. 'count' counts rows per ``x`` into a
    column named ``y``. Any other value returns ``df`` unaggregated.
    """
    if agg_func in ('sum', 'mean'):
        return df.groupby(x)[y].agg(agg_func).reset_index()
    if agg_func == 'count':
        return df.groupby(x).size().reset_index(name=y)
    return df


def _draw_heatmap(ax, df, x, y, color, agg_func, top_n):
    """Draw ``y`` as a heatmap with ``x`` values as rows and ``color`` values as columns.

    ``pivot_table`` is used on the unaggregated frame. Grouping by ``x`` first
    would drop the ``color`` column, and plain ``pivot`` raises on duplicate
    (x, color) pairs. Cells are combined with ``agg_func`` when it is
    'sum'/'mean'/'count', otherwise with 'mean'. With ``top_n`` only the rows
    with the largest totals are kept.
    """
    if color is None or color not in df.columns:
        raise ValueError(f"heatmap needs additional.color to name a column, got {color!r}")
    aggfunc = agg_func if agg_func in ('sum', 'mean', 'count') else 'mean'
    pivot_df = df.pivot_table(index=x, columns=color, values=y, aggfunc=aggfunc)
    if top_n:
        pivot_df = pivot_df.loc[pivot_df.sum(axis=1).nlargest(top_n).index]
    ax.imshow(pivot_df, cmap='YlOrRd')
    ax.set_xticks(range(len(pivot_df.columns)))
    ax.set_yticks(range(len(pivot_df.index)))
    ax.set_xticklabels(pivot_df.columns)
    ax.set_yticklabels(pivot_df.index)


def _render_plot(df, plot):
    """Render one plot-specification dict for ``df`` into a PNG-backed PIL image.

    Drawing happens on a standalone ``matplotlib.figure.Figure`` instead of
    through ``pyplot``. pyplot keeps process-global "current figure" state,
    which concurrent server sessions would otherwise share. The figure is never
    registered with pyplot, so there is nothing to close afterwards.
    """
    from matplotlib.colors import is_color_like
    from matplotlib.figure import Figure

    x, y = plot['x'], plot['y']
    plot_type = plot.get('plot_type')
    agg_func = plot.get('agg_func')
    if isinstance(agg_func, str):
        agg_func = agg_func.strip().lower()
    additional = plot.get('additional') or {}
    top_n = _as_positive_int(plot.get('top_n'))

    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()

    if plot_type == 'heatmap':
        _draw_heatmap(ax, df, x, y, additional.get('color'), agg_func, top_n)
    else:
        plot_df = _aggregate(df, x, y, agg_func)
        if top_n and y in plot_df.columns:
            plot_df = plot_df.nlargest(top_n, y)
        if plot_type == 'bar':
            plot_df.plot(kind='bar', x=x, y=y, ax=ax)
        elif plot_type == 'line':
            plot_df.plot(kind='line', x=x, y=y, ax=ax)
        elif plot_type == 'scatter':
            color = additional.get('color')
            size = additional.get('size')
            # ``color`` may name a column or be a literal colour. Anything else
            # (placeholder text, a column removed by aggregation) is dropped
            # rather than passed to matplotlib as an invalid colour.
            if color is not None and color not in plot_df.columns and not is_color_like(color):
                color = None
            # Without a usable size column pandas falls back to its default
            # marker size (the old code looked up a column literally named 'y').
            sizes = plot_df[size] if size in plot_df.columns else None
            plot_df.plot(kind='scatter', x=x, y=y, ax=ax, c=color, s=sizes)
        elif plot_type == 'hist':
            # Same drawing as Series.hist(bins=20). Series.hist itself consults
            # pyplot's current figure and rejects axes of a standalone Figure.
            ax.hist(plot_df[x].dropna(), bins=20)
            ax.grid(True)
        elif plot_type == 'pie':
            plot_df.plot(kind='pie', y=y, labels=plot_df[x], ax=ax, autopct='%1.1f%%')

    ax.set_title(plot.get('title', ''))
    if plot_type != 'pie':
        ax.set_xlabel(x)
        ax.set_ylabel(y)
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    return Image.open(buf)


def _error_image(exc):
    """Print ``exc`` with its traceback and return the text drawn in red on an 800x400 image.

    Call this from inside the ``except`` block handling ``exc``, so that
    ``traceback.format_exc()`` still sees the active exception.
    """
    error_message = f"Error: {str(exc)}\n\nTraceback:\n{traceback.format_exc()}"
    print(error_message)  # Print to console for debugging
    error_image = Image.new('RGB', (800, 400), (255, 255, 255))
    ImageDraw.Draw(error_image).text((10, 10), error_message, fill=(255, 0, 0))
    return error_image


def process_file(file, instructions):
    """Build up to three model-suggested visualisations for an uploaded dataset.

    Args:
        file: Uploaded-file dict with ``'filename'`` (str) and ``'content'``
            (raw bytes) keys.
        instructions: Free-text analysis request forwarded to the model.

    Returns:
        A list of exactly three PIL images.
        - If reading the file, calling the model or parsing its reply fails,
          all three entries are the same error image.
        - Otherwise each of the first three plot specs becomes one image.
        - A spec that fails to render becomes an image of its own error, so
          it does not blank the other plots.
        - Missing plots are padded with blank white 800x600 images.
    """
    try:
        df = _load_dataframe(file)
        plots = _request_plot_specs(df, instructions)
    except Exception as e:
        return [_error_image(e)] * 3

    images = []
    for plot in plots[:3]:  # Ensure max 3 plots
        try:
            images.append(_render_plot(df, plot))
        except Exception as e:
            images.append(_error_image(e))
    while len(images) < 3:
        images.append(Image.new('RGB', (800, 600), (255, 255, 255)))
    return images
def data_analysis_dashboard():
    """Render the dashboard page, then serve analysis requests until the session ends.

    PyWebIO input functions (``file_upload``, ``input``) block until the user
    submits and then return the value. They are not output widgets and cannot
    be placed inside ``put_row``/``put_column``. The form is therefore
    collected by ``generate_insights`` with ``input_group``, in a loop so the
    user can run several analyses in one session.
    """
    from pywebio.output import put_markdown, put_scope

    set_env(title="Data Analysis Dashboard")
    put_markdown("# Data Analysis Dashboard")
    # Three fixed slots that generate_insights() clears and refills.
    with use_scope('output'):
        for i in range(3):
            put_scope(f'visualization_{i+1}')
    # When the browser session closes, PyWebIO raises SessionClosedException
    # from the pending input call, which ends this loop.
    while True:
        generate_insights()


def generate_insights(file=None, instructions=None):
    """Obtain a dataset and instructions, run ``process_file`` and show the three plots.

    Args:
        file: Optional uploaded-file dict (``'filename'`` / ``'content'`` keys).
        instructions: Optional analysis instructions text.

    When both arguments are ``None`` the user is prompted for them with a
    PyWebIO form. Results replace the contents of the ``visualization_1`` to
    ``visualization_3`` scopes.
    """
    from pywebio.input import input_group

    if file is None and instructions is None:
        form = input_group("Generate Insights", [
            file_upload("Upload Dataset", accept=[".csv", ".xlsx"], name="file"),
            input("Analysis Instructions", type="text",
                  placeholder="Describe the analysis you want...", name="instructions"),
        ])
        file, instructions = form['file'], form['instructions']
    if not file or not instructions:
        put_text("Please upload a file and provide instructions.")
        return
    images = process_file(file, instructions)
    for i, img in enumerate(images):
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        with use_scope(f'visualization_{i+1}', clear=True):
            # Pass raw PNG bytes. put_image treats a str as an image URL, so
            # the bare base64 string used previously never rendered.
            put_image(buffered.getvalue(), format='png', width='100%')
def main():
    """Per-session PyWebIO task: hand the visitor the data-analysis dashboard."""
    render_dashboard = data_analysis_dashboard
    render_dashboard()
if __name__ == '__main__':
    # Script entry point. Host, port and debug mode can be overridden through
    # the environment; the defaults reproduce the previous hard-coded values.
    # NOTE(review): debug=True combined with a 0.0.0.0 bind enables the
    # server's debug mode for every client that can reach the port. Consider
    # setting DASHBOARD_DEBUG=0 for deployments.
    start_server(
        main,
        host=os.environ.get('DASHBOARD_HOST', '0.0.0.0'),
        port=int(os.environ.get('DASHBOARD_PORT', '7860')),
        debug=os.environ.get('DASHBOARD_DEBUG', '1') != '0',
        cdn=False,
        auto_open_webbrowser=True,
    )