bluenevus commited on
Commit
601022d
·
verified ·
1 Parent(s): 3c50a2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -31
app.py CHANGED
@@ -2,12 +2,12 @@ import gradio as gr
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import io
5
- import ast
6
  from PIL import Image, ImageDraw
7
  import google.generativeai as genai
8
  import traceback
9
 
10
- def process_file(api_key, file, instructions):
11
  try:
12
  # Initialize Gemini
13
  genai.configure(api_key=api_key)
@@ -17,29 +17,36 @@ def process_file(api_key, file, instructions):
17
  file_path = file.name
18
  df = pd.read_csv(file_path) if file_path.endswith('.csv') else pd.read_excel(file_path)
19
 
20
- # Generate visualization code
21
- response = model.generate_content(f"""
22
- Create 3 matplotlib visualization codes based on: {instructions}
23
- Data columns: {list(df.columns)}
24
- Return Python code as: [('title','plot_type','x','y'), ...]
25
- Allowed plot_types: bar, line, scatter, hist
26
- Use only DataFrame 'df' and these exact variable names.
27
- """)
28
-
29
- # Extract code block safely
30
- code_block = response.text.split('```python')[1].split('```')[0].strip()
31
 
32
- # Print the code block for debugging
33
- print("Generated code block:")
34
- print(code_block)
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- plots = ast.literal_eval(code_block)
 
37
 
38
  # Generate visualizations
39
  images = []
40
- for plot in plots[:3]: # Ensure max 3 plots
41
  fig, ax = plt.subplots(figsize=(10, 6))
42
- title, plot_type, x, y = plot
43
 
44
  if plot_type == 'bar':
45
  df.plot(kind='bar', x=x, y=y, ax=ax)
@@ -48,21 +55,21 @@ def process_file(api_key, file, instructions):
48
  elif plot_type == 'scatter':
49
  df.plot(kind='scatter', x=x, y=y, ax=ax)
50
  elif plot_type == 'hist':
51
- df[y].hist(ax=ax)
52
 
53
  ax.set_title(title)
54
  ax.set_xlabel(x)
55
- ax.set_ylabel(y)
56
  plt.tight_layout()
57
 
58
  buf = io.BytesIO()
59
  plt.savefig(buf, format='png')
60
  buf.seek(0)
61
  img = Image.open(buf)
62
- images.append(img)
63
  plt.close(fig)
64
 
65
- return images if len(images) == 3 else images + [Image.new('RGB', (800, 600), (255,255,255))]*(3-len(images))
66
 
67
  except Exception as e:
68
  error_message = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
@@ -70,25 +77,25 @@ def process_file(api_key, file, instructions):
70
  error_image = Image.new('RGB', (800, 400), (255, 255, 255))
71
  draw = ImageDraw.Draw(error_image)
72
  draw.text((10, 10), error_message, fill=(255, 0, 0))
73
- return [error_image] * 3
74
 
75
- with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
76
  gr.Markdown("# Data Analysis Dashboard")
77
 
78
  with gr.Row():
79
- api_key = gr.Textbox(label="Gemini API Key", type="password")
80
  file = gr.File(label="Upload Dataset", file_types=[".csv", ".xlsx"])
 
81
 
82
- instructions = gr.Textbox(label="Analysis Instructions")
83
  submit = gr.Button("Generate Insights", variant="primary")
84
 
85
- with gr.Row():
86
- outputs = [gr.Image(label=f"Visualization {i+1}", width=600) for i in range(3)]
87
 
88
  submit.click(
89
  process_file,
90
- inputs=[api_key, file, instructions],
91
- outputs=outputs
92
  )
93
 
94
  if __name__ == "__main__":
 
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import io
5
+ import json
6
  from PIL import Image, ImageDraw
7
  import google.generativeai as genai
8
  import traceback
9
 
10
+ def process_file(file, instructions, api_key):
11
  try:
12
  # Initialize Gemini
13
  genai.configure(api_key=api_key)
 
17
  file_path = file.name
18
  df = pd.read_csv(file_path) if file_path.endswith('.csv') else pd.read_excel(file_path)
19
 
20
+ # Generate visualization code using Gemini
21
+ prompt = f"""
22
+ Analyze the following dataset and instructions:
23
+
24
+ Data columns: {list(df.columns)}
25
+ Instructions: {instructions}
 
 
 
 
 
26
 
27
+ Based on this, create 3 appropriate visualizations. For each visualization, provide:
28
+ 1. A title
29
+ 2. The most suitable plot type (choose from: bar, line, scatter, hist)
30
+ 3. The column to use for the x-axis
31
+ 4. The column to use for the y-axis (use None for histograms)
32
+ 5. A brief explanation of why this visualization is appropriate
33
+
34
+ Return your response as a JSON string in this format:
35
+ [
36
+ {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "explanation": "..."}},
37
+ {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "explanation": "..."}},
38
+ {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "explanation": "..."}}
39
+ ]
40
+ """
41
 
42
+ response = model.generate_content(prompt)
43
+ plots = json.loads(response.text)
44
 
45
  # Generate visualizations
46
  images = []
47
+ for plot in plots:
48
  fig, ax = plt.subplots(figsize=(10, 6))
49
+ title, plot_type, x, y = plot['title'], plot['plot_type'], plot['x'], plot['y']
50
 
51
  if plot_type == 'bar':
52
  df.plot(kind='bar', x=x, y=y, ax=ax)
 
55
  elif plot_type == 'scatter':
56
  df.plot(kind='scatter', x=x, y=y, ax=ax)
57
  elif plot_type == 'hist':
58
+ df[x].hist(ax=ax)
59
 
60
  ax.set_title(title)
61
  ax.set_xlabel(x)
62
+ ax.set_ylabel(y if y else 'Frequency')
63
  plt.tight_layout()
64
 
65
  buf = io.BytesIO()
66
  plt.savefig(buf, format='png')
67
  buf.seek(0)
68
  img = Image.open(buf)
69
+ images.append((img, plot['explanation']))
70
  plt.close(fig)
71
 
72
+ return images if len(images) == 3 else images + [(Image.new('RGB', (800, 600), (255,255,255)), "")]*(3-len(images))
73
 
74
  except Exception as e:
75
  error_message = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
 
77
  error_image = Image.new('RGB', (800, 400), (255, 255, 255))
78
  draw = ImageDraw.Draw(error_image)
79
  draw.text((10, 10), error_message, fill=(255, 0, 0))
80
+ return [(error_image, "An error occurred")] * 3
81
 
82
+ with gr.Blocks(theme=gr.themes.Default()) as demo:
83
  gr.Markdown("# Data Analysis Dashboard")
84
 
85
  with gr.Row():
 
86
  file = gr.File(label="Upload Dataset", file_types=[".csv", ".xlsx"])
87
+ instructions = gr.Textbox(label="Analysis Instructions", placeholder="Describe the analysis you want...")
88
 
89
+ api_key = gr.Textbox(label="Gemini API Key", type="password")
90
  submit = gr.Button("Generate Insights", variant="primary")
91
 
92
+ output_images = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]
93
+ output_texts = [gr.Textbox(label=f"Explanation {i+1}") for i in range(3)]
94
 
95
  submit.click(
96
  process_file,
97
+ inputs=[file, instructions, api_key],
98
+ outputs=output_images + output_texts
99
  )
100
 
101
  if __name__ == "__main__":