first
Browse files- app.py +99 -0
- definitions.py +24 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from definitions import *
|
2 |
+
|
3 |
+
st.set_option('deprecation.showPyplotGlobalUse', False)
|
4 |
+
st.sidebar.subheader("请选择模型参数:sunglasses:")
|
5 |
+
|
6 |
+
num_leaves = st.sidebar.slider(label = 'num_leaves', min_value = 4,
|
7 |
+
max_value = 200 ,
|
8 |
+
value = 31,
|
9 |
+
step = 1)
|
10 |
+
|
11 |
+
max_depth = st.sidebar.slider(label = 'max_depth', min_value = -1,
|
12 |
+
max_value = 15,
|
13 |
+
value = -1,
|
14 |
+
step = 1)
|
15 |
+
|
16 |
+
min_data_in_leaf = st.sidebar.slider(label = 'min_data_in_leaf', min_value = 8,
|
17 |
+
max_value = 55,
|
18 |
+
value = 20,
|
19 |
+
step = 1)
|
20 |
+
|
21 |
+
feature_fraction = st.sidebar.slider(label = 'feature_fraction', min_value = 0.0,
|
22 |
+
max_value = 1.0 ,
|
23 |
+
value = 0.8,
|
24 |
+
step = 0.1)
|
25 |
+
|
26 |
+
min_data_per_group = st.sidebar.slider(label = 'min_data_per_group', min_value = 6,
|
27 |
+
max_value = 289 ,
|
28 |
+
value = 100,
|
29 |
+
step = 1)
|
30 |
+
|
31 |
+
max_cat_threshold = st.sidebar.slider(label = 'max_cat_threshold', min_value = 6,
|
32 |
+
max_value = 289 ,
|
33 |
+
value = 32,
|
34 |
+
step = 1)
|
35 |
+
|
36 |
+
learning_rate = st.sidebar.slider(label = 'learning_rate', min_value = 0.0,
|
37 |
+
max_value = 1.00,
|
38 |
+
value = 0.05,
|
39 |
+
step = 0.01)
|
40 |
+
|
41 |
+
num_leaves = st.sidebar.slider(label = 'num_leaves', min_value = 6,
|
42 |
+
max_value = 289 ,
|
43 |
+
value = 31,
|
44 |
+
step = 1)
|
45 |
+
|
46 |
+
max_bin = st.sidebar.slider(label = 'max_bin', min_value = 6,
|
47 |
+
max_value = 289 ,
|
48 |
+
value = 255,
|
49 |
+
step = 1)
|
50 |
+
|
51 |
+
num_iterations = st.sidebar.slider(label = 'num_iterations', min_value = 8,
|
52 |
+
max_value = 289,
|
53 |
+
value = 100,
|
54 |
+
step = 1)
|
55 |
+
|
56 |
+
st.header('LightGBM-parameter-tuning-with-streamlit')
|
57 |
+
|
58 |
+
|
59 |
+
# 加载数据
|
60 |
+
breast_cancer = load_breast_cancer()
|
61 |
+
data = breast_cancer.data
|
62 |
+
target = breast_cancer.target
|
63 |
+
|
64 |
+
# 划分训练数据和测试数据
|
65 |
+
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)
|
66 |
+
|
67 |
+
# 转换为Dataset数据格式
|
68 |
+
lgb_train = lgb.Dataset(X_train, y_train)
|
69 |
+
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
|
70 |
+
|
71 |
+
# 模型训练
|
72 |
+
params = {'num_leaves': num_leaves, 'max_depth': max_depth,
|
73 |
+
'min_data_in_leaf': min_data_in_leaf,
|
74 |
+
'feature_fraction': feature_fraction,
|
75 |
+
'min_data_per_group': min_data_per_group,
|
76 |
+
'max_cat_threshold': max_cat_threshold,
|
77 |
+
'learning_rate':learning_rate,'num_leaves':num_leaves,
|
78 |
+
'max_bin':max_bin,'num_iterations':num_iterations
|
79 |
+
}
|
80 |
+
|
81 |
+
gbm = lgb.train(params, lgb_train, num_boost_round=2000, valid_sets=lgb_eval, early_stopping_rounds=500)
|
82 |
+
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
|
83 |
+
probs = gbm.predict(X_test, num_iteration=gbm.best_iteration) # 输出的是概率结果
|
84 |
+
|
85 |
+
fpr, tpr, thresholds = roc_curve(y_test, probs)
|
86 |
+
st.write('------------------------------------')
|
87 |
+
st.write('Confusion Matrix:')
|
88 |
+
st.write(confusion_matrix(y_test, np.where(probs > 0.5, 1, 0)))
|
89 |
+
|
90 |
+
st.write('------------------------------------')
|
91 |
+
st.write('Classification Report:')
|
92 |
+
report = classification_report(y_test, np.where(probs > 0.5, 1, 0), output_dict=True)
|
93 |
+
report_matrix = pd.DataFrame(report).transpose()
|
94 |
+
st.dataframe(report_matrix)
|
95 |
+
|
96 |
+
st.write('------------------------------------')
|
97 |
+
st.write('ROC:')
|
98 |
+
|
99 |
+
plot_roc(fpr, tpr)
|
definitions.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import streamlit as st
|
3 |
+
import numpy as np
|
4 |
+
from sklearn.model_selection import train_test_split
|
5 |
+
from sklearn.datasets import load_breast_cancer
|
6 |
+
from sklearn.metrics import roc_auc_score,roc_curve,auc,accuracy_score,classification_report,confusion_matrix,precision_recall_curve
|
7 |
+
import lightgbm as lgb
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import warnings
|
10 |
+
warnings.filterwarnings('ignore')
|
11 |
+
|
12 |
+
def plot_roc(fpr, tpr, label=None):
|
13 |
+
roc_auc = auc(fpr, tpr)
|
14 |
+
plt.title('Receiver Operating Characteristic')
|
15 |
+
plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
|
16 |
+
plt.legend(loc = 'lower right')
|
17 |
+
plt.plot([0, 1], [0, 1],'r--')
|
18 |
+
plt.xlim([0, 1])
|
19 |
+
plt.ylim([0, 1])
|
20 |
+
plt.ylabel('True Positive Rate')
|
21 |
+
plt.xlabel('False Positive Rate')
|
22 |
+
plt.show()
|
23 |
+
st.pyplot()
|
24 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pandas==1.3.1
|
2 |
+
streamlit==1.8.1
|
3 |
+
numpy==1.20.3
|
4 |
+
lightgbm==3.3.2
|
5 |
+
matplotlib==3.4.2
|
6 |
+
scikit-learn==1.0.1
|