File size: 9,477 Bytes
b7ede47
12ce912
72c5969
3c50a2d
7ccda80
b7ede47
 
 
efa0dd6
b7ede47
 
 
 
4fc79a4
b7ede47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efa0dd6
 
 
 
 
 
b7ede47
 
 
ceb21fe
 
 
 
 
 
 
 
 
 
 
 
 
efa0dd6
 
b7ede47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ccda80
9a4fc1b
7ccda80
 
 
 
 
f00eea3
ec365ce
422964b
ec365ce
72c5969
 
a5f2a3b
 
72c5969
2e5c04a
a5f2a3b
 
6ac004d
 
d3510d0
2e5c04a
 
 
d3510d0
2e5c04a
a5f2a3b
 
6ac004d
 
 
a5f2a3b
72c5969
601022d
72c5969
 
 
 
 
 
 
 
b7ede47
9a4fc1b
b7ede47
 
9a4fc1b
b7ede47
 
 
 
 
 
 
 
7a74363
b7ede47
 
7a74363
b7ede47
 
 
 
 
 
 
 
 
 
 
 
 
05370c5
b7ede47
 
 
efa0dd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7ede47
 
 
ceb21fe
 
b7ede47
efa0dd6
b7ede47
7ccda80
b7ede47
7ccda80
ceb21fe
 
b7ede47
ceb21fe
 
 
 
b7ede47
7ccda80
ceb21fe
 
6013c50
ceb21fe
 
 
 
 
05370c5
b7ede47
7ccda80
da1fc70
7ccda80
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import ast
import base64
import io
import json
import os
import traceback
from threading import Thread

import dash
from dash import dcc, html, Input, Output, State, callback_context
import dash_bootstrap_components as dbc
import pandas as pd
import plotly.graph_objs as go
import google.generativeai as genai

# Initialize the Dash app with the Bootstrap stylesheet shipped by
# dash-bootstrap-components.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Layout: an input card (file upload, upload status, delete button, free-text
# instructions, submit button), an error line, and three graphs that the
# `update_output` callback fills in.
app.layout = dbc.Container([
    html.H1("Data Analysis Dashboard", className="my-4"),
    dbc.Card([
        dbc.CardBody([
            dcc.Upload(
                id='upload-data',
                children=html.Div([
                    'Drag and Drop or ',
                    html.A('Select Files')
                ]),
                style={
                    'width': '100%',
                    'height': '60px',
                    'lineHeight': '60px',
                    'borderWidth': '1px',
                    'borderStyle': 'dashed',
                    'borderRadius': '5px',
                    'textAlign': 'center',
                    # Vertical margin only: a horizontal margin on a box that
                    # is already 100% wide pushes it past the card's edge.
                    'margin': '10px 0'
                },
                multiple=False
            ),
            html.Div(id='upload-feedback', className="mt-2"),
            html.Div([
                # Bootstrap 5 (dash-bootstrap-components >= 1.0) renamed the
                # "mr-*" spacing utilities to "me-*".
                html.Span(id='filename-display', className="me-2"),
                # Hidden until a file has been uploaded successfully.
                dbc.Button("Delete File", id="delete-file-button", color="danger", className="mt-2", style={'display': 'none'})
            ], className="mt-2"),
            dbc.Input(id="instructions", placeholder="Describe the analysis you want...", type="text", className="mt-3"),
            dbc.Button("Generate Insights", id="submit-button", color="primary", className="mt-3"),
        ])
    ], className="mb-4"),
    html.Div(id="error-message", className="text-danger mb-3"),
    dcc.Loading(
        id="loading-visualizations",
        type="default",
        children=[
            dbc.Card([
                dbc.CardBody([
                    dcc.Graph(id='visualization-1'),
                    dcc.Graph(id='visualization-2'),
                    dcc.Graph(id='visualization-3'),
                ])
            ])
        ]
    ),
    # Holds the raw upload data URL between the upload and submit callbacks.
    dcc.Store(id='uploaded-data')
], fluid=True)

