Théo Villette
committed on
Commit
•
cac478c
1
Parent(s):
02d33be
fix features importance
Browse files
app.py
CHANGED
@@ -11,6 +11,8 @@ st.set_page_config(layout="wide")
|
|
11 |
# add categorical features
|
12 |
# handle missing values with automl (not possible)
|
13 |
|
|
|
|
|
14 |
with st.sidebar:
|
15 |
|
16 |
st.subheader('Demo Datasets')
|
|
|
11 |
# add categorical features
|
12 |
# handle missing values with automl (not possible)
|
13 |
|
14 |
+
# fix importance features problem
|
15 |
+
|
16 |
with st.sidebar:
|
17 |
|
18 |
st.subheader('Demo Datasets')
|
autoML.py
CHANGED
@@ -26,21 +26,25 @@ def autoML(csv, task, budget, label, metric_to_minimize_class, metric_to_minimiz
|
|
26 |
|
27 |
|
28 |
if task == 'Classification':
|
|
|
|
|
29 |
automl_settings = {
|
30 |
"time_budget": int(budget),
|
31 |
-
"metric":
|
32 |
"task": 'classification',
|
33 |
-
"log_file_name":
|
34 |
"early_stop": True,
|
35 |
"eval_method": "holdout"
|
36 |
}
|
37 |
|
38 |
if task == 'Regression':
|
|
|
|
|
39 |
automl_settings = {
|
40 |
"time_budget": int(budget),
|
41 |
-
"metric":
|
42 |
"task": 'regression',
|
43 |
-
"log_file_name":
|
44 |
"early_stop": True,
|
45 |
"eval_method": "holdout"
|
46 |
}
|
@@ -55,13 +59,6 @@ def autoML(csv, task, budget, label, metric_to_minimize_class, metric_to_minimiz
|
|
55 |
tab1, tab2 = st.tabs(["AutoML", "Best Model"])
|
56 |
|
57 |
with tab1:
|
58 |
-
|
59 |
-
if task == 'Classification':
|
60 |
-
log = 'classlog.log'
|
61 |
-
metric = metric_to_minimize_class
|
62 |
-
if task == 'Regression':
|
63 |
-
log = 'reglog.log'
|
64 |
-
metric = metric_to_minimize_reg
|
65 |
|
66 |
time_history, best_valid_loss_history, valid_loss_history, config_history, metric_history = get_output_from_log(filename=log, time_budget=120)
|
67 |
|
@@ -109,17 +106,18 @@ def autoML(csv, task, budget, label, metric_to_minimize_class, metric_to_minimiz
|
|
109 |
col1, col2, col3 = st.columns((1,1,1))
|
110 |
|
111 |
with col1:
|
112 |
-
st.metric(label=
|
113 |
with col2:
|
114 |
st.metric(label="Time to find", value=str(round(automl.time_to_find_best_model, 2))+' sec')
|
115 |
with col3:
|
116 |
st.metric(label="Time to train", value=str(round(automl.best_config_train_time, 2))+' sec')
|
117 |
|
118 |
-
|
119 |
-
|
|
|
120 |
|
121 |
-
|
122 |
-
|
123 |
|
124 |
|
125 |
def download_model(model):
|
|
|
26 |
|
27 |
|
28 |
if task == 'Classification':
|
29 |
+
metric = metric_to_minimize_class
|
30 |
+
log = 'classlog.log'
|
31 |
automl_settings = {
|
32 |
"time_budget": int(budget),
|
33 |
+
"metric": metric,
|
34 |
"task": 'classification',
|
35 |
+
"log_file_name": log,
|
36 |
"early_stop": True,
|
37 |
"eval_method": "holdout"
|
38 |
}
|
39 |
|
40 |
if task == 'Regression':
|
41 |
+
metric = metric_to_minimize_reg
|
42 |
+
log = 'reglog.log'
|
43 |
automl_settings = {
|
44 |
"time_budget": int(budget),
|
45 |
+
"metric": metric,
|
46 |
"task": 'regression',
|
47 |
+
"log_file_name": log,
|
48 |
"early_stop": True,
|
49 |
"eval_method": "holdout"
|
50 |
}
|
|
|
59 |
tab1, tab2 = st.tabs(["AutoML", "Best Model"])
|
60 |
|
61 |
with tab1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
time_history, best_valid_loss_history, valid_loss_history, config_history, metric_history = get_output_from_log(filename=log, time_budget=120)
|
64 |
|
|
|
106 |
col1, col2, col3 = st.columns((1,1,1))
|
107 |
|
108 |
with col1:
|
109 |
+
st.metric(label=metric, value=round(1 - automl.best_loss, 2))
|
110 |
with col2:
|
111 |
st.metric(label="Time to find", value=str(round(automl.time_to_find_best_model, 2))+' sec')
|
112 |
with col3:
|
113 |
st.metric(label="Time to train", value=str(round(automl.best_config_train_time, 2))+' sec')
|
114 |
|
115 |
+
if automl.best_estimator == 'lgbm':
|
116 |
+
df_features_importance = pd.DataFrame({'features name': automl.model.estimator.feature_name_, 'features importance': automl.model.estimator.feature_importances_})
|
117 |
+
fig_features = px.bar(df_features_importance, x='features importance', y='features name')
|
118 |
|
119 |
+
st.divider()
|
120 |
+
st.plotly_chart(fig_features, theme="streamlit")
|
121 |
|
122 |
|
123 |
def download_model(model):
|