bluenevus commited on
Commit
e5b6950
·
verified ·
1 Parent(s): 5fcb5c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -33
app.py CHANGED
@@ -6,7 +6,6 @@ import google.generativeai as genai
6
  from PIL import Image
7
 
8
  def process_file(api_key, file, instructions):
9
- # Configure Gemini with precise model version
10
  genai.configure(api_key=api_key)
11
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
12
 
@@ -16,34 +15,23 @@ def process_file(api_key, file, instructions):
16
  else:
17
  df = pd.read_excel(file.name)
18
 
19
- # Enhanced prompt with strict code requirements
20
- prompt = f"""Generate 3 matplotlib visualization codes for this data:
21
- Columns: {list(df.columns)}
22
- First 3 rows: {df.head(3).to_dict()}
 
 
23
 
24
- Requirements:
25
- 1. Each visualization must start with:
26
- plt.figure(figsize=(16,9), dpi=120)
27
- plt.style.use('seaborn')
28
- 2. Include complete plotting code with:
29
- - Title
30
- - Axis labels
31
- - Legend if needed
32
- - plt.tight_layout()
33
- 3. Different chart types (bar, line, scatter, etc)
34
- 4. No explanations - only valid Python code
35
-
36
- User instructions: {instructions}
37
 
38
  Format exactly as:
39
  # Visualization 1
40
- [complete code]
41
-
42
- # Visualization 2
43
- [complete code]
44
-
45
- # Visualization 3
46
- [complete code]
47
  """
48
 
49
  response = model.generate_content(prompt)
@@ -52,21 +40,22 @@ def process_file(api_key, file, instructions):
52
  visualizations = []
53
  for i, block in enumerate(code_blocks, 1):
54
  try:
55
- # Clean and validate code
56
  cleaned_code = '\n'.join([
57
  line.strip() for line in block.split('\n')
58
  if line.strip() and not line.startswith('```')
59
  ])
60
 
61
- # Create HD figure
62
  buf = io.BytesIO()
63
  plt.figure(figsize=(16, 9), dpi=120)
64
- plt.style.use('seaborn')
65
 
66
- # Execute generated code
67
- exec(cleaned_code, {'df': df, 'plt': plt})
 
 
 
 
 
68
 
69
- # Save HD image
70
  plt.tight_layout()
71
  plt.savefig(buf, format='png', bbox_inches='tight')
72
  plt.close()
@@ -77,7 +66,6 @@ def process_file(api_key, file, instructions):
77
  print(f"Visualization {i} Error: {str(e)}")
78
  visualizations.append(None)
79
 
80
- # Return exactly 3 images, filling with None if needed
81
  return visualizations + [None]*(3-len(visualizations))
82
 
83
  # Gradio interface
@@ -88,7 +76,7 @@ with gr.Blocks() as demo:
88
  api_key = gr.Textbox(label="Gemini API Key", type="password")
89
  file = gr.File(label="Upload CSV/Excel", file_types=[".csv", ".xlsx"])
90
 
91
- instructions = gr.Textbox(label="Instructions (optional)")
92
  submit = gr.Button("Generate Visualizations")
93
 
94
  with gr.Row():
 
6
  from PIL import Image
7
 
8
  def process_file(api_key, file, instructions):
 
9
  genai.configure(api_key=api_key)
10
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
11
 
 
15
  else:
16
  df = pd.read_excel(file.name)
17
 
18
+ # Updated prompt with valid matplotlib styles
19
+ prompt = f"""Generate 3 matplotlib codes with these requirements:
20
+ 1. Start with: plt.figure(figsize=(16,9), dpi=120)
21
+ 2. Use one of these styles: ggplot, bmh, dark_background, fast
22
+ 3. Include: title, labels, grid, legend if needed
23
+ 4. Different chart types (bar, line, scatter, etc)
24
 
25
+ Data columns: {list(df.columns)}
26
+ Sample data: {df.head(3).to_dict()}
27
+ User instructions: {instructions or 'None'}
 
 
 
 
 
 
 
 
 
 
28
 
29
  Format exactly as:
30
  # Visualization 1
31
+ plt.figure(figsize=(16,9), dpi=120)
32
+ plt.style.use('ggplot') # Example valid style
33
+ [code]
34
+ plt.tight_layout()
 
 
 
35
  """
36
 
37
  response = model.generate_content(prompt)
 
40
  visualizations = []
41
  for i, block in enumerate(code_blocks, 1):
42
  try:
 
43
  cleaned_code = '\n'.join([
44
  line.strip() for line in block.split('\n')
45
  if line.strip() and not line.startswith('```')
46
  ])
47
 
 
48
  buf = io.BytesIO()
49
  plt.figure(figsize=(16, 9), dpi=120)
 
50
 
51
+ # Execute code with safe environment
52
+ exec_env = {
53
+ 'df': df,
54
+ 'plt': plt,
55
+ 'pd': pd
56
+ }
57
+ exec(cleaned_code, exec_env)
58
 
 
59
  plt.tight_layout()
60
  plt.savefig(buf, format='png', bbox_inches='tight')
61
  plt.close()
 
66
  print(f"Visualization {i} Error: {str(e)}")
67
  visualizations.append(None)
68
 
 
69
  return visualizations + [None]*(3-len(visualizations))
70
 
71
  # Gradio interface
 
76
  api_key = gr.Textbox(label="Gemini API Key", type="password")
77
  file = gr.File(label="Upload CSV/Excel", file_types=[".csv", ".xlsx"])
78
 
79
+ instructions = gr.Textbox(label="Custom Instructions")
80
  submit = gr.Button("Generate Visualizations")
81
 
82
  with gr.Row():