shane666 commited on
Commit
ac71df5
·
1 Parent(s): 4f469c9

Upload Sklearn-Decision Tree_gradio.py

Browse files
Files changed (1) hide show
  1. Sklearn-Decision Tree_gradio.py +62 -0
Sklearn-Decision Tree_gradio.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from sklearn.preprocessing import LabelEncoder, OneHotEncoder
4
+ from six import StringIO
5
+ from sklearn import tree
6
+ import pandas as pd
7
+ import numpy as np
8
+ import pydotplus
9
+ import gradio as gr
10
+
11
+
12
+ def Tree_Detection(sample):
13
+ with open('lenses.txt', 'r') as fr: # 加载文件
14
+ lenses = [inst.strip().split('\t') for inst in fr.readlines()] # 处理文件
15
+ lenses_target = [] # 提取每组数据的类别,保存在列表里
16
+ # print(lenses)
17
+ for each in lenses:
18
+ lenses_target.append(each[-1])
19
+ # print(lenses_target)
20
+
21
+ lensesLabels = ['noise', 'rotation', 'power-up', 'temp'] # 特征标签
22
+ lenses_list = [] # 保存lenses数据的临时列表
23
+ lenses_dict = {} # 保存lenses数据的字典,用于生成pandas
24
+ for each_label in lensesLabels: # 提取信息,生成字典
25
+ for each in lenses:
26
+ lenses_list.append(each[lensesLabels.index(each_label)])
27
+ lenses_dict[each_label] = lenses_list
28
+ lenses_list = []
29
+ # print(lenses_dict) # 打印字典信息
30
+ lenses_pd = pd.DataFrame(lenses_dict) # 生成pandas.DataFrame
31
+ # print(lenses_pd) # 打印pandas.DataFrame
32
+ le = LabelEncoder() # 创建LabelEncoder()对象,用于序列化
33
+ for col in lenses_pd.columns: # 序列化
34
+ lenses_pd[col] = le.fit_transform(lenses_pd[col])
35
+ # print(lenses_pd) # 打印编码信息
36
+
37
+ clf = tree.DecisionTreeClassifier(max_depth=None) # 创建DecisionTreeClassifier()类
38
+ clf = clf.fit(lenses_pd.values.tolist(), lenses_target) # 使用数据,构建决策树
39
+
40
+ dot_data = StringIO()
41
+ tree.export_graphviz(clf, out_file=dot_data, # 绘制决策树
42
+ feature_names=lenses_pd.keys(),
43
+ class_names=clf.classes_,
44
+ filled=True, rounded=True,
45
+ special_characters=True)
46
+
47
+ graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
48
+ # img = graph.write_jpg("tree.jpg") # 保存绘制好的决策树,以JPG的形式存储。
49
+
50
+ result = clf.predict([sample]) # 预测
51
+
52
+ return result
53
+ # print(Tree_Detection([2, 1, 1, 0]))
54
+
55
+
56
+ demo = gr.Interface(
57
+ fn=Tree_Detection,
58
+ inputs='text',
59
+ outputs='text'
60
+ )
61
+ demo.launch()
62
+