File size: 6,103 Bytes
9a4f73a 12ce912 72c5969 bca92aa ec365ce 3c50a2d 904e6a1 4fc79a4 904e6a1 9a4fc1b ec365ce 904e6a1 ec365ce 422964b ec365ce 9a4fc1b 9a4f73a 9a4fc1b 72c5969 a5f2a3b 72c5969 2e5c04a a5f2a3b 6ac004d d3510d0 2e5c04a d3510d0 2e5c04a a5f2a3b 6ac004d a5f2a3b 72c5969 601022d 72c5969 3c50a2d 72c5969 9a4fc1b bca92aa ec365ce 72c5969 1b2886c ec365ce 2e5c04a bf5218c 2e5c04a a5f2a3b 6ac004d bf5218c 6ac004d bf5218c 6ac004d 2e5c04a 6ac004d 2e5c04a 6ac004d 2e5c04a 6ac004d 2e5c04a 1b2886c a5f2a3b 6ac004d 2e5c04a 1b2886c ec365ce 9a4fc1b 1b2886c 9a4fc1b 1b2886c b06a85b 1b2886c 12ce912 b06a85b 12ce912 9a4fc1b 3c50a2d ec365ce 3c50a2d b06a85b 9a4fc1b 9a4f73a 7a74363 9a4f73a 7a74363 9a4f73a 05370c5 9a4f73a 6013c50 9a4f73a 05370c5 9a4f73a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import ast
import io
import json
import os
import re
import traceback

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageDraw
# Default Gemini model; set GEMINI_MODEL in the environment to override it.
DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro-preview-03-25'

# Number of visualization slots; process_file always returns this many images.
NUM_PLOTS = 3

# Plot types _render_plot knows how to draw.
SUPPORTED_PLOT_TYPES = ('bar', 'line', 'scatter', 'hist', 'pie', 'heatmap')

# agg_func values applied through pandas GroupBy.agg. 'count' is handled
# separately because it counts rows instead of aggregating a column.
_GROUPBY_AGGS = ('sum', 'mean', 'median', 'min', 'max')

# First Markdown code fence in the model reply, skipping an optional
# language tag such as ``python`` or ``json`` after the opening backticks.
_FENCED_BLOCK = re.compile(r'```[ \t]*(?:[A-Za-z][\w.+-]*)?\s*(.*?)```', re.DOTALL)

# Prompt sent to Gemini. {columns}, {shape} and {instructions} are filled per
# request; doubled braces render as literal braces in the example output.
_PROMPT_TEMPLATE = """
Analyze the following dataset and instructions:
Data columns: {columns}
Data shape: {shape}
Instructions: {instructions}
Based on this, create 3 appropriate visualizations that provide meaningful insights. For each visualization:
1. Choose the most suitable plot type (bar, line, scatter, hist, pie, heatmap)
2. Determine appropriate data aggregation (e.g., top 5 categories, yearly averages)
3. Select relevant columns for x-axis, y-axis, and any additional dimensions (color, size)
4. Provide a clear, concise title that explains the insight
Consider data density and choose visualizations that simplify and clarify the information.
Limit the number of data points displayed to ensure readability (e.g., top 5, top 10, yearly).
Return your response as a Python list of dictionaries:
[
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
{{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}}
]
"""


def _load_dataframe(file):
    """Read the uploaded dataset into a DataFrame.

    ``file`` is used through its ``.name`` attribute when it has one and is
    otherwise treated as a path. Paths ending in ``.csv`` (any case) are read
    with ``pd.read_csv``; everything else with ``pd.read_excel``.

    Raises:
        ValueError: if no file was uploaded.
    """
    if file is None:
        raise ValueError('No file uploaded. Please upload a .csv or .xlsx dataset.')
    file_path = getattr(file, 'name', file)
    if str(file_path).lower().endswith('.csv'):
        return pd.read_csv(file_path)
    return pd.read_excel(file_path)


def _parse_plot_specs(reply):
    """Extract the list of plot-spec dicts from the model's text reply.

    Accepts a fenced code block with or without a language tag, or an
    unfenced reply. Python literals are tried first, then JSON (so
    ``null``/``true`` work). A single dict is wrapped in a list and
    non-dict entries are dropped.

    Raises:
        ValueError: if nothing usable can be parsed.
    """
    text = (reply or '').strip()
    fenced = _FENCED_BLOCK.search(text)
    candidate = fenced.group(1).strip() if fenced else text
    if not candidate.startswith(('[', '{')) and '[' in candidate and ']' in candidate:
        # Surrounding prose or an assignment like "plots = [...]": keep the list.
        candidate = candidate[candidate.index('['):candidate.rindex(']') + 1]
    try:
        # literal_eval only accepts literals, so model output cannot run code.
        parsed = ast.literal_eval(candidate)
    except (ValueError, SyntaxError, TypeError):
        try:
            parsed = json.loads(candidate)
        except ValueError as exc:
            raise ValueError(
                f"Could not parse plot specifications from the model reply: {candidate[:200]!r}"
            ) from exc
    if isinstance(parsed, dict):
        parsed = [parsed]
    if not isinstance(parsed, (list, tuple)):
        raise ValueError(f"Expected a list of plot specifications, got {type(parsed).__name__}.")
    specs = [spec for spec in parsed if isinstance(spec, dict)]
    if not specs:
        raise ValueError('The model reply did not contain any plot specifications.')
    return specs


