bluenevus commited on
Commit
bca92aa
·
verified ·
1 Parent(s): ec365ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -19
app.py CHANGED
@@ -3,7 +3,7 @@ import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import io
5
  import ast
6
- from PIL import Image
7
  import google.generativeai as genai
8
 
9
  def process_file(api_key, file, instructions):
@@ -13,30 +13,25 @@ def process_file(api_key, file, instructions):
13
  model = genai.GenerativeModel('gemini-pro')
14
 
15
  # Read uploaded file
16
- file_path = file.name # Get full file path
17
- if file_path.endswith('.csv'):
18
- df = pd.read_csv(file_path)
19
- else:
20
- df = pd.read_excel(file_path)
21
 
22
- # Generate visualization code based on instructions
23
- columns = list(df.columns)
24
  response = model.generate_content(f"""
25
  Create 3 matplotlib visualization codes based on: {instructions}
26
- Data columns: {columns}
27
- Return only Python code as: [('title','plot_type','x','y'), ...]
28
  Allowed plot_types: bar, line, scatter, hist
29
  Use only DataFrame 'df' and these exact variable names.
30
  """)
 
 
 
 
31
 
32
- # Parse and validate generated code
33
- plots = ast.literal_eval(response.text.split('```')[-2].strip('python\n '))
34
- if len(plots) != 3:
35
- raise ValueError("Exactly 3 visualizations required")
36
-
37
- # Generate plots
38
  images = []
39
- for plot in plots:
40
  fig = plt.figure()
41
  title, plot_type, x, y = plot
42
 
@@ -56,7 +51,7 @@ def process_file(api_key, file, instructions):
56
  images.append(Image.open(buf))
57
  plt.close()
58
 
59
- return images
60
 
61
  except Exception as e:
62
  error_image = Image.new('RGB', (800, 100), (255, 255, 255))
@@ -71,4 +66,17 @@ with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
71
  api_key = gr.Textbox(label="Gemini API Key", type="password")
72
  file = gr.File(label="Upload Dataset", file_types=[".csv", ".xlsx"])
73
 
74
- instructions = gr.Textbox(label="Analysis Instructions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import matplotlib.pyplot as plt
4
  import io
5
  import ast
6
+ from PIL import Image, ImageDraw
7
  import google.generativeai as genai
8
 
9
  def process_file(api_key, file, instructions):
 
13
  model = genai.GenerativeModel('gemini-pro')
14
 
15
  # Read uploaded file
16
+ file_path = file.name
17
+ df = pd.read_csv(file_path) if file_path.endswith('.csv') else pd.read_excel(file_path)
 
 
 
18
 
19
+ # Generate visualization code
 
20
  response = model.generate_content(f"""
21
  Create 3 matplotlib visualization codes based on: {instructions}
22
+ Data columns: {list(df.columns)}
23
+ Return Python code as: [('title','plot_type','x','y'), ...]
24
  Allowed plot_types: bar, line, scatter, hist
25
  Use only DataFrame 'df' and these exact variable names.
26
  """)
27
+
28
+ # Extract code block safely
29
+ code_block = response.text.split('```python')[1].split('```')[0].strip()
30
+ plots = ast.literal_eval(code_block)
31
 
32
+ # Generate visualizations
 
 
 
 
 
33
  images = []
34
+ for plot in plots[:3]: # Ensure max 3 plots
35
  fig = plt.figure()
36
  title, plot_type, x, y = plot
37
 
 
51
  images.append(Image.open(buf))
52
  plt.close()
53
 
54
+ return images if len(images) == 3 else images + [Image.new('RGB', (800, 600), (255,255,255))]*(3-len(images))
55
 
56
  except Exception as e:
57
  error_image = Image.new('RGB', (800, 100), (255, 255, 255))
 
66
  api_key = gr.Textbox(label="Gemini API Key", type="password")
67
  file = gr.File(label="Upload Dataset", file_types=[".csv", ".xlsx"])
68
 
69
+ instructions = gr.Textbox(label="Analysis Instructions")
70
+ submit = gr.Button("Generate Insights", variant="primary")
71
+
72
+ with gr.Row():
73
+ outputs = [gr.Image(label=f"Visualization {i+1}", width=600) for i in range(3)]
74
+
75
+ submit.click(
76
+ process_file,
77
+ inputs=[api_key, file, instructions],
78
+ outputs=outputs
79
+ )
80
+
81
+ if __name__ == "__main__":
82
+ demo.launch()