bluenevus commited on
Commit
6cff8d5
Β·
verified Β·
1 Parent(s): 8d1334d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -22
app.py CHANGED
@@ -7,67 +7,120 @@ from PIL import Image
7
  import ast
8
 
9
  def process_file(api_key, file, instructions):
 
10
  genai.configure(api_key=api_key)
11
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
12
 
13
  try:
 
14
  if file.name.endswith('.csv'):
15
  df = pd.read_csv(file.name)
16
  else:
17
  df = pd.read_excel(file.name)
18
  except Exception as e:
19
- return [None]*3
 
20
 
21
- # Enhanced prompt with strict variable requirements
22
- prompt = f"""Generate 3 matplotlib codes with these rules:
23
- 1. Use EXACTLY these variables: df (DataFrame), plt (matplotlib)
24
- 2. NO imports or additional variables
25
- 3. Start each visualization with:
26
- plt.figure(figsize=(16,9), dpi=120)
27
- plt.style.use('ggplot')
28
- 4. End with plt.tight_layout()
 
 
29
 
30
- Data columns: {list(df.columns)}
31
- First 3 rows: {df.head(3).to_dict()}
32
  User instructions: {instructions or 'None'}
33
 
34
  Format EXACTLY as:
35
  # Visualization 1
36
- [code using df and plt]
 
 
 
37
  """
38
 
39
- response = model.generate_content(prompt)
40
- code_blocks = response.text.split("# Visualization ")[1:4]
 
 
 
 
41
 
42
  visualizations = []
43
  for i, block in enumerate(code_blocks, 1):
44
  buf = io.BytesIO()
45
  try:
46
- # Advanced code cleaning
47
- cleaned_code = '\n'.join(
48
  line.replace('data', 'df').split('#')[0].strip()
49
  for line in block.split('\n')[1:]
50
  if line.strip() and
51
- not any(s in line.lower() for s in ['import', 'data =', 'data='])
52
- )
53
 
54
- # Validate syntax
55
  ast.parse(cleaned_code)
56
 
57
- # Execute with controlled environment
58
  exec_env = {'df': df, 'plt': plt}
59
  plt.figure(figsize=(16, 9), dpi=120)
60
  exec(cleaned_code, exec_env)
61
 
 
62
  plt.savefig(buf, format='png', bbox_inches='tight')
63
  plt.close()
64
  buf.seek(0)
65
  visualizations.append(Image.open(buf))
66
  except Exception as e:
67
  print(f"Visualization {i} Error: {str(e)}")
68
- print(f"Cleaned Code:\n{cleaned_code}")
69
  visualizations.append(None)
70
 
 
71
  return visualizations + [None]*(3-len(visualizations))
72
 
73
- # Rest of the Gradio interface remains the same
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import ast
8
 
9
  def process_file(api_key, file, instructions):
10
+ # Configure Gemini API with correct model
11
  genai.configure(api_key=api_key)
12
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
13
 
14
  try:
15
+ # Read uploaded file
16
  if file.name.endswith('.csv'):
17
  df = pd.read_csv(file.name)
18
  else:
19
  df = pd.read_excel(file.name)
20
  except Exception as e:
21
+ print(f"File Error: {str(e)}")
22
+ return [None, None, None]
23
 
24
+ # Enhanced prompt with strict coding rules
25
+ prompt = f"""Generate 3 distinct matplotlib visualization codes with these rules:
26
+ 1. Use ONLY these variables: df (existing DataFrame), plt
27
+ 2. No imports or additional data loading
28
+ 3. Each visualization must:
29
+ - Start with: plt.figure(figsize=(16,9), dpi=120)
30
+ - Use plt.style.use('ggplot')
31
+ - Include title, axis labels, and grid
32
+ - End with plt.tight_layout()
33
+ 4. Different chart types (bar, line, scatter, etc)
34
 
35
+ Dataset columns: {list(df.columns)}
36
+ Sample data: {df.head(3).to_dict()}
37
  User instructions: {instructions or 'None'}
38
 
39
  Format EXACTLY as:
40
  # Visualization 1
41
+ plt.figure(figsize=(16,9), dpi=120)
42
+ plt.style.use('ggplot')
43
+ # Visualization code using df
44
+ plt.tight_layout()
45
  """
46
 
47
+ try:
48
+ response = model.generate_content(prompt)
49
+ code_blocks = response.text.split("# Visualization ")[1:4]
50
+ except Exception as e:
51
+ print(f"Gemini Error: {str(e)}")
52
+ return [None, None, None]
53
 
54
  visualizations = []
55
  for i, block in enumerate(code_blocks, 1):
56
  buf = io.BytesIO()
57
  try:
58
+ # Clean and validate generated code
59
+ cleaned_code = '\n'.join([
60
  line.replace('data', 'df').split('#')[0].strip()
61
  for line in block.split('\n')[1:]
62
  if line.strip() and
63
+ not any(s in line.lower() for s in ['import', 'data=', 'data ='])
64
+ ])
65
 
66
+ # Syntax check
67
  ast.parse(cleaned_code)
68
 
69
+ # Execute code in controlled environment
70
  exec_env = {'df': df, 'plt': plt}
71
  plt.figure(figsize=(16, 9), dpi=120)
72
  exec(cleaned_code, exec_env)
73
 
74
+ # Save HD image
75
  plt.savefig(buf, format='png', bbox_inches='tight')
76
  plt.close()
77
  buf.seek(0)
78
  visualizations.append(Image.open(buf))
79
  except Exception as e:
80
  print(f"Visualization {i} Error: {str(e)}")
81
+ print(f"Problematic Code:\n{cleaned_code}")
82
  visualizations.append(None)
83
 
84
+ # Ensure exactly 3 outputs
85
  return visualizations + [None]*(3-len(visualizations))
86
 
87
+ # Gradio interface
88
+ with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
89
+ gr.Markdown("# πŸ” Data Visualization Generator")
90
+
91
+ with gr.Row():
92
+ api_key = gr.Textbox(
93
+ label="πŸ”‘ Gemini API Key",
94
+ type="password",
95
+ placeholder="Enter your API key here"
96
+ )
97
+ file = gr.File(
98
+ label="πŸ“ Upload Dataset",
99
+ file_types=[".csv", ".xlsx"],
100
+ type="filepath"
101
+ )
102
+
103
+ instructions = gr.Textbox(
104
+ label="πŸ’‘ Custom Instructions",
105
+ placeholder="E.g.: Compare sales trends across regions..."
106
+ )
107
+
108
+ submit = gr.Button("πŸš€ Generate Visualizations", variant="primary")
109
+
110
+ with gr.Row():
111
+ outputs = [
112
+ gr.Image(
113
+ label=f"Visualization {i+1}",
114
+ width=600,
115
+ height=400
116
+ ) for i in range(3)
117
+ ]
118
+
119
+ submit.click(
120
+ process_file,
121
+ inputs=[api_key, file, instructions],
122
+ outputs=outputs
123
+ )
124
+
125
+ if __name__ == "__main__":
126
+ demo.launch()