itou-daiki commited on
Commit
f2d31c5
·
verified ·
1 Parent(s): e259083

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -82
app.py CHANGED
@@ -32,15 +32,18 @@ if uploaded_file is not None and ('last_uploaded_file' not in st.session_state o
32
  st.session_state.last_uploaded_file = uploaded_file.name # 最後にアップロードされたファイル名を保存
33
 
34
  if uploaded_file:
35
- if uploaded_file.name.endswith('.csv'):
36
- train_data = pd.read_csv(uploaded_file)
37
- elif uploaded_file.name.endswith('.xlsx'):
38
- train_data = pd.read_excel(uploaded_file)
39
- else:
40
- st.error("無効なファイルタイプです。CSVまたはExcelファイルをアップロードしてください.")
41
-
42
- st.session_state.uploaded_data = train_data
43
- st.dataframe(train_data.head())
 
 
 
44
 
45
  # ターゲット変数の選択
46
  if 'uploaded_data' in st.session_state:
@@ -62,7 +65,7 @@ if 'uploaded_data' in st.session_state:
62
 
63
  # Excelファイルダウンロード機能
64
  towrite = io.BytesIO()
65
- downloaded_file = filtered_data.to_excel(towrite, encoding='utf-8', index=False, header=True)
66
  towrite.seek(0)
67
  b64 = base64.b64encode(towrite.read()).decode()
68
  original_filename = uploaded_file.name.split('.')[0] # 元のファイル名を取得
@@ -83,41 +86,68 @@ if 'uploaded_data' in st.session_state:
83
  st.warning("ターゲット変数に欠損値が含まれているレコードを削除します。")
84
  st.session_state.uploaded_data = st.session_state.uploaded_data.dropna(subset=[target_variable])
85
 
 
 
 
86
 
87
  # 前処理(モデル作成の準備)
88
  if 'exp_clf101_setup_done' not in st.session_state: # このセッションで既にセットアップが完了していない場合のみ実行
89
- with st.spinner('データの前処理中...'):
90
- exp_clf101 = setup(data=st.session_state.uploaded_data,
91
- target=target_variable,
92
- session_id=123,
93
- remove_outliers=remove_outliers_option,
94
- ignore_features=ignore_variable)
95
- st.session_state.exp_clf101 = exp_clf101
96
- st.session_state.exp_clf101_setup_done = True # セットアップ完了フラグをセッションに保存
 
 
 
 
 
 
 
 
97
 
98
- st.info("前処理が完了しました")
99
  setup_list_df = pull()
100
  st.write("前処理の結果")
101
  st.caption("以下は、前処理のステップとそれに伴うデータのパラメータを示す表です。")
102
  st.write(setup_list_df)
103
 
104
  # モデルの比較
105
- with st.spinner('モデルを比較中...'):
106
- models_comparison = compare_models(exclude=['dummy','catboost'])
107
- st.session_state.models_comparison = models_comparison # セッション状態にモデル比較を保存
108
- models_comparison_df = pull()
109
- st.session_state.models_comparison_df = models_comparison_df
 
 
 
 
 
 
 
 
110
 
111
  # モデルの選択とチューニング
112
  if 'models_comparison' in st.session_state:
113
- st.success("モデルの比較が完了しました!")
114
  # モデル比較の表示
115
  models_comparison_df = pull()
116
  st.session_state.models_comparison_df = models_comparison_df
117
  st.write("モデル比較結果")
118
  st.caption("以下は、利用可能な各モデルの性能を示す表です。")
119
  st.dataframe(st.session_state.models_comparison_df)
120
-
 
 
 
 
 
 
 
 
 
 
121
  st.header("5. モデルの選択とチューニング")
122
  st.caption("最も性能の良いモデルを選択し、さらにそのモデルのパラメータをチューニングします。")
123
  selected_model_name = st.selectbox('使用するモデルを選択してください。', st.session_state.models_comparison_df.index)
@@ -130,6 +160,8 @@ if 'uploaded_data' in st.session_state:
130
  if selected_model_name in tree_models:
131
  max_depth = st.slider("決定木の最大の深さを選択", 1, 10, 3) # 例として最小1、最大10、デフォルト3
132
 
 
 
133
 
134
  if st.button('チューニングの実行'):
135
  with st.spinner('チューニング中...'):
@@ -153,17 +185,23 @@ if 'uploaded_data' in st.session_state:
153
  st.write(setup_tuned_model_df)
154
  st.caption("上記表は、チューニング前後のモデルの交差検証結果を示す表です。")
155
 
 
 
 
 
 
 
156
  # チューニング後のモデルを保存
157
  if 'tuned_model' in st.session_state:
158
  # モデルをバイナリ形式で保存
159
- with open("tuned_model.pkl", "wb") as f:
160
  joblib.dump(st.session_state.tuned_model, f)
161
 
162
  # ファイルをbase64エンコードしてダウンロードリンクを作成
163
- with open("tuned_model.pkl", "rb") as f:
164
  model_file = f.read()
165
  model_b64 = base64.b64encode(model_file).decode()
166
- href = f'<a href="data:application/octet-stream;base64,{model_b64}" download="tuned_model.pkl">チューニングされたモデルをダウンロード</a>'
167
  st.markdown(href, unsafe_allow_html=True)
168
 
169
  st.header("6. モデルの可視化及び評価")
@@ -178,67 +216,69 @@ if 'uploaded_data' in st.session_state:
178
  ('vc', '<検証曲線��', 'パラメータの異なる値に対するモデルの性能を示しています'),
179
  ('manifold', '<マニホールド学習>', '高次元データを2次元にマッピングしたものを示しています')
180
  ]
