File size: 7,161 Bytes
12ce912
 
 
72c5969
bca92aa
ec365ce
3c50a2d
904e6a1
7a74363
365373d
da863e4
6013c50
7a74363
 
4fc79a4
904e6a1
9a4fc1b
ec365ce
904e6a1
ec365ce
422964b
ec365ce
9a4fc1b
7a74363
 
 
 
 
9a4fc1b
72c5969
 
a5f2a3b
 
72c5969
2e5c04a
a5f2a3b
 
6ac004d
 
d3510d0
2e5c04a
 
a5f2a3b
2e5c04a
d3510d0
2e5c04a
a5f2a3b
 
6ac004d
 
 
a5f2a3b
72c5969
601022d
72c5969
 
 
 
 
 
 
 
 
3c50a2d
72c5969
9a4fc1b
bca92aa
ec365ce
72c5969
1b2886c
ec365ce
2e5c04a
bf5218c
2e5c04a
 
 
 
 
 
 
 
 
a5f2a3b
6ac004d
bf5218c
6ac004d
bf5218c
6ac004d
2e5c04a
 
6ac004d
2e5c04a
6ac004d
2e5c04a
6ac004d
2e5c04a
 
 
 
 
 
1b2886c
a5f2a3b
6ac004d
2e5c04a
 
1b2886c
ec365ce
9a4fc1b
1b2886c
9a4fc1b
1b2886c
b06a85b
1b2886c
12ce912
b06a85b
12ce912
9a4fc1b
3c50a2d
 
 
ec365ce
3c50a2d
b06a85b
9a4fc1b
7a74363
6013c50
7a74363
05370c5
7a74363
 
 
6013c50
 
ebf86ea
 
 
 
 
 
6013c50
 
7a74363
6013c50
 
7a74363
 
 
 
 
 
05370c5
7a74363
 
 
 
 
 
6013c50
7a74363
 
05370c5
7a74363
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import ast
import base64
import io
import json
import os
import re
import threading
import traceback

import google.generativeai as genai
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image, ImageDraw
from pywebio import start_server
from pywebio.input import file_upload, input, input_group
from pywebio.output import (
    put_buttons,
    put_column,
    put_image,
    put_markdown,
    put_row,
    put_scope,
    put_text,
    use_scope,
)
from pywebio.session import run_js, set_env

def _extract_code_block(text):
    """Return the body of the first Markdown code fence in *text*.

    Accepts fences with any language tag (```python, ```json, ...) or none,
    and tolerates a missing closing fence. Without any fence, *text* itself
    is returned (stripped).
    """
    fence = re.search(r"```(?:[A-Za-z0-9_+-]+)?\s*(.*?)(?:```|\Z)", text, re.DOTALL)
    return (fence.group(1) if fence else text).strip()


def _parse_plot_specs(code_block):
    """Parse the model's list-of-dicts reply into a list of plot specs.

    Raises:
        ValueError: if the text is neither a Python nor a JSON literal, or
            does not describe a list (a single dict is wrapped in a list).
    """
    text = code_block.strip()
    start, end = text.find('['), text.rfind(']')
    if text[:1] not in ('[', '{') and start != -1 and end > start:
        # Tolerate a leading assignment or prose such as "plots = [...]".
        text = text[start:end + 1]
    try:
        specs = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        # JSON-style replies use true/false/null, which literal_eval rejects.
        specs = json.loads(text)
    if isinstance(specs, dict):
        specs = [specs]
    if not isinstance(specs, list):
        raise ValueError(
            f"Expected a list of plot specifications, got {type(specs).__name__}"
        )
    return specs


