Spaces:
Runtime error
Runtime error
import io

import gradio as gr
import numpy as np
import pandas as pd
import pydotplus
from PIL import Image
from six import StringIO
from sklearn import tree
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
# Feature column names, in the same order as the first four tab-separated
# fields of each line in lenses.txt.
_FEATURE_LABELS = ['noise', 'rotation', 'power-up', 'temp']


def _load_lenses(path='lenses.txt'):
    """Read the tab-separated training file into a list of string rows.

    Blank lines are skipped: a stray empty line would otherwise become a
    one-field row and make the column indexing in ``_fit_tree`` raise
    IndexError.
    """
    with open(path, 'r', encoding='utf-8') as fr:
        return [line.strip().split('\t') for line in fr if line.strip()]


def _fit_tree(rows):
    """Label-encode the feature columns of *rows* and fit a decision tree.

    The last field of each row is the class label; the first
    ``len(_FEATURE_LABELS)`` fields are the features.

    Returns:
        ``(clf, features)`` -- the fitted ``DecisionTreeClassifier`` and the
        encoded feature ``DataFrame`` (columns named by ``_FEATURE_LABELS``).
    """
    targets = [row[-1] for row in rows]
    features = pd.DataFrame(
        {label: [row[i] for row in rows] for i, label in enumerate(_FEATURE_LABELS)}
    )
    # Encode each string column to integer codes. LabelEncoder numbers the
    # distinct values of a column in sorted order; the UI labels
    # (e.g. "dron=0, explosion=1, soundless=2") presumably follow that
    # order -- confirm against lenses.txt.
    le = LabelEncoder()
    for col in features.columns:
        features[col] = le.fit_transform(features[col])
    clf = tree.DecisionTreeClassifier(max_depth=None)
    clf.fit(features.values.tolist(), targets)
    return clf, features


def _render_tree(clf, feature_names, fallback_path='tree.jpg'):
    """Draw *clf* with Graphviz and return it as a PIL image.

    Rendering runs the Graphviz ``dot`` executable through pydotplus. If that
    fails (e.g. Graphviz is not installed), a pre-rendered image stored at
    *fallback_path* is returned instead.
    """
    dot_source = tree.export_graphviz(
        clf,
        out_file=None,  # return the DOT source as a string
        feature_names=list(feature_names),
        class_names=[str(c) for c in clf.classes_],
        filled=True,
        rounded=True,
        special_characters=True,
    )
    graph = pydotplus.graph_from_dot_data(dot_source)
    try:
        png_bytes = graph.create_png()
    except pydotplus.graphviz.InvocationException:
        return Image.open(fallback_path)
    return Image.open(io.BytesIO(png_bytes))


def Tree_Detection(noise, rotation, power_up, temp):
    """Predict the fault type for one sample and draw the fitted decision tree.

    The tree is retrained from ``lenses.txt`` on every call.

    Args:
        noise, rotation, power_up, temp: integer codes (ints or numeric
            strings, as delivered by the Gradio text boxes) for the four
            features, using the LabelEncoder numbering of each column.

    Returns:
        ``(result, image)`` -- a sentence naming the predicted class, and a
        PIL image of the tree that produced the prediction.

    Raises:
        ValueError: if an argument cannot be converted with ``int()``.

    Example::

        print(Tree_Detection("2", "1", "1", "0"))
    """
    # Convert first so malformed input fails before any file I/O.
    sample = [int(noise), int(rotation), int(power_up), int(temp)]
    clf, features = _fit_tree(_load_lenses())
    result = f'The fault type is : {clf.predict([sample])[0]}'
    image = _render_tree(clf, features.columns)
    return result, image
def test(image): | |
return image | |
# Gradio UI: four text boxes, one integer code per feature, in the argument
# order of Tree_Detection. Outputs are the predicted fault type as text and
# the decision-tree image.
demo = gr.Interface(
    fn=Tree_Detection,
    inputs=[
        gr.components.Textbox(label="noise: dron=0, explosion=1, soundless=2"),
        gr.components.Textbox(label="rotation: common=0, delay=1"),
        gr.components.Textbox(label="power-up: no=0, yes=1"),
        gr.components.Textbox(label="temperature: high=0, normal=1"),
    ],
    outputs=['text', 'image'],
)
# Starts the web server as soon as this module is executed.
demo.launch()