File size: 2,922 Bytes
4fc79a4 12ce912 ec365ce bca92aa ec365ce 4fc79a4 12ce912 9a4fc1b ec365ce 9a4fc1b bca92aa 9a4fc1b bca92aa ec365ce bca92aa ec365ce bca92aa 9a4fc1b bca92aa ec365ce bca92aa ec365ce 9a4fc1b ec365ce 9a4fc1b ec365ce 12ce912 bca92aa 12ce912 9a4fc1b ec365ce 9a4fc1b 23a3b49 ec365ce 5be932a 23a3b49 80cfa8c ec365ce 6cff8d5 bca92aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import ast
import io
import re

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageDraw
# Number of visualizations produced per request. The UI wires exactly this
# many gr.Image outputs to process_file, so keep the two in sync.
_NUM_PLOTS = 3

# Size of the white placeholder used when the model returns fewer charts.
_BLANK_SIZE = (800, 600)

# NOTE(review): 'gemini-pro' may have been retired by Google; confirm the
# model name is still served before relying on it.
_MODEL_NAME = "gemini-pro"


def _resolve_path(file):
    """Return the filesystem path of a Gradio upload.

    Depending on the Gradio version, ``gr.File`` hands the callback either a
    path string or a tempfile-like object exposing ``.name``; accept both.
    """
    return getattr(file, "name", file)


def _read_dataframe(path):
    """Load *path* as a DataFrame.

    CSV when the extension is ``.csv`` (case-insensitive), otherwise Excel.
    """
    if str(path).lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_excel(path)


def _extract_code_block(text):
    """Return the body of the first Markdown code fence in *text*.

    Any fence tag (```python, ```py, bare ```) is accepted and a missing
    closing fence is tolerated. If the reply contains no fence at all, the
    whole stripped text is returned so it can still be parsed.
    """
    match = re.search(r"```[\w+-]*[ \t]*\n?(.*?)(?:```|\Z)", text, re.DOTALL)
    return (match.group(1) if match else text).strip()


def _error_image(error):
    """Render *error* as red ``Error: ...`` text on a white 800x100 image."""
    image = Image.new("RGB", (800, 100), (255, 255, 255))
    ImageDraw.Draw(image).text((10, 40), f"Error: {error}", fill=(255, 0, 0))
    return image


def _render_plot(df, title, plot_type, x, y):
    """Draw one chart of *df* and return it as a PNG-backed PIL image.

    ``bar``/``line``/``scatter`` plot column *y* against column *x*; ``hist``
    plots the distribution of column *y* only (*x* is ignored).

    Raises:
        ValueError: if *plot_type* is not one of the four supported types.
        Any pandas/matplotlib error raised while plotting (e.g. unknown column).
    """
    # Draw on an explicit Axes instead of pyplot's global "current axes",
    # which is shared state across concurrent requests.
    fig, ax = plt.subplots()
    try:
        if plot_type == "bar":
            df.plot.bar(x=x, y=y, ax=ax)
        elif plot_type == "line":
            df.plot.line(x=x, y=y, ax=ax)
        elif plot_type == "scatter":
            df.plot.scatter(x=x, y=y, ax=ax)
        elif plot_type == "hist":
            df[y].hist(ax=ax)
        else:
            raise ValueError(f"Unsupported plot type: {plot_type!r}")
        ax.set_title(title)
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Release this figure even when plotting fails, so failed requests
        # don't accumulate open figures.
        plt.close(fig)
    buf.seek(0)
    # Image.open reads from buf lazily, so buf must stay open (it does: the
    # returned image holds the reference).
    return Image.open(buf)


def process_file(api_key, file, instructions):
    """Ask Gemini for chart specs for an uploaded dataset and render them.

    Args:
        api_key: Gemini API key entered in the UI.
        file: The Gradio upload (a path string or an object with ``.name``).
        instructions: Free-text analysis instructions forwarded to the model.

    Returns:
        A list of exactly ``_NUM_PLOTS`` PIL images. Missing charts are padded
        with blank white images, and a single chart that fails to render is
        replaced by an error image. If setup fails (API call, file read,
        reply parsing), every slot shows the same error image. Errors are
        reported visually rather than raised.
    """
    try:
        if file is None:
            raise ValueError("Please upload a CSV or Excel file.")

        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(_MODEL_NAME)

        df = _read_dataframe(_resolve_path(file))

        prompt = (
            f"Create {_NUM_PLOTS} matplotlib visualization codes based on: {instructions}\n"
            f"Data columns: {list(df.columns)}\n"
            "Return Python code as: [('title','plot_type','x','y'), ...]\n"
            "Allowed plot_types: bar, line, scatter, hist\n"
            "Use only DataFrame 'df' and these exact variable names.\n"
        )
        response = model.generate_content(prompt)

        # literal_eval only accepts Python literals, so model output is never
        # executed as code.
        plots = ast.literal_eval(_extract_code_block(response.text))
        if not isinstance(plots, (list, tuple)):
            raise ValueError(
                f"Expected a list of plot specs, got {type(plots).__name__}"
            )
    except Exception as e:
        # Setup failed (bad key, unreadable file, unparseable reply): show the
        # same error in every output slot.
        return [_error_image(e)] * _NUM_PLOTS

    images = []
    for spec in plots[:_NUM_PLOTS]:
        try:
            title, plot_type, x, y = spec
            images.append(_render_plot(df, title, plot_type, x, y))
        except Exception as e:
            # A malformed spec (wrong arity, unknown column or type) only
            # affects its own slot.
            images.append(_error_image(e))

    blank = Image.new("RGB", _BLANK_SIZE, (255, 255, 255))
    images.extend([blank] * (_NUM_PLOTS - len(images)))
    return images
# Single-page dashboard: a row of inputs, then a row of three image outputs
# that receive the list returned by process_file.
with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
    gr.Markdown("# Data Analysis Dashboard")

    # Input row: credentials, dataset upload, free-text instructions, trigger.
    with gr.Row():
        api_key = gr.Textbox(type="password", label="Gemini API Key")
        file = gr.File(file_types=[".csv", ".xlsx"], label="Upload Dataset")
        instructions = gr.Textbox(label="Analysis Instructions")
        submit = gr.Button("Generate Insights", variant="primary")

    # Output row: one image slot per visualization, labelled 1..3.
    with gr.Row():
        outputs = []
        for idx in range(1, 4):
            outputs.append(gr.Image(width=600, label=f"Visualization {idx}"))

    submit.click(
        fn=process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs,
    )
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()