def parse_contents(contents, filename):
    """Decode a ``dcc.Upload`` payload into a DataFrame.

    Args:
        contents: Data URL produced by ``dcc.Upload``, i.e. a header, a comma,
            then the base64-encoded file bytes.
        filename: Original name of the uploaded file. Its extension, compared
            case-insensitively, picks the parser: ``.csv`` -> CSV;
            ``.xls``/``.xlsx``/``.xlsm`` -> Excel.

    Returns:
        The parsed ``pandas.DataFrame``, or ``None`` when the extension is not
        supported or the payload cannot be decoded or parsed (the error is
        printed).
    """
    try:
        extension = os.path.splitext(filename)[1].lower()
        if extension not in ('.csv', '.xls', '.xlsx', '.xlsm'):
            return None

        # Everything after the first comma is the base64 payload.
        _header, separator, content_string = contents.partition(',')
        if not separator:
            print(f"Malformed upload payload for {filename!r}")
            return None
        decoded = base64.b64decode(content_string)

        if extension == '.csv':
            # 'utf-8-sig' also strips a leading byte-order mark (common in
            # Excel CSV exports) that would otherwise prefix the first header.
            return pd.read_csv(io.StringIO(decoded.decode('utf-8-sig')))
        return pd.read_excel(io.BytesIO(decoded))
    except Exception as e:
        print(f"Error parsing {filename!r}: {e}")
        return None

