Upload 4 files
Browse files- README.md +5 -5
- app.py +244 -0
- gitattributes +35 -0
- requirements.txt +13 -0
README.md
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
---
|
2 |
-
title: Pycaret Datascience Streamlit
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
-
sdk_version: 1.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: afl-3.0
|
|
|
1 |
---
|
2 |
+
title: Pycaret Datascience Streamlit
|
3 |
+
emoji: 🌍
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: gray
|
6 |
sdk: streamlit
|
7 |
+
sdk_version: 1.27.2
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: afl-3.0
|
app.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
from pycaret.classification import *
|
5 |
+
from pycaret.regression import *
|
6 |
+
from sklearn.tree import *
|
7 |
+
import plotly.figure_factory as ff
|
8 |
+
import graphviz
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
import japanize_matplotlib
|
11 |
+
import joblib
|
12 |
+
import base64
|
13 |
+
import io
|
14 |
+
|
15 |
+
# Streamlitのページ設定
|
16 |
+
st.set_page_config(page_title="AIデータサイエンス")
|
17 |
+
|
18 |
+
# タイトルの表示
|
19 |
+
st.title("AIデータサイエンス")
|
20 |
+
st.caption("Created by Dit-Lab.(Daiki Ito)")
|
21 |
+
st.write("アップロードされたデータセットに基づいて、機械学習モデルの作成と評価を行います。")
|
22 |
+
st.write("データの読み込み → モデル比較 → チューニング → 可視化 を行うことができます")
|
23 |
+
|
24 |
+
# データファイルのアップロード
|
25 |
+
st.header("1. データファイルのアップロード")
|
26 |
+
st.caption("こちらからデータをアップロードしてください。アップロードしたデータは次のステップで前処理され、モデルの訓練に使用されます。")
|
27 |
+
uploaded_file = st.file_uploader("CSVまたはExcelファイルをアップロードしてください。", type=['csv', 'xlsx'])
|
28 |
+
|
29 |
+
# 新しいデータがアップロードされたときにセッションをリセット
|
30 |
+
if uploaded_file is not None and ('last_uploaded_file' not in st.session_state or uploaded_file.name != st.session_state.last_uploaded_file):
|
31 |
+
st.session_state.clear() # セッションステートをクリア
|
32 |
+
st.session_state.last_uploaded_file = uploaded_file.name # 最後にアップロードされたファイル名を保存
|
33 |
+
|
34 |
+
if uploaded_file:
|
35 |
+
if uploaded_file.name.endswith('.csv'):
|
36 |
+
train_data = pd.read_csv(uploaded_file)
|
37 |
+
elif uploaded_file.name.endswith('.xlsx'):
|
38 |
+
train_data = pd.read_excel(uploaded_file)
|
39 |
+
else:
|
40 |
+
st.error("無効なファイルタイプです。CSVまたはExcelファイルをアップロードしてください.")
|
41 |
+
|
42 |
+
st.session_state.uploaded_data = train_data
|
43 |
+
st.dataframe(train_data.head())
|
44 |
+
|
45 |
+
# ターゲット変数の選択
|
46 |
+
if 'uploaded_data' in st.session_state:
|
47 |
+
st.header("2. ターゲット変数の選択")
|
48 |
+
st.caption("モデルの予測対象となるターゲット変数を選択してください。この変数がモデルの予測ターゲットとなります。")
|
49 |
+
target_variable = st.selectbox('ターゲット変数を選択してください。', st.session_state.uploaded_data.columns)
|
50 |
+
st.session_state.target_variable = target_variable
|
51 |
+
|
52 |
+
# 分析から除外する変数の選択
|
53 |
+
st.header("3. 分析から除外する変数の選択")
|
54 |
+
st.caption("モデルの訓練から除外したい変数を選択してください。これらの変数はモデルの訓練には使用されません。")
|
55 |
+
ignore_variable = st.multiselect('分析から除外する変数を選択してください。', st.session_state.uploaded_data.columns)
|
56 |
+
st.session_state.ignore_variable = ignore_variable
|
57 |
+
|
58 |
+
# フィルタリングされたデータフレームの表示
|
59 |
+
filtered_data = st.session_state.uploaded_data.drop(columns=ignore_variable)
|
60 |
+
st.write("フィルタリングされたデータフレームの表示")
|
61 |
+
st.write(filtered_data)
|
62 |
+
|
63 |
+
# Excelファイルダウンロード機能
|
64 |
+
towrite = io.BytesIO()
|
65 |
+
downloaded_file = filtered_data.to_excel(towrite, encoding='utf-8', index=False, header=True)
|
66 |
+
towrite.seek(0)
|
67 |
+
b64 = base64.b64encode(towrite.read()).decode()
|
68 |
+
original_filename = uploaded_file.name.split('.')[0] # 元のファイル名を取得
|
69 |
+
download_filename = f"{original_filename}_filtered.xlsx" # フィルタリングされたファイル名を作成
|
70 |
+
link = f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{download_filename}">フィルタリングされたデータフレームをExcelファイルとしてダウンロード</a>'
|
71 |
+
st.markdown(link, unsafe_allow_html=True)
|
72 |
+
|
73 |
+
# 前処理の実行及びモデルの比較
|
74 |
+
st.header("4. 前処理の実行とモデルの比較")
|
75 |
+
st.caption("データの前処理を行い、利用可能な複数のモデルを比較します。最も適したモデルを選択するための基準としてください。")
|
76 |
+
|
77 |
+
# 外れ値の処理
|
78 |
+
remove_outliers_option = st.checkbox('外れ値を削除する', value=False)
|
79 |
+
|
80 |
+
if st.button('前処理とモデルの比較の実行'): # この条件を追加
|
81 |
+
# データの検証
|
82 |
+
if st.session_state.uploaded_data[target_variable].isnull().any():
|
83 |
+
st.warning("ターゲット変数に欠損値が含まれているレコードを削除します。")
|
84 |
+
st.session_state.uploaded_data = st.session_state.uploaded_data.dropna(subset=[target_variable])
|
85 |
+
|
86 |
+
|
87 |
+
# 前処理(モデル作成の準備)
|
88 |
+
if 'exp_clf101_setup_done' not in st.session_state: # このセッションで既にセットアップが完了していない場合のみ実行
|
89 |
+
with st.spinner('データの前処理中...'):
|
90 |
+
exp_clf101 = setup(data=st.session_state.uploaded_data,
|
91 |
+
target=target_variable,
|
92 |
+
session_id=123,
|
93 |
+
remove_outliers=remove_outliers_option,
|
94 |
+
ignore_features=ignore_variable)
|
95 |
+
st.session_state.exp_clf101 = exp_clf101
|
96 |
+
st.session_state.exp_clf101_setup_done = True # セットアップ完了フラグをセッションに保存
|
97 |
+
|
98 |
+
st.info("前処理が完了しました")
|
99 |
+
setup_list_df = pull()
|
100 |
+
st.write("前処理の結果")
|
101 |
+
st.caption("以下は、前処理のステップとそれに伴うデータのパラメータを示す表です。")
|
102 |
+
st.write(setup_list_df)
|
103 |
+
|
104 |
+
# モデルの比較
|
105 |
+
with st.spinner('モデルを比較中...'):
|
106 |
+
models_comparison = compare_models(exclude=['dummy','catboost'])
|
107 |
+
st.session_state.models_comparison = models_comparison # セッション状態にモデル比較を保存
|
108 |
+
models_comparison_df = pull()
|
109 |
+
st.session_state.models_comparison_df = models_comparison_df
|
110 |
+
|
111 |
+
# モデルの選択とチューニング
|
112 |
+
if 'models_comparison' in st.session_state:
|
113 |
+
st.success("モデルの比較が完了しました!")
|
114 |
+
# モデル比較の表示
|
115 |
+
models_comparison_df = pull()
|
116 |
+
st.session_state.models_comparison_df = models_comparison_df
|
117 |
+
st.write("モデル比較結果")
|
118 |
+
st.caption("以下は、利用可能な各モデルの性能を示す表です。")
|
119 |
+
st.dataframe(st.session_state.models_comparison_df)
|
120 |
+
|
121 |
+
st.header("5. モデルの選択とチューニング")
|
122 |
+
st.caption("最も性能の良いモデルを選択し、さらにそのモデルのパラメータをチューニングします。")
|
123 |
+
selected_model_name = st.selectbox('使用するモデルを選択してください。', st.session_state.models_comparison_df.index)
|
124 |
+
|
125 |
+
# 決定木プロットは可能なモデルのリストを表示
|
126 |
+
tree_models = ['ada', 'et', 'rf', 'dt', 'gbr', 'catboost', 'lightgbm', 'xgboost']
|
127 |
+
st.write("決定木プロットが可能なモデル: " + ", ".join(tree_models))
|
128 |
+
|
129 |
+
# max_depth のオプションを表示
|
130 |
+
if selected_model_name in tree_models:
|
131 |
+
max_depth = st.slider("決定木の最大の深さを選択", 1, 10, 3) # 例として最小1、最大10、デフォルト3
|
132 |
+
|
133 |
+
|
134 |
+
if st.button('チューニングの実行'):
|
135 |
+
with st.spinner('チューニング中...'):
|
136 |
+
if selected_model_name in tree_models:
|
137 |
+
model = create_model(selected_model_name, max_depth=max_depth)
|
138 |
+
else:
|
139 |
+
# モデルの作成とチューニング
|
140 |
+
model = create_model(selected_model_name)
|
141 |
+
|
142 |
+
pre_tuned_scores_df = pull()
|
143 |
+
tuned_model = tune_model(model)
|
144 |
+
st.session_state.tuned_model = tuned_model # tuned_modelをセッションステートに保存
|
145 |
+
st.success("モデルのチューニングが完了しました!")
|
146 |
+
setup_tuned_model_df = pull()
|
147 |
+
col1, col2 = st.columns(2)
|
148 |
+
with col1:
|
149 |
+
st.write("<チューニング前の交差検証の結果>")
|
150 |
+
st.write(pre_tuned_scores_df)
|
151 |
+
with col2:
|
152 |
+
st.write("<チューニング後の交差検証の結果>")
|
153 |
+
st.write(setup_tuned_model_df)
|
154 |
+
st.caption("上記表は、チューニング前後のモデルの交差検証結果を示す表です。")
|
155 |
+
|
156 |
+
# チューニング後のモデルを保存
|
157 |
+
if 'tuned_model' in st.session_state:
|
158 |
+
# モデルをバイナリ形式で保存
|
159 |
+
with open("tuned_model.pkl", "wb") as f:
|
160 |
+
joblib.dump(st.session_state.tuned_model, f)
|
161 |
+
|
162 |
+
# ファイルをbase64エンコードしてダウンロードリンクを作成
|
163 |
+
with open("tuned_model.pkl", "rb") as f:
|
164 |
+
model_file = f.read()
|
165 |
+
model_b64 = base64.b64encode(model_file).decode()
|
166 |
+
href = f'<a href="data:application/octet-stream;base64,{model_b64}" download="tuned_model.pkl">チューニングされたモデルをダウンロード</a>'
|
167 |
+
st.markdown(href, unsafe_allow_html=True)
|
168 |
+
|
169 |
+
st.header("6. モデルの可視化及び評価")
|
170 |
+
st.caption("以下は、チューニング後のモデルのさまざまな可視化を示しています。これらの可視化は、モデルの性能や特性を理解する��に役立ちます。")
|
171 |
+
plot_types = [
|
172 |
+
('pipeline', '<前処理パイプライン>', '前処理の流れ(フロー)を表しています'),
|
173 |
+
('residuals', '<残差プロット>', '実際の値と予測値との差(残差)を示しています'),
|
174 |
+
('error', '<予測誤差プロット>', 'モデルの予測誤差を示しています'),
|
175 |
+
('feature', '<特徴量の重要度>', '各特徴量のモデルにおける重要度を示しています'),
|
176 |
+
('cooks', '<クックの距離プロット>', 'データポイントがモデルに与える影響を示しています'),
|
177 |
+
('learning', '<学習曲線>', '訓練データのサイズに対するモデルの性能を示しています'),
|
178 |
+
('vc', '<検証曲線>', 'パラメータの異なる値に対するモデルの性能を示しています'),
|
179 |
+
('manifold', '<マニホールド学習>', '高次元データを2次元にマッピングしたものを示しています')
|
180 |
+
]
|
181 |
+
|
182 |
+
for plot_type, plot_name, plot_description in plot_types:
|
183 |
+
with st.spinner('プロット中...'):
|
184 |
+
try:
|
185 |
+
st.write(plot_name)
|
186 |
+
img = plot_model(tuned_model, plot=plot_type, display_format="streamlit", save=True)
|
187 |
+
st.image(img)
|
188 |
+
st.caption(plot_description) # グラフの説明を追加
|
189 |
+
except Exception as e:
|
190 |
+
st.warning(f"{plot_name}の表示中にエラーが発生しました: {str(e)}")
|
191 |
+
|
192 |
+
# 決定木のプロット
|
193 |
+
if selected_model_name in tree_models:
|
194 |
+
st.write("<決定木のプロット>")
|
195 |
+
st.caption("決定木は、モデルがどのように予測を行っているかを理解するのに役立ちます。")
|
196 |
+
|
197 |
+
with st.spinner('プロット中...'):
|
198 |
+
|
199 |
+
if selected_model_name in ['dt']:
|
200 |
+
from sklearn.tree import plot_tree
|
201 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
202 |
+
plot_tree(tuned_model, proportion=True, filled=True, rounded=True, ax=ax, max_depth=3)
|
203 |
+
st.pyplot(fig)
|
204 |
+
|
205 |
+
elif selected_model_name in ['rf', 'et']:
|
206 |
+
from sklearn.tree import plot_tree
|
207 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
208 |
+
plot_tree(tuned_model.estimators_[0], feature_names=train_data.columns, proportion=True, filled=True, rounded=True, ax=ax, max_depth=3)
|
209 |
+
st.pyplot(fig)
|
210 |
+
|
211 |
+
elif selected_model_name == 'ada':
|
212 |
+
from sklearn.tree import plot_tree
|
213 |
+
base_estimator = tuned_model.get_model().estimators_[0]
|
214 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
215 |
+
plot_tree(base_estimator, filled=True, rounded=True, ax=ax, max_depth=3)
|
216 |
+
st.pyplot(fig)
|
217 |
+
|
218 |
+
elif selected_model_name == 'gbr':
|
219 |
+
from sklearn.tree import plot_tree
|
220 |
+
base_estimator = tuned_model.get_model().estimators_[0][0]
|
221 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
222 |
+
plot_tree(base_estimator, filled=True, rounded=True, ax=ax, max_depth=3)
|
223 |
+
st.pyplot(fig)
|
224 |
+
|
225 |
+
elif selected_model_name == 'catboost':
|
226 |
+
from catboost import CatBoostClassifier, plot_tree
|
227 |
+
catboost_model = tuned_model.get_model()
|
228 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
229 |
+
plot_tree(catboost_model, tree_idx=0, ax=ax, max_depth=3)
|
230 |
+
st.pyplot(fig)
|
231 |
+
|
232 |
+
elif selected_model_name == 'lightgbm':
|
233 |
+
import lightgbm as lgb
|
234 |
+
booster = tuned_model._Booster # LightGBM Booster object
|
235 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
236 |
+
lgb.plot_tree(booster, tree_index=0, ax=ax, max_depth=3)
|
237 |
+
st.pyplot(fig)
|
238 |
+
|
239 |
+
elif selected_model_name == 'xgboost':
|
240 |
+
import xgboost as xgb
|
241 |
+
booster = tuned_model.get_booster() # XGBoost Booster object
|
242 |
+
fig, ax = plt.subplots(figsize=(20,10))
|
243 |
+
xgb.plot_tree(booster, num_trees=0, ax=ax, max_depth=3)
|
244 |
+
st.pyplot(fig)
|
gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit
|
2 |
+
pandas
|
3 |
+
numpy
|
4 |
+
pycaret[full]
|
5 |
+
scikit-learn
|
6 |
+
graphviz
|
7 |
+
matplotlib
|
8 |
+
japanize-matplotlib
|
9 |
+
openpyxl
|
10 |
+
joblib
|
11 |
+
lightgbm
|
12 |
+
xgboost
|
13 |
+
catboost
|