|
import ast
import base64
import io
import json
import os
import traceback
from threading import Thread

import dash
from dash import dcc, html, Input, Output, State
import dash_bootstrap_components as dbc
import pandas as pd
import plotly.graph_objs as go
import google.generativeai as genai
|
|
|
|
|
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Page structure: an input card (file upload, free-text instructions, submit
# button) above a card holding the three generated charts. The component ids
# below are the ones wired into the `update_output` callback.
app.layout = dbc.Container([
    html.H1("Data Analysis Dashboard", className="my-4"),
    dbc.Card([
        dbc.CardBody([
            dcc.Upload(
                id='upload-data',
                children=html.Div([
                    'Drag and Drop or ',
                    html.A('Select Files')
                ]),
                # Only offer the CSV/Excel types the upload parser can read.
                accept='.csv,.xls,.xlsx,.xlsm,.xlsb',
                style={
                    'width': '100%',
                    'height': '60px',
                    'lineHeight': '60px',
                    'borderWidth': '1px',
                    'borderStyle': 'dashed',
                    'borderRadius': '5px',
                    'textAlign': 'center',
                    # Vertical margin only: a horizontal margin on a
                    # width:100% box pushes it past the card's edge.
                    'margin': '10px 0'
                },
                multiple=False
            ),
            dbc.Input(id="instructions", placeholder="Describe the analysis you want...", type="text"),
            dbc.Button("Generate Insights", id="submit-button", color="primary", className="mt-3"),
        ])
    ], className="mb-4"),
    dbc.Card([
        dbc.CardBody([
            # The callback waits on a remote Gemini API call, so show a
            # spinner over the charts while it is running.
            dcc.Loading(
                type='default',
                children=[
                    dcc.Graph(id='visualization-1'),
                    dcc.Graph(id='visualization-2'),
                    dcc.Graph(id='visualization-3'),
                ],
            ),
        ])
    ])
], fluid=True)
|
|
|
def parse_contents(contents, filename):
    """Decode a ``dcc.Upload`` payload into a DataFrame.

    Args:
        contents: The upload's ``contents`` string, a data URL of the form
            ``"data:<mime>;base64,<payload>"``.
        filename: Original file name. Its extension (case-insensitive)
            selects the parser: ``.csv`` -> ``pd.read_csv``;
            ``.xls``/``.xlsx``/``.xlsm``/``.xlsb`` -> ``pd.read_excel``.

    Returns:
        The parsed DataFrame, or ``None`` for a missing/unsupported file type
        or any decode/parse failure (the error is printed).
    """
    extension = os.path.splitext(filename or '')[1].lower()
    if extension not in ('.csv', '.xls', '.xlsx', '.xlsm', '.xlsb'):
        return None

    try:
        # Split only on the first comma: everything after it is the payload.
        _header, payload = contents.split(',', 1)
        decoded = base64.b64decode(payload)
        if extension == '.csv':
            # utf-8-sig behaves like utf-8 but also drops a leading BOM, which
            # Excel writes at the start of "CSV UTF-8" exports.
            return pd.read_csv(io.StringIO(decoded.decode('utf-8-sig')))
        return pd.read_excel(io.BytesIO(decoded))
    except Exception as e:
        print(e)
        return None
|
|
|
_DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro-preview-03-25'


def _build_analysis_prompt(df, instructions):
    """Return the Gemini prompt describing df's schema and the user's request."""
    column_types = {str(col): str(dtype) for col, dtype in df.dtypes.items()}
    request = str(instructions).strip() if instructions else ''
    if not request:
        request = 'Provide a general overview of the most important patterns in the data.'
    return f"""
    Analyze the following dataset and instructions:

    Data columns: {list(df.columns)}
    Column types: {column_types}
    Data shape: {df.shape}
    Instructions: {request}

    Based on this, create 3 appropriate visualizations that provide meaningful insights. For each visualization:
    1. Choose the most suitable plot type (bar, line, scatter, hist, pie, heatmap)
    2. Determine appropriate data aggregation (e.g., top 5 categories, yearly averages)
    3. Select relevant columns for x-axis, y-axis, and any additional dimensions (color, size)
    4. Provide a clear, concise title that explains the insight
    Consider data density and choose visualizations that simplify and clarify the information.
    Limit the number of data points displayed to ensure readability (e.g., top 5, top 10, yearly).

    Constraints:
    - "x", "y" and "color" must be exact column names from the list above.
    - "agg_func" must be one of "sum", "mean", "count" or "none"; "sum" and "mean" need a numeric "y".
    - "top_n" must be an integer or None.
    - A heatmap must set "additional": {{"color": <column name>}}.

    Return ONLY a Python list of dictionaries, with no surrounding explanation:
    [
        {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
        {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}},
        {{"title": "...", "plot_type": "...", "x": "...", "y": "...", "agg_func": "...", "top_n": ..., "additional": {{"color": "...", "size": "..."}}}}
    ]
    """


def _parse_plot_specs(text):
    """Extract the list of plot-spec dicts from a model reply; raise ValueError if absent."""
    candidate = text.strip()
    if '```' in candidate:
        fenced = candidate.split('```')[1]
        tag, newline, rest = fenced.partition('\n')
        # Drop a language tag on the opening fence (```python, ```json, ...).
        if newline and tag.strip().isalnum():
            fenced = rest
        candidate = fenced.strip()
    # Tolerate prose around the list by keeping only the outermost brackets.
    start, end = candidate.find('['), candidate.rfind(']')
    if start != -1 and end > start:
        candidate = candidate[start:end + 1]
    try:
        # literal_eval only accepts literals, so model output cannot run code.
        specs = ast.literal_eval(candidate)
    except (ValueError, SyntaxError):
        # JSON literals such as null/true/false are not Python; try JSON.
        specs = json.loads(candidate)
    if not isinstance(specs, list) or not all(isinstance(spec, dict) for spec in specs):
        raise ValueError(f"Expected a list of dicts, got {type(specs).__name__}")
    return specs


def process_data(df, instructions, model_name=None):
    """Ask Gemini to propose visualizations for ``df``.

    Only the column names, column dtypes and shape of ``df`` are put in the
    prompt; row values are not sent.

    Args:
        df: The uploaded dataset.
        instructions: Free-text analysis request from the user (may be None
            or empty, in which case a generic overview is requested).
        model_name: Gemini model id. Defaults to the ``GEMINI_MODEL``
            environment variable, falling back to ``_DEFAULT_GEMINI_MODEL``.

    Returns:
        A list of plot-spec dicts (keys such as ``title``, ``plot_type``,
        ``x``, ``y``, ``agg_func``, ``top_n``, ``additional``) for
        ``generate_plot``, or ``None`` if the API call or parsing fails.
    """
    try:
        genai.configure(api_key=os.environ.get('GEMINI_API_KEY'))
        model = genai.GenerativeModel(
            model_name or os.environ.get('GEMINI_MODEL', _DEFAULT_GEMINI_MODEL)
        )
        response = model.generate_content(_build_analysis_prompt(df, instructions))
        return _parse_plot_specs(response.text)
    except Exception as e:
        # Best effort: API, quota, blocked-response and parse errors all
        # yield None, which the dashboard callback treats as "no update".
        print(f"Error in process_data: {str(e)}")
        traceback.print_exc()
        return None
|
|
|
def _coerce_top_n(value):
    """Return value as a positive int, or None when it is missing or unusable."""
    try:
        n = int(value)
    except (TypeError, ValueError):
        return None
    return n if n > 0 else None


def _aggregate_for_plot(df, x, y, agg_func):
    """Apply the spec's aggregation and return (frame, value column name)."""
    if agg_func == 'count':
        # size() has no natural column name; use 'count' when y is missing
        # or would collide with the group key column.
        value_col = y if y and y != x else 'count'
        return df.groupby(x).size().reset_index(name=value_col), value_col
    if agg_func in ('sum', 'mean'):
        return df.groupby(x)[y].agg(agg_func).reset_index(), y
    # Any other agg_func (e.g. 'none') means plot the rows as-is.
    return df, y


def _limit_rows(plot_df, y, top_n):
    """Keep the top_n rows by y, or the first top_n rows when y is not numeric."""
    if y in plot_df.columns and pd.api.types.is_numeric_dtype(plot_df[y]):
        return plot_df.nlargest(top_n, y)
    return plot_df.head(top_n)


def _heatmap_grid(df, x, y, color, agg_func, top_n):
    """Aggregate df into an x-by-color table of values for a heatmap."""
    if not color or color not in df.columns:
        raise ValueError("heatmap spec needs additional.color naming a column")
    if agg_func == 'count' or not y:
        grid = df.groupby([x, color]).size().unstack(fill_value=0)
    else:
        # Duplicate (x, color) pairs are common, so always aggregate;
        # anything other than 'sum' falls back to the mean.
        aggfunc = 'sum' if agg_func == 'sum' else 'mean'
        grid = df.pivot_table(index=x, columns=color, values=y, aggfunc=aggfunc)
    if top_n:
        # Keep the x rows with the largest totals across all color columns.
        grid = grid.loc[grid.sum(axis=1).nlargest(top_n).index]
    return grid


def generate_plot(df, plot_info):
    """Build a Plotly figure from one model-proposed plot spec.

    Args:
        df: The uploaded dataset.
        plot_info: Spec dict with ``x`` and ``plot_type`` (bar, line,
            scatter, hist, pie, heatmap), usually ``y``, and optionally
            ``title``, ``agg_func`` ('sum' | 'mean' | 'count'; anything else
            means no aggregation), ``top_n`` and ``additional`` (its
            ``color`` key is used as the heatmap's second dimension).

    Returns:
        A ``go.Figure``. An unsupported ``plot_type`` yields an empty figure
        that still carries the title.

    Raises:
        KeyError: if the spec omits ``x`` or names columns missing from ``df``.
        ValueError: if a heatmap spec has no usable color column.
    """
    plot_type = plot_info.get('plot_type')
    x = plot_info['x']
    y = plot_info.get('y')
    agg_func = plot_info.get('agg_func')
    top_n = _coerce_top_n(plot_info.get('top_n'))
    x_title, y_title = x, y

    # Trace constructors for the plot types that share the aggregate -> trim path.
    simple_traces = {
        'bar': lambda xs, ys: go.Bar(x=xs, y=ys),
        'line': lambda xs, ys: go.Scatter(x=xs, y=ys, mode='lines'),
        'scatter': lambda xs, ys: go.Scatter(x=xs, y=ys, mode='markers'),
        'pie': lambda xs, ys: go.Pie(labels=xs, values=ys),
    }

    if plot_type == 'heatmap':
        color = (plot_info.get('additional') or {}).get('color')
        grid = _heatmap_grid(df, x, y, color, agg_func, top_n)
        fig = go.Figure(go.Heatmap(z=grid.values, x=grid.columns, y=grid.index))
        # Grid columns (horizontal axis) are color categories; rows are x values.
        x_title, y_title = color, x
    elif plot_type == 'hist':
        # A histogram bins raw rows; aggregating first would leave one row
        # per x value and flatten the distribution.
        fig = go.Figure(go.Histogram(x=df[x]))
        y_title = 'count'
    elif plot_type in simple_traces:
        plot_df, y = _aggregate_for_plot(df, x, y, agg_func)
        y_title = y
        if top_n:
            plot_df = _limit_rows(plot_df, y, top_n)
        if plot_type == 'line':
            # nlargest reorders rows by value; restore x order so the line
            # is drawn left to right instead of zigzagging.
            plot_df = plot_df.sort_values(x)
        fig = go.Figure(simple_traces[plot_type](plot_df[x], plot_df[y]))
    else:
        # Unsupported type from the model: return an empty, titled figure
        # instead of failing on an unbound `fig`.
        fig = go.Figure()

    fig.update_layout(title=plot_info.get('title'), xaxis_title=x_title, yaxis_title=y_title)
    return fig
|
|
|
@app.callback(
    [Output('visualization-1', 'figure'),
     Output('visualization-2', 'figure'),
     Output('visualization-3', 'figure')],
    [Input('submit-button', 'n_clicks')],
    [State('upload-data', 'contents'),
     State('upload-data', 'filename'),
     State('instructions', 'value')]
)
def update_output(n_clicks, contents, filename, instructions):
    """Regenerate the three charts when "Generate Insights" is clicked.

    Args:
        n_clicks: Click count of the submit button (``None`` before the first click).
        contents: Data-URL payload of the uploaded file, or ``None``.
        filename: Name of the uploaded file.
        instructions: The user's free-text analysis request.

    Returns:
        Three values, one per ``visualization-*`` graph. Slots that cannot be
        filled are ``dash.no_update`` so the previous chart stays visible; a
        spec that fails to render shows an empty figure titled with the error.
    """
    slots = 3  # number of visualization-* outputs declared above
    unchanged = [dash.no_update] * slots

    if n_clicks is None or contents is None:
        return unchanged

    df = parse_contents(contents, filename)
    if df is None:
        return unchanged

    plots = process_data(df, instructions)
    if not isinstance(plots, list) or not plots:
        return unchanged

    figures = []
    for plot_info in plots[:slots]:
        try:
            figures.append(generate_plot(df, plot_info))
        except Exception as e:
            # Specs come from the model and may name missing columns or
            # unusable aggregations; report the failure in this slot instead
            # of failing the whole callback.
            traceback.print_exc()
            figures.append(go.Figure().update_layout(title=f'Could not render visualization: {e}'))

    # The model returned fewer specs than slots: leave the rest untouched.
    figures.extend([dash.no_update] * (slots - len(figures)))
    return figures
|
|
|
if __name__ == '__main__':
    # Debug mode turns on the dev tools and the Flask/Werkzeug debugger; do
    # not expose that on 0.0.0.0 unless explicitly requested (DASH_DEBUG=1).
    debug_enabled = os.environ.get('DASH_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(
        debug=debug_enabled,
        host='0.0.0.0',
        port=int(os.environ.get('PORT', 7860)),
        threaded=True,
    )