|
import ast
import io
import json
import re
import traceback

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.colors import is_color_like
from PIL import Image, ImageDraw
|
|
|
# NOTE(review): preview model IDs are periodically retired; confirm this one is still served.
GEMINI_MODEL_NAME = "gemini-2.5-pro-preview-03-25"

# Number of charts produced per request. Keep in sync with the prompt text
# below and with the number of output image slots in the UI.
_NUM_PLOTS = 3

# A Markdown code fence with an optional language tag (```python, ```json, ```).
_FENCE_RE = re.compile(r"```[\w-]*\s*(.*?)```", re.DOTALL)

_PROMPT_TEMPLATE = """
Analyze the following dataset and instructions:

Data columns: {columns}
Column types: {dtypes}
Data shape: {shape}
Instructions: {instructions}

Based on this, create 3 appropriate visualizations that provide meaningful insights. For each visualization:
1. Choose the most suitable plot type (bar, line, scatter, hist, pie, heatmap)
2. Determine appropriate data aggregation (e.g., top 5 categories, monthly averages)
3. Select relevant columns for x-axis, y-axis, and any additional dimensions (color, size)
4. Provide a clear, concise title that explains the insight

Consider data density and choose visualizations that simplify and clarify the information.
Limit the number of data points displayed to ensure readability (e.g., top 5, top 10).
Use only the exact column names listed above, and set "agg_func" to one of "sum", "mean", "count" or "none".

Return your response as a Python list of dictionaries:
[
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}}
]
"""


def _load_dataframe(file):
    """Read the uploaded CSV or Excel file into a DataFrame.

    Depending on the Gradio version, ``file`` is either a tempfile-like object
    exposing ``.name`` or a plain path string; both are accepted.

    Raises:
        ValueError: if no file was uploaded.
    """
    if file is None:
        raise ValueError("Please upload a CSV or Excel file first.")
    file_path = getattr(file, "name", file)
    # Case-insensitive so "DATA.CSV" is not handed to the Excel reader.
    if str(file_path).lower().endswith(".csv"):
        return pd.read_csv(file_path)
    return pd.read_excel(file_path)


def _parse_plot_specs(text):
    """Extract the list of plot-spec dictionaries from the model's reply.

    Accepts a bare list or one wrapped in a Markdown fence with any language
    tag, plus surrounding prose. Python-literal syntax is tried first
    (``ast.literal_eval`` only evaluates literals, so untrusted model output
    cannot execute code); JSON (``null``/``true``/``false``) is the fallback.

    Raises:
        ValueError: if no non-empty list of dictionaries can be parsed.
    """
    match = _FENCE_RE.search(text)
    payload = (match.group(1) if match else text).strip()
    # Drop any commentary around the list literal itself.
    start, end = payload.find("["), payload.rfind("]")
    if start != -1 and end > start:
        payload = payload[start:end + 1]

    print("Generated code block:")
    print(payload)

    try:
        specs = ast.literal_eval(payload)
    except (ValueError, SyntaxError):
        try:
            specs = json.loads(payload)
        except json.JSONDecodeError as err:
            raise ValueError(
                f"Could not parse plot specifications from the model reply:\n{payload[:500]}"
            ) from err

    if isinstance(specs, dict):
        specs = [specs]
    if not isinstance(specs, list) or not all(isinstance(spec, dict) for spec in specs):
        raise ValueError(f"Expected a list of plot dictionaries, got: {payload[:500]}")
    if not specs:
        raise ValueError("The model did not return any plot specifications.")
    return specs


def _request_plot_specs(df, instructions, api_key):
    """Ask Gemini to propose plot specifications for ``df`` and parse them.

    Raises:
        ValueError: if the API key is empty or the reply cannot be parsed.
    """
    if not api_key:
        raise ValueError("Please enter a Gemini API key.")
    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(GEMINI_MODEL_NAME)
    prompt = _PROMPT_TEMPLATE.format(
        columns=list(df.columns),
        dtypes={str(col): str(dtype) for col, dtype in df.dtypes.items()},
        shape=df.shape,
        instructions=instructions,
    )
    response = model.generate_content(prompt)
    return _parse_plot_specs(response.text)


