itou-daiki commited on
Commit
10831b3
·
verified ·
1 Parent(s): a2ca727

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +5 -5
  2. app.py +244 -0
  3. gitattributes +35 -0
  4. requirements.txt +13 -0
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Pycaret Datascience Streamlit Demo
3
- emoji: 🐠
4
- colorFrom: gray
5
- colorTo: yellow
6
  sdk: streamlit
7
- sdk_version: 1.30.0
8
  app_file: app.py
9
  pinned: false
10
  license: afl-3.0
 
1
  ---
2
+ title: Pycaret Datascience Streamlit
3
+ emoji: 🌍
4
+ colorFrom: blue
5
+ colorTo: gray
6
  sdk: streamlit
7
+ sdk_version: 1.27.2
8
  app_file: app.py
9
  pinned: false
10
  license: afl-3.0
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ from pycaret.classification import *
5
+ from pycaret.regression import *
6
+ from sklearn.tree import *
7
+ import plotly.figure_factory as ff
8
+ import graphviz
9
+ import matplotlib.pyplot as plt
10
+ import japanize_matplotlib
11
+ import joblib
12
+ import base64
13
+ import io
14
+
15
+ # Streamlitのページ設定
16
+ st.set_page_config(page_title="AIデータサイエンス")
17
+
18
+ # タイトルの表示
19
+ st.title("AIデータサイエンス")
20
+ st.caption("Created by Dit-Lab.(Daiki Ito)")
21
+ st.write("アップロードされたデータセットに基づいて、機械学習モデルの作成と評価を行います。")
22
+ st.write("データの読み込み → モデル比較 → チューニング → 可視化 を行うことができます")
23
+
24
+ # データファイルのアップロード
25
+ st.header("1. データファイルのアップロード")
26
+ st.caption("こちらからデータをアップロードしてください。アップロードしたデータは次のステップで前処理され、モデルの訓練に使用されます。")
27
+ uploaded_file = st.file_uploader("CSVまたはExcelファイルをアップロードしてください。", type=['csv', 'xlsx'])
28
+
29
+ # 新しいデータがアップロードされたときにセッションをリセット
30
+ if uploaded_file is not None and ('last_uploaded_file' not in st.session_state or uploaded_file.name != st.session_state.last_uploaded_file):
31
+ st.session_state.clear() # セッションステートをクリア
32
+ st.session_state.last_uploaded_file = uploaded_file.name # 最後にアップロードされたファイル名を保存
33
+
34
+ if uploaded_file:
35
+ if uploaded_file.name.endswith('.csv'):
36
+ train_data = pd.read_csv(uploaded_file)
37
+ elif uploaded_file.name.endswith('.xlsx'):
38
+ train_data = pd.read_excel(uploaded_file)
39
+ else:
40
+ st.error("無効なファイルタイプです。CSVまたはExcelファイルをアップロードしてください.")
41
+
42
+ st.session_state.uploaded_data = train_data
43
+ st.dataframe(train_data.head())
44
+
45
+ # ターゲット変数の選択
46
+ if 'uploaded_data' in st.session_state:
47
+ st.header("2. ターゲット変数の選択")
48
+ st.caption("モデルの予測対象となるターゲット変数を選択してください。この変数がモデルの予測ターゲットとなります。")
49
+ target_variable = st.selectbox('ターゲット変数を選択してください。', st.session_state.uploaded_data.columns)
50
+ st.session_state.target_variable = target_variable
51
+
52
+ # 分析から除外する変数の選択
53
+ st.header("3. 分析から除外する変数の選択")
54
+ st.caption("モデルの訓練から除外したい変数を選択してください。これらの変数はモデルの訓練には使用されません。")
55
+ ignore_variable = st.multiselect('分析から除外する変数を選択してください。', st.session_state.uploaded_data.columns)
56
+ st.session_state.ignore_variable = ignore_variable
57
+
58
+ # フィルタリングされたデータフレームの表示
59
+ filtered_data = st.session_state.uploaded_data.drop(columns=ignore_variable)
60
+ st.write("フィルタリングされたデータフレームの表示")
61
+ st.write(filtered_data)
62
+
63
+ # Excelファイルダウンロード機能
64
+ towrite = io.BytesIO()
65
+ downloaded_file = filtered_data.to_excel(towrite, encoding='utf-8', index=False, header=True)
66
+ towrite.seek(0)
67
+ b64 = base64.b64encode(towrite.read()).decode()
68
+ original_filename = uploaded_file.name.split('.')[0] # 元のファイル名を取得
69
+ download_filename = f"{original_filename}_filtered.xlsx" # フィルタリングされたファイル名を作成
70
+ link = f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{download_filename}">フィルタリングされたデータフレームをExcelファイルとしてダウンロード</a>'
71
+ st.markdown(link, unsafe_allow_html=True)
72
+
73
+ # 前処理の実行及びモデルの比較
74
+ st.header("4. 前処理の実行とモデルの比較")
75
+ st.caption("データの前処理を行い、利用可能な複数のモデルを比較します。最も適したモデルを選択するための基準としてください。")
76
+
77
+ # 外れ値の処理
78
+ remove_outliers_option = st.checkbox('外れ値を削除する', value=False)
79
+
80
+ if st.button('前処理とモデルの比較の実行'): # この条件を追加
81
+ # データの検証
82
+ if st.session_state.uploaded_data[target_variable].isnull().any():
83
+ st.warning("ターゲット変数に欠損値が含まれているレコードを削除します。")
84
+ st.session_state.uploaded_data = st.session_state.uploaded_data.dropna(subset=[target_variable])
85
+
86
+
87
+ # 前処理(モデル作成の準備)
88
+ if 'exp_clf101_setup_done' not in st.session_state: # このセッションで既にセットアップが完了していない場合のみ実行
89
+ with st.spinner('データの前処理中...'):
90
+ exp_clf101 = setup(data=st.session_state.uploaded_data,
91
+ target=target_variable,
92
+ session_id=123,
93
+ remove_outliers=remove_outliers_option,
94
+ ignore_features=ignore_variable)
95
+ st.session_state.exp_clf101 = exp_clf101
96
+ st.session_state.exp_clf101_setup_done = True # セットアップ完了フラグをセッションに保存
97
+
98
+ st.info("前処理が完了しました")
99
+ setup_list_df = pull()
100
+ st.write("前処理の結果")
101
+ st.caption("以下は、前処理のステップとそれに伴うデータのパラメータを示す表です。")
102
+ st.write(setup_list_df)
103
+
104
+ # モデルの比較
105
+ with st.spinner('モデルを比較中...'):
106
+ models_comparison = compare_models(exclude=['dummy','catboost'])
107
+ st.session_state.models_comparison = models_comparison # セッション状態にモデル比較を保存
108
+ models_comparison_df = pull()
109
+ st.session_state.models_comparison_df = models_comparison_df
110
+
111
+ # モデルの選択とチューニング
112
+ if 'models_comparison' in st.session_state:
113
+ st.success("モデルの比較が完了しました!")
114
+ # モデル比較の表示
115
+ models_comparison_df = pull()
116
+ st.session_state.models_comparison_df = models_comparison_df
117
+ st.write("モデル比較結果")
118
+ st.caption("以下は、利用可能な各モデルの性能を示す表です。")
119
+ st.dataframe(st.session_state.models_comparison_df)
120
+
121
+ st.header("5. モデルの選択とチューニング")
122
+ st.caption("最も性能の良いモデルを選択し、さらにそのモデルのパラメータをチューニングします。")
123
+ selected_model_name = st.selectbox('使用するモデルを選択してください。', st.session_state.models_comparison_df.index)
124
+
125
+ # 決定木プロットは可能なモデルのリストを表示
126
+ tree_models = ['ada', 'et', 'rf', 'dt', 'gbr', 'catboost', 'lightgbm', 'xgboost']
127
+ st.write("決定木プロットが可能なモデル: " + ", ".join(tree_models))
128
+
129
+ # max_depth のオプションを表示
130
+ if selected_model_name in tree_models:
131
+ max_depth = st.slider("決定木の最大の深さを選択", 1, 10, 3) # 例として最小1、最大10、デフォルト3
132
+
133
+
134
+ if st.button('チューニングの実行'):
135
+ with st.spinner('チューニング中...'):
136
+ if selected_model_name in tree_models:
137
+ model = create_model(selected_model_name, max_depth=max_depth)
138
+ else:
139
+ # モデルの作成とチューニング
140
+ model = create_model(selected_model_name)
141
+
142
+ pre_tuned_scores_df = pull()
143
+ tuned_model = tune_model(model)
144
+ st.session_state.tuned_model = tuned_model # tuned_modelをセッションステートに保存
145
+ st.success("モデルのチューニングが完了しました!")
146
+ setup_tuned_model_df = pull()
147
+ col1, col2 = st.columns(2)
148
+ with col1:
149
+ st.write("<チューニング前の交差検証の結果>")
150
+ st.write(pre_tuned_scores_df)
151
+ with col2:
152
+ st.write("<チューニング後の交差検証の結果>")
153
+ st.write(setup_tuned_model_df)
154
+ st.caption("上記表は、チューニング前後のモデルの交差検証結果を示す表です。")
155
+
156
+ # チューニング後のモデルを保存
157
+ if 'tuned_model' in st.session_state:
158
+ # モデルをバイナリ形式で保存
159
+ with open("tuned_model.pkl", "wb") as f:
160
+ joblib.dump(st.session_state.tuned_model, f)
161
+
162
+ # ファイルをbase64エンコードしてダウンロードリンクを作成
163
+ with open("tuned_model.pkl", "rb") as f:
164
+ model_file = f.read()
165
+ model_b64 = base64.b64encode(model_file).decode()
166
+ href = f'<a href="data:application/octet-stream;base64,{model_b64}" download="tuned_model.pkl">チューニングされたモデルをダウンロード</a>'
167
+ st.markdown(href, unsafe_allow_html=True)
168
+
169
+ st.header("6. モデルの可視化及び評価")
170
+ st.caption("以下は、チューニング後のモデルのさまざまな可視化を示しています。これらの可視化は、モデルの性能や特性を理解する��に役立ちます。")
171
+ plot_types = [
172
+ ('pipeline', '<前処理パイプライン>', '前処理の流れ(フロー)を表しています'),
173
+ ('residuals', '<残差プロット>', '実際の値と予測値との差(残差)を示しています'),
174
+ ('error', '<予測誤差プロット>', 'モデルの予測誤差を示しています'),
175
+ ('feature', '<特徴量の重要度>', '各特徴量のモデルにおける重要度を示しています'),
176
+ ('cooks', '<クックの距離プロット>', 'データポイントがモデルに与える影響を示しています'),
177
+ ('learning', '<学習曲線>', '訓練データのサイズに対するモデルの性能を示しています'),
178
+ ('vc', '<検証曲線>', 'パラメータの異なる値に対するモデルの性能を示しています'),
179
+ ('manifold', '<マニホールド学習>', '高次元データを2次元にマッピングしたものを示しています')
180
+ ]
181
+
182
+ for plot_type, plot_name, plot_description in plot_types:
183
+ with st.spinner('プロット中...'):
184
+ try:
185
+ st.write(plot_name)
186
+ img = plot_model(tuned_model, plot=plot_type, display_format="streamlit", save=True)
187
+ st.image(img)
188
+ st.caption(plot_description) # グラフの説明を追加
189
+ except Exception as e:
190
+ st.warning(f"{plot_name}の表示中にエラーが発生しました: {str(e)}")
191
+
192
+ # 決定木のプロット
193
+ if selected_model_name in tree_models:
194
+ st.write("<決定木のプロット>")
195
+ st.caption("決定木は、モデルがどのように予測を行っているかを理解するのに役立ちます。")
196
+
197
+ with st.spinner('プロット中...'):
198
+
199
+ if selected_model_name in ['dt']:
200
+ from sklearn.tree import plot_tree
201
+ fig, ax = plt.subplots(figsize=(20,10))
202
+ plot_tree(tuned_model, proportion=True, filled=True, rounded=True, ax=ax, max_depth=3)
203
+ st.pyplot(fig)
204
+
205
+ elif selected_model_name in ['rf', 'et']:
206
+ from sklearn.tree import plot_tree
207
+ fig, ax = plt.subplots(figsize=(20,10))
208
+ plot_tree(tuned_model.estimators_[0], feature_names=train_data.columns, proportion=True, filled=True, rounded=True, ax=ax, max_depth=3)
209
+ st.pyplot(fig)
210
+
211
+ elif selected_model_name == 'ada':
212
+ from sklearn.tree import plot_tree
213
+ base_estimator = tuned_model.get_model().estimators_[0]
214
+ fig, ax = plt.subplots(figsize=(20,10))
215
+ plot_tree(base_estimator, filled=True, rounded=True, ax=ax, max_depth=3)
216
+ st.pyplot(fig)
217
+
218
+ elif selected_model_name == 'gbr':
219
+ from sklearn.tree import plot_tree
220
+ base_estimator = tuned_model.get_model().estimators_[0][0]
221
+ fig, ax = plt.subplots(figsize=(20,10))
222
+ plot_tree(base_estimator, filled=True, rounded=True, ax=ax, max_depth=3)
223
+ st.pyplot(fig)
224
+
225
+ elif selected_model_name == 'catboost':
226
+ from catboost import CatBoostClassifier, plot_tree
227
+ catboost_model = tuned_model.get_model()
228
+ fig, ax = plt.subplots(figsize=(20,10))
229
+ plot_tree(catboost_model, tree_idx=0, ax=ax, max_depth=3)
230
+ st.pyplot(fig)
231
+
232
+ elif selected_model_name == 'lightgbm':
233
+ import lightgbm as lgb
234
+ booster = tuned_model._Booster # LightGBM Booster object
235
+ fig, ax = plt.subplots(figsize=(20,10))
236
+ lgb.plot_tree(booster, tree_index=0, ax=ax, max_depth=3)
237
+ st.pyplot(fig)
238
+
239
+ elif selected_model_name == 'xgboost':
240
+ import xgboost as xgb
241
+ booster = tuned_model.get_booster() # XGBoost Booster object
242
+ fig, ax = plt.subplots(figsize=(20,10))
243
+ xgb.plot_tree(booster, num_trees=0, ax=ax, max_depth=3)
244
+ st.pyplot(fig)
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ pandas
3
+ numpy
4
+ pycaret[full]
5
+ scikit-learn
6
+ graphviz
7
+ matplotlib
8
+ japanize-matplotlib
9
+ openpyxl
10
+ joblib
11
+ lightgbm
12
+ xgboost
13
+ catboost