singhk28
committed on
Commit
·
d6586a1
1
Parent(s):
6c16a82
Add Classifier. Improve error catching for SHAP analysis/feature importance.
Browse files
app.py
CHANGED
@@ -4,7 +4,8 @@ import numpy as np
|
|
4 |
import streamlit as st
|
5 |
from pycaret import regression as reg
|
6 |
from pycaret import classification as clf
|
7 |
-
from sklearn.metrics import mean_absolute_error, max_error, r2_score, mean_squared_error
|
|
|
8 |
import matplotlib.pyplot as plt
|
9 |
import streamlit.components.v1 as components
|
10 |
import mpld3
|
@@ -22,32 +23,44 @@ col1, mid, col2 = st.columns([10,1,20])
|
|
22 |
with col1:
|
23 |
st.image('https://images.pexels.com/photos/2599244/pexels-photo-2599244.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=1')
|
24 |
with col2:
|
25 |
-
st.markdown("""This tool prepares a machine learning model
|
|
|
|
|
|
|
26 |
st.markdown("""---""")
|
27 |
|
28 |
st.markdown(f"**To use this tool**, fill out all the requested fields from top to bottom.")
|
29 |
st.markdown(f"**Note:** If an error is obtained refresh the page and start over.")
|
|
|
30 |
## Column Name
|
31 |
st.markdown(f'<h3 style="color:#000000;font-size:20px;">{"1) Provide name of the column you want to predict with model."}</h3>', unsafe_allow_html=True)
|
32 |
target_col = st.text_input("Enter the exact name of the column with your target variable. This field is case sensitive. (i.e., capital letters must match.)")
|
33 |
-
|
|
|
34 |
st.markdown(f'<h3 style="color:#000000;font-size:20px;">{"2) Select type of model you would like to build"}</h3>', unsafe_allow_html=True)
|
35 |
-
mod_type = st.selectbox("What type of model would you like to train? Pick regression model for continous values
|
|
|
36 |
## Mode of Use
|
37 |
st.markdown(f'<h3 style="color:#000000;font-size:20px;">{"3) Select mode of use"}</h3>', unsafe_allow_html=True)
|
38 |
mode_type = st.selectbox("What would you like to use the tool for?", ('Benchmarking (finding the best algorithm for your problem)', 'Parameter Search (find combination of parameters to get a desired value)'))
|
39 |
if mode_type == 'Parameter Search (find combination of parameters to get a desired value)':
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
43 |
if mod_type == 'regression':
|
44 |
-
|
45 |
-
|
|
|
|
|
46 |
else:
|
47 |
desired_value = st.text_input("Enter the desired target parameter value. This field is case sensitive. (i.e., capital letters must match.)", key="DV for Classifier")
|
48 |
## Ask for Dataset
|
49 |
st.markdown(f'<h3 style="color:#000000;font-size:20px;">{"5) Upload CSV file "}</h3>', unsafe_allow_html=True)
|
50 |
uploaded_file = st.file_uploader("Upload a CSV file", type="csv")
|
|
|
51 |
else:
|
52 |
## Ask for Dataset
|
53 |
st.markdown(f'<h3 style="color:#000000;font-size:20px;">{"4) Upload CSV file "}</h3>', unsafe_allow_html=True)
|
@@ -84,20 +97,23 @@ if uploaded_file:
|
|
84 |
# Figure out Column Data Types
|
85 |
object_columns = data.select_dtypes(include="object").columns.tolist()
|
86 |
|
|
|
|
|
|
|
87 |
# ---------------------------------------------------------------------------------------------------------------------- #
|
88 |
# Build Regression Model
|
89 |
if mod_type == "regression":
|
90 |
# Setup Regressor Problem
|
91 |
if object_columns:
|
92 |
if data_size > 20:
|
93 |
-
s = reg.setup(train_data, target = target_col,
|
94 |
else:
|
95 |
-
s = reg.setup(data, target = target_col,
|
96 |
else:
|
97 |
if data_size > 20:
|
98 |
-
s = reg.setup(train_data, target = target_col,
|
99 |
else:
|
100 |
-
s = reg.setup(data, target = target_col,
|
101 |
|
102 |
# Find the best algorithm to build Model:
|
103 |
st.subheader("Algorithm Selection")
|
@@ -127,8 +143,9 @@ if uploaded_file:
|
|
127 |
st.write('Best hyperparameters: ', final_mod.get_params())
|
128 |
|
129 |
# Print a SHAP Analysis Summary Plot:
|
130 |
-
|
131 |
-
|
|
|
132 |
|
133 |
if len(data) > 20:
|
134 |
# Predict on the test set if it was created:
|
@@ -231,6 +248,64 @@ if uploaded_file:
|
|
231 |
# ---------------------------------------------------------------------------------------------------------------------- #
|
232 |
# Build Classifier Model
|
233 |
if mod_type == "classifier":
|
234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
|
|
|
236 |
st.markdown("")
|
|
|
4 |
import streamlit as st
|
5 |
from pycaret import regression as reg
|
6 |
from pycaret import classification as clf
|
7 |
+
from sklearn.metrics import mean_absolute_error, max_error, r2_score, mean_squared_error, confusion_matrix, ConfusionMatrixDisplay
|
8 |
+
from sklearn.metrics import accuracy_score, auc, recall_score, precision_score, f1_score, cohen_kappa_score
|
9 |
import matplotlib.pyplot as plt
|
10 |
import streamlit.components.v1 as components
|
11 |
import mpld3
|
|
|
23 |
with col1:
|
24 |
st.image('https://images.pexels.com/photos/2599244/pexels-photo-2599244.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=1')
|
25 |
with col2:
|
26 |
+
st.markdown("""This tool prepares a machine learning model using your tabular data. The tool can be used in 2 ways:""", unsafe_allow_html=True)
|
27 |
+
st.markdown("""1) Benchmark different algorithms for your dataset to find the best algorithm and then tune that model to determine best hyperparameters.""", unsafe_allow_html=True)
|
28 |
+
st.markdown("""2) In the case of experimental science, the best obtained model can be used to make predictions for various combinations of the provided data to try to obtain a combination that achieves a desired target value (if possible).""", unsafe_allow_html=True)
|
29 |
+
st.markdown("""**The tool is currently under active development. Please direct any bug reports or inquiries to the <a href="http://cleanenergy.utoronto.ca/">clean energy lab at UofT.</a>**""", unsafe_allow_html=True)
|
30 |
st.markdown("""---""")
|
31 |
|
32 |
st.markdown(f"**To use this tool**, fill out all the requested fields from top to bottom.")
|
33 |
st.markdown(f"**Note:** If an error is obtained refresh the page and start over.")
|
34 |
+
|
35 |
## Column Name
|
36 |
st.markdown(f'<h3 style="color:#000000;font-size:20px;">{"1) Provide name of the column you want to predict with model."}</h3>', unsafe_allow_html=True)
|
37 |
target_col = st.text_input("Enter the exact name of the column with your target variable. This field is case sensitive. (i.e., capital letters must match.)")
|
38 |
+
|
39 |
+
## Task Type: Regression or Classification
|
40 |
st.markdown(f'<h3 style="color:#000000;font-size:20px;">{"2) Select type of model you would like to build"}</h3>', unsafe_allow_html=True)
|
41 |
+
# Task selector: the chosen value ('regression' or 'classifier') drives the branches below.
mod_type = st.selectbox("What type of model would you like to train? Pick regression model for continuous values or classifier for categorical values.", ('regression', 'classifier'))
|
42 |
+
|
43 |
## Mode of Use
|
44 |
st.markdown(f'<h3 style="color:#000000;font-size:20px;">{"3) Select mode of use"}</h3>', unsafe_allow_html=True)
|
45 |
mode_type = st.selectbox("What would you like to use the tool for?", ('Benchmarking (finding the best algorithm for your problem)', 'Parameter Search (find combination of parameters to get a desired value)'))
|
46 |
if mode_type == 'Parameter Search (find combination of parameters to get a desired value)':
|
47 |
+
## Desired Target Value
|
48 |
+
if mod_type == 'classifier':
|
49 |
+
st.write('Parameter search not currently supported with classifier type models.')
|
50 |
+
st.write('Please refresh page and try again with the supported tasks.')
|
51 |
+
# Halt this script run through Streamlit's own API instead of the interpreter's exit() helper,
# so the messages written above stay rendered and the app keeps serving.
st.stop()
|
52 |
+
|
53 |
if mod_type == 'regression':
|
54 |
+
st.markdown(f'<h3 style="color:#000000;font-size:20px;">{"4) Type of parameter search"}</h3>', unsafe_allow_html=True)
|
55 |
+
opt_type = st.selectbox("What do you want to do with the output?", ('Maximize it', 'Minimize it', 'Obtain a desired value'))
|
56 |
+
if opt_type == 'Obtain a desired value':
|
57 |
+
desired_value = float(st.number_input("Enter the desired value for the target variable."))
|
58 |
else:
|
59 |
desired_value = st.text_input("Enter the desired target parameter value. This field is case sensitive. (i.e., capital letters must match.)", key="DV for Classifier")
|
60 |
## Ask for Dataset
|
61 |
st.markdown(f'<h3 style="color:#000000;font-size:20px;">{"5) Upload CSV file "}</h3>', unsafe_allow_html=True)
|
62 |
uploaded_file = st.file_uploader("Upload a CSV file", type="csv")
|
63 |
+
|
64 |
else:
|
65 |
## Ask for Dataset
|
66 |
st.markdown(f'<h3 style="color:#000000;font-size:20px;">{"4) Upload CSV file "}</h3>', unsafe_allow_html=True)
|
|
|
97 |
# Figure out Column Data Types
|
98 |
object_columns = data.select_dtypes(include="object").columns.tolist()
|
99 |
|
100 |
+
# Create a list of Tree Models:
|
101 |
+
tree_mods_list = ['Extra Trees Regressor', 'Extra Trees Classifier', 'Random Forest Regressor', 'Random Forest Classifier', 'Decision Tree Regressor', 'Decision Tree Classifier', 'CatBoost Regressor', 'Light Gradient Boosting Machine']
|
102 |
+
|
103 |
# ---------------------------------------------------------------------------------------------------------------------- #
|
104 |
# Build Regression Model
|
105 |
if mod_type == "regression":
|
106 |
# Setup Regressor Problem
|
107 |
if object_columns:
|
108 |
if data_size > 20:
|
109 |
+
s = reg.setup(train_data, target = target_col, normalize=True, categorical_features=object_columns, fold=5, silent= True)
|
110 |
else:
|
111 |
+
s = reg.setup(data, target = target_col, normalize=True, categorical_features=object_columns, silent= True)
|
112 |
else:
|
113 |
if data_size > 20:
|
114 |
+
s = reg.setup(train_data, target = target_col, normalize=True, silent= True, fold=5)
|
115 |
else:
|
116 |
+
s = reg.setup(data, target = target_col, normalize=True, silent= True)
|
117 |
|
118 |
# Find the best algorithm to build Model:
|
119 |
st.subheader("Algorithm Selection")
|
|
|
143 |
st.write('Best hyperparameters: ', final_mod.get_params())
|
144 |
|
145 |
# Print a SHAP Analysis Summary Plot:
|
146 |
+
if best_mod_name in tree_mods_list:
|
147 |
+
st.subheader("SHAP Analysis Summary Plot")
|
148 |
+
st.pyplot(reg.interpret_model(final_mod))
|
149 |
|
150 |
if len(data) > 20:
|
151 |
# Predict on the test set if it was created:
|
|
|
248 |
# ---------------------------------------------------------------------------------------------------------------------- #
|
249 |
# Build Classifier Model
|
250 |
if mod_type == "classifier":
|
251 |
+
# Setup Classifier Problem
|
252 |
+
if data_size > 20:
|
253 |
+
s = clf.setup(train_data, target = target_col, normalize=True, silent= True, fold=5)
|
254 |
+
else:
|
255 |
+
s = clf.setup(data, target = target_col, normalize=True, silent= True)
|
256 |
+
|
257 |
+
# Find the best algorithm to build Model:
|
258 |
+
st.subheader("Algorithm Selection")
|
259 |
+
start_algo = time.time()
|
260 |
+
with st.spinner(text="Finding the best algorithm for your dataset..."):
|
261 |
+
best_mod = clf.compare_models()
|
262 |
+
classifier_results = clf.pull()
|
263 |
+
# Take the first row of the results table positionally (presumably the top-ranked model; confirm
# the ordering of clf.pull()); .iloc avoids a label lookup for the key 0 on the results index.
best_mod_name = classifier_results.Model.iloc[0]
|
264 |
+
st.write(classifier_results)
|
265 |
+
end_algo = time.time()
|
266 |
+
st.write('Time taken to select algorithm:', end_algo - start_algo, 'seconds')
|
267 |
+
|
268 |
+
# Tune the hyperparameters for the best algorithm:
|
269 |
+
st.subheader("Tuning the Model")
|
270 |
+
start_tune = time.time()
|
271 |
+
with st.spinner(text="Tuning the algorithm..."):
|
272 |
+
tuned_mod = clf.tune_model(best_mod, optimize = 'AUC', n_iter=5)
|
273 |
+
end_tune = time.time()
|
274 |
+
st.write('Time taken to select hyperparameters:', end_tune - start_tune, 'seconds')
|
275 |
+
|
276 |
+
# Finalize the model (Train on the entire train dataset):
|
277 |
+
with st.spinner("Finalizing the model..."):
|
278 |
+
final_mod = clf.finalize_model(tuned_mod)
|
279 |
+
|
280 |
+
st.success('Model successfully trained! Here are your results:')
|
281 |
+
st.write('Best algorithm: ', best_mod_name)
|
282 |
+
st.write('Best hyperparameters: ', final_mod.get_params())
|
283 |
+
|
284 |
+
# Print a Feature Importance Plot:
|
285 |
+
if best_mod_name in tree_mods_list:
|
286 |
+
st.subheader("Feature Importance Plot")
|
287 |
+
st.pyplot(clf.plot_model(final_mod, plot='feature'))
|
288 |
+
|
289 |
+
if len(data) > 20:
|
290 |
+
# Predict on the test set if it was created:
|
291 |
+
st.subheader("Evaluating model on the test/hold out data:")
|
292 |
+
predictions = clf.predict_model(final_mod, data=test_data)
|
293 |
+
st.success('Here are your results:')
|
294 |
+
st.write(predictions)
|
295 |
+
st.caption('"Label" is the value predicted by the model.')
|
296 |
+
st.write('---')
|
297 |
+
|
298 |
+
# Provide Accuracy:
|
299 |
+
mod_accuracy = accuracy_score(predictions[target_col], predictions['Label'])
|
300 |
+
st.write('**Model accuracy on test set :**', f'{(mod_accuracy):.2f}')
|
301 |
+
|
302 |
+
# Create a confusion matrix:
|
303 |
+
st.subheader("Confusion Matrix for test set:")
|
304 |
+
cm = confusion_matrix(predictions[target_col], predictions['Label'])
|
305 |
+
# confusion_matrix orders rows/columns by the sorted labels seen in truth OR prediction,
# so the tick labels must use that same ordering (unique() returns appearance order).
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=sorted(set(predictions[target_col]) | set(predictions['Label'])))
|
306 |
+
disp.plot()
|
307 |
+
# Turn the grid off explicitly; the `b` keyword is no longer accepted by recent Matplotlib
# and `b=None` only toggled the current state rather than disabling it.
plt.grid(False)
|
308 |
+
# Render the confusion-matrix figure explicitly rather than relying on the global pyplot figure.
st.pyplot(disp.figure_)
|
309 |
|
310 |
+
# Visitor Badge
|
311 |
st.markdown("")
|