def _has_column(data, name):
    """Return True if ``name`` is a column of ``data``; False for None or unhashable names."""
    try:
        return name is not None and name in data.columns
    except TypeError:  # e.g. the model returned a list instead of a column name
        return False


def _is_numeric(series):
    """Return True for numeric, non-boolean series (``nlargest`` rejects bools)."""
    return pd.api.types.is_numeric_dtype(series) and not pd.api.types.is_bool_dtype(series)


def _positive_int(value):
    """Return ``value`` as a positive int, or None if it cannot be read as one."""
    try:
        number = int(value)
    except (TypeError, ValueError):
        return None
    return number if number > 0 else None


def _prepare_data(df, spec):
    """Aggregate and trim ``df`` as requested by ``spec``.

    Returns:
        ``(data, y)`` where ``y`` is the value column to plot. For ``count``
        aggregation it names the synthetic count column.
    """
    x, y = spec.get("x"), spec.get("y")
    agg = str(spec.get("agg_func") or "").lower()
    if agg in ("sum", "mean"):
        data = df.groupby(x)[y].agg(agg).reset_index()
    elif agg == "count":
        # Naming the counts after the grouping column would collide in reset_index.
        if not y or y == x:
            y = "count"
        data = df.groupby(x).size().reset_index(name=y)
    else:
        data = df

    top_n = _positive_int(spec.get("top_n"))
    # Only rank by y when it is an actual numeric column (e.g. not for histograms of x).
    if top_n and _has_column(data, y) and _is_numeric(data[y]):
        data = data.nlargest(top_n, y)
    return data, y


def _scatter_style(data, extra):
    """Build the ``c``/``s`` keyword arguments for a scatter plot from ``extra``."""
    style = {}
    color = extra.get("color")
    if _has_column(data, color) and _is_numeric(data[color]):
        style["c"] = color  # colour points by a numeric column via a colormap
    elif isinstance(color, str) and not _has_column(data, color) and is_color_like(color):
        style["c"] = color  # a literal colour such as "red"

    size = extra.get("size")
    if _has_column(data, size) and _is_numeric(data[size]):
        values = data[size].astype(float)
        low, high = values.min(), values.max()
        if pd.notna(low) and high > low:
            # Map raw values onto readable marker areas instead of using them directly.
            style["s"] = (20 + 280 * (values - low) / (high - low)).fillna(20)
    return style


def _draw_heatmap(ax, df, spec, extra):
    """Draw an x-by-``additional.color`` heatmap of y aggregated from the raw data."""
    x, y = spec.get("x"), spec.get("y")
    columns = extra.get("color")
    if not _has_column(df, columns):
        raise ValueError("A heatmap needs 'additional.color' to name the column to pivot on.")
    agg = str(spec.get("agg_func") or "").lower()
    # Pivot the raw frame: the x-only aggregation would already have dropped `columns`,
    # and pivot_table (unlike pivot) tolerates duplicate (x, columns) pairs.
    if agg == "count" or not _has_column(df, y):
        table = pd.crosstab(df[x], df[columns])
    else:
        table = df.pivot_table(
            index=x, columns=columns, values=y,
            aggfunc=agg if agg in ("sum", "mean") else "mean",
        )
    top_n = _positive_int(spec.get("top_n"))
    if top_n:
        table = table.loc[table.sum(axis=1).nlargest(top_n).index]

    image = ax.imshow(table.to_numpy(dtype=float), cmap="YlOrRd", aspect="auto")
    ax.figure.colorbar(image, ax=ax)
    ax.set_xticks(range(len(table.columns)))
    ax.set_yticks(range(len(table.index)))
    ax.set_xticklabels(table.columns, rotation=45, ha="right")
    ax.set_yticklabels(table.index)
    ax.set_xlabel(str(columns))
    ax.set_ylabel(str(x))


