bluenevus commited on
Commit
2e5c04a
·
verified ·
1 Parent(s): 05370c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -22
app.py CHANGED
@@ -22,20 +22,23 @@ def process_file(file, instructions, api_key):
22
  Analyze the following dataset and instructions:
23
 
24
  Data columns: {list(df.columns)}
 
25
  Instructions: {instructions}
26
 
27
- Based on this, create 3 appropriate visualizations. For each visualization, provide:
28
- 1. A title
29
- 2. The most suitable plot type (choose from: bar, line, scatter, hist)
30
- 3. The column to use for the x-axis
31
- 4. The column(s) to use for the y-axis (can be a list for multiple columns, or None for histograms)
32
- 5. Any necessary data preprocessing steps (e.g., grouping, sorting, etc.)
33
 
 
 
 
34
  Return your response as a Python list of dictionaries:
35
  [
36
- {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "preprocessing": "..."}},
37
- {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "preprocessing": "..."}},
38
- {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "preprocessing": "..."}}
39
  ]
40
  """)
41
 
@@ -56,29 +59,41 @@ def process_file(file, instructions, api_key):
56
  for plot in plots[:3]: # Ensure max 3 plots
57
  fig, ax = plt.subplots(figsize=(10, 6))
58
 
59
- # Apply preprocessing
60
  plot_df = df.copy()
61
- if 'Group data by' in plot['preprocessing']:
62
- group_by = plot['x']
63
- agg_column = plot['y'][0] if isinstance(plot['y'], list) else plot['y']
64
- plot_df = plot_df.groupby(group_by)[agg_column].sum().reset_index()
65
- if 'Sort' in plot['preprocessing']:
66
- plot_df = plot_df.sort_values(by=plot['y'][0] if isinstance(plot['y'], list) else plot['y'], ascending=False)
67
- if 'Filter to keep only the top 5' in plot['preprocessing']:
68
- plot_df = plot_df.head(5)
 
69
 
70
  if plot['plot_type'] == 'bar':
71
  plot_df.plot(kind='bar', x=plot['x'], y=plot['y'], ax=ax)
72
  elif plot['plot_type'] == 'line':
73
  plot_df.plot(kind='line', x=plot['x'], y=plot['y'], ax=ax)
74
  elif plot['plot_type'] == 'scatter':
75
- plot_df.plot(kind='scatter', x=plot['x'], y=plot['y'], ax=ax)
 
76
  elif plot['plot_type'] == 'hist':
77
- plot_df[plot['x']].hist(ax=ax)
 
 
 
 
 
 
 
 
 
78
 
79
  ax.set_title(plot['title'])
80
- ax.set_xlabel(plot['x'])
81
- ax.set_ylabel(plot['y'][0] if isinstance(plot['y'], list) else plot['y'])
 
82
  plt.tight_layout()
83
 
84
  buf = io.BytesIO()
 
22
  Analyze the following dataset and instructions:
23
 
24
  Data columns: {list(df.columns)}
25
+ Data shape: {df.shape}
26
  Instructions: {instructions}
27
 
28
+ Based on this, create 3 appropriate visualizations that provide meaningful insights. For each visualization:
29
+ 1. Choose the most suitable plot type (bar, line, scatter, hist, pie, heatmap)
30
+ 2. Determine appropriate data aggregation (e.g., top 5 categories, monthly averages)
31
+ 3. Select relevant columns for x-axis, y-axis, and any additional dimensions (color, size)
32
+ 4. Provide a clear, concise title that explains the insight
 
33
 
34
+ Consider data density and choose visualizations that simplify and clarify the information.
35
+ Limit the number of data points displayed to ensure readability (e.g., top 5, top 10).
36
+
37
  Return your response as a Python list of dictionaries:
38
  [
39
+ {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
40
+ {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
41
+ {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}
42
  ]
43
  """)
44
 
 
59
  for plot in plots[:3]: # Ensure max 3 plots
60
  fig, ax = plt.subplots(figsize=(10, 6))
61
 
62
+ # Apply preprocessing and aggregation
63
  plot_df = df.copy()
64
+ if plot['agg_func'] == 'sum':
65
+ plot_df = plot_df.groupby(plot['x'])[plot['y']].sum().reset_index()
66
+ elif plot['agg_func'] == 'mean':
67
+ plot_df = plot_df.groupby(plot['x'])[plot['y']].mean().reset_index()
68
+ elif plot['agg_func'] == 'count':
69
+ plot_df = plot_df.groupby(plot['x']).size().reset_index(name=plot['y'])
70
+
71
+ if 'top_n' in plot and plot['top_n']:
72
+ plot_df = plot_df.nlargest(plot['top_n'], plot['y'])
73
 
74
  if plot['plot_type'] == 'bar':
75
  plot_df.plot(kind='bar', x=plot['x'], y=plot['y'], ax=ax)
76
  elif plot['plot_type'] == 'line':
77
  plot_df.plot(kind='line', x=plot['x'], y=plot['y'], ax=ax)
78
  elif plot['plot_type'] == 'scatter':
79
+ plot_df.plot(kind='scatter', x=plot['x'], y=plot['y'], ax=ax,
80
+ c=plot['additional'].get('color'), s=plot_df[plot['additional'].get('size', 'y')])
81
  elif plot['plot_type'] == 'hist':
82
+ plot_df[plot['x']].hist(ax=ax, bins=20)
83
+ elif plot['plot_type'] == 'pie':
84
+ plot_df.plot(kind='pie', y=plot['y'], labels=plot_df[plot['x']], ax=ax, autopct='%1.1f%%')
85
+ elif plot['plot_type'] == 'heatmap':
86
+ pivot_df = plot_df.pivot(index=plot['x'], columns=plot['additional']['color'], values=plot['y'])
87
+ ax.imshow(pivot_df, cmap='YlOrRd')
88
+ ax.set_xticks(range(len(pivot_df.columns)))
89
+ ax.set_yticks(range(len(pivot_df.index)))
90
+ ax.set_xticklabels(pivot_df.columns)
91
+ ax.set_yticklabels(pivot_df.index)
92
 
93
  ax.set_title(plot['title'])
94
+ if plot['plot_type'] != 'pie':
95
+ ax.set_xlabel(plot['x'])
96
+ ax.set_ylabel(plot['y'])
97
  plt.tight_layout()
98
 
99
  buf = io.BytesIO()