Update app.py
Browse files
app.py
CHANGED
|
@@ -2,6 +2,7 @@ import gradio as gr
|
|
| 2 |
import pandas as pd
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
import io
|
|
|
|
| 5 |
from PIL import Image, ImageDraw
|
| 6 |
import google.generativeai as genai
|
| 7 |
import traceback
|
|
@@ -16,33 +17,30 @@ def process_file(file, instructions, api_key):
|
|
| 16 |
file_path = file.name
|
| 17 |
df = pd.read_csv(file_path) if file_path.endswith('.csv') else pd.read_excel(file_path)
|
| 18 |
|
| 19 |
-
# Generate visualization code
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
1. A title
|
| 28 |
-
2. The most suitable plot type (choose from: bar, line, scatter, hist)
|
| 29 |
-
3. The column to use for the x-axis
|
| 30 |
-
4. The column to use for the y-axis (use None for histograms)
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
|
| 41 |
-
plots = eval(response.text)
|
| 42 |
|
| 43 |
# Generate visualizations
|
| 44 |
images = []
|
| 45 |
-
for plot in plots:
|
| 46 |
fig, ax = plt.subplots(figsize=(10, 6))
|
| 47 |
title, plot_type, x, y = plot
|
| 48 |
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
import matplotlib.pyplot as plt
|
| 4 |
import io
|
| 5 |
+
import ast
|
| 6 |
from PIL import Image, ImageDraw
|
| 7 |
import google.generativeai as genai
|
| 8 |
import traceback
|
|
|
|
| 17 |
file_path = file.name
|
| 18 |
df = pd.read_csv(file_path) if file_path.endswith('.csv') else pd.read_excel(file_path)
|
| 19 |
|
| 20 |
+
# Generate visualization code
|
| 21 |
+
response = model.generate_content(f"""
|
| 22 |
+
Create 3 matplotlib visualization codes based on: {instructions}
|
| 23 |
+
Data columns: {list(df.columns)}
|
| 24 |
+
Return Python code as: [('title','plot_type','x','y'), ...]
|
| 25 |
+
Allowed plot_types: bar, line, scatter, hist
|
| 26 |
+
Use only DataFrame 'df' and these exact variable names.
|
| 27 |
+
""")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
# Extract code block safely
|
| 30 |
+
code_block = response.text
|
| 31 |
+
if '```python' in code_block:
|
| 32 |
+
code_block = code_block.split('```python')[1].split('```')[0].strip()
|
| 33 |
+
elif '```' in code_block:
|
| 34 |
+
code_block = code_block.split('```')[1].strip()
|
| 35 |
+
|
| 36 |
+
print("Generated code block:")
|
| 37 |
+
print(code_block)
|
| 38 |
|
| 39 |
+
plots = ast.literal_eval(code_block)
|
|
|
|
| 40 |
|
| 41 |
# Generate visualizations
|
| 42 |
images = []
|
| 43 |
+
for plot in plots[:3]: # Ensure max 3 plots
|
| 44 |
fig, ax = plt.subplots(figsize=(10, 6))
|
| 45 |
title, plot_type, x, y = plot
|
| 46 |
|