bluenevus commited on
Commit
a888f55
Β·
verified Β·
1 Parent(s): 69a028d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -4
app.py CHANGED
@@ -6,11 +6,9 @@ import google.generativeai as genai
6
  from PIL import Image
7
 
8
  def process_file(api_key, file, instructions):
9
- # Configure Gemini API
10
  genai.configure(api_key=api_key)
11
  model = genai.GenerativeModel('gemini-2.5-pro-latest')
12
 
13
- # File handling with error prevention
14
  try:
15
  if file.name.endswith('.csv'):
16
  df = pd.read_csv(file.name)
@@ -19,7 +17,7 @@ def process_file(api_key, file, instructions):
19
  except Exception as e:
20
  return [f"File Error: {str(e)}"] * 3
21
 
22
- # Enhanced prompt template
23
  prompt = f"""Generate exactly 3 distinct matplotlib visualizations for:
24
  Columns: {list(df.columns)}
25
  Data types: {dict(df.dtypes)}
@@ -30,4 +28,65 @@ def process_file(api_key, file, instructions):
30
  2. Professional styling (seaborn, grid, proper labels)
31
  3. Diverse chart types (include at least 1 advanced visualization)
32
 
33
- User instructions: {instructions or 'None provided
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from PIL import Image
7
 
8
  def process_file(api_key, file, instructions):
 
9
  genai.configure(api_key=api_key)
10
  model = genai.GenerativeModel('gemini-2.5-pro-latest')
11
 
 
12
  try:
13
  if file.name.endswith('.csv'):
14
  df = pd.read_csv(file.name)
 
17
  except Exception as e:
18
  return [f"File Error: {str(e)}"] * 3
19
 
20
+ # Properly terminated prompt string
21
  prompt = f"""Generate exactly 3 distinct matplotlib visualizations for:
22
  Columns: {list(df.columns)}
23
  Data types: {dict(df.dtypes)}
 
28
  2. Professional styling (seaborn, grid, proper labels)
29
  3. Diverse chart types (include at least 1 advanced visualization)
30
 
31
+ User instructions: {instructions or 'None provided'}
32
+
33
+ Format response strictly as:
34
+ # Visualization 1
35
+ plt.figure(figsize=(16,9), dpi=120)
36
+ [code]
37
+ plt.tight_layout()
38
+
39
+ # Visualization 2
40
+ ...
41
+ """ # Closing triple quotes added here
42
+
43
+ response = model.generate_content(prompt)
44
+ code_blocks = response.text.split("# Visualization ")[1:4]
45
+
46
+ visualizations = []
47
+ for i, block in enumerate(code_blocks, 1):
48
+ try:
49
+ plt.figure(figsize=(16, 9), dpi=120)
50
+ plt.style.use('seaborn')
51
+
52
+ cleaned_code = '\n'.join([line.strip() for line in block.split('\n')[1:] if line.strip()])
53
+ exec(cleaned_code, {'df': df, 'plt': plt})
54
+ plt.title(f"Visualization {i}", fontsize=14)
55
+ plt.tight_layout()
56
+
57
+ buf = io.BytesIO()
58
+ plt.savefig(buf, format='png', dpi=120, bbox_inches='tight')
59
+ plt.close()
60
+ buf.seek(0)
61
+ visualizations.append(Image.open(buf))
62
+ except Exception as e:
63
+ print(f"Visualization {i} failed: {str(e)}")
64
+ visualizations.append(Image.new('RGB', (1920, 1080), color=(73, 109, 137)))
65
+
66
+ while len(visualizations) < 3:
67
+ visualizations.append(Image.new('RGB', (1920, 1080), color=(73, 109, 137)))
68
+
69
+ return visualizations[:3]
70
+
71
+ # Gradio interface
72
+ with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
73
+ gr.Markdown("# **HD Data Visualizer** πŸ“Šβœ¨")
74
+
75
+ with gr.Row():
76
+ api_key = gr.Textbox(label="πŸ”‘ Gemini API Key", type="password")
77
+ file = gr.File(label="πŸ“ Upload Dataset", file_types=[".csv", ".xlsx"])
78
+
79
+ instructions = gr.Textbox(label="πŸ’‘ Custom Instructions (optional)",
80
+ placeholder="E.g.: Focus on time series patterns...")
81
+ submit = gr.Button("πŸš€ Generate Visualizations", variant="primary")
82
+
83
+ with gr.Row():
84
+ outputs = [gr.Image(label=f"Visualization {i+1}", width=600) for i in range(3)]
85
+
86
+ submit.click(
87
+ process_file,
88
+ inputs=[api_key, file, instructions],
89
+ outputs=outputs
90
+ )
91
+
92
+ demo.launch()