181
-
182
  for plot_type, plot_name, plot_description in plot_types:
183
- with st.spinner('プロット中...'):
184
- try:
185
- st.write(plot_name)
186
  img = plot_model(tuned_model, plot=plot_type, display_format="streamlit", save=True)
 
187
  st.image(img)
188
  st.caption(plot_description) # グラフの説明を追加
189
- except Exception as e:
190
- st.warning(f"{plot_name}の表示中にエラーが発生しました: {str(e)}")
191
 
192
  # 決定木のプロット
193
  if selected_model_name in tree_models:
194
- st.write("<決定木のプロット>")
195
  st.caption("決定木は、モデルがどのように予測を行っているかを理解するのに役立ちます。")
196
 
197
- with st.spinner('プロット中...'):
198
-
199
- if selected_model_name in ['dt']:
200
- from sklearn.tree import plot_tree
201
- fig, ax = plt.subplots(figsize=(20,10))
202
- plot_tree(tuned_model, proportion=True, filled=True, rounded=True, ax=ax, max_depth=3)
203
- st.pyplot(fig)
204
-
205
- elif selected_model_name in ['rf', 'et']:
206
- from sklearn.tree import plot_tree
207
- fig, ax = plt.subplots(figsize=(20,10))
208
- plot_tree(tuned_model.estimators_[0], feature_names=train_data.columns, proportion=True, filled=True, rounded=True, ax=ax, max_depth=3)
209
- st.pyplot(fig)
210
-
211
- elif selected_model_name == 'ada':
212
- from sklearn.tree import plot_tree
213
- base_estimator = tuned_model.get_model().estimators_[0]
214
- fig, ax = plt.subplots(figsize=(20,10))
215
- plot_tree(base_estimator, filled=True, rounded=True, ax=ax, max_depth=3)
216
- st.pyplot(fig)
217
-
218
- elif selected_model_name == 'gbr':
219
- from sklearn.tree import plot_tree
220
- base_estimator = tuned_model.get_model().estimators_[0][0]
221
- fig, ax = plt.subplots(figsize=(20,10))
222
- plot_tree(base_estimator, filled=True, rounded=True, ax=ax, max_depth=3)
223
- st.pyplot(fig)
224
-
225
- elif selected_model_name == 'catboost':
226
- from catboost import CatBoostClassifier, plot_tree
227
- catboost_model = tuned_model.get_model()
228
- fig, ax = plt.subplots(figsize=(20,10))
229
- plot_tree(catboost_model, tree_idx=0, ax=ax, max_depth=3)
230
- st.pyplot(fig)
231
 