def _has_column(df, name):
    """Return True if ``name`` is a column of ``df`` (False for unhashable names)."""
    try:
        return name in df.columns
    except TypeError:
        return False


def _is_numeric_column(df, name):
    """Return True if ``name`` is a column of ``df`` with a numeric dtype."""
    return _has_column(df, name) and pd.api.types.is_numeric_dtype(df[name])


def _require_columns(df, *columns):
    """Raise ValueError naming any of ``columns`` missing from ``df``.

    List/tuple entries are flattened so multi-column ``y`` values are checked.
    """
    names = []
    for column in columns:
        names.extend(column if isinstance(column, (list, tuple)) else [column])
    missing = [name for name in names if not _has_column(df, name)]
    if missing:
        raise ValueError(f"Column(s) not found in dataset: {missing}")


def _agg_name(spec):
    """Return the spec's agg_func normalized to lowercase ('' when absent)."""
    return str(spec.get('agg_func') or '').strip().lower()


def _parse_top_n(value):
    """Return ``value`` as a positive int, or None when it is missing or unusable."""
    try:
        n = int(value)
    except (TypeError, ValueError, OverflowError):
        return None
    return n if n > 0 else None


def _aggregate(df, spec):
    """Apply the spec's aggregation; return ``(plot_df, y_column)``.

    'count' counts rows per x value into a column named y, or 'count' when
    y is missing or equals x. The names in ``_GROUPBY_AGGS`` aggregate y per
    x value. Any other agg_func leaves the data unaggregated.
    """
    x, y = spec.get('x'), spec.get('y')
    agg = _agg_name(spec)
    if agg == 'count':
        _require_columns(df, x)
        count_col = y if isinstance(y, str) and y and y != x else 'count'
        return df.groupby(x).size().reset_index(name=count_col), count_col
    if agg in _GROUPBY_AGGS:
        _require_columns(df, x, y)
        return df.groupby(x)[y].agg(agg).reset_index(), y
    return df, y


def _limit_rows(plot_df, y, top_n, sort_by=None):
    """Keep the ``top_n`` rows with the largest y, then optionally sort by ``sort_by``.

    Returns ``plot_df`` unchanged when top_n is unusable or y is not a
    numeric column. ``sort_by`` lets line plots be redrawn in x order after
    nlargest has reordered the rows by value.
    """
    n = _parse_top_n(top_n)
    value_cols = list(y) if isinstance(y, (list, tuple)) else [y]
    if n is None or not all(_is_numeric_column(plot_df, col) for col in value_cols):
        return plot_df
    limited = plot_df.nlargest(n, value_cols)
    if sort_by is not None and _has_column(limited, sort_by):
        limited = limited.sort_values(sort_by)
    return limited


def _scatter_styling(plot_df, additional):
    """Build ``c``/``s`` kwargs for a scatter plot from the spec's color/size columns.

    A column is used only when it exists in ``plot_df`` and is numeric.
    Otherwise pandas' default marker color/size apply. After aggregation
    only x and y remain, so extra dimensions are dropped.
    """
    styling = {}
    color = additional.get('color')
    if _is_numeric_column(plot_df, color):
        styling['c'] = color
    size = additional.get('size')
    if _is_numeric_column(plot_df, size):
        styling['s'] = plot_df[size]
    return styling


def _draw_heatmap(fig, ax, df, spec, color):
    """Draw a heatmap of y with x values as rows and ``color`` values as columns.

    The pivot aggregates from the raw data with the spec's aggregation,
    falling back to 'mean', so duplicate (x, color) pairs do not fail.
    top_n keeps the rows with the largest row totals.
    """
    x, y = spec.get('x'), spec.get('y')
    _require_columns(df, x, color, y)
    agg = _agg_name(spec)
    aggfunc = agg if agg in _GROUPBY_AGGS or agg == 'count' else 'mean'
    pivot_df = df.pivot_table(index=x, columns=color, values=y, aggfunc=aggfunc)
    n = _parse_top_n(spec.get('top_n'))
    if n is not None:
        pivot_df = pivot_df.loc[pivot_df.sum(axis=1).nlargest(n).index]
    image = ax.imshow(pivot_df, cmap='YlOrRd', aspect='auto')
    fig.colorbar(image, ax=ax, label=str(y))
    ax.set_xticks(range(len(pivot_df.columns)))
    ax.set_yticks(range(len(pivot_df.index)))
    ax.set_xticklabels(pivot_df.columns)
    ax.set_yticklabels(pivot_df.index)