def _prepare_plot_data(df, plot):
    """Apply a spec's ``agg_func`` and ``top_n`` to *df*.

    *df* is never modified; it may be returned as-is when neither step applies.
    """
    x, y = plot.get('x'), plot.get('y')
    extra = plot.get('additional') or {}

    keys = [x]
    color = extra.get('color')
    # A heatmap pivots on x and its "color" column, so that column must
    # survive the group-by instead of being aggregated away.
    if plot.get('plot_type') == 'heatmap' and color in df.columns and color != x:
        keys.append(color)

    agg = plot.get('agg_func')
    if agg == 'sum':
        plot_df = df.groupby(keys)[y].sum().reset_index()
    elif agg == 'mean':
        plot_df = df.groupby(keys)[y].mean().reset_index()
    elif agg == 'count':
        plot_df = df.groupby(keys).size().reset_index(name=y)
    else:
        plot_df = df

    top_n = plot.get('top_n')
    # nlargest needs an existing numeric column; skip the cut otherwise
    # rather than failing the whole plot.
    if top_n and y in plot_df.columns and pd.api.types.is_numeric_dtype(plot_df[y]):
        plot_df = plot_df.nlargest(int(top_n), y)
    return plot_df


def _render_plot(plot_df, plot):
    """Draw one plot spec from *plot_df* and return it as a PNG-backed PIL image."""
    x, y = plot.get('x'), plot.get('y')
    kind = plot.get('plot_type')
    extra = plot.get('additional') or {}

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if kind in ('bar', 'line'):
            plot_df.plot(kind=kind, x=x, y=y, ax=ax)
        elif kind == 'scatter':
            color, size = extra.get('color'), extra.get('size')
            # Map color/size only when they name an existing column;
            # otherwise fall back to matplotlib's default color and size.
            plot_df.plot(
                kind='scatter', x=x, y=y, ax=ax,
                c=color if color in plot_df.columns else None,
                s=plot_df[size] if size in plot_df.columns else None,
            )
        elif kind == 'hist':
            plot_df[x].hist(ax=ax, bins=20)
        elif kind == 'pie':
            plot_df.plot(kind='pie', y=y, labels=plot_df[x], ax=ax, autopct='%1.1f%%')
        elif kind == 'heatmap':
            color = extra.get('color')
            if color not in plot_df.columns:
                raise ValueError("heatmap spec needs 'additional.color' naming a column")
            # pivot_table tolerates repeated (x, color) pairs, which plain
            # pivot rejects when no aggregation was requested.
            pivot_df = plot_df.pivot_table(index=x, columns=color, values=y)
            ax.imshow(pivot_df, cmap='YlOrRd')
            ax.set_xticks(range(len(pivot_df.columns)))
            ax.set_yticks(range(len(pivot_df.index)))
            ax.set_xticklabels(pivot_df.columns)
            ax.set_yticklabels(pivot_df.index)

        ax.set_title(plot.get('title', ''))
        if kind != 'pie':
            ax.set_xlabel(x)
            ax.set_ylabel(y)
        # Work on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        buf.seek(0)
        image = Image.open(buf)
        image.load()  # decode the pixels now instead of lazily
        return image
    finally:
        # Release the figure even when plotting raised.
        plt.close(fig)


def _error_image(message, size=(800, 400)):
    """Return a white canvas of *size* with *message* drawn in red."""
    image = Image.new('RGB', size, (255, 255, 255))
    ImageDraw.Draw(image).text((10, 10), message, fill=(255, 0, 0))
    return image