232
- elif selected_model_name == 'lightgbm':
233
- import lightgbm as lgb
234
- booster = tuned_model._Booster # LightGBM Booster object
235
- fig, ax = plt.subplots(figsize=(20,10))
236
- lgb.plot_tree(booster, tree_index=0, ax=ax, max_depth=3)
237
- st.pyplot(fig)
238
 
239
- elif selected_model_name == 'xgboost':
240
- import xgboost as xgb
241
- booster = tuned_model.get_booster() # XGBoost Booster object
242
- fig, ax = plt.subplots(figsize=(20,10))
243
- xgb.plot_tree(booster, num_trees=0, ax=ax, max_depth=3)
244
- st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  st.session_state.last_uploaded_file = uploaded_file.name # 最後にアップロードされたファイル名を保存
33
 
34
  if uploaded_file:
35
+ try:
36
+ if uploaded_file.name.endswith('.csv'):
37
+ train_data = pd.read_csv(uploaded_file)
38
+ elif uploaded_file.name.endswith('.xlsx'):
39
+ train_data = pd.read_excel(uploaded_file)
40
+ else:
41
+ raise ValueError("無効なファイルタイプです。CSVまたはExcelファイルをアップロードしてください。")
42
+
43
+ st.session_state.uploaded_data = train_data
44
+ st.dataframe(train_data.head())
45
+ except Exception as e:
46
+ st.error(str(e))
47
 
48
  # ターゲット変数の選択
49
  if 'uploaded_data' in st.session_state:
 
65
 
66
  # Excelファイルダウンロード機能
67
  towrite = io.BytesIO()
68
+ downloaded_file = filtered_data.to_excel(towrite, index=False, header=True)
69
  towrite.seek(0)
70
  b64 = base64.b64encode(towrite.read()).decode()
71
  original_filename = uploaded_file.name.split('.')[0] # 元のファイル名を取得
 
86
  st.warning("ターゲット変数に欠損値が含まれているレコードを削除します。")
87
  st.session_state.uploaded_data = st.session_state.uploaded_data.dropna(subset=[target_variable])
88
 
89
+ # 前処理の進捗状況を表示
90
+ progress_bar = st.progress(0)
91
+ status_text = st.empty()
92
 
93
  # 前処理(モデル作成の準備)
94
  if 'exp_clf101_setup_done' not in st.session_state: # このセッションで既にセットアップが完了していない場合のみ実行
95
+ try:
96
+ with st.spinner('データの前処理中...'):
97
+ exp_clf101 = setup(data=st.session_state.uploaded_data,
98
+ target=target_variable,
99
+ session_id=123,
100
+ remove_outliers=remove_outliers_option,
101
+ ignore_features=ignore_variable)
102
+ st.session_state.exp_clf101 = exp_clf101
103
+ st.session_state.exp_clf101_setup_done = True # セットアップ完了フラグをセッションに保存
104
+
105
+ # 前処理の進捗状況を更新
106
+ progress_bar.progress(50)
107
+ status_text.text("前処理が完了しました。")
108
+ except Exception as e:
109
+ st.error(f"前処理中にエラーが発生しました: {str(e)}")
110
+ st.stop()
111
 
 
112
  setup_list_df = pull()
113
  st.write("前処理の結果")
114
  st.caption("以下は、前処理のステップとそれに伴うデータのパラメータを示す表です。")
115
  st.write(setup_list_df)
116
 
117
  # モデルの比較
118
+ try:
119
+ with st.spinner('モデルを比較中...'):
120
+ models_comparison = compare_models(exclude=['dummy','catboost'])
121
+ st.session_state.models_comparison = models_comparison # セッション状態にモデル比較を保存
122
+ models_comparison_df = pull()
123
+ st.session_state.models_comparison_df = models_comparison_df
124
+
125
+ # モデル比較の進捗状況を���新
126
+ progress_bar.progress(100)
127
+ status_text.text("モデルの比較が完了しました!")
128
+ except Exception as e:
129
+ st.error(f"モデルの比較中にエラーが発生しました: {str(e)}")
130
+ st.stop()
131
 
132
  # モデルの選択とチューニング
133
  if 'models_comparison' in st.session_state:
 
134
  # モデル比較の表示
135
  models_comparison_df = pull()
136
  st.session_state.models_comparison_df = models_comparison_df
137
  st.write("モデル比較結果")
138
  st.caption("以下は、利用可能な各モデルの性能を示す表です。")
139
  st.dataframe(st.session_state.models_comparison_df)
140
+
141
+ # モデル比較結果の解釈の説明を追加
142
+ st.write("モデル比較結果の解釈:")
143
+ st.write("- Accuracy: モデルの予測精度を示します。値が高いほどモデルの性能が良いことを示します。")
144
+ st.write("- AUC: ROC曲線下の面積を示します。値が高いほどモデルの性能が良いことを示します。")
145
+ st.write("- Recall: 実際の正例のうち、正しく正例と予測された割合を示します。")
146
+ st.write("- Precision: 正例と予測されたもののうち、実際に正例である割合を示します。")
147
+ st.write("- F1: RecallとPrecisionの調和平均を示します。両者のバランスを考慮した指標です。")
148
+ st.write("- Kappa: モデルの予測結果と実際の結果の一致度を示します。値が高いほどモデルの性能が良いことを示します。")
149
+ st.write("- MCC: 不均衡データにおけるモデルの性能を示します。値が高いほどモデルの性能が良いことを示します。")
150
+
151
  st.header("5. モデルの選択とチューニング")
152
  st.caption("最も性能の良いモデルを選択し、さらにそのモデルのパラメータをチューニングします。")
153
  selected_model_name = st.selectbox('使用するモデルを選択してください。', st.session_state.models_comparison_df.index)
 
160
  if selected_model_name in tree_models:
161
  max_depth = st.slider("決定木の最大の深さを選択", 1, 10, 3) # 例として最小1、最大10、デフォルト3
162
 
163
+ # モデル名の入力
164
+ model_name = st.text_input("保存するモデルの名前を入力してください", value="tuned_model")
165
 
166
  if st.button('チューニングの実行'):
167
  with st.spinner('チューニング中...'):
 
185
  st.write(setup_tuned_model_df)
186
  st.caption("上記表は、チューニング前後のモデルの交差検証結果を示す表です。")
187
 
188
+ # チューニング前後の比較結果の解釈の説明を追加
189
+ st.write("チューニング前後の比較結果の解釈:")
190
+ st.write("- チューニング後の方が、Accuracy、AUC、Recall、Precision、F1、Kappa、MCCの値が高い場合、モデルの性能が向上したことを示します。")
191
+ st.write("- チューニング後の方が、これらの指標の値が低い場合、モデルの性能が悪化したことを示します。")
192
+ st.write("- チューニングによる変化がない場合は、モデルの性能に大きな影響がなかったことを示します。")
193
+
194
  # チューニング後のモデルを保存
195
  if 'tuned_model' in st.session_state:
196
  # モデルをバイナリ形式で保存
197
+ with open(f"{model_name}.pkl", "wb") as f:
198
  joblib.dump(st.session_state.tuned_model, f)
199
 
200
  # ファイルをbase64エンコードしてダウンロードリンクを作成
201
+ with open(f"{model_name}.pkl", "rb") as f:
202
  model_file = f.read()
203
  model_b64 = base64.b64encode(model_file).decode()
204
+ href = f'<a href="data:application/octet-stream;base64,{model_b64}" download="{model_name}.pkl">チューニングされたモデルをダウンロード</a>'
205
  st.markdown(href, unsafe_allow_html=True)
206
 
207
  st.header("6. モデルの可視化及び評価")
 
216
  ('vc', '<検証曲線��', 'パラメータの異なる値に対するモデルの性能を示しています'),
217
  ('manifold', '<マニホールド学習>', '高次元データを2次元にマッピングしたものを示しています')
218
  ]
 
219
  for plot_type, plot_name, plot_description in plot_types:
220
+ try:
221
+ with st.spinner(f'{plot_name}のプロット中...'):
 
222
  img = plot_model(tuned_model, plot=plot_type, display_format="streamlit", save=True)
223
+ st.subheader(plot_name) # グラフのタイトルを追加
224
  st.image(img)
225
  st.caption(plot_description) # グラフの説明を追加
226
+ except Exception as e:
227
+ st.warning(f"{plot_name}の表示中にエラーが発生しました: {str(e)}")
228
 
229
  # 決定木のプロット
230
  if selected_model_name in tree_models:
231
+ st.subheader("<決定木のプロット>")
232
  st.caption("決定木は、モデルがどのように予測を行っているかを理解するのに役立ちます。")
233
 
234
+ try:
235
+ with st.spinner('決定木のプロット中...'):
236
+
237
+ if selected_model_name in ['dt']:
238
+ from sklearn.tree import plot_tree
239
+ fig, ax = plt.subplots(figsize=(20,10))
240
+ plot_tree(tuned_model, proportion=True, filled=True, rounded=True, ax=ax, max_depth=3, fontsize=14) # フォントサイズを変更
241
+ st.pyplot(fig)
242
+
243
+ elif selected_model_name in ['rf', 'et']:
244
+ from sklearn.tree import plot_tree
245
+ fig, ax = plt.subplots(figsize=(20,10))
246
+ plot_tree(tuned_model.estimators_[0], feature_names=train_data.columns, proportion=True, filled=True, rounded=True, ax=ax, max_depth=3, fontsize=14) # フォントサイズを変更
247
+ st.pyplot(fig)
248
+
249
+ elif selected_model_name == 'ada':
250
+ from sklearn.tree import plot_tree
251
+ base_estimator = tuned_model.get_model().estimators_[0]
252
+ fig, ax = plt.subplots(figsize=(20,10))
253
+ plot_tree(base_estimator, filled=True, rounded=True, ax=ax, max_depth=3, fontsize=14) # フォントサイズを変更
254
+ st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
+ elif selected_model_name == 'gbr':
257
+ from sklearn.tree import plot_tree
258
+ base_estimator = tuned_model.get_model().estimators_[0][0]
259
+ fig, ax = plt.subplots(figsize=(20,10))
260
+ plot_tree(base_estimator, filled=True, rounded=True, ax=ax, max_depth=3, fontsize=14) # フォントサイズを変更
261
+ st.pyplot(fig)
262
 
263
+ elif selected_model_name == 'catboost':
264
+ from catboost import CatBoostClassifier, plot_tree
265
+ catboost_model = tuned_model.get_model()
266
+ fig, ax = plt.subplots(figsize=(20,10))
267
+ plot_tree(catboost_model, tree_idx=0, ax=ax, max_depth=3)
268
+ st.pyplot(fig)
269
+
270
+ elif selected_model_name == 'lightgbm':
271
+ import lightgbm as lgb
272
+ booster = tuned_model._Booster # LightGBM Booster object
273
+ fig, ax = plt.subplots(figsize=(20,10))
274
+ lgb.plot_tree(booster, tree_index=0, ax=ax, max_depth=3)
275
+ st.pyplot(fig)
276
+
277
+ elif selected_model_name == 'xgboost':
278
+ import xgboost as xgb
279
+ booster = tuned_model.get_booster() # XGBoost Booster object
280
+ fig, ax = plt.subplots(figsize=(20,10))
281
+ xgb.plot_tree(booster, num_trees=0, ax=ax, max_depth=3)
282
+ st.pyplot(fig)
283
+ except Exception as e:
284
+ st.warning(f"決定木のプロット中にエラーが発生しました: {str(e)}")