File size: 2,922 Bytes
4fc79a4
12ce912
 
 
ec365ce
bca92aa
ec365ce
4fc79a4
12ce912
9a4fc1b
ec365ce
 
 
 
9a4fc1b
bca92aa
 
9a4fc1b
bca92aa
ec365ce
 
bca92aa
 
ec365ce
 
 
bca92aa
 
 
 
9a4fc1b
bca92aa
ec365ce
bca92aa
ec365ce
 
 
 
 
 
 
 
 
 
 
 
 
9a4fc1b
ec365ce
9a4fc1b
ec365ce
 
12ce912
bca92aa
12ce912
9a4fc1b
ec365ce
 
 
 
9a4fc1b
23a3b49
ec365ce
5be932a
23a3b49
80cfa8c
ec365ce
6cff8d5
bca92aa
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import ast
import io
import textwrap

import google.generativeai as genai
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageDraw

# Plot kinds the prompt allows and the renderer knows how to draw.
_ALLOWED_PLOT_TYPES = ('bar', 'line', 'scatter', 'hist')
# Number of output panes wired up in the UI.
_NUM_PLOTS = 3


def _error_image(message, size=(800, 100)):
    """Return a white RGB image with *message* drawn in red.

    Long messages are wrapped so they are not clipped at the right edge; the
    characters-per-line estimate is approximate because it depends on PIL's
    default font.
    """
    image = Image.new('RGB', size, (255, 255, 255))
    draw = ImageDraw.Draw(image)
    wrapped = '\n'.join(textwrap.wrap(message, width=max(size[0] // 7, 20)))
    draw.text((10, 10), wrapped, fill=(255, 0, 0))
    return image


def _extract_plot_specs(text):
    """Parse the model reply into a list of plot-spec entries.

    Prefers a ```python fenced block, falls back to the first generic ```
    fence (dropping an optional language tag), and finally to the whole
    reply. The payload goes through ``ast.literal_eval``, so only Python
    literals are accepted, never executable code.

    Raises:
        ValueError, SyntaxError: if the payload is not a list/tuple literal.
    """
    if '```python' in text:
        body = text.split('```python', 1)[1].split('```', 1)[0]
    elif '```' in text:
        body = text.split('```', 2)[1]
        # A fence such as "```py" leaves a bare tag on the first line.
        tag, newline, rest = body.partition('\n')
        if newline and tag.strip().isalnum():
            body = rest
    else:
        body = text
    specs = ast.literal_eval(body.strip())
    if not isinstance(specs, (list, tuple)):
        raise ValueError('Model reply is not a list of plot specifications.')
    return list(specs)


def _render_plot(df, spec):
    """Draw one ``(title, plot_type, x, y)`` spec from *df* as a PIL image.

    For ``'hist'`` only the ``y`` column is used; ``x`` is ignored.

    Raises:
        ValueError: if the spec is malformed or names an unsupported type.
        KeyError: (from pandas) if a referenced column does not exist.
    """
    if not (isinstance(spec, (list, tuple)) and len(spec) == 4):
        raise ValueError(f"Expected (title, plot_type, x, y), got {spec!r}")
    title, plot_type, x, y = spec
    if plot_type not in _ALLOWED_PLOT_TYPES:
        raise ValueError(f"Unsupported plot type {plot_type!r}")

    fig, ax = plt.subplots()
    try:
        if plot_type == 'hist':
            df[y].hist(ax=ax)
        else:
            # bar / line / scatter map directly onto DataFrame.plot.<kind>.
            getattr(df.plot, plot_type)(x=x, y=y, ax=ax)
        ax.set_title(title)
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Release the figure even when plotting fails, so repeated requests
        # do not accumulate open pyplot figures.
        plt.close(fig)

    buf.seek(0)
    image = Image.open(buf)
    image.load()  # Image.open is lazy; decode now while buf is in hand.
    return image


def process_file(api_key, file, instructions):
    """Ask Gemini for chart specs and render them for an uploaded dataset.

    Args:
        api_key: Gemini API key used to configure ``genai``.
        file: The ``gr.File`` upload value. Either an object exposing a
            ``.name`` path or a plain path string is accepted, because the
            value's type differs between Gradio versions.
        instructions: Free-text analysis request forwarded to the model.

    Returns:
        Exactly three PIL images, one per output pane.
        - Failures before plotting produce three copies of one error image.
          These include missing inputs, an unreadable file, and model or
          parse errors.
        - A failure while drawing an individual plot produces an error image
          in that slot only.
        - Unused slots are blank white 800x600 images.
    """
    try:
        if not api_key:
            raise ValueError("Please enter a Gemini API key.")
        if file is None:
            raise ValueError("Please upload a CSV or Excel file.")

        # Initialize Gemini
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-pro')

        # Read uploaded file
        file_path = getattr(file, 'name', file)
        if str(file_path).lower().endswith('.csv'):
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)

        # Generate visualization specs
        response = model.generate_content(f"""
            Create 3 matplotlib visualization codes based on: {instructions}
            Data columns: {list(df.columns)}
            Return Python code as: [('title','plot_type','x','y'), ...]
            Allowed plot_types: bar, line, scatter, hist
            Use only DataFrame 'df' and these exact variable names.
        """)
        specs = _extract_plot_specs(response.text)
        if not specs:
            raise ValueError("Model returned no plot specifications.")
    except Exception as e:
        # UI boundary: surface the failure in every pane instead of crashing.
        return [_error_image(f"Error: {e}")] * _NUM_PLOTS

    images = []
    for spec in specs[:_NUM_PLOTS]:
        try:
            images.append(_render_plot(df, spec))
        except Exception as e:
            # One bad spec (e.g. unknown column) should not hide the others.
            images.append(_error_image(f"Error: {e}"))

    blank = Image.new('RGB', (800, 600), (255, 255, 255))
    return images + [blank] * (_NUM_PLOTS - len(images))

# Dashboard layout: credentials and dataset side by side, a free-text
# instruction box and a submit button, then a row of three image panes that
# receive the three images returned by process_file.
with gr.Blocks(theme=gr.themes.Default(spacing_size="lg")) as demo:
    gr.Markdown("# Data Analysis Dashboard")

    with gr.Row():
        api_key = gr.Textbox(type="password", label="Gemini API Key")
        file = gr.File(file_types=[".csv", ".xlsx"], label="Upload Dataset")

    instructions = gr.Textbox(label="Analysis Instructions")
    submit = gr.Button("Generate Insights", variant="primary")

    with gr.Row():
        outputs = [
            gr.Image(width=600, label=f"Visualization {pane}")
            for pane in range(1, 4)
        ]

    # Inputs are passed positionally to process_file in this order.
    submit.click(
        fn=process_file,
        inputs=[api_key, file, instructions],
        outputs=outputs,
    )

if __name__ == "__main__":
    # Serve the app only when executed as a script, not when imported.
    demo.launch()