Spaces:
Runtime error
Runtime error
Upload Sklearn-Decision Tree_gradio.py
Browse files
Sklearn-Decision Tree_gradio.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
|
4 |
+
from six import StringIO
|
5 |
+
from sklearn import tree
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
import pydotplus
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
|
12 |
+
def Tree_Detection(sample):
|
13 |
+
with open('lenses.txt', 'r') as fr: # 加载文件
|
14 |
+
lenses = [inst.strip().split('\t') for inst in fr.readlines()] # 处理文件
|
15 |
+
lenses_target = [] # 提取每组数据的类别,保存在列表里
|
16 |
+
# print(lenses)
|
17 |
+
for each in lenses:
|
18 |
+
lenses_target.append(each[-1])
|
19 |
+
# print(lenses_target)
|
20 |
+
|
21 |
+
lensesLabels = ['noise', 'rotation', 'power-up', 'temp'] # 特征标签
|
22 |
+
lenses_list = [] # 保存lenses数据的临时列表
|
23 |
+
lenses_dict = {} # 保存lenses数据的字典,用于生成pandas
|
24 |
+
for each_label in lensesLabels: # 提取信息,生成字典
|
25 |
+
for each in lenses:
|
26 |
+
lenses_list.append(each[lensesLabels.index(each_label)])
|
27 |
+
lenses_dict[each_label] = lenses_list
|
28 |
+
lenses_list = []
|
29 |
+
# print(lenses_dict) # 打印字典信息
|
30 |
+
lenses_pd = pd.DataFrame(lenses_dict) # 生成pandas.DataFrame
|
31 |
+
# print(lenses_pd) # 打印pandas.DataFrame
|
32 |
+
le = LabelEncoder() # 创建LabelEncoder()对象,用于序列化
|
33 |
+
for col in lenses_pd.columns: # 序列化
|
34 |
+
lenses_pd[col] = le.fit_transform(lenses_pd[col])
|
35 |
+
# print(lenses_pd) # 打印编码信息
|
36 |
+
|
37 |
+
clf = tree.DecisionTreeClassifier(max_depth=None) # 创建DecisionTreeClassifier()类
|
38 |
+
clf = clf.fit(lenses_pd.values.tolist(), lenses_target) # 使用数据,构建决策树
|
39 |
+
|
40 |
+
dot_data = StringIO()
|
41 |
+
tree.export_graphviz(clf, out_file=dot_data, # 绘制决策树
|
42 |
+
feature_names=lenses_pd.keys(),
|
43 |
+
class_names=clf.classes_,
|
44 |
+
filled=True, rounded=True,
|
45 |
+
special_characters=True)
|
46 |
+
|
47 |
+
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
|
48 |
+
# img = graph.write_jpg("tree.jpg") # 保存绘制好的决策树,以JPG的形式存储。
|
49 |
+
|
50 |
+
result = clf.predict([sample]) # 预测
|
51 |
+
|
52 |
+
return result
|
53 |
+
# print(Tree_Detection([2, 1, 1, 0]))
|
54 |
+
|
55 |
+
|
56 |
+
demo = gr.Interface(
|
57 |
+
fn=Tree_Detection,
|
58 |
+
inputs='text',
|
59 |
+
outputs='text'
|
60 |
+
)
|
61 |
+
demo.launch()
|
62 |
+
|