Spaces:
Runtime error
Runtime error
File size: 2,248 Bytes
ac71df5 76d02ff 84fbda4 ac71df5 960a75a ac71df5 b0ecf6a ac71df5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from six import StringIO
from sklearn import tree
import pandas as pd
import numpy as np
import pydotplus
import gradio as gr
def _parse_sample(sample):
    """Convert user input into a flat list of feature values.

    Strings (what the Gradio text box delivers) are split on commas,
    semicolons, whitespace and brackets, so ``"2, 1, 1, 0"``,
    ``"[2, 1, 1, 0]"`` and ``"2 1 1 0"`` are all accepted.  A string made
    only of digits, such as ``"2110"``, is read as one digit per feature.
    Any other iterable is simply copied into a list.

    Raises:
        ValueError: if a string sample contains a non-numeric token.
    """
    if not isinstance(sample, str):
        # Already a sequence of feature values, e.g. [2, 1, 1, 0].
        return list(sample)

    text = sample.strip()
    if text.isdigit():
        # Compact form: one digit per feature, matching what the former
        # ``list(sample)`` call produced for such a string.
        tokens = list(text)
    else:
        for separator in '[](),;':
            text = text.replace(separator, ' ')
        tokens = text.split()

    try:
        return [float(token) for token in tokens]
    except ValueError as err:
        raise ValueError(
            f'sample must contain only numbers, got {sample!r}'
        ) from err


def Tree_Detection(sample):
    """Train a decision tree on ``lenses.txt`` and classify one sample.

    ``lenses.txt`` is read as tab-separated text: the first four columns
    are the features named in ``lensesLabels`` and the last column is the
    class label.  Each feature column is label-encoded to integers before
    training, so ``sample`` must be given in that encoded form (the raw
    sample is passed straight to ``predict``).

    Args:
        sample: four feature values, either as an iterable of numbers
            (``[2, 1, 1, 0]``) or as text (``"2, 1, 1, 0"``).

    Returns:
        The array returned by ``DecisionTreeClassifier.predict`` for the
        single sample, i.e. a one-element array holding the predicted
        class label.

    Raises:
        ValueError: if the sample cannot be parsed or does not contain
            exactly one value per feature.
        OSError: if ``lenses.txt`` cannot be opened.
    """
    lensesLabels = ['noise', 'rotation', 'power-up', 'temp']  # feature names

    # Validate the request first so bad input fails fast, before any file
    # access or model training.
    features = _parse_sample(sample)
    if len(features) != len(lensesLabels):
        expected_names = ', '.join(lensesLabels)
        raise ValueError(
            f'expected {len(lensesLabels)} feature values '
            f'({expected_names}), got {len(features)}'
        )

    with open('lenses.txt', 'r') as fr:  # load the data file
        # Skip blank lines: they would otherwise become [''] rows and
        # break the column extraction below.
        lenses = [line.strip().split('\t') for line in fr if line.strip()]

    # The class label is the last field of every row.
    lenses_target = [row[-1] for row in lenses]

    # Column-oriented view of the feature fields, used to build the frame.
    lenses_dict = {
        label: [row[position] for row in lenses]
        for position, label in enumerate(lensesLabels)
    }
    lenses_pd = pd.DataFrame(lenses_dict)

    # Encode every (string) feature column as integers.
    le = LabelEncoder()
    for col in lenses_pd.columns:
        lenses_pd[col] = le.fit_transform(lenses_pd[col])

    # A fixed random_state keeps tie-breaking between equally good splits
    # reproducible, so the same sample always yields the same prediction.
    clf = tree.DecisionTreeClassifier(max_depth=None, random_state=0)
    clf = clf.fit(lenses_pd.values.tolist(), lenses_target)

    result = clf.predict([features])  # predict the class of the sample
    return result
# print(Tree_Detection([2, 1, 1, 0]))
def test(image):
    """Return ``image`` unchanged.

    NOTE(review): nothing in the visible code calls this function; it looks
    like a leftover placeholder -- confirm before relying on or removing it.
    """
    return image
# Gradio UI: a single text box whose content is passed to Tree_Detection,
# with the prediction shown back as text.
demo = gr.Interface(
    fn=Tree_Detection,
    inputs='text',
    outputs='text'
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse Tree_Detection) does not block on it.
if __name__ == '__main__':
    demo.launch()
|