File size: 3,054 Bytes
4fc79a4 ff768e2 4fc79a4 29f87f1 4fc79a4 5fcb5c4 69a028d 12598b6 4fc79a4 5fcb5c4 69a028d 5fcb5c4 69a028d 5fcb5c4 29f87f1 69a028d 5fcb5c4 4fc79a4 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 5fcb5c4 a888f55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import io
import textwrap

import gradio as gr
import google.generativeai as genai
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
def _load_dataframe(file):
    """Read the uploaded CSV or Excel file into a DataFrame.

    Depending on the Gradio version, the upload arrives either as a
    tempfile-like object exposing ``.name`` or as a path string; accept both.
    """
    path = getattr(file, "name", file)
    # Case-insensitive so e.g. "DATA.CSV" is not routed to read_excel.
    if str(path).lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_excel(path)


def _resolve_plot_style():
    """Return the name of a seaborn-like matplotlib style that exists here.

    Matplotlib 3.6 renamed its bundled 'seaborn' style to 'seaborn-v0_8' and
    later releases dropped the old name, so ``plt.style.use('seaborn')`` fails
    on current matplotlib. Fall back to matplotlib's defaults if neither exists.
    """
    for name in ("seaborn-v0_8", "seaborn"):
        if name in plt.style.available:
            return name
    return "default"


def _build_prompt(df, instructions, style):
    """Build the Gemini prompt asking for three self-contained plotting snippets."""
    return "\n".join([
        "Generate 3 matplotlib visualization codes for this data:",
        f"Columns: {list(df.columns)}",
        f"First 3 rows: {df.head(3).to_dict()}",
        "The data is already loaded in a pandas DataFrame named df "
        "(pandas is available as pd); do not read any files.",
        "Requirements:",
        "1. Each visualization must start with:",
        "plt.figure(figsize=(16,9), dpi=120)",
        f"plt.style.use('{style}')",
        "2. Include complete plotting code with:",
        "- Title",
        "- Axis labels",
        "- Legend if needed",
        "- plt.tight_layout()",
        "3. Different chart types (bar, line, scatter, etc)",
        "4. No explanations - only valid Python code",
        "5. Do not call plt.show() or plt.close()",
        f"User instructions: {instructions or ''}",
        "Format exactly as:",
        "# Visualization 1",
        "[complete code]",
        "# Visualization 2",
        "[complete code]",
        "# Visualization 3",
        "[complete code]",
        "",
    ])


def _extract_code_blocks(text, count=3):
    """Split the model reply into at most *count* code snippets.

    Relies on the "# Visualization N" headers requested in the prompt.
    Markdown fence lines are dropped, but relative indentation is preserved
    so loops and conditionals in the generated code stay valid Python.
    """
    snippets = []
    for chunk in text.split("# Visualization ")[1:count + 1]:
        # The first line is the rest of the header ("1", "2: Bar chart", ...).
        _, _, body = chunk.partition("\n")
        kept = [
            line.rstrip()
            for line in body.splitlines()
            if not line.lstrip().startswith("```")
        ]
        snippets.append(textwrap.dedent("\n".join(kept)).strip("\n"))
    return snippets


def _render_visualization(code, df, style):
    """Execute one generated snippet and return the resulting figure as a PIL image.

    Any exception raised by the snippet propagates to the caller.
    """
    # SECURITY: this runs model-generated code with full interpreter
    # privileges. Only acceptable for a trusted/local tool; sandbox it before
    # exposing the app to untrusted users.
    # Older prompts/models may still ask for the legacy style name.
    for quote in ("'", '"'):
        code = code.replace(
            f"plt.style.use({quote}seaborn{quote})", f"plt.style.use({style!r})"
        )
    already_open = set(plt.get_fignums())
    try:
        # style.context restores rcParams afterwards, even if the snippet
        # calls plt.style.use itself.
        with plt.style.context(style):
            plt.figure(figsize=(16, 9), dpi=120)
            # A copy keeps in-place edits from leaking into later snippets.
            exec(code, {"df": df.copy(), "plt": plt, "pd": pd})
            plt.tight_layout()
            buf = io.BytesIO()
            plt.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Close every figure opened here (ours and any the snippet created),
        # including when the snippet raised.
        for num in set(plt.get_fignums()) - already_open:
            plt.close(num)
    buf.seek(0)
    image = Image.open(buf)
    image.load()
    return image


def process_file(api_key, file, instructions, *,
                 model_name="gemini-2.5-pro-preview-03-25"):
    """Generate up to three matplotlib charts for an uploaded table via Gemini.

    Args:
        api_key: Gemini API key passed to ``genai.configure``.
        file: Upload from the ``gr.File`` component (CSV, otherwise Excel).
        instructions: Optional free-text guidance appended to the prompt.
        model_name: Gemini model to request; defaults to the original model.

    Returns:
        A list of exactly three items, each a PIL image or ``None`` when the
        corresponding snippet was missing, empty, or failed to run.

    Raises:
        gr.Error: if the API key or file is missing, or Gemini returns no text.
    """
    if not api_key:
        raise gr.Error("Please enter your Gemini API key.")
    if file is None:
        raise gr.Error("Please upload a CSV or Excel file.")

    genai.configure(api_key=api_key)
    model = genai.GenerativeModel(model_name)
    df = _load_dataframe(file)
    style = _resolve_plot_style()

    response = model.generate_content(_build_prompt(df, instructions, style))
    try:
        text = response.text
    except ValueError as err:
        # The SDK's ``.text`` accessor raises when the reply has no text part
        # (e.g. a blocked response).
        raise gr.Error(f"Gemini returned no usable text: {err}") from err

    visualizations = []
    for i, code in enumerate(_extract_code_blocks(text), 1):
        if not code:
            visualizations.append(None)
            continue
        try:
            visualizations.append(_render_visualization(code, df, style))
        except Exception as e:  # generated code can fail in arbitrary ways
            print(f"Visualization {i} Error: {e}")
            visualizations.append(None)
    # Return exactly 3 slots, padding with None if the reply had fewer blocks.
    return visualizations + [None] * (3 - len(visualizations))
# Gradio interface
# Gradio interface: API key, data file and optional instructions in;
# three generated chart images out (one per slot returned by process_file).
with gr.Blocks() as demo:
    gr.Markdown("# Data Visualization Tool")
    with gr.Row():
        api_key = gr.Textbox(label="Gemini API Key", type="password")
        file = gr.File(label="Upload CSV/Excel", file_types=[".csv", ".xlsx"])
    instructions = gr.Textbox(label="Instructions (optional)")
    submit = gr.Button("Generate Visualizations")
    with gr.Row():
        outputs = [gr.Image(label=f"Visualization {i + 1}") for i in range(3)]
    # Event listeners must be registered inside the Blocks context.
    submit.click(
        process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs,
    )

# NOTE(review): launching at import time blocks any importer; consider an
# `if __name__ == "__main__":` guard if this module is ever imported.
demo.launch()