def process_file(file, instructions):
    """Ask Gemini for three plot specs for an uploaded dataset and render them.

    Args:
        file: Dict with at least ``'filename'`` and ``'content'`` (raw bytes).
            A filename ending in ``.csv`` (any case) is read with
            ``pd.read_csv``; anything else with ``pd.read_excel``.
        instructions: Free-text analysis request forwarded to the model.

    Returns:
        Exactly three PIL images. Missing plots are padded with blank white
        canvases. A single plot that fails becomes an image of its error
        text. If the dataset cannot be read or the model reply cannot be
        parsed, all three entries are the same error image. Errors are also
        printed to stdout.
    """
    try:
        # Initialize Gemini
        api_key = os.environ.get('GEMINI_API_KEY')
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

        # Read uploaded file
        content = file['content']
        if file['filename'].lower().endswith('.csv'):
            df = pd.read_csv(io.BytesIO(content))
        else:
            df = pd.read_excel(io.BytesIO(content))

        # Generate visualization specs
        response = model.generate_content(f"""
            Analyze the following dataset and instructions:
            
            Data columns: {list(df.columns)}
            Data shape: {df.shape}
            Instructions: {instructions}
            
            Based on this, create 3 appropriate visualizations that provide meaningful insights. For each visualization:
            1. Choose the most suitable plot type (bar, line, scatter, hist, pie, heatmap)
            2. Determine appropriate data aggregation (e.g., top 5 categories, yearly averages)
            3. Select relevant columns for x-axis, y-axis, and any additional dimensions (color, size)
            4. Provide a clear, concise title that explains the insight

            Consider data density and choose visualizations that simplify and clarify the information.
            Limit the number of data points displayed to ensure readability (e.g., top 5, top 10, yearly).
            
            Return your response as a Python list of dictionaries:
            [
                {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
                {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
                {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}}
            ]
        """)

        code_block = _extract_code_block(response.text)
        print("Generated code block:")
        print(code_block)

        plots = _parse_plot_specs(code_block)

        images = []
        for plot in plots[:3]:  # Ensure max 3 plots
            # One malformed spec should not wipe out the other plots.
            try:
                images.append(_render_plot(_prepare_plot_data(df, plot), plot))
            except Exception as e:
                message = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
                print(message)
                images.append(_error_image(message))

        while len(images) < 3:
            images.append(Image.new('RGB', (800, 600), (255, 255, 255)))
        return images

    except Exception as e:
        error_message = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)  # Print to console for debugging
        return [_error_image(error_message)] * 3

def data_analysis_dashboard():
    """Render the dashboard page and serve analysis requests for this session.

    Creates three placeholder scopes (``visualization_1`` .. ``visualization_3``)
    for the plots. It then repeatedly shows a form asking for a dataset and
    instructions. Each submission is passed to ``generate_insights``.
    """
    set_env(title="Data Analysis Dashboard")
    # put_text would show the leading "#" literally; render it as a heading.
    put_markdown("# Data Analysis Dashboard")

    with use_scope('output'):
        for i in range(3):
            put_scope(f'visualization_{i+1}')

    # file_upload/input are blocking *input* functions, not output widgets.
    # input_group shows them as one form and returns {name: value} on submit.
    # Loop so the user can run several analyses in one session.
    while True:
        data = input_group("Generate Insights", [
            file_upload("Upload Dataset", accept=[".csv", ".xlsx"], name="file"),
            input("Analysis Instructions", type="text",
                  placeholder="Describe the analysis you want...", name="instructions"),
        ])
        generate_insights(data['file'], data['instructions'])


def generate_insights(file=None, instructions=None):
    """Analyze *file* per *instructions* and show the plots in the page scopes.

    Args:
        file: Upload dict (``filename``/``content``) as returned by
            ``file_upload``; ``None`` when nothing was uploaded.
        instructions: Free-text analysis request; ``None``/empty if missing.
    """
    if not file or not instructions:
        put_text("Please upload a file and provide instructions.")
        return

    images = process_file(file, instructions)

    for i, img in enumerate(images):
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        with use_scope(f'visualization_{i+1}', clear=True):
            # put_image treats a str as a URL; raw bytes are embedded as image data.
            put_image(buffered.getvalue(), format='png', width='100%')

def main():
    """PyWebIO application entry point passed to ``start_server``; builds the dashboard."""
    return data_analysis_dashboard()

if __name__ == '__main__':
    # Serve `main` on all interfaces at port 7860, using locally bundled
    # static assets (cdn=False) and opening a browser tab on startup.
    # NOTE(review): debug=True on host 0.0.0.0 — consider disabling outside development.
    start_server(
        main,
        host='0.0.0.0',
        port=7860,
        debug=True,
        cdn=False,
        auto_open_webbrowser=True,
    )