def _extract_plot_specs(text):
    """Parse the plot-spec literal out of a model reply.

    Accepts a bare literal or one wrapped in a Markdown code fence with or
    without a language tag (``python``, ``json``, ...). Python literal syntax
    is tried first, then JSON, so ``true``/``false``/``null`` are accepted too.

    Raises:
        ValueError: if the text is neither a Python literal nor valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    body = text.strip()
    if '```' in body:
        body = body.split('```')[1]
        # Drop the fence's language-tag line such as "python" or "json".
        tag, newline, rest = body.partition('\n')
        if newline and tag.strip().isidentifier():
            body = rest
    body = body.strip()
    try:
        return ast.literal_eval(body)
    except (ValueError, SyntaxError, TypeError):
        return json.loads(body)


def process_data(df, instructions):
    """Ask Gemini to propose three visualizations for *df*.

    Only the column names and the shape of *df* are sent, together with the
    user's *instructions*. The API key is read from ``GEMINI_API_KEY``.

    Returns:
        The list (or tuple) of plot-spec dicts parsed from the reply, or
        ``None`` on any failure: missing API key, API error, or a reply that
        is not a list/tuple literal. The error is printed.
    """
    try:
        # Get API key from environment variable
        api_key = os.getenv('GEMINI_API_KEY')
        if not api_key:
            raise ValueError("Gemini API key not found in environment variables")

        # Initialize Gemini with provided API key
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

        # Generate visualization code
        response = model.generate_content(f"""
            Analyze the following dataset and instructions:
            
            Data columns: {list(df.columns)}
            Data shape: {df.shape}
            Instructions: {instructions}
            
            Based on this, create 3 appropriate visualizations that provide meaningful insights. For each visualization:
            1. Choose the most suitable plot type (bar, line, scatter, hist, pie, heatmap)
            2. Determine appropriate data aggregation (e.g., top 5 categories, yearly averages)
            3. Select relevant columns for x-axis, y-axis, and any additional dimensions (color, size)
            4. Provide a clear, concise title that explains the insight
            Consider data density and choose visualizations that simplify and clarify the information.
            Limit the number of data points displayed to ensure readability (e.g., top 5, top 10, yearly).
            
            Return your response as a Python list of dictionaries:
            [
                {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
                {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
                {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}}
            ]
        """)

        plots = _extract_plot_specs(response.text)
        # Callers slice and index the result, so reject dicts/scalars here.
        if not isinstance(plots, (list, tuple)):
            raise ValueError(f"Expected a list of plot specs, got {type(plots).__name__}")
        return plots
    except Exception as e:
        print(f"Error in process_data: {str(e)}")
        return None

# Plot types generate_plot can draw (the same list offered to the model).
_SUPPORTED_PLOT_TYPES = ('bar', 'line', 'scatter', 'hist', 'pie', 'heatmap')
# agg_func values applied as a pandas groupby aggregation ('count' is separate).
_GROUPBY_AGGS = ('sum', 'mean', 'min', 'max', 'median')


def _aggregate(df, keys, y, agg_func):
    """Aggregate *df* per *keys* according to *agg_func*.

    Returns ``(frame, value_column)``. ``'count'`` counts rows per group and
    names the count column *y*. It falls back to ``'count'`` when *y* is
    missing or is itself a grouping key, because reusing that name would
    collide in ``reset_index``. Any other unrecognised *agg_func* (including
    ``None``) returns *df* unchanged.
    """
    key_list = keys if isinstance(keys, list) else [keys]
    if agg_func == 'count':
        value_col = y if y and y not in key_list else 'count'
        return df.groupby(keys).size().reset_index(name=value_col), value_col
    if agg_func in _GROUPBY_AGGS:
        return df.groupby(keys)[y].agg(agg_func).reset_index(), y
    return df, y


def generate_plot(df, plot_info):
    """Render one plot specification (as returned by ``process_data``) with Plotly.

    Args:
        df: Source data. It is not modified.
        plot_info: Spec dict.
            Required keys: ``plot_type`` and ``x``.
            Optional keys:
              - ``y``
              - ``title``
              - ``agg_func``: one of ``sum``/``mean``/``min``/``max``/
                ``median``/``count``.
              - ``top_n``
              - ``additional``: its ``color`` entry is the second axis of a
                heatmap.

    Returns:
        A ``plotly.graph_objs.Figure``.

    Raises:
        ValueError: for an unsupported ``plot_type``, or a heatmap spec
            without ``additional['color']``.
        KeyError: if the spec names a column that is not in *df*.
    """
    plot_type = plot_info['plot_type']
    if plot_type not in _SUPPORTED_PLOT_TYPES:
        raise ValueError(
            f"Unsupported plot_type {plot_type!r}; expected one of "
            f"{', '.join(_SUPPORTED_PLOT_TYPES)}"
        )
    x = plot_info['x']
    y = plot_info.get('y')
    color = (plot_info.get('additional') or {}).get('color')
    if plot_type == 'heatmap' and not color:
        raise ValueError("A heatmap spec needs additional['color'] as its second axis")

    # Heatmaps group by both axes so the color column survives aggregation
    # and is still available to the pivot below.
    keys = [x, color] if plot_type == 'heatmap' else x
    plot_df, y_col = _aggregate(df, keys, y, plot_info.get('agg_func'))

    try:
        top_n = int(plot_info.get('top_n') or 0)
    except (TypeError, ValueError):
        # e.g. the model echoed the prompt template's "..." placeholder.
        top_n = 0
    if top_n > 0 and y_col in plot_df.columns:
        plot_df = plot_df.nlargest(top_n, y_col)
    if plot_type == 'line':
        # nlargest orders rows by value; a line chart must follow x order.
        plot_df = plot_df.sort_values(x)

    if plot_type == 'bar':
        fig = go.Figure(go.Bar(x=plot_df[x], y=plot_df[y_col]))
    elif plot_type == 'line':
        fig = go.Figure(go.Scatter(x=plot_df[x], y=plot_df[y_col], mode='lines'))
    elif plot_type == 'scatter':
        fig = go.Figure(go.Scatter(x=plot_df[x], y=plot_df[y_col], mode='markers'))
    elif plot_type == 'hist':
        fig = go.Figure(go.Histogram(x=plot_df[x]))
    elif plot_type == 'pie':
        fig = go.Figure(go.Pie(labels=plot_df[x], values=plot_df[y_col]))
    else:  # 'heatmap' is the only type left after the validation above.
        # pivot_table averages duplicate (x, color) cells, which pivot() rejects.
        pivot_df = plot_df.pivot_table(index=x, columns=color, values=y_col, aggfunc='mean')
        fig = go.Figure(go.Heatmap(z=pivot_df.values, x=pivot_df.columns, y=pivot_df.index))

    fig.update_layout(title=plot_info.get('title'), xaxis_title=x, yaxis_title=y_col)
    return fig

@app.callback(
    [Output('upload-feedback', 'children'),
     Output('filename-display', 'children'),
     Output('delete-file-button', 'style'),
     Output('uploaded-data', 'data')],
    [Input('upload-data', 'contents'),
     Input('delete-file-button', 'n_clicks')],
    [State('upload-data', 'filename')]
)
def update_upload_feedback(contents, delete_clicks, filename):
    """Validate a new upload, or clear it when "Delete File" is clicked.

    Returns a 4-tuple for (feedback message, filename label, delete-button
    style, stored upload data). After a successful parse, the raw upload data
    URL is put in the ``uploaded-data`` store for ``update_output`` to re-parse.
    """
    skip = (dash.no_update,) * 4

    fired = callback_context.triggered
    if not fired:
        return skip

    # prop_id looks like "<component-id>.<property>".
    source_id = fired[0]['prop_id'].split('.')[0]
    if source_id == 'delete-file-button':
        return "File deleted.", "", {'display': 'none'}, None

    if contents is None:
        return skip

    # Parse only to validate; the DataFrame itself is not kept.
    if parse_contents(contents, filename) is None:
        return (
            dbc.Alert("Error parsing the file. Please upload a valid CSV or Excel file.", color="danger"),
            "",
            {'display': 'none'},
            None,
        )

    return (
        dbc.Alert("File uploaded successfully!", color="success"),
        f"Uploaded: {filename}",
        {'display': 'inline-block'},
        contents,
    )

@app.callback(
    [Output('visualization-1', 'figure'),
     Output('visualization-2', 'figure'),
     Output('visualization-3', 'figure'),
     Output('error-message', 'children')],
    [Input('submit-button', 'n_clicks')],
    [State('uploaded-data', 'data'),
     State('upload-data', 'filename'),
     State('instructions', 'value')]
)
def update_output(n_clicks, contents, filename, instructions):
    """Generate the three visualizations when "Generate Insights" is clicked.

    Args:
        n_clicks: Submit-button click count (``None`` before the first click).
        contents: Upload data URL held in the ``uploaded-data`` store.
        filename: Name of the uploaded file; it selects the parser.
        instructions: Free-text analysis request. It is ``None`` if the input
            box was never edited.

    Returns:
        ``(figure_1, figure_2, figure_3, error_message)``. A spec that fails to
        render becomes an empty figure and is named in the error message, so
        one bad spec no longer discards the other visualizations.
    """
    if n_clicks is None or contents is None:
        return dash.no_update, dash.no_update, dash.no_update, ""

    try:
        df = parse_contents(contents, filename)
        if df is None:
            return dash.no_update, dash.no_update, dash.no_update, "Unable to parse the uploaded file."

        # An untouched input box yields None; don't send the literal "None".
        plots = process_data(df, instructions or "")
        if plots is None or len(plots) < 3:
            return dash.no_update, dash.no_update, dash.no_update, "Unable to generate visualizations. Please check your instructions and try again."

        figures = []
        failures = []
        for index, plot_info in enumerate(plots[:3], start=1):
            try:
                figures.append(generate_plot(df, plot_info))
            except Exception as e:
                # Model-produced specs can name missing columns or bad types;
                # isolate each figure so the others still render.
                traceback.print_exc()
                figures.append(go.Figure())
                failures.append(f"Visualization {index} failed: {e}")
        return figures[0], figures[1], figures[2], "; ".join(failures)
    except Exception as e:
        traceback.print_exc()
        error_message = f"An error occurred: {str(e)}"
        return dash.no_update, dash.no_update, dash.no_update, error_message

if __name__ == '__main__':
    # Script entry point: serve on all interfaces, port 7860, debug mode off.
    print("Starting the Dash application...")
    app.run(host='0.0.0.0', port=7860, debug=False)
    # Printed only once app.run() returns.
    print("Dash application has finished running.")