File size: 2,903 Bytes
ac71df5
 
 
 
 
 
 
 
 
68c6bf2
ac71df5
 
3b59c19
76d02ff
ccbcf3a
 
6f6c90e
ccbcf3a
b2f8e9c
9f91232
b2f8e9c
ac71df5
b2f8e9c
9f91232
ac71df5
 
b2f8e9c
 
 
 
9f91232
 
 
 
ac71df5
b2f8e9c
ac71df5
b2f8e9c
 
9f91232
ac71df5
 
b2f8e9c
 
ac71df5
b2f8e9c
 
ac71df5
 
 
 
 
b2f8e9c
3a7225e
ccbcf3a
 
 
3b59c19
ccbcf3a
ac71df5
d2fa1da
ac71df5
68c6bf2
 
e087c8d
ac71df5
 
960a75a
 
 
ac71df5
 
b0ecf6a
25c68df
 
 
 
 
 
68c6bf2
ac71df5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84


from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from six import StringIO
from sklearn import tree
import pandas as pd
import numpy as np
import pydotplus
import gradio as gr
from PIL import Image


def _load_lenses(path):
    """Read a tab-separated data file into a list of rows (lists of strings).

    Blank lines are skipped so that an empty (e.g. trailing) line cannot
    produce a one-field row that breaks per-column indexing.
    """
    with open(path, 'r', encoding='utf-8') as fr:
        return [line.strip().split('\t') for line in fr if line.strip()]


def _render_tree_image(clf, feature_names, fallback_path='tree.jpg'):
    """Return a PIL image of the fitted tree ``clf``.

    The tree is exported to Graphviz DOT and rendered to JPEG in memory. If
    rendering fails (for example when the Graphviz executables are not
    installed), the previously rendered image at ``fallback_path`` is returned.
    """
    import io

    dot_data = StringIO()
    tree.export_graphviz(clf, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=clf.classes_,
                         filled=True, rounded=True,
                         special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    try:
        # Best-effort: rendering shells out to Graphviz, which may be absent
        # on the host; any failure falls back to the pre-rendered file.
        jpg_bytes = graph.create(format='jpg')
    except Exception:
        jpg_bytes = None
    if jpg_bytes:
        return Image.open(io.BytesIO(jpg_bytes))
    # Load the fallback fully and close its file handle instead of leaving
    # a lazily-opened file behind.
    with Image.open(fallback_path) as img:
        return img.copy()


def Tree_Detection(noise, rotation, power_up, temp):
    """Fit a decision tree on ``lenses.txt`` and classify one sample.

    Each call re-reads ``lenses.txt`` from the working directory
    (tab-separated). The first four fields of each row are used as the
    features ``noise``, ``rotation``, ``power-up``, ``temp``, and the last
    field is used as the class label. Every feature column is label-encoded,
    a ``DecisionTreeClassifier`` is fitted, and the class of
    ``[noise, rotation, power_up, temp]`` is predicted.

    Args:
        noise, rotation, power_up, temp: Feature values; anything ``int()``
            accepts (the Gradio textboxes pass strings).

    Returns:
        tuple: ``(result, image)``. ``result`` is the text
        ``'The fault type is : <class>'``. ``image`` is a PIL image of the
        decision tree.

    Raises:
        ValueError: if an input cannot be converted with ``int()``.
        OSError: if ``lenses.txt`` (or the fallback ``tree.jpg``) cannot be
            opened.
    """
    # Convert the inputs first, in the same order as before, so bad input
    # fails before any file work.
    sample = [int(noise), int(rotation), int(power_up), int(temp)]

    lenses = _load_lenses('lenses.txt')
    lenses_target = [row[-1] for row in lenses]  # class label = last field

    lenses_labels = ['noise', 'rotation', 'power-up', 'temp']  # feature labels
    # One DataFrame column per feature label, taken positionally from each row.
    lenses_pd = pd.DataFrame({
        label: [row[col] for row in lenses]
        for col, label in enumerate(lenses_labels)
    })
    # Encode each string column as integers (LabelEncoder numbers the
    # distinct values in sorted order).
    # NOTE(review): `sample` is passed to predict as raw ints, which assumes
    # the user enters already-encoded values — confirm against lenses.txt.
    for col in lenses_pd.columns:
        lenses_pd[col] = LabelEncoder().fit_transform(lenses_pd[col])

    clf = tree.DecisionTreeClassifier(max_depth=None)
    clf = clf.fit(lenses_pd.values.tolist(), lenses_target)

    result = f'The fault type is : {clf.predict([sample])[0]}'
    image = _render_tree_image(clf, list(lenses_pd.columns))

    return result, image
# print(Tree_Detection([2, 1, 1, 0]))

def test(image):
    return image


# Gradio UI: four textboxes feed Tree_Detection's (noise, rotation, power_up,
# temp) parameters; the outputs are the prediction text and the tree image.
demo = gr.Interface(
    fn=Tree_Detection,
    # `inputs` must be a component or a list of components. Components are
    # matched to fn's parameters by position. A dict is treated as a
    # component config and fails, so a list is used here.
    inputs=[
        gr.inputs.Textbox(label="说明1:"),
        gr.inputs.Textbox(label="说明2:"),
        gr.inputs.Textbox(label="说明3:"),
        gr.inputs.Textbox(label="说明4:"),
    ],
    outputs=['text', 'image'],
)
demo.launch()