Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -12,46 +12,46 @@ import gradio as gr
|
|
12 |
def Tree_Detection(sample):
|
13 |
|
14 |
sample=list(sample)
|
15 |
-
|
16 |
lenses = [inst.strip().split('\t') for inst in fr.readlines()] # 处理文件
|
17 |
-
|
18 |
# print(lenses)
|
19 |
-
|
20 |
lenses_target.append(each[-1])
|
21 |
# print(lenses_target)
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
for each in lenses:
|
28 |
lenses_list.append(each[lensesLabels.index(each_label)])
|
29 |
lenses_dict[each_label] = lenses_list
|
30 |
lenses_list = []
|
31 |
# print(lenses_dict) # 打印字典信息
|
32 |
-
|
33 |
# print(lenses_pd) # 打印pandas.DataFrame
|
34 |
-
|
35 |
-
|
36 |
lenses_pd[col] = le.fit_transform(lenses_pd[col])
|
37 |
# print(lenses_pd) # 打印编码信息
|
38 |
|
39 |
-
|
40 |
-
|
41 |
|
42 |
-
|
43 |
-
|
44 |
feature_names=lenses_pd.keys(),
|
45 |
class_names=clf.classes_,
|
46 |
filled=True, rounded=True,
|
47 |
special_characters=True)
|
48 |
|
49 |
-
|
50 |
# img = graph.write_jpg("tree.jpg") # 保存绘制好的决策树,以JPG的形式存储。
|
51 |
|
52 |
-
|
53 |
|
54 |
-
|
55 |
# print(Tree_Detection([2, 1, 1, 0]))
|
56 |
|
57 |
def test(image):
|
|
|
12 |
def Tree_Detection(sample):
|
13 |
|
14 |
sample=list(sample)
|
15 |
+
with open('lenses.txt', 'r') as fr: # 加载文件
|
16 |
lenses = [inst.strip().split('\t') for inst in fr.readlines()] # 处理文件
|
17 |
+
lenses_target = [] # 提取每组数据的类别,保存在列表里
|
18 |
# print(lenses)
|
19 |
+
for each in lenses:
|
20 |
lenses_target.append(each[-1])
|
21 |
# print(lenses_target)
|
22 |
|
23 |
+
lensesLabels = ['noise', 'rotation', 'power-up', 'temp'] # 特征标签
|
24 |
+
lenses_list = [] # 保存lenses数据的临时列表
|
25 |
+
lenses_dict = {} # 保存lenses数据的字典,用于生成pandas
|
26 |
+
for each_label in lensesLabels: # 提取信息,生成字典
|
27 |
for each in lenses:
|
28 |
lenses_list.append(each[lensesLabels.index(each_label)])
|
29 |
lenses_dict[each_label] = lenses_list
|
30 |
lenses_list = []
|
31 |
# print(lenses_dict) # 打印字典信息
|
32 |
+
lenses_pd = pd.DataFrame(lenses_dict) # 生成pandas.DataFrame
|
33 |
# print(lenses_pd) # 打印pandas.DataFrame
|
34 |
+
le = LabelEncoder() # 创建LabelEncoder()对象,用于序列化
|
35 |
+
for col in lenses_pd.columns: # 序列化
|
36 |
lenses_pd[col] = le.fit_transform(lenses_pd[col])
|
37 |
# print(lenses_pd) # 打印编码信息
|
38 |
|
39 |
+
clf = tree.DecisionTreeClassifier(max_depth=None) # 创建DecisionTreeClassifier()类
|
40 |
+
clf = clf.fit(lenses_pd.values.tolist(), lenses_target) # 使用数据,构建决策树
|
41 |
|
42 |
+
dot_data = StringIO()
|
43 |
+
tree.export_graphviz(clf, out_file=dot_data, # 绘制决策树
|
44 |
feature_names=lenses_pd.keys(),
|
45 |
class_names=clf.classes_,
|
46 |
filled=True, rounded=True,
|
47 |
special_characters=True)
|
48 |
|
49 |
+
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
|
50 |
# img = graph.write_jpg("tree.jpg") # 保存绘制好的决策树,以JPG的形式存储。
|
51 |
|
52 |
+
result = clf.predict([sample]) # 预测
|
53 |
|
54 |
+
return result
|
55 |
# print(Tree_Detection([2, 1, 1, 0]))
|
56 |
|
57 |
def test(image):
|