File size: 2,248 Bytes
ac71df5
 
 
 
 
 
 
 
 
 
 
 
76d02ff
84fbda4
ac71df5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
960a75a
 
 
ac71df5
 
b0ecf6a
 
 
ac71df5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68


from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from six import StringIO
from sklearn import tree
import pandas as pd
import numpy as np
import pydotplus
import gradio as gr


def _parse_sample(sample, n_features):
    """Convert a raw sample into a list of ``n_features`` numeric feature codes.

    Accepts either a sequence of values (e.g. ``[2, 1, 1, 0]``) or a string as
    typed into the Gradio textbox: comma/whitespace separated (``"2,1,1,0"``,
    ``"2 1 1 0"``) or one digit per feature with no separator (``"2110"``).

    Raises:
        ValueError: if the sample does not contain exactly ``n_features``
            numeric values.
    """
    if isinstance(sample, str):
        text = sample.strip()
        tokens = text.replace(',', ' ').split()
        # Compact form such as "2110": one digit per feature, no separators.
        if len(tokens) == 1 and text.isdigit() and len(text) == n_features:
            tokens = list(text)
        try:
            values = [float(tok) for tok in tokens]
        except ValueError as err:
            raise ValueError(f'non-numeric feature value in {sample!r}') from err
    else:
        values = list(sample)
    if len(values) != n_features:
        raise ValueError(
            f'expected {n_features} feature values, got {len(values)}: {sample!r}'
        )
    return values


def _load_dataset(data_path, feature_labels):
    """Read the tab-separated training file.

    Each non-blank line is split on tabs; column ``i`` holds the value of
    ``feature_labels[i]`` and the last column holds the class label.

    Returns:
        (features, targets): a DataFrame of raw string features (one column
        per label, in ``feature_labels`` order) and the list of class labels.
    """
    with open(data_path, 'r', encoding='utf-8') as fr:
        # Skip blank lines (e.g. a trailing empty line) instead of crashing on them.
        rows = [line.strip().split('\t') for line in fr if line.strip()]
    targets = [row[-1] for row in rows]
    features = pd.DataFrame(
        {label: [row[i] for row in rows] for i, label in enumerate(feature_labels)}
    )
    return features, targets


def Tree_Detection(sample, *, data_path='lenses.txt'):
    """Train a decision tree on the lenses data and classify one sample.

    The tree is rebuilt from ``data_path`` on every call. Each feature column
    is label-encoded with ``LabelEncoder`` (codes follow the sorted order of
    that column's distinct strings), so ``sample`` must already be in that
    encoded form, ordered as noise, rotation, power-up, temp.

    Args:
        sample: the four encoded feature values, either as a sequence
            (``[2, 1, 1, 0]``) or as a string (``"2,1,1,0"``, ``"2 1 1 0"``,
            ``"2110"``).
        data_path: tab-separated training file; a relative path is resolved
            against the current working directory.

    Returns:
        The array returned by ``DecisionTreeClassifier.predict`` for the
        single sample (one predicted class label).

    Raises:
        ValueError: if ``sample`` does not hold exactly four numeric values.
        OSError: if the training file cannot be opened.
    """
    feature_labels = ['noise', 'rotation', 'power-up', 'temp']
    # Validate user input first so malformed input fails fast, before any
    # file I/O or training happens.
    values = _parse_sample(sample, len(feature_labels))

    features, targets = _load_dataset(data_path, feature_labels)

    # Refit the encoder per column so every feature gets its own
    # string -> integer mapping.
    le = LabelEncoder()
    for col in features.columns:
        features[col] = le.fit_transform(features[col])

    # Fixed random_state: sklearn permutes features at each split, so ties
    # between equally good splits could otherwise produce a different tree
    # (and possibly a different prediction) on each call.
    # (To visualize the tree, pass clf to tree.export_graphviz.)
    clf = tree.DecisionTreeClassifier(max_depth=None, random_state=0)
    clf.fit(features.to_numpy(), targets)

    return clf.predict([values])
# print(Tree_Detection([2, 1, 1, 0]))

def test(image):
    return image


# Expose Tree_Detection as a single text-in / text-out web form and serve it
# (gr.Interface takes fn, inputs, outputs as its first positional arguments).
demo = gr.Interface(Tree_Detection, 'text', 'text')
demo.launch()