bluenevus commited on
Commit
72c5969
·
verified ·
1 Parent(s): 422964b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -22
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import io
 
5
  from PIL import Image, ImageDraw
6
  import google.generativeai as genai
7
  import traceback
@@ -16,33 +17,30 @@ def process_file(file, instructions, api_key):
16
  file_path = file.name
17
  df = pd.read_csv(file_path) if file_path.endswith('.csv') else pd.read_excel(file_path)
18
 
19
- # Generate visualization code using Gemini
20
- prompt = f"""
21
- Analyze the following dataset and instructions:
22
-
23
- Data columns: {list(df.columns)}
24
- Instructions: {instructions}
25
-
26
- Based on this, create 3 appropriate visualizations. For each visualization, provide:
27
- 1. A title
28
- 2. The most suitable plot type (choose from: bar, line, scatter, hist)
29
- 3. The column to use for the x-axis
30
- 4. The column to use for the y-axis (use None for histograms)
31
 
32
- Return your response as a Python list of tuples:
33
- [
34
- ("Title 1", "plot_type1", "x_column1", "y_column1"),
35
- ("Title 2", "plot_type2", "x_column2", "y_column2"),
36
- ("Title 3", "plot_type3", "x_column3", "y_column3")
37
- ]
38
- """
 
 
39
 
40
- response = model.generate_content(prompt)
41
- plots = eval(response.text)
42
 
43
  # Generate visualizations
44
  images = []
45
- for plot in plots:
46
  fig, ax = plt.subplots(figsize=(10, 6))
47
  title, plot_type, x, y = plot
48
 
 
2
  import pandas as pd
3
  import matplotlib.pyplot as plt
4
  import io
5
+ import ast
6
  from PIL import Image, ImageDraw
7
  import google.generativeai as genai
8
  import traceback
 
17
  file_path = file.name
18
  df = pd.read_csv(file_path) if file_path.endswith('.csv') else pd.read_excel(file_path)
19
 
20
+ # Generate visualization code
21
+ response = model.generate_content(f"""
22
+ Create 3 matplotlib visualization codes based on: {instructions}
23
+ Data columns: {list(df.columns)}
24
+ Return Python code as: [('title','plot_type','x','y'), ...]
25
+ Allowed plot_types: bar, line, scatter, hist
26
+ Use only DataFrame 'df' and these exact variable names.
27
+ """)
 
 
 
 
28
 
29
+ # Extract code block safely
30
+ code_block = response.text
31
+ if '```python' in code_block:
32
+ code_block = code_block.split('```python')[1].split('```')[0].strip()
33
+ elif '```' in code_block:
34
+ code_block = code_block.split('```')[1].strip()
35
+
36
+ print("Generated code block:")
37
+ print(code_block)
38
 
39
+ plots = ast.literal_eval(code_block)
 
40
 
41
  # Generate visualizations
42
  images = []
43
+ for plot in plots[:3]: # Ensure max 3 plots
44
  fig, ax = plt.subplots(figsize=(10, 6))
45
  title, plot_type, x, y = plot
46