# app.py (author: bluenevus)
# Last commit: "Update app.py" (5fcb5c4, verified); file size 3.05 kB
import io
import textwrap

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
def _read_table(file):
    """Load the uploaded CSV or Excel file into a DataFrame.

    Accepts either a path string or an object exposing ``.name`` -- the value
    a ``gr.File`` component hands over differs across Gradio versions.
    """
    path = getattr(file, 'name', file)
    # Case-insensitive so e.g. "DATA.CSV" is not handed to the Excel reader.
    if str(path).lower().endswith('.csv'):
        return pd.read_csv(path)
    return pd.read_excel(path)


def _build_prompt(df, instructions):
    """Return the Gemini prompt describing ``df`` and the required reply format."""
    return f"""Generate 3 matplotlib visualization codes for this data:
Columns: {list(df.columns)}
First 3 rows: {df.head(3).to_dict()}
The data is already loaded in a pandas DataFrame named `df`; `pd`, `np` and `plt` are already imported.
Requirements:
1. Each visualization must start with:
plt.figure(figsize=(16,9), dpi=120)
2. Include complete plotting code with:
- Title
- Axis labels
- Legend if needed
- plt.tight_layout()
3. Different chart types (bar, line, scatter, etc)
4. Do not load files, call plt.style.use(), or call plt.show()
5. No explanations - only valid Python code
User instructions: {instructions}
Format exactly as:
# Visualization 1
[complete code]
# Visualization 2
[complete code]
# Visualization 3
[complete code]
"""


def _extract_code_blocks(text, count=3):
    """Split the model reply into at most ``count`` runnable code snippets.

    Each chunk after a "# Visualization " marker begins with the rest of the
    marker line (e.g. "1" or "1: Sales by region"); that line is discarded.
    Markdown fence lines are dropped, and indentation is preserved (only the
    common left margin is removed) so loops and ``with`` blocks still parse.
    """
    snippets = []
    for chunk in text.split('# Visualization ')[1:count + 1]:
        _, _, body = chunk.partition('\n')
        kept = []
        for line in body.splitlines():
            stripped = line.strip()
            if stripped.startswith('```'):
                continue
            if stripped.startswith('plt.style.use('):
                # The host applies the style itself; a stale style name in
                # generated code (e.g. 'seaborn') would abort the whole chart.
                # Keep a `pass` at the same indent so enclosing blocks stay valid.
                indent = line[:len(line) - len(line.lstrip())]
                kept.append(indent + 'pass')
                continue
            kept.append(line)
        snippets.append(textwrap.dedent('\n'.join(kept)).strip('\n'))
    return snippets


def _apply_plot_style():
    """Apply the bundled seaborn look under whichever name this matplotlib uses."""
    # matplotlib 3.6 renamed the 'seaborn' sheets to 'seaborn-v0_8*' and 3.8
    # removed the old names, so plt.style.use('seaborn') raises on new versions.
    for style_name in ('seaborn-v0_8', 'seaborn'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            return


def _render_figure(code, df):
    """Execute one generated snippet against ``df`` and return it as a PIL image.

    Raises whatever the snippet raises; every figure opened during the call is
    closed either way.
    """
    # SECURITY: this runs model-generated code with full interpreter
    # privileges (file system, network). Only deploy in a sandboxed environment.
    namespace = {'df': df.copy(), 'pd': pd, 'np': np, 'plt': plt}
    existing = set(plt.get_fignums())
    buf = io.BytesIO()
    try:
        _apply_plot_style()
        # Default HD canvas in case the snippet does not open its own figure.
        plt.figure(figsize=(16, 9), dpi=120)
        exec(code, namespace)
        plt.tight_layout()
        plt.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # The snippet usually calls plt.figure() again, so close every figure
        # opened here (not just the current one) to avoid leaking them.
        for fig_num in set(plt.get_fignums()) - existing:
            plt.close(fig_num)
    buf.seek(0)
    return Image.open(buf)


def process_file(api_key, file, instructions, model_name='gemini-2.5-pro-preview-03-25'):
    """Ask Gemini for three matplotlib charts of the uploaded table and render them.

    Args:
        api_key: Gemini API key entered by the user.
        file: Uploaded CSV/Excel file from ``gr.File`` (path string or an
            object exposing ``.name``).
        instructions: Optional free-text guidance appended to the prompt.
        model_name: Gemini model identifier; defaults to the originally
            pinned preview model.

    Returns:
        A list of exactly three items, each a PIL image or ``None`` when that
        chart could not be produced.

    Raises:
        gr.Error: if no API key or no file was supplied.
    """
    if not api_key:
        raise gr.Error("Please enter a Gemini API key.")
    if file is None:
        raise gr.Error("Please upload a CSV or Excel file.")

    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(model_name)

    df = _read_table(file)
    response = model.generate_content(_build_prompt(df, instructions))

    visualizations = []
    for i, code in enumerate(_extract_code_blocks(response.text), 1):
        try:
            visualizations.append(_render_figure(code, df))
        except Exception as e:
            # Best effort per chart: one bad snippet must not sink the others.
            print(f"Visualization {i} Error: {str(e)}")
            visualizations.append(None)
    # Return exactly 3 images, filling with None if the reply had fewer blocks.
    return visualizations + [None] * (3 - len(visualizations))
# Gradio UI: API key, data file and optional instructions feed process_file,
# whose three return values fill the three image slots below.
with gr.Blocks() as demo:
    gr.Markdown("# Data Visualization Tool")

    with gr.Row():
        api_key = gr.Textbox(label="Gemini API Key", type="password")
        file = gr.File(label="Upload CSV/Excel", file_types=[".csv", ".xlsx"])
    instructions = gr.Textbox(label="Instructions (optional)")
    submit = gr.Button("Generate Visualizations")

    with gr.Row():
        outputs = [gr.Image(label=f"Visualization {n}") for n in range(1, 4)]

    input_components = [api_key, file, instructions]
    submit.click(fn=process_file, inputs=input_components, outputs=outputs)

demo.launch()