chrisli commited on
Commit
659f17d
1 Parent(s): 196e19d
Files changed (3) hide show
  1. app.py +99 -0
  2. definitions.py +24 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from definitions import *
2
+
3
+ st.set_option('deprecation.showPyplotGlobalUse', False)
4
+ st.sidebar.subheader("请选择模型参数:sunglasses:")
5
+
6
+ num_leaves = st.sidebar.slider(label = 'num_leaves', min_value = 4,
7
+ max_value = 200 ,
8
+ value = 31,
9
+ step = 1)
10
+
11
+ max_depth = st.sidebar.slider(label = 'max_depth', min_value = -1,
12
+ max_value = 15,
13
+ value = -1,
14
+ step = 1)
15
+
16
+ min_data_in_leaf = st.sidebar.slider(label = 'min_data_in_leaf', min_value = 8,
17
+ max_value = 55,
18
+ value = 20,
19
+ step = 1)
20
+
21
+ feature_fraction = st.sidebar.slider(label = 'feature_fraction', min_value = 0.0,
22
+ max_value = 1.0 ,
23
+ value = 0.8,
24
+ step = 0.1)
25
+
26
+ min_data_per_group = st.sidebar.slider(label = 'min_data_per_group', min_value = 6,
27
+ max_value = 289 ,
28
+ value = 100,
29
+ step = 1)
30
+
31
+ max_cat_threshold = st.sidebar.slider(label = 'max_cat_threshold', min_value = 6,
32
+ max_value = 289 ,
33
+ value = 32,
34
+ step = 1)
35
+
36
+ learning_rate = st.sidebar.slider(label = 'learning_rate', min_value = 0.0,
37
+ max_value = 1.00,
38
+ value = 0.05,
39
+ step = 0.01)
40
+
41
+ num_leaves = st.sidebar.slider(label = 'num_leaves', min_value = 6,
42
+ max_value = 289 ,
43
+ value = 31,
44
+ step = 1)
45
+
46
+ max_bin = st.sidebar.slider(label = 'max_bin', min_value = 6,
47
+ max_value = 289 ,
48
+ value = 255,
49
+ step = 1)
50
+
51
+ num_iterations = st.sidebar.slider(label = 'num_iterations', min_value = 8,
52
+ max_value = 289,
53
+ value = 100,
54
+ step = 1)
55
+
56
+ st.header('LightGBM-parameter-tuning-with-streamlit')
57
+
58
+
59
+ # 加载数据
60
+ breast_cancer = load_breast_cancer()
61
+ data = breast_cancer.data
62
+ target = breast_cancer.target
63
+
64
+ # 划分训练数据和测试数据
65
+ X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)
66
+
67
+ # 转换为Dataset数据格式
68
+ lgb_train = lgb.Dataset(X_train, y_train)
69
+ lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
70
+
71
+ # 模型训练
72
+ params = {'num_leaves': num_leaves, 'max_depth': max_depth,
73
+ 'min_data_in_leaf': min_data_in_leaf,
74
+ 'feature_fraction': feature_fraction,
75
+ 'min_data_per_group': min_data_per_group,
76
+ 'max_cat_threshold': max_cat_threshold,
77
+ 'learning_rate':learning_rate,'num_leaves':num_leaves,
78
+ 'max_bin':max_bin,'num_iterations':num_iterations
79
+ }
80
+
81
+ gbm = lgb.train(params, lgb_train, num_boost_round=2000, valid_sets=lgb_eval, early_stopping_rounds=500)
82
+ lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
83
+ probs = gbm.predict(X_test, num_iteration=gbm.best_iteration) # 输出的是概率结果
84
+
85
+ fpr, tpr, thresholds = roc_curve(y_test, probs)
86
+ st.write('------------------------------------')
87
+ st.write('Confusion Matrix:')
88
+ st.write(confusion_matrix(y_test, np.where(probs > 0.5, 1, 0)))
89
+
90
+ st.write('------------------------------------')
91
+ st.write('Classification Report:')
92
+ report = classification_report(y_test, np.where(probs > 0.5, 1, 0), output_dict=True)
93
+ report_matrix = pd.DataFrame(report).transpose()
94
+ st.dataframe(report_matrix)
95
+
96
+ st.write('------------------------------------')
97
+ st.write('ROC:')
98
+
99
+ plot_roc(fpr, tpr)
definitions.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import streamlit as st
3
+ import numpy as np
4
+ from sklearn.model_selection import train_test_split
5
+ from sklearn.datasets import load_breast_cancer
6
+ from sklearn.metrics import roc_auc_score,roc_curve,auc,accuracy_score,classification_report,confusion_matrix,precision_recall_curve
7
+ import lightgbm as lgb
8
+ import matplotlib.pyplot as plt
9
+ import warnings
10
+ warnings.filterwarnings('ignore')
11
+
12
+ def plot_roc(fpr, tpr, label=None):
13
+ roc_auc = auc(fpr, tpr)
14
+ plt.title('Receiver Operating Characteristic')
15
+ plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
16
+ plt.legend(loc = 'lower right')
17
+ plt.plot([0, 1], [0, 1],'r--')
18
+ plt.xlim([0, 1])
19
+ plt.ylim([0, 1])
20
+ plt.ylabel('True Positive Rate')
21
+ plt.xlabel('False Positive Rate')
22
+ plt.show()
23
+ st.pyplot()
24
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ pandas==1.3.1
2
+ streamlit==1.8.1
3
+ numpy==1.20.3
4
+ lightgbm==3.3.2
5
+ matplotlib==3.4.2
6
+ scikit-learn==1.0.1