bluenevus commited on
Commit
b06a85b
·
verified ·
1 Parent(s): 601022d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -14
app.py CHANGED
@@ -2,7 +2,6 @@ import gradio as gr
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import io
5
- import json
6
  from PIL import Image, ImageDraw
7
  import google.generativeai as genai
8
  import traceback
@@ -11,7 +10,7 @@ def process_file(file, instructions, api_key):
11
  try:
12
  # Initialize Gemini
13
  genai.configure(api_key=api_key)
14
- model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
15
 
16
  # Read uploaded file
17
  file_path = file.name
@@ -29,24 +28,23 @@ def process_file(file, instructions, api_key):
29
  2. The most suitable plot type (choose from: bar, line, scatter, hist)
30
  3. The column to use for the x-axis
31
  4. The column to use for the y-axis (use None for histograms)
32
- 5. A brief explanation of why this visualization is appropriate
33
 
34
- Return your response as a JSON string in this format:
35
  [
36
- {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "explanation": "..."}},
37
- {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "explanation": "..."}},
38
- {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "explanation": "..."}}
39
  ]
40
  """
41
 
42
  response = model.generate_content(prompt)
43
- plots = json.loads(response.text)
44
 
45
  # Generate visualizations
46
  images = []
47
  for plot in plots:
48
  fig, ax = plt.subplots(figsize=(10, 6))
49
- title, plot_type, x, y = plot['title'], plot['plot_type'], plot['x'], plot['y']
50
 
51
  if plot_type == 'bar':
52
  df.plot(kind='bar', x=x, y=y, ax=ax)
@@ -66,10 +64,10 @@ def process_file(file, instructions, api_key):
66
  plt.savefig(buf, format='png')
67
  buf.seek(0)
68
  img = Image.open(buf)
69
- images.append((img, plot['explanation']))
70
  plt.close(fig)
71
 
72
- return images if len(images) == 3 else images + [(Image.new('RGB', (800, 600), (255,255,255)), "")]*(3-len(images))
73
 
74
  except Exception as e:
75
  error_message = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
@@ -77,7 +75,7 @@ def process_file(file, instructions, api_key):
77
  error_image = Image.new('RGB', (800, 400), (255, 255, 255))
78
  draw = ImageDraw.Draw(error_image)
79
  draw.text((10, 10), error_message, fill=(255, 0, 0))
80
- return [(error_image, "An error occurred")] * 3
81
 
82
  with gr.Blocks(theme=gr.themes.Default()) as demo:
83
  gr.Markdown("# Data Analysis Dashboard")
@@ -90,12 +88,11 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
90
  submit = gr.Button("Generate Insights", variant="primary")
91
 
92
  output_images = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]
93
- output_texts = [gr.Textbox(label=f"Explanation {i+1}") for i in range(3)]
94
 
95
  submit.click(
96
  process_file,
97
  inputs=[file, instructions, api_key],
98
- outputs=output_images + output_texts
99
  )
100
 
101
  if __name__ == "__main__":
 
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import io
 
5
  from PIL import Image, ImageDraw
6
  import google.generativeai as genai
7
  import traceback
 
10
  try:
11
  # Initialize Gemini
12
  genai.configure(api_key=api_key)
13
+ model = genai.GenerativeModel('gemini-pro')
14
 
15
  # Read uploaded file
16
  file_path = file.name
 
28
  2. The most suitable plot type (choose from: bar, line, scatter, hist)
29
  3. The column to use for the x-axis
30
  4. The column to use for the y-axis (use None for histograms)
 
31
 
32
+ Return your response as a Python list of tuples:
33
  [
34
+ ("Title 1", "plot_type1", "x_column1", "y_column1"),
35
+ ("Title 2", "plot_type2", "x_column2", "y_column2"),
36
+ ("Title 3", "plot_type3", "x_column3", "y_column3")
37
  ]
38
  """
39
 
40
  response = model.generate_content(prompt)
41
+ plots = eval(response.text)
42
 
43
  # Generate visualizations
44
  images = []
45
  for plot in plots:
46
  fig, ax = plt.subplots(figsize=(10, 6))
47
+ title, plot_type, x, y = plot
48
 
49
  if plot_type == 'bar':
50
  df.plot(kind='bar', x=x, y=y, ax=ax)
 
64
  plt.savefig(buf, format='png')
65
  buf.seek(0)
66
  img = Image.open(buf)
67
+ images.append(img)
68
  plt.close(fig)
69
 
70
+ return images if len(images) == 3 else images + [Image.new('RGB', (800, 600), (255,255,255))]*(3-len(images))
71
 
72
  except Exception as e:
73
  error_message = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
 
75
  error_image = Image.new('RGB', (800, 400), (255, 255, 255))
76
  draw = ImageDraw.Draw(error_image)
77
  draw.text((10, 10), error_message, fill=(255, 0, 0))
78
+ return [error_image] * 3
79
 
80
  with gr.Blocks(theme=gr.themes.Default()) as demo:
81
  gr.Markdown("# Data Analysis Dashboard")
 
88
  submit = gr.Button("Generate Insights", variant="primary")
89
 
90
  output_images = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]
 
91
 
92
  submit.click(
93
  process_file,
94
  inputs=[file, instructions, api_key],
95
+ outputs=output_images
96
  )
97
 
98
  if __name__ == "__main__":