def _draw_plot(ax, df, data, spec, y):
    """Draw one chart of the requested type onto ``ax``.

    Raises:
        ValueError: for an unsupported ``plot_type``.
    """
    kind = str(spec.get("plot_type") or "").lower()
    x = spec.get("x")
    extra = spec.get("additional")
    if not isinstance(extra, dict):
        extra = {}

    if kind in ("bar", "line"):
        if kind == "line":
            data = data.sort_values(x)  # nlargest reorders rows; restore x order for a sane line
        data.plot(kind=kind, x=x, y=y, ax=ax)
    elif kind == "scatter":
        data.plot(kind="scatter", x=x, y=y, ax=ax, **_scatter_style(data, extra))
    elif kind == "hist":
        data[x].hist(ax=ax, bins=20)
    elif kind == "pie":
        data.plot(kind="pie", y=y, labels=data[x], ax=ax, autopct="%1.1f%%")
    elif kind == "heatmap":
        _draw_heatmap(ax, df, spec, extra)
    else:
        raise ValueError(f"Unsupported plot type: {kind!r}")

    ax.set_title(spec.get("title", ""))
    if kind not in ("pie", "heatmap"):  # pie has no axes; heatmap labels itself
        ax.set_xlabel(str(x))
        ax.set_ylabel("Frequency" if kind == "hist" else str(y))


def _figure_to_image(fig):
    """Rasterize ``fig`` to an in-memory PNG and return it as a PIL image."""
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    buffer.seek(0)
    image = Image.open(buffer)
    image.load()  # decode now instead of lazily on first access
    return image


def _error_image(message, size=(800, 400)):
    """Return a white image with ``message`` drawn in red."""
    image = Image.new("RGB", size, (255, 255, 255))
    ImageDraw.Draw(image).text((10, 10), message, fill=(255, 0, 0))
    return image


def _render_spec(df, spec):
    """Render one plot spec to a PIL image, or an error image if it fails."""
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        data, y = _prepare_data(df, spec)
        _draw_plot(ax, df, data, spec, y)
        fig.tight_layout()
        return _figure_to_image(fig)
    except Exception as err:  # isolate failures so the other charts still render
        message = (
            f"Could not render {spec.get('title', 'plot')!r}: {err}\n\n"
            f"Traceback:\n{traceback.format_exc()}"
        )
        print(message)
        return _error_image(message)
    finally:
        plt.close(fig)  # always release the pyplot figure, even on error


def process_file(file, instructions, api_key):
    """Generate up to three charts for an uploaded dataset using Gemini.

    Args:
        file: The upload from ``gr.File``. Depending on the Gradio version this
            is a path string or a tempfile-like object with ``.name``.
        instructions: Free-text analysis goals forwarded to the model.
        api_key: Gemini API key.

    Returns:
        A list of exactly ``_NUM_PLOTS`` PIL images. A spec that cannot be
        rendered yields an error image in its own slot, and unused slots are
        blank. If the file cannot be read or the model reply cannot be parsed,
        every slot holds the same error image.
    """
    try:
        df = _load_dataframe(file)
        specs = _request_plot_specs(df, instructions, api_key)
    except Exception as e:  # UI boundary: show the error instead of crashing the handler
        error_message = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)
        return [_error_image(error_message)] * _NUM_PLOTS

    images = [_render_spec(df, spec) for spec in specs[:_NUM_PLOTS]]
    blanks = [Image.new("RGB", (800, 600), (255, 255, 255))] * (_NUM_PLOTS - len(images))
    return images + blanks
|
|
|
# Dashboard layout: dataset + instructions on one row, credentials and the
# trigger button below, followed by three output image slots.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Data Analysis Dashboard")

    with gr.Row():
        file = gr.File(
            label="Upload Dataset",
            file_types=[".csv", ".xlsx"],
        )
        instructions = gr.Textbox(
            label="Analysis Instructions",
            placeholder="Describe the analysis you want...",
        )

    api_key = gr.Textbox(label="Gemini API Key", type="password")
    submit = gr.Button("Generate Insights", variant="primary")

    # One slot per image returned by process_file, labelled from 1.
    output_images = [gr.Image(label=f"Visualization {n}") for n in range(1, 4)]

    # Clicking the button feeds the three inputs to process_file and routes
    # its list of images into the output slots in order.
    submit.click(
        fn=process_file,
        inputs=[file, instructions, api_key],
        outputs=output_images,
    )


if __name__ == "__main__":
    demo.launch()