bluenevus commited on
Commit
5fcb5c4
Β·
verified Β·
1 Parent(s): 12598b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -40
app.py CHANGED
@@ -6,39 +6,45 @@ import google.generativeai as genai
6
  from PIL import Image
7
 
8
  def process_file(api_key, file, instructions):
 
9
  genai.configure(api_key=api_key)
10
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
11
 
12
- try:
13
- if file.name.endswith('.csv'):
14
- df = pd.read_csv(file.name)
15
- else:
16
- df = pd.read_excel(file.name)
17
- except Exception as e:
18
- return [f"File Error: {str(e)}"] * 3
19
 
20
- # Properly terminated prompt string
21
- prompt = f"""Generate exactly 3 distinct matplotlib visualizations for:
22
  Columns: {list(df.columns)}
23
- Data types: {dict(df.dtypes)}
24
- Sample data: {df.head(3).to_dict()}
25
 
26
  Requirements:
27
- 1. 1920x1080 resolution (figsize=(16,9), dpi=120)
28
- 2. Professional styling (seaborn, grid, proper labels)
29
- 3. Diverse chart types (include at least 1 advanced visualization)
 
 
 
 
 
 
 
30
 
31
- User instructions: {instructions or 'None provided'}
32
 
33
- Format response strictly as:
34
  # Visualization 1
35
- plt.figure(figsize=(16,9), dpi=120)
36
- [code]
37
- plt.tight_layout()
38
 
39
  # Visualization 2
40
- ...
41
- """ # Closing triple quotes added here
 
 
 
42
 
43
  response = model.generate_content(prompt)
44
  code_blocks = response.text.split("# Visualization ")[1:4]
@@ -46,42 +52,47 @@ def process_file(api_key, file, instructions):
46
  visualizations = []
47
  for i, block in enumerate(code_blocks, 1):
48
  try:
 
 
 
 
 
 
 
 
49
  plt.figure(figsize=(16, 9), dpi=120)
50
  plt.style.use('seaborn')
51
 
52
- cleaned_code = '\n'.join([line.strip() for line in block.split('\n')[1:] if line.strip()])
53
  exec(cleaned_code, {'df': df, 'plt': plt})
54
- plt.title(f"Visualization {i}", fontsize=14)
 
55
  plt.tight_layout()
56
-
57
- buf = io.BytesIO()
58
- plt.savefig(buf, format='png', dpi=120, bbox_inches='tight')
59
  plt.close()
 
60
  buf.seek(0)
61
  visualizations.append(Image.open(buf))
62
  except Exception as e:
63
- print(f"Visualization {i} failed: {str(e)}")
64
- visualizations.append(Image.new('RGB', (1920, 1080), color=(73, 109, 137)))
65
-
66
- while len(visualizations) < 3:
67
- visualizations.append(Image.new('RGB', (1920, 1080), color=(73, 109, 137)))
68
 
69
- return visualizations[:3]
 
70
 
71
  # Gradio interface
72
- with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
73
- gr.Markdown("# **HD Data Visualizer** πŸ“Šβœ¨")
74
 
75
  with gr.Row():
76
- api_key = gr.Textbox(label="πŸ”‘ Gemini API Key", type="password")
77
- file = gr.File(label="πŸ“ Upload Dataset", file_types=[".csv", ".xlsx"])
78
 
79
- instructions = gr.Textbox(label="πŸ’‘ Custom Instructions (optional)",
80
- placeholder="E.g.: Focus on time series patterns...")
81
- submit = gr.Button("πŸš€ Generate Visualizations", variant="primary")
82
 
83
  with gr.Row():
84
- outputs = [gr.Image(label=f"Visualization {i+1}", width=600) for i in range(3)]
85
 
86
  submit.click(
87
  process_file,
 
6
  from PIL import Image
7
 
8
  def process_file(api_key, file, instructions):
9
+ # Configure Gemini with precise model version
10
  genai.configure(api_key=api_key)
11
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
12
 
13
+ # Read file
14
+ if file.name.endswith('.csv'):
15
+ df = pd.read_csv(file.name)
16
+ else:
17
+ df = pd.read_excel(file.name)
 
 
18
 
19
+ # Enhanced prompt with strict code requirements
20
+ prompt = f"""Generate 3 matplotlib visualization codes for this data:
21
  Columns: {list(df.columns)}
22
+ First 3 rows: {df.head(3).to_dict()}
 
23
 
24
  Requirements:
25
+ 1. Each visualization must start with:
26
+ plt.figure(figsize=(16,9), dpi=120)
27
+ plt.style.use('seaborn')
28
+ 2. Include complete plotting code with:
29
+ - Title
30
+ - Axis labels
31
+ - Legend if needed
32
+ - plt.tight_layout()
33
+ 3. Different chart types (bar, line, scatter, etc)
34
+ 4. No explanations - only valid Python code
35
 
36
+ User instructions: {instructions}
37
 
38
+ Format exactly as:
39
  # Visualization 1
40
+ [complete code]
 
 
41
 
42
  # Visualization 2
43
+ [complete code]
44
+
45
+ # Visualization 3
46
+ [complete code]
47
+ """
48
 
49
  response = model.generate_content(prompt)
50
  code_blocks = response.text.split("# Visualization ")[1:4]
 
52
  visualizations = []
53
  for i, block in enumerate(code_blocks, 1):
54
  try:
55
+ # Clean and validate code
56
+ cleaned_code = '\n'.join([
57
+ line.strip() for line in block.split('\n')
58
+ if line.strip() and not line.startswith('```')
59
+ ])
60
+
61
+ # Create HD figure
62
+ buf = io.BytesIO()
63
  plt.figure(figsize=(16, 9), dpi=120)
64
  plt.style.use('seaborn')
65
 
66
+ # Execute generated code
67
  exec(cleaned_code, {'df': df, 'plt': plt})
68
+
69
+ # Save HD image
70
  plt.tight_layout()
71
+ plt.savefig(buf, format='png', bbox_inches='tight')
 
 
72
  plt.close()
73
+
74
  buf.seek(0)
75
  visualizations.append(Image.open(buf))
76
  except Exception as e:
77
+ print(f"Visualization {i} Error: {str(e)}")
78
+ visualizations.append(None)
 
 
 
79
 
80
+ # Return exactly 3 images, filling with None if needed
81
+ return visualizations + [None]*(3-len(visualizations))
82
 
83
  # Gradio interface
84
+ with gr.Blocks() as demo:
85
+ gr.Markdown("# Data Visualization Tool")
86
 
87
  with gr.Row():
88
+ api_key = gr.Textbox(label="Gemini API Key", type="password")
89
+ file = gr.File(label="Upload CSV/Excel", file_types=[".csv", ".xlsx"])
90
 
91
+ instructions = gr.Textbox(label="Instructions (optional)")
92
+ submit = gr.Button("Generate Visualizations")
 
93
 
94
  with gr.Row():
95
+ outputs = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]
96
 
97
  submit.click(
98
  process_file,