bluenevus committed on
Commit
bf5218c
·
verified ·
1 Parent(s): a5f2a3b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -28
app.py CHANGED
@@ -56,22 +56,29 @@ def process_file(file, instructions, api_key):
56
  for plot in plots[:3]: # Ensure max 3 plots
57
  fig, ax = plt.subplots(figsize=(10, 6))
58
 
59
- # Apply preprocessing if any
60
- if plot['preprocessing']:
61
- exec(plot['preprocessing'])
 
 
 
 
 
 
 
62
 
63
  if plot['plot_type'] == 'bar':
64
- df.plot(kind='bar', x=plot['x'], y=plot['y'], ax=ax)
65
  elif plot['plot_type'] == 'line':
66
- df.plot(kind='line', x=plot['x'], y=plot['y'], ax=ax)
67
  elif plot['plot_type'] == 'scatter':
68
- df.plot(kind='scatter', x=plot['x'], y=plot['y'], ax=ax)
69
  elif plot['plot_type'] == 'hist':
70
- df[plot['x']].hist(ax=ax)
71
 
72
  ax.set_title(plot['title'])
73
  ax.set_xlabel(plot['x'])
74
- ax.set_ylabel(plot['y'] if plot['y'] else 'Frequency')
75
  plt.tight_layout()
76
 
77
  buf = io.BytesIO()
@@ -91,23 +98,4 @@ def process_file(file, instructions, api_key):
91
  draw.text((10, 10), error_message, fill=(255, 0, 0))
92
  return [error_image] * 3
93
 
94
- with gr.Blocks(theme=gr.themes.Default()) as demo:
95
- gr.Markdown("# Data Analysis Dashboard")
96
-
97
- with gr.Row():
98
- file = gr.File(label="Upload Dataset", file_types=[".csv", ".xlsx"])
99
- instructions = gr.Textbox(label="Analysis Instructions", placeholder="Describe the analysis you want...")
100
-
101
- api_key = gr.Textbox(label="Gemini API Key", type="password")
102
- submit = gr.Button("Generate Insights", variant="primary")
103
-
104
- output_images = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]
105
-
106
- submit.click(
107
- process_file,
108
- inputs=[file, instructions, api_key],
109
- outputs=output_images
110
- )
111
-
112
- if __name__ == "__main__":
113
- demo.launch()
 
56
  for plot in plots[:3]: # Ensure max 3 plots
57
  fig, ax = plt.subplots(figsize=(10, 6))
58
 
59
+ # Apply preprocessing
60
+ plot_df = df.copy()
61
+ if 'Group data by' in plot['preprocessing']:
62
+ group_by = plot['x']
63
+ agg_column = plot['y'][0] if isinstance(plot['y'], list) else plot['y']
64
+ plot_df = plot_df.groupby(group_by)[agg_column].sum().reset_index()
65
+ if 'Sort' in plot['preprocessing']:
66
+ plot_df = plot_df.sort_values(by=plot['y'][0] if isinstance(plot['y'], list) else plot['y'], ascending=False)
67
+ if 'Filter to keep only the top 5' in plot['preprocessing']:
68
+ plot_df = plot_df.head(5)
69
 
70
  if plot['plot_type'] == 'bar':
71
+ plot_df.plot(kind='bar', x=plot['x'], y=plot['y'], ax=ax)
72
  elif plot['plot_type'] == 'line':
73
+ plot_df.plot(kind='line', x=plot['x'], y=plot['y'], ax=ax)
74
  elif plot['plot_type'] == 'scatter':
75
+ plot_df.plot(kind='scatter', x=plot['x'], y=plot['y'], ax=ax)
76
  elif plot['plot_type'] == 'hist':
77
+ plot_df[plot['x']].hist(ax=ax)
78
 
79
  ax.set_title(plot['title'])
80
  ax.set_xlabel(plot['x'])
81
+ ax.set_ylabel(plot['y'][0] if isinstance(plot['y'], list) else plot['y'])
82
  plt.tight_layout()
83
 
84
  buf = io.BytesIO()
 
98
  draw.text((10, 10), error_message, fill=(255, 0, 0))
99
  return [error_image] * 3
100
 
101
+ # The rest of your code remains unchanged