bluenevus commited on
Commit
ec365ce
·
verified ·
1 Parent(s): 7de4e79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -71
app.py CHANGED
@@ -2,92 +2,73 @@ import gradio as gr
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import io
5
- from PIL import Image, ImageDraw, ImageFont
6
- import traceback
 
7
 
8
  def process_file(api_key, file, instructions):
9
  try:
 
 
 
 
10
  # Read uploaded file
11
- if file.name.endswith('.csv'):
12
- df = pd.read_csv(file.name)
13
- elif file.name.endswith('.xlsx'):
14
- df = pd.read_excel(file.name)
15
  else:
16
- raise ValueError("Unsupported file format")
17
-
18
- # Generate sample visualizations (replace with actual logic)
19
- fig1, ax1 = plt.subplots()
20
- df.plot(kind='bar', ax=ax1)
21
- ax1.set_title("Sample Bar Chart")
22
 
23
- fig2, ax2 = plt.subplots()
24
- df.plot(kind='line', ax=ax2)
25
- ax2.set_title("Sample Line Chart")
 
 
 
 
 
 
26
 
27
- fig3, ax3 = plt.subplots()
28
- df.plot(kind='hist', ax=ax3)
29
- ax3.set_title("Sample Histogram")
30
-
31
- # Convert plots to PIL Images
32
- def fig_to_image(fig):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  buf = io.BytesIO()
34
- fig.savefig(buf, format='png')
35
  buf.seek(0)
36
- return Image.open(buf)
 
37
 
38
- return [
39
- fig_to_image(fig1),
40
- fig_to_image(fig2),
41
- fig_to_image(fig3)
42
- ]
43
 
44
  except Exception as e:
45
- error_message = f"{str(e)}\n{traceback.format_exc()}"
46
- return [generate_error_image(error_message)] * 3
47
-
48
- def generate_error_image(message):
49
- """Create error indication image with message"""
50
- try:
51
- img = Image.new('RGB', (800, 400), color=(255, 255, 255))
52
- draw = ImageDraw.Draw(img)
53
- font = ImageFont.load_default()
54
-
55
- # Wrap text
56
- lines = []
57
- for line in message.split('\n'):
58
- if len(line) > 80:
59
- lines.extend([line[i:i+80] for i in range(0, len(line), 80)])
60
- else:
61
- lines.append(line)
62
-
63
- y_text = 10
64
- for line in lines[:20]: # Limit to 20 lines
65
- draw.text((10, y_text), line, font=font, fill=(255, 0, 0))
66
- y_text += 15
67
 
68
- return img
69
- except Exception as e:
70
- return Image.new('RGB', (800, 400), color=(255, 255, 255))
71
-
72
- # Gradio interface
73
  with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
74
- gr.Markdown("# AutoData Visualizer")
75
 
76
  with gr.Row():
77
  api_key = gr.Textbox(label="Gemini API Key", type="password")
78
- file = gr.File(label="Upload Data File", file_types=[".csv", ".xlsx"])
79
-
80
- instructions = gr.Textbox(label="Visualization Instructions")
81
- submit = gr.Button("Generate Insights", variant="primary")
82
 
83
- with gr.Row():
84
- outputs = [gr.Image(label=f"Visualization {i+1}", width=600) for i in range(3)]
85
-
86
- submit.click(
87
- process_file,
88
- inputs=[api_key, file, instructions],
89
- outputs=outputs
90
- )
91
-
92
- if __name__ == "__main__":
93
- demo.launch()
 
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import io
5
+ import ast
6
+ from PIL import Image
7
+ import google.generativeai as genai
8
 
9
  def process_file(api_key, file, instructions):
10
  try:
11
+ # Initialize Gemini
12
+ genai.configure(api_key=api_key)
13
+ model = genai.GenerativeModel('gemini-pro')
14
+
15
  # Read uploaded file
16
+ file_path = file.name # Get full file path
17
+ if file_path.endswith('.csv'):
18
+ df = pd.read_csv(file_path)
 
19
  else:
20
+ df = pd.read_excel(file_path)
 
 
 
 
 
21
 
22
+ # Generate visualization code based on instructions
23
+ columns = list(df.columns)
24
+ response = model.generate_content(f"""
25
+ Create 3 matplotlib visualization codes based on: {instructions}
26
+ Data columns: {columns}
27
+ Return only Python code as: [('title','plot_type','x','y'), ...]
28
+ Allowed plot_types: bar, line, scatter, hist
29
+ Use only DataFrame 'df' and these exact variable names.
30
+ """)
31
 
32
+ # Parse and validate generated code
33
+ plots = ast.literal_eval(response.text.split('```')[-2].strip('python\n '))
34
+ if len(plots) != 3:
35
+ raise ValueError("Exactly 3 visualizations required")
36
+
37
+ # Generate plots
38
+ images = []
39
+ for plot in plots:
40
+ fig = plt.figure()
41
+ title, plot_type, x, y = plot
42
+
43
+ if plot_type == 'bar':
44
+ df.plot.bar(x=x, y=y, ax=plt.gca())
45
+ elif plot_type == 'line':
46
+ df.plot.line(x=x, y=y, ax=plt.gca())
47
+ elif plot_type == 'scatter':
48
+ df.plot.scatter(x=x, y=y, ax=plt.gca())
49
+ elif plot_type == 'hist':
50
+ df[y].hist(ax=plt.gca())
51
+
52
+ plt.title(title)
53
  buf = io.BytesIO()
54
+ fig.savefig(buf, format='png', bbox_inches='tight')
55
  buf.seek(0)
56
+ images.append(Image.open(buf))
57
+ plt.close()
58
 
59
+ return images
 
 
 
 
60
 
61
  except Exception as e:
62
+ error_image = Image.new('RGB', (800, 100), (255, 255, 255))
63
+ draw = ImageDraw.Draw(error_image)
64
+ draw.text((10, 40), f"Error: {str(e)}", fill=(255, 0, 0))
65
+ return [error_image] * 3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
 
 
 
 
 
67
  with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
68
+ gr.Markdown("# Data Analysis Dashboard")
69
 
70
  with gr.Row():
71
  api_key = gr.Textbox(label="Gemini API Key", type="password")
72
+ file = gr.File(label="Upload Dataset", file_types=[".csv", ".xlsx"])
 
 
 
73
 
74
+ instructions = gr.Textbox(label="Analysis Instructions