def _render_plot(df, spec):
    """Render the visualization described by ``spec`` and return it as a PIL image.

    Raises:
        ValueError: for an unsupported plot type or unknown column names.
            pandas/matplotlib errors propagate unchanged.
    """
    plot_type = str(spec.get('plot_type') or '').strip().lower()
    if plot_type not in SUPPORTED_PLOT_TYPES:
        raise ValueError(f"Unsupported plot type: {spec.get('plot_type')!r}")
    x = spec.get('x')
    additional = spec.get('additional')
    if not isinstance(additional, dict):
        additional = {}

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if plot_type == 'heatmap':
            color = additional.get('color')
            _draw_heatmap(fig, ax, df, spec, color)
            # Heatmap columns are color values and rows are x values.
            xlabel, ylabel = color, x
        else:
            plot_df, y = _aggregate(df, spec)
            plot_df = _limit_rows(plot_df, y, spec.get('top_n'),
                                  sort_by=x if plot_type == 'line' else None)
            if plot_type == 'hist':
                _require_columns(plot_df, x)
                plot_df[x].hist(ax=ax, bins=20)
            else:
                _require_columns(plot_df, x, y)
                if plot_type in ('bar', 'line'):
                    plot_df.plot(kind=plot_type, x=x, y=y, ax=ax)
                elif plot_type == 'scatter':
                    plot_df.plot(kind='scatter', x=x, y=y, ax=ax,
                                 **_scatter_styling(plot_df, additional))
                else:  # pie
                    plot_df.plot(kind='pie', y=y, labels=plot_df[x], ax=ax, autopct='%1.1f%%')
            xlabel, ylabel = x, y

        ax.set_title(spec.get('title', ''))
        if plot_type != 'pie':
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        return Image.open(buf)
    finally:
        # Release the figure even when drawing fails, so pyplot does not leak it.
        plt.close(fig)


def _error_image(error):
    """Print ``error`` with the active traceback and return it drawn on an image.

    Must be called from inside an ``except`` block so that
    ``traceback.format_exc`` reports the exception being handled.
    """
    error_message = f"Error: {str(error)}\n\nTraceback:\n{traceback.format_exc()}"
    print(error_message)  # Print to console for debugging
    error_image = Image.new('RGB', (800, 400), (255, 255, 255))
    ImageDraw.Draw(error_image).text((10, 10), error_message, fill=(255, 0, 0))
    return error_image


def process_file(file, instructions):
    """Ask Gemini for plot specifications for an uploaded dataset and render them.

    Args:
        file: The uploaded dataset. Its ``.name`` (or the value itself when
            it has no ``.name``) is the path. ``.csv`` files are read with
            ``pd.read_csv`` and anything else with ``pd.read_excel``.
        instructions: Free-text analysis request inserted into the prompt.

    Returns:
        A list of exactly ``NUM_PLOTS`` PIL images. This function never
        raises:
        - A failure before plotting (file, model call, reply parsing) fills
          every slot with the same error image.
        - A failure while drawing one plot replaces only that slot.
        - Missing plots are padded with blank white images.
    """
    try:
        df = _load_dataframe(file)
        genai.configure(api_key=os.environ.get('GEMINI_API_KEY'))
        model = genai.GenerativeModel(os.environ.get('GEMINI_MODEL', DEFAULT_GEMINI_MODEL))
        response = model.generate_content(_PROMPT_TEMPLATE.format(
            columns=list(df.columns), shape=df.shape, instructions=instructions))
        print("Model reply:")
        print(response.text)
        specs = _parse_plot_specs(response.text)
    except Exception as e:
        # UI boundary: show the failure in the image slots instead of raising into Gradio.
        return [_error_image(e)] * NUM_PLOTS

    images = []
    for spec in specs[:NUM_PLOTS]:
        try:
            images.append(_render_plot(df, spec))
        except Exception as e:
            # A bad spec blanks only its own slot; the other plots still render.
            images.append(_error_image(e))
    # Pad with blank white images so every output slot receives a value.
    images += [Image.new('RGB', (800, 600), (255, 255, 255))
               for _ in range(NUM_PLOTS - len(images))]
    return images
# Dashboard layout:
# - the dataset upload and the free-text instructions share one row;
# - the button beneath them runs process_file;
# - the three images process_file returns fill the three gr.Image slots in order.
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown("# Data Analysis Dashboard")

    with gr.Row():
        file = gr.File(file_types=[".csv", ".xlsx"], label="Upload Dataset")
        instructions = gr.Textbox(
            placeholder="Describe the analysis you want...",
            label="Analysis Instructions",
        )

    submit = gr.Button(value="Generate Insights", variant="primary")
    output_images = [gr.Image(label=f"Visualization {n}") for n in range(1, 4)]

    submit.click(fn=process_file, inputs=[file, instructions], outputs=output_images)
if __name__ == "__main__":
    # Launch the Gradio app only when this file is run as a script, not on import.
    demo.launch()