Update app.py
Browse files
app.py
CHANGED
@@ -32,15 +32,18 @@ if uploaded_file is not None and ('last_uploaded_file' not in st.session_state o
|
|
32 |
st.session_state.last_uploaded_file = uploaded_file.name # 最後にアップロードされたファイル名を保存
|
33 |
|
34 |
if uploaded_file:
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
44 |
|
45 |
# ターゲット変数の選択
|
46 |
if 'uploaded_data' in st.session_state:
|
@@ -62,7 +65,7 @@ if 'uploaded_data' in st.session_state:
|
|
62 |
|
63 |
# Excelファイルダウンロード機能
|
64 |
towrite = io.BytesIO()
|
65 |
-
downloaded_file = filtered_data.to_excel(towrite,
|
66 |
towrite.seek(0)
|
67 |
b64 = base64.b64encode(towrite.read()).decode()
|
68 |
original_filename = uploaded_file.name.split('.')[0] # 元のファイル名を取得
|
@@ -83,41 +86,68 @@ if 'uploaded_data' in st.session_state:
|
|
83 |
st.warning("ターゲット変数に欠損値が含まれているレコードを削除します。")
|
84 |
st.session_state.uploaded_data = st.session_state.uploaded_data.dropna(subset=[target_variable])
|
85 |
|
|
|
|
|
|
|
86 |
|
87 |
# 前処理(モデル作成の準備)
|
88 |
if 'exp_clf101_setup_done' not in st.session_state: # このセッションで既にセットアップが完了していない場合のみ実行
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
-
st.info("前処理が完了しました")
|
99 |
setup_list_df = pull()
|
100 |
st.write("前処理の結果")
|
101 |
st.caption("以下は、前処理のステップとそれに伴うデータのパラメータを示す表です。")
|
102 |
st.write(setup_list_df)
|
103 |
|
104 |
# モデルの比較
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
# モデルの選択とチューニング
|
112 |
if 'models_comparison' in st.session_state:
|
113 |
-
st.success("モデルの比較が完了しました!")
|
114 |
# モデル比較の表示
|
115 |
models_comparison_df = pull()
|
116 |
st.session_state.models_comparison_df = models_comparison_df
|
117 |
st.write("モデル比較結果")
|
118 |
st.caption("以下は、利用可能な各モデルの性能を示す表です。")
|
119 |
st.dataframe(st.session_state.models_comparison_df)
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
st.header("5. モデルの選択とチューニング")
|
122 |
st.caption("最も性能の良いモデルを選択し、さらにそのモデルのパラメータをチューニングします。")
|
123 |
selected_model_name = st.selectbox('使用するモデルを選択してください。', st.session_state.models_comparison_df.index)
|
@@ -130,6 +160,8 @@ if 'uploaded_data' in st.session_state:
|
|
130 |
if selected_model_name in tree_models:
|
131 |
max_depth = st.slider("決定木の最大の深さを選択", 1, 10, 3) # 例として最小1、最大10、デフォルト3
|
132 |
|
|
|
|
|
133 |
|
134 |
if st.button('チューニングの実行'):
|
135 |
with st.spinner('チューニング中...'):
|
@@ -153,17 +185,23 @@ if 'uploaded_data' in st.session_state:
|
|
153 |
st.write(setup_tuned_model_df)
|
154 |
st.caption("上記表は、チューニング前後のモデルの交差検証結果を示す表です。")
|
155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
# チューニング後のモデルを保存
|
157 |
if 'tuned_model' in st.session_state:
|
158 |
# モデルをバイナリ形式で保存
|
159 |
-
with open("
|
160 |
joblib.dump(st.session_state.tuned_model, f)
|
161 |
|
162 |
# ファイルをbase64エンコードしてダウンロードリンクを作成
|
163 |
-
with open("
|
164 |
model_file = f.read()
|
165 |
model_b64 = base64.b64encode(model_file).decode()
|
166 |
-
href = f'<a href="data:application/octet-stream;base64,{model_b64}" download="
|
167 |
st.markdown(href, unsafe_allow_html=True)
|
168 |
|
169 |
st.header("6. モデルの可視化及び評価")
|
@@ -178,67 +216,69 @@ if 'uploaded_data' in st.session_state:
|
|
178 |
('vc', '<検証曲線��', 'パラメータの異なる値に対するモデルの性能を示しています'),
|
179 |
('manifold', '<マニホールド学習>', '高次元データを2次元にマッピングしたものを示しています')
|
180 |
]
|
181 |
-
|
182 |
for plot_type, plot_name, plot_description in plot_types:
|
183 |
-
|
184 |
-
|
185 |
-
st.write(plot_name)
|
186 |
img = plot_model(tuned_model, plot=plot_type, display_format="streamlit", save=True)
|
|
|
187 |
st.image(img)
|
188 |
st.caption(plot_description) # グラフの説明を追加
|
189 |
-
|
190 |
-
|
191 |
|
192 |
# 決定木のプロット
|
193 |
if selected_model_name in tree_models:
|
194 |
-
st.
|
195 |
st.caption("決定木は、モデルがどのように予測を行っているかを理解するのに役立ちます。")
|
196 |
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
elif selected_model_name == 'gbr':
|
219 |
-
from sklearn.tree import plot_tree
|
220 |
-
base_estimator = tuned_model.get_model().estimators_[0][0]
|
221 |
-
fig, ax = plt.subplots(figsize=(20,10))
|
222 |
-
plot_tree(base_estimator, filled=True, rounded=True, ax=ax, max_depth=3)
|
223 |
-
st.pyplot(fig)
|
224 |
-
|
225 |
-
elif selected_model_name == 'catboost':
|
226 |
-
from catboost import CatBoostClassifier, plot_tree
|
227 |
-
catboost_model = tuned_model.get_model()
|
228 |
-
fig, ax = plt.subplots(figsize=(20,10))
|
229 |
-
plot_tree(catboost_model, tree_idx=0, ax=ax, max_depth=3)
|
230 |
-
st.pyplot(fig)
|
231 |
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
st.session_state.last_uploaded_file = uploaded_file.name # 最後にアップロードされたファイル名を保存
|
33 |
|
34 |
if uploaded_file:
|
35 |
+
try:
|
36 |
+
if uploaded_file.name.endswith('.csv'):
|
37 |
+
train_data = pd.read_csv(uploaded_file)
|
38 |
+
elif uploaded_file.name.endswith('.xlsx'):
|
39 |
+
train_data = pd.read_excel(uploaded_file)
|
40 |
+
else:
|
41 |
+
raise ValueError("無効なファイルタイプです。CSVまたはExcelファイルをアップロードしてください。")
|
42 |
+
|
43 |
+
st.session_state.uploaded_data = train_data
|
44 |
+
st.dataframe(train_data.head())
|
45 |
+
except Exception as e:
|
46 |
+
st.error(str(e))
|
47 |
|
48 |
# ターゲット変数の選択
|
49 |
if 'uploaded_data' in st.session_state:
|
|
|
65 |
|
66 |
# Excelファイルダウンロード機能
|
67 |
towrite = io.BytesIO()
|
68 |
+
downloaded_file = filtered_data.to_excel(towrite, index=False, header=True)
|
69 |
towrite.seek(0)
|
70 |
b64 = base64.b64encode(towrite.read()).decode()
|
71 |
original_filename = uploaded_file.name.split('.')[0] # 元のファイル名を取得
|
|
|
86 |
st.warning("ターゲット変数に欠損値が含まれているレコードを削除します。")
|
87 |
st.session_state.uploaded_data = st.session_state.uploaded_data.dropna(subset=[target_variable])
|
88 |
|
89 |
+
# 前処理の進捗状況を表示
|
90 |
+
progress_bar = st.progress(0)
|
91 |
+
status_text = st.empty()
|
92 |
|
93 |
# 前処理(モデル作成の準備)
|
94 |
if 'exp_clf101_setup_done' not in st.session_state: # このセッションで既にセットアップが完了していない場合のみ実行
|
95 |
+
try:
|
96 |
+
with st.spinner('データの前処理中...'):
|
97 |
+
exp_clf101 = setup(data=st.session_state.uploaded_data,
|
98 |
+
target=target_variable,
|
99 |
+
session_id=123,
|
100 |
+
remove_outliers=remove_outliers_option,
|
101 |
+
ignore_features=ignore_variable)
|
102 |
+
st.session_state.exp_clf101 = exp_clf101
|
103 |
+
st.session_state.exp_clf101_setup_done = True # セットアップ完了フラグをセッションに保存
|
104 |
+
|
105 |
+
# 前処理の進捗状況を更新
|
106 |
+
progress_bar.progress(50)
|
107 |
+
status_text.text("前処理が完了しました。")
|
108 |
+
except Exception as e:
|
109 |
+
st.error(f"前処理中にエラーが発生しました: {str(e)}")
|
110 |
+
st.stop()
|
111 |
|
|
|
112 |
setup_list_df = pull()
|
113 |
st.write("前処理の結果")
|
114 |
st.caption("以下は、前処理のステップとそれに伴うデータのパラメータを示す表です。")
|
115 |
st.write(setup_list_df)
|
116 |
|
117 |
# モデルの比較
|
118 |
+
try:
|
119 |
+
with st.spinner('モデルを比較中...'):
|
120 |
+
models_comparison = compare_models(exclude=['dummy','catboost'])
|
121 |
+
st.session_state.models_comparison = models_comparison # セッション状態にモデル比較を保存
|
122 |
+
models_comparison_df = pull()
|
123 |
+
st.session_state.models_comparison_df = models_comparison_df
|
124 |
+
|
125 |
+
# モデル比較の進捗状況を���新
|
126 |
+
progress_bar.progress(100)
|
127 |
+
status_text.text("モデルの比較が完了しました!")
|
128 |
+
except Exception as e:
|
129 |
+
st.error(f"モデルの比較中にエラーが発生しました: {str(e)}")
|
130 |
+
st.stop()
|
131 |
|
132 |
# モデルの選択とチューニング
|
133 |
if 'models_comparison' in st.session_state:
|
|
|
134 |
# モデル比較の表示
|
135 |
models_comparison_df = pull()
|
136 |
st.session_state.models_comparison_df = models_comparison_df
|
137 |
st.write("モデル比較結果")
|
138 |
st.caption("以下は、利用可能な各モデルの性能を示す表です。")
|
139 |
st.dataframe(st.session_state.models_comparison_df)
|
140 |
+
|
141 |
+
# モデル比較結果の解釈の説明を追加
|
142 |
+
st.write("モデル比較結果の解釈:")
|
143 |
+
st.write("- Accuracy: モデルの予測精度を示します。値が高いほどモデルの性能が良いことを示します。")
|
144 |
+
st.write("- AUC: ROC曲線下の面積を示します。値が高いほどモデルの性能が良いことを示します。")
|
145 |
+
st.write("- Recall: 実際の正例のうち、正しく正例と予測された割合を示します。")
|
146 |
+
st.write("- Precision: 正例と予測されたもののうち、実際に正例である割合を示します。")
|
147 |
+
st.write("- F1: RecallとPrecisionの調和平均を示します。両者のバランスを考慮した指標です。")
|
148 |
+
st.write("- Kappa: モデルの予測結果と実際の結果の一致度を示します。値が高いほどモデルの性能が良いことを示します。")
|
149 |
+
st.write("- MCC: 不均衡データにおけるモデルの性能を示します。値が高いほどモデルの性能が良いことを示します。")
|
150 |
+
|
151 |
st.header("5. モデルの選択とチューニング")
|
152 |
st.caption("最も性能の良いモデルを選択し、さらにそのモデルのパラメータをチューニングします。")
|
153 |
selected_model_name = st.selectbox('使用するモデルを選択してください。', st.session_state.models_comparison_df.index)
|
|
|
160 |
if selected_model_name in tree_models:
|
161 |
max_depth = st.slider("決定木の最大の深さを選択", 1, 10, 3) # 例として最小1、最大10、デフォルト3
|
162 |
|
163 |
+
# モデル名の入力
|
164 |
+
model_name = st.text_input("保存するモデルの名前を入力してください", value="tuned_model")
|
165 |
|
166 |
if st.button('チューニングの実行'):
|
167 |
with st.spinner('チューニング中...'):
|
|
|
185 |
st.write(setup_tuned_model_df)
|
186 |
st.caption("上記表は、チューニング前後のモデルの交差検証結果を示す表です。")
|
187 |
|
188 |
+
# チューニング前後の比較結果の解釈の説明を追加
|
189 |
+
st.write("チューニング前後の比較結果の解釈:")
|
190 |
+
st.write("- チューニング後の方が、Accuracy、AUC、Recall、Precision、F1、Kappa、MCCの値が高い場合、モデルの性能が向上したことを示します。")
|
191 |
+
st.write("- チューニング後の方が、これらの指標の値が低い場合、モデルの性能が悪化したことを示します。")
|
192 |
+
st.write("- チューニングによる変化がない場合は、モデルの性能に大きな影響がなかったことを示します。")
|
193 |
+
|
194 |
# チューニング後のモデルを保存
|
195 |
if 'tuned_model' in st.session_state:
|
196 |
# モデルをバイナリ形式で保存
|
197 |
+
with open(f"{model_name}.pkl", "wb") as f:
|
198 |
joblib.dump(st.session_state.tuned_model, f)
|
199 |
|
200 |
# ファイルをbase64エンコードしてダウンロードリンクを作成
|
201 |
+
with open(f"{model_name}.pkl", "rb") as f:
|
202 |
model_file = f.read()
|
203 |
model_b64 = base64.b64encode(model_file).decode()
|
204 |
+
href = f'<a href="data:application/octet-stream;base64,{model_b64}" download="{model_name}.pkl">チューニングされたモデルをダウンロード</a>'
|
205 |
st.markdown(href, unsafe_allow_html=True)
|
206 |
|
207 |
st.header("6. モデルの可視化及び評価")
|
|
|
216 |
('vc', '<検証曲線��', 'パラメータの異なる値に対するモデルの性能を示しています'),
|
217 |
('manifold', '<マニホールド学習>', '高次元データを2次元にマッピングしたものを示しています')
|
218 |
]
|
|
|
219 |
for plot_type, plot_name, plot_description in plot_types:
|
220 |
+
try:
|
221 |
+
with st.spinner(f'{plot_name}のプロット中...'):
|
|
|
222 |
img = plot_model(tuned_model, plot=plot_type, display_format="streamlit", save=True)
|
223 |
+
st.subheader(plot_name) # グラフのタイトルを追加
|
224 |
st.image(img)
|
225 |
st.caption(plot_description) # グラフの説明を追加
|
226 |
+
except Exception as e:
|
227 |
+
st.warning(f"{plot_name}の表示中にエラーが発生しました: {str(e)}")
|
228 |
|
229 |
# 決定木のプロット
|
230 |
if selected_model_name in tree_models:
|
231 |
+
st.subheader("<決定木のプロット>")
|
232 |
st.caption("決定木は、モデルがどのように予測を行っているかを理解するのに役立ちます。")
|
233 |
|
234 |
+
try:
|
235 |
+
with st.spinner('決定木のプロット中...'):
|
236 |
+
|
237 |
+
if selected_model_name in ['dt']:
|
238 |
+
from sklearn.tree import plot_tree
|
239 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
240 |
+
plot_tree(tuned_model, proportion=True, filled=True, rounded=True, ax=ax, max_depth=3, fontsize=14) # フォントサイズを変更
|
241 |
+
st.pyplot(fig)
|
242 |
+
|
243 |
+
elif selected_model_name in ['rf', 'et']:
|
244 |
+
from sklearn.tree import plot_tree
|
245 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
246 |
+
plot_tree(tuned_model.estimators_[0], feature_names=train_data.columns, proportion=True, filled=True, rounded=True, ax=ax, max_depth=3, fontsize=14) # フォントサイズを変更
|
247 |
+
st.pyplot(fig)
|
248 |
+
|
249 |
+
elif selected_model_name == 'ada':
|
250 |
+
from sklearn.tree import plot_tree
|
251 |
+
base_estimator = tuned_model.get_model().estimators_[0]
|
252 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
253 |
+
plot_tree(base_estimator, filled=True, rounded=True, ax=ax, max_depth=3, fontsize=14) # フォントサイズを変更
|
254 |
+
st.pyplot(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
256 |
+
elif selected_model_name == 'gbr':
|
257 |
+
from sklearn.tree import plot_tree
|
258 |
+
base_estimator = tuned_model.get_model().estimators_[0][0]
|
259 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
260 |
+
plot_tree(base_estimator, filled=True, rounded=True, ax=ax, max_depth=3, fontsize=14) # フォントサイズを変更
|
261 |
+
st.pyplot(fig)
|
262 |
|
263 |
+
elif selected_model_name == 'catboost':
|
264 |
+
from catboost import CatBoostClassifier, plot_tree
|
265 |
+
catboost_model = tuned_model.get_model()
|
266 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
267 |
+
plot_tree(catboost_model, tree_idx=0, ax=ax, max_depth=3)
|
268 |
+
st.pyplot(fig)
|
269 |
+
|
270 |
+
elif selected_model_name == 'lightgbm':
|
271 |
+
import lightgbm as lgb
|
272 |
+
booster = tuned_model._Booster # LightGBM Booster object
|
273 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
274 |
+
lgb.plot_tree(booster, tree_index=0, ax=ax, max_depth=3)
|
275 |
+
st.pyplot(fig)
|
276 |
+
|
277 |
+
elif selected_model_name == 'xgboost':
|
278 |
+
import xgboost as xgb
|
279 |
+
booster = tuned_model.get_booster() # XGBoost Booster object
|
280 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
281 |
+
xgb.plot_tree(booster, num_trees=0, ax=ax, max_depth=3)
|
282 |
+
st.pyplot(fig)
|
283 |
+
except Exception as e:
|
284 |
+
st.warning(f"決定木のプロット中にエラーが発生しました: {str(e)}")
|