bluenevus commited on
Commit
a5f2a3b
·
verified ·
1 Parent(s): 72c5969

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -16
app.py CHANGED
@@ -19,11 +19,24 @@ def process_file(file, instructions, api_key):
19
 
20
  # Generate visualization code
21
  response = model.generate_content(f"""
22
- Create 3 matplotlib visualization codes based on: {instructions}
 
23
  Data columns: {list(df.columns)}
24
- Return Python code as: [('title','plot_type','x','y'), ...]
25
- Allowed plot_types: bar, line, scatter, hist
26
- Use only DataFrame 'df' and these exact variable names.
 
 
 
 
 
 
 
 
 
 
 
 
27
  """)
28
 
29
  # Extract code block safely
@@ -42,20 +55,23 @@ def process_file(file, instructions, api_key):
42
  images = []
43
  for plot in plots[:3]: # Ensure max 3 plots
44
  fig, ax = plt.subplots(figsize=(10, 6))
45
- title, plot_type, x, y = plot
46
 
47
- if plot_type == 'bar':
48
- df.plot(kind='bar', x=x, y=y, ax=ax)
49
- elif plot_type == 'line':
50
- df.plot(kind='line', x=x, y=y, ax=ax)
51
- elif plot_type == 'scatter':
52
- df.plot(kind='scatter', x=x, y=y, ax=ax)
53
- elif plot_type == 'hist':
54
- df[x].hist(ax=ax)
 
 
 
 
55
 
56
- ax.set_title(title)
57
- ax.set_xlabel(x)
58
- ax.set_ylabel(y if y else 'Frequency')
59
  plt.tight_layout()
60
 
61
  buf = io.BytesIO()
 
19
 
20
  # Generate visualization code
21
  response = model.generate_content(f"""
22
+ Analyze the following dataset and instructions:
23
+
24
  Data columns: {list(df.columns)}
25
+ Instructions: {instructions}
26
+
27
+ Based on this, create 3 appropriate visualizations. For each visualization, provide:
28
+ 1. A title
29
+ 2. The most suitable plot type (choose from: bar, line, scatter, hist)
30
+ 3. The column to use for the x-axis
31
+ 4. The column(s) to use for the y-axis (can be a list for multiple columns, or None for histograms)
32
+ 5. Any necessary data preprocessing steps (e.g., grouping, sorting, etc.)
33
+
34
+ Return your response as a Python list of dictionaries:
35
+ [
36
+ {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "preprocessing": "..."}},
37
+ {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "preprocessing": "..."}},
38
+ {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "preprocessing": "..."}}
39
+ ]
40
  """)
41
 
42
  # Extract code block safely
 
55
  images = []
56
  for plot in plots[:3]: # Ensure max 3 plots
57
  fig, ax = plt.subplots(figsize=(10, 6))
 
58
 
59
+ # Apply preprocessing if any
60
+ if plot['preprocessing']:
61
+ exec(plot['preprocessing'])
62
+
63
+ if plot['plot_type'] == 'bar':
64
+ df.plot(kind='bar', x=plot['x'], y=plot['y'], ax=ax)
65
+ elif plot['plot_type'] == 'line':
66
+ df.plot(kind='line', x=plot['x'], y=plot['y'], ax=ax)
67
+ elif plot['plot_type'] == 'scatter':
68
+ df.plot(kind='scatter', x=plot['x'], y=plot['y'], ax=ax)
69
+ elif plot['plot_type'] == 'hist':
70
+ df[plot['x']].hist(ax=ax)
71
 
72
+ ax.set_title(plot['title'])
73
+ ax.set_xlabel(plot['x'])
74
+ ax.set_ylabel(plot['y'] if plot['y'] else 'Frequency')
75
  plt.tight_layout()
76
 
77
  buf = io.BytesIO()