bluenevus committed on
Commit
b7ede47
·
verified ·
1 Parent(s): 3ffb96a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -86
app.py CHANGED
@@ -1,24 +1,78 @@
1
- import gradio as gr
2
- import pandas as pd
3
- import matplotlib.pyplot as plt
4
  import io
 
5
  import ast
6
- from PIL import Image, ImageDraw
7
- import google.generativeai as genai
8
  import traceback
9
- import os
 
 
 
 
 
 
 
10
 
11
- def process_file(file, instructions):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  try:
13
  # Initialize Gemini
14
  api_key = os.environ.get('GEMINI_API_KEY')
15
  genai.configure(api_key=api_key)
16
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
17
 
18
- # Read uploaded file
19
- file_path = file.name
20
- df = pd.read_csv(file_path) if file_path.endswith('.csv') else pd.read_excel(file_path)
21
-
22
  # Generate visualization code
23
  response = model.generate_content(f"""
24
  Analyze the following dataset and instructions:
@@ -50,86 +104,64 @@ def process_file(file, instructions):
50
  elif '```' in code_block:
51
  code_block = code_block.split('```')[1].strip()
52
 
53
- print("Generated code block:")
54
- print(code_block)
55
-
56
  plots = ast.literal_eval(code_block)
57
-
58
- # Generate visualizations
59
- images = []
60
- for plot in plots[:3]: # Ensure max 3 plots
61
- fig, ax = plt.subplots(figsize=(10, 6))
62
-
63
- # Apply preprocessing and aggregation
64
- plot_df = df.copy()
65
- if plot['agg_func'] == 'sum':
66
- plot_df = plot_df.groupby(plot['x'])[plot['y']].sum().reset_index()
67
- elif plot['agg_func'] == 'mean':
68
- plot_df = plot_df.groupby(plot['x'])[plot['y']].mean().reset_index()
69
- elif plot['agg_func'] == 'count':
70
- plot_df = plot_df.groupby(plot['x']).size().reset_index(name=plot['y'])
71
-
72
- if 'top_n' in plot and plot['top_n']:
73
- plot_df = plot_df.nlargest(plot['top_n'], plot['y'])
74
-
75
- if plot['plot_type'] == 'bar':
76
- plot_df.plot(kind='bar', x=plot['x'], y=plot['y'], ax=ax)
77
- elif plot['plot_type'] == 'line':
78
- plot_df.plot(kind='line', x=plot['x'], y=plot['y'], ax=ax)
79
- elif plot['plot_type'] == 'scatter':
80
- plot_df.plot(kind='scatter', x=plot['x'], y=plot['y'], ax=ax,
81
- c=plot['additional'].get('color'), s=plot_df[plot['additional'].get('size', 'y')])
82
- elif plot['plot_type'] == 'hist':
83
- plot_df[plot['x']].hist(ax=ax, bins=20)
84
- elif plot['plot_type'] == 'pie':
85
- plot_df.plot(kind='pie', y=plot['y'], labels=plot_df[plot['x']], ax=ax, autopct='%1.1f%%')
86
- elif plot['plot_type'] == 'heatmap':
87
- pivot_df = plot_df.pivot(index=plot['x'], columns=plot['additional']['color'], values=plot['y'])
88
- ax.imshow(pivot_df, cmap='YlOrRd')
89
- ax.set_xticks(range(len(pivot_df.columns)))
90
- ax.set_yticks(range(len(pivot_df.index)))
91
- ax.set_xticklabels(pivot_df.columns)
92
- ax.set_yticklabels(pivot_df.index)
93
-
94
- ax.set_title(plot['title'])
95
- if plot['plot_type'] != 'pie':
96
- ax.set_xlabel(plot['x'])
97
- ax.set_ylabel(plot['y'])
98
- plt.tight_layout()
99
-
100
- buf = io.BytesIO()
101
- plt.savefig(buf, format='png')
102
- buf.seek(0)
103
- img = Image.open(buf)
104
- images.append(img)
105
- plt.close(fig)
106
-
107
- return images if len(images) == 3 else images + [Image.new('RGB', (800, 600), (255,255,255))]*(3-len(images))
108
-
109
  except Exception as e:
110
- error_message = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
111
- print(error_message) # Print to console for debugging
112
- error_image = Image.new('RGB', (800, 400), (255, 255, 255))
113
- draw = ImageDraw.Draw(error_image)
114
- draw.text((10, 10), error_message, fill=(255, 0, 0))
115
- return [error_image] * 3
116
 
117
- with gr.Blocks(theme=gr.themes.Default()) as demo:
118
- gr.Markdown("# Data Analysis Dashboard")
 
 
 
 
 
 
119
 
120
- with gr.Row():
121
- file = gr.File(label="Upload Dataset", file_types=[".csv", ".xlsx"])
122
- instructions = gr.Textbox(label="Analysis Instructions", placeholder="Describe the analysis you want...")
123
 
124
- submit = gr.Button("Generate Insights", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
- output_images = [gr.Image(label=f"Visualization {i+1}") for i in range(3)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- submit.click(
129
- process_file,
130
- inputs=[file, instructions],
131
- outputs=output_images
132
- )
133
 
134
- if __name__ == "__main__":
135
- demo.launch()
 
1
+ import base64
 
 
2
  import io
3
+ import os
4
  import ast
 
 
5
  import traceback
6
+ from threading import Thread
7
+
8
+ import dash
9
+ from dash import dcc, html, Input, Output, State
10
+ import dash_bootstrap_components as dbc
11
+ import pandas as pd
12
+ import plotly.graph_objs as go
13
+ import google.generativeai as genai
14
 
15
# Create the Dash application, styled with the Bootstrap theme.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Dashed drop-zone styling applied to the upload widget.
upload_style = {
    'width': '100%',
    'height': '60px',
    'lineHeight': '60px',
    'borderWidth': '1px',
    'borderStyle': 'dashed',
    'borderRadius': '5px',
    'textAlign': 'center',
    'margin': '10px'
}

# Single-file upload area (drag and drop, or click to browse).
upload_box = dcc.Upload(
    id='upload-data',
    children=html.Div([
        'Drag and Drop or ',
        html.A('Select Files')
    ]),
    style=upload_style,
    multiple=False
)

# Top card: dataset upload, free-text instructions and the submit button.
input_card = dbc.Card(
    [
        dbc.CardBody([
            upload_box,
            dbc.Input(id="instructions", placeholder="Describe the analysis you want...", type="text"),
            dbc.Button("Generate Insights", id="submit-button", color="primary", className="mt-3"),
        ])
    ],
    className="mb-4",
)

# Bottom card: the three graph slots that the callback fills in.
graphs_card = dbc.Card([
    dbc.CardBody([
        dcc.Graph(id='visualization-1'),
        dcc.Graph(id='visualization-2'),
        dcc.Graph(id='visualization-3'),
    ])
])

# Page layout: heading, input card, then the visualizations.
app.layout = dbc.Container(
    [
        html.H1("Data Analysis Dashboard", className="my-4"),
        input_card,
        graphs_card,
    ],
    fluid=True,
)
53
+
54
def parse_contents(contents, filename):
    """Decode an uploaded file payload into a pandas DataFrame.

    Args:
        contents: Upload string of the form ``"<header>,<base64 data>"``.
        filename: Name of the uploaded file; its extension selects the parser.

    Returns:
        The parsed DataFrame, or ``None`` when the payload or filename is
        missing, the payload has no comma separator, the extension is not
        supported, or the data cannot be decoded/parsed.
    """
    if not contents or not filename:
        return None

    # Match on the real extension, case-insensitively: a substring test
    # treats "csv_export.xlsx" as CSV and rejects "DATA.CSV".
    name = filename.lower()
    try:
        # Base64 data never contains a comma, so everything after the last
        # comma is the payload regardless of what the header looks like.
        _, separator, content_string = contents.rpartition(',')
        if not separator:
            return None
        decoded = base64.b64decode(content_string)

        if name.endswith('.csv'):
            # utf-8-sig strips a leading BOM so it does not end up glued to
            # the first column name; plain UTF-8 input is unaffected.
            return pd.read_csv(io.StringIO(decoded.decode('utf-8-sig')))
        if name.endswith(('.xls', '.xlsx', '.xlsm')):
            return pd.read_excel(io.BytesIO(decoded))
        return None
    except Exception as e:
        # Best-effort parsing: report the problem and let the caller treat
        # the upload as unusable.
        print(e)
        return None
68
+
69
+ def process_data(df, instructions):
70
  try:
71
  # Initialize Gemini
72
  api_key = os.environ.get('GEMINI_API_KEY')
73
  genai.configure(api_key=api_key)
74
  model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')
75
 
 
 
 
 
76
  # Generate visualization code
77
  response = model.generate_content(f"""
78
  Analyze the following dataset and instructions:
 
104
  elif '```' in code_block:
105
  code_block = code_block.split('```')[1].strip()
106
 
 
 
 
107
  plots = ast.literal_eval(code_block)
108
+ return plots
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  except Exception as e:
110
+ print(f"Error in process_data: {str(e)}")
111
+ return None
 
 
 
 
112
 
113
def generate_plot(df, plot_info):
    """Build a Plotly figure from a single plot specification.

    Args:
        df: Source DataFrame; a copy is used, the argument is not modified.
        plot_info: Dict describing the plot. Keys read here:
            'plot_type' (one of 'bar', 'line', 'scatter', 'hist', 'pie',
            'heatmap'), 'x', 'y', 'title', optional 'agg_func'
            ('sum', 'mean' or 'count'), optional 'top_n', and, for
            heatmaps, ``plot_info['additional']['color']``.

    Returns:
        A ``go.Figure`` with the title and axis titles applied.

    Raises:
        ValueError: If 'plot_type' is missing or not one of the supported
            values.
    """
    plot_type = plot_info.get('plot_type')
    supported = ('bar', 'line', 'scatter', 'hist', 'pie', 'heatmap')
    # Fail with a clear message up front; otherwise no branch below would
    # assign the figure and the function would die on an unbound local.
    if plot_type not in supported:
        raise ValueError(
            f"Unsupported plot_type {plot_type!r}; expected one of {supported}"
        )

    x_col = plot_info['x']
    y_col = plot_info['y']

    # Optional aggregation step; any other value (or a missing key) leaves
    # the data as-is.
    plot_df = df.copy()
    agg_func = plot_info.get('agg_func')
    if agg_func == 'sum':
        plot_df = plot_df.groupby(x_col)[y_col].sum().reset_index()
    elif agg_func == 'mean':
        plot_df = plot_df.groupby(x_col)[y_col].mean().reset_index()
    elif agg_func == 'count':
        # The group sizes are stored under the requested y column name.
        plot_df = plot_df.groupby(x_col).size().reset_index(name=y_col)

    # Optionally keep only the rows with the largest y values.
    top_n = plot_info.get('top_n')
    if top_n:
        plot_df = plot_df.nlargest(top_n, y_col)

    if plot_type == 'bar':
        fig = go.Figure(go.Bar(x=plot_df[x_col], y=plot_df[y_col]))
    elif plot_type == 'line':
        fig = go.Figure(go.Scatter(x=plot_df[x_col], y=plot_df[y_col], mode='lines'))
    elif plot_type == 'scatter':
        fig = go.Figure(go.Scatter(x=plot_df[x_col], y=plot_df[y_col], mode='markers'))
    elif plot_type == 'hist':
        fig = go.Figure(go.Histogram(x=plot_df[x_col]))
    elif plot_type == 'pie':
        fig = go.Figure(go.Pie(labels=plot_df[x_col], values=plot_df[y_col]))
    else:  # 'heatmap' (validated above)
        pivot_df = plot_df.pivot(index=x_col, columns=plot_info['additional']['color'], values=y_col)
        fig = go.Figure(go.Heatmap(z=pivot_df.values, x=pivot_df.columns, y=pivot_df.index))

    fig.update_layout(title=plot_info['title'], xaxis_title=x_col, yaxis_title=y_col)
    return fig
141
+
142
@app.callback(
    [Output('visualization-1', 'figure'),
     Output('visualization-2', 'figure'),
     Output('visualization-3', 'figure')],
    [Input('submit-button', 'n_clicks')],
    [State('upload-data', 'contents'),
     State('upload-data', 'filename'),
     State('instructions', 'value')]
)
def update_output(n_clicks, contents, filename, instructions):
    """Regenerate the three visualizations when the submit button is clicked.

    Args:
        n_clicks: Click count of the submit button (``None`` before any click).
        contents: Uploaded file payload from the upload component.
        filename: Name of the uploaded file.
        instructions: Free-text analysis instructions entered by the user.

    Returns:
        Three values, one per graph. ``dash.no_update`` for all three when
        there is nothing to do (no click, no upload, unparseable file, or no
        usable plot specifications); otherwise a list of three figures.
    """
    unchanged = (dash.no_update, dash.no_update, dash.no_update)

    if n_clicks is None or contents is None:
        return unchanged

    df = parse_contents(contents, filename)
    if df is None:
        return unchanged

    plots = process_data(df, instructions)
    # The specs come from evaluating model output, so guard against both a
    # failure (None) and an unexpected non-sequence or empty result.
    if not isinstance(plots, (list, tuple)) or not plots:
        return unchanged

    figures = []
    for plot_info in plots[:3]:
        try:
            figures.append(generate_plot(df, plot_info))
        except Exception as e:
            # One bad specification should not discard the plots that did
            # render; show a titled placeholder in its slot instead.
            print(f"Error generating plot: {str(e)}")
            traceback.print_exc()
            placeholder = go.Figure()
            placeholder.update_layout(title=f"Could not render plot: {e}")
            figures.append(placeholder)

    # Fewer than three specs: keep what was produced and pad the remaining
    # slots with blank figures so every output receives a value.
    while len(figures) < 3:
        figures.append(go.Figure())
    return figures
 
 
 
165
 
166
if __name__ == '__main__':
    # debug=True turns on the Werkzeug interactive debugger, which lets
    # anyone who can reach the server execute arbitrary Python. The app
    # binds to all interfaces, so debug mode stays off unless explicitly
    # requested through the DASH_DEBUG environment variable.
    debug_enabled = os.environ.get('DASH_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(debug=debug_enabled, host='0.0.0.0', port=7860, threaded=True)