Update app.py
Browse files
app.py
CHANGED
@@ -56,22 +56,29 @@ def process_file(file, instructions, api_key):
|
|
56 |
for plot in plots[:3]: # Ensure max 3 plots
|
57 |
fig, ax = plt.subplots(figsize=(10, 6))
|
58 |
|
59 |
-
# Apply preprocessing
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
if plot['plot_type'] == 'bar':
|
64 |
-
|
65 |
elif plot['plot_type'] == 'line':
|
66 |
-
|
67 |
elif plot['plot_type'] == 'scatter':
|
68 |
-
|
69 |
elif plot['plot_type'] == 'hist':
|
70 |
-
|
71 |
|
72 |
ax.set_title(plot['title'])
|
73 |
ax.set_xlabel(plot['x'])
|
74 |
-
ax.set_ylabel(plot['y'] if plot['y'] else '
|
75 |
plt.tight_layout()
|
76 |
|
77 |
buf = io.BytesIO()
|
@@ -91,23 +98,4 @@ def process_file(file, instructions, api_key):
|
|
91 |
draw.text((10, 10), error_message, fill=(255, 0, 0))
|
92 |
return [error_image] * 3
|
93 |
|
94 |
-
|
95 |
-
gr.Markdown("# Data Analysis Dashboard")
|
96 |
-
|
97 |
-
with gr.Row():
|
98 |
-
file = gr.File(label="Upload Dataset", file_types=[".csv", ".xlsx"])
|
99 |
-
instructions = gr.Textbox(label="Analysis Instructions", placeholder="Describe the analysis you want...")
|
100 |
-
|
101 |
-
api_key = gr.Textbox(label="Gemini API Key", type="password")
|
102 |
-
submit = gr.Button("Generate Insights", variant="primary")
|
103 |
-
|
104 |
-
output_images = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]
|
105 |
-
|
106 |
-
submit.click(
|
107 |
-
process_file,
|
108 |
-
inputs=[file, instructions, api_key],
|
109 |
-
outputs=output_images
|
110 |
-
)
|
111 |
-
|
112 |
-
if __name__ == "__main__":
|
113 |
-
demo.launch()
|
|
|
56 |
for plot in plots[:3]: # Ensure max 3 plots
|
57 |
fig, ax = plt.subplots(figsize=(10, 6))
|
58 |
|
59 |
+
# Apply preprocessing
|
60 |
+
plot_df = df.copy()
|
61 |
+
if 'Group data by' in plot['preprocessing']:
|
62 |
+
group_by = plot['x']
|
63 |
+
agg_column = plot['y'][0] if isinstance(plot['y'], list) else plot['y']
|
64 |
+
plot_df = plot_df.groupby(group_by)[agg_column].sum().reset_index()
|
65 |
+
if 'Sort' in plot['preprocessing']:
|
66 |
+
plot_df = plot_df.sort_values(by=plot['y'][0] if isinstance(plot['y'], list) else plot['y'], ascending=False)
|
67 |
+
if 'Filter to keep only the top 5' in plot['preprocessing']:
|
68 |
+
plot_df = plot_df.head(5)
|
69 |
|
70 |
if plot['plot_type'] == 'bar':
|
71 |
+
plot_df.plot(kind='bar', x=plot['x'], y=plot['y'], ax=ax)
|
72 |
elif plot['plot_type'] == 'line':
|
73 |
+
plot_df.plot(kind='line', x=plot['x'], y=plot['y'], ax=ax)
|
74 |
elif plot['plot_type'] == 'scatter':
|
75 |
+
plot_df.plot(kind='scatter', x=plot['x'], y=plot['y'], ax=ax)
|
76 |
elif plot['plot_type'] == 'hist':
|
77 |
+
plot_df[plot['x']].hist(ax=ax)
|
78 |
|
79 |
ax.set_title(plot['title'])
|
80 |
ax.set_xlabel(plot['x'])
|
81 |
+
ax.set_ylabel(plot['y'][0] if isinstance(plot['y'], list) else plot['y'])
|
82 |
plt.tight_layout()
|
83 |
|
84 |
buf = io.BytesIO()
|
|
|
98 |
draw.text((10, 10), error_message, fill=(255, 0, 0))
|
99 |
return [error_image] * 3
|
100 |
|
101 |
+
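# NOTE(review): the Gradio UI block (gr.Markdown, inputs, submit.click) and the `if __name__ == "__main__": demo.launch()` entry point that used to follow were deleted by this change, not left unchanged -- restore them so the app still has a UI to launch.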
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|