shane666 commited on
Commit
b2f8e9c
·
1 Parent(s): 76d02ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -17
app.py CHANGED
@@ -12,46 +12,46 @@ import gradio as gr
12
  def Tree_Detection(sample):
13
 
14
  sample=list(sample)
15
- with open('lenses.txt', 'r') as fr: # 加载文件
16
  lenses = [inst.strip().split('\t') for inst in fr.readlines()] # 处理文件
17
- lenses_target = [] # 提取每组数据的类别,保存在列表里
18
  # print(lenses)
19
- for each in lenses:
20
  lenses_target.append(each[-1])
21
  # print(lenses_target)
22
 
23
- lensesLabels = ['noise', 'rotation', 'power-up', 'temp'] # 特征标签
24
- lenses_list = [] # 保存lenses数据的临时列表
25
- lenses_dict = {} # 保存lenses数据的字典,用于生成pandas
26
- for each_label in lensesLabels: # 提取信息,生成字典
27
  for each in lenses:
28
  lenses_list.append(each[lensesLabels.index(each_label)])
29
  lenses_dict[each_label] = lenses_list
30
  lenses_list = []
31
  # print(lenses_dict) # 打印字典信息
32
- lenses_pd = pd.DataFrame(lenses_dict) # 生成pandas.DataFrame
33
  # print(lenses_pd) # 打印pandas.DataFrame
34
- le = LabelEncoder() # 创建LabelEncoder()对象,用于序列化
35
- for col in lenses_pd.columns: # 序列化
36
  lenses_pd[col] = le.fit_transform(lenses_pd[col])
37
  # print(lenses_pd) # 打印编码信息
38
 
39
- clf = tree.DecisionTreeClassifier(max_depth=None) # 创建DecisionTreeClassifier()类
40
- clf = clf.fit(lenses_pd.values.tolist(), lenses_target) # 使用数据,构建决策树
41
 
42
- dot_data = StringIO()
43
- tree.export_graphviz(clf, out_file=dot_data, # 绘制决策树
44
  feature_names=lenses_pd.keys(),
45
  class_names=clf.classes_,
46
  filled=True, rounded=True,
47
  special_characters=True)
48
 
49
- graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
50
  # img = graph.write_jpg("tree.jpg") # 保存绘制好的决策树,以JPG的形式存储。
51
 
52
- result = clf.predict([sample]) # 预测
53
 
54
- return result
55
  # print(Tree_Detection([2, 1, 1, 0]))
56
 
57
  def test(image):
 
12
  def Tree_Detection(sample):
13
 
14
  sample=list(sample)
15
+ with open('lenses.txt', 'r') as fr: # 加载文件
16
  lenses = [inst.strip().split('\t') for inst in fr.readlines()] # 处理文件
17
+ lenses_target = [] # 提取每组数据的类别,保存在列表里
18
  # print(lenses)
19
+ for each in lenses:
20
  lenses_target.append(each[-1])
21
  # print(lenses_target)
22
 
23
+ lensesLabels = ['noise', 'rotation', 'power-up', 'temp'] # 特征标签
24
+ lenses_list = [] # 保存lenses数据的临时列表
25
+ lenses_dict = {} # 保存lenses数据的字典,用于生成pandas
26
+ for each_label in lensesLabels: # 提取信息,生成字典
27
  for each in lenses:
28
  lenses_list.append(each[lensesLabels.index(each_label)])
29
  lenses_dict[each_label] = lenses_list
30
  lenses_list = []
31
  # print(lenses_dict) # 打印字典信息
32
+ lenses_pd = pd.DataFrame(lenses_dict) # 生成pandas.DataFrame
33
  # print(lenses_pd) # 打印pandas.DataFrame
34
+ le = LabelEncoder() # 创建LabelEncoder()对象,用于序列化
35
+ for col in lenses_pd.columns: # 序列化
36
  lenses_pd[col] = le.fit_transform(lenses_pd[col])
37
  # print(lenses_pd) # 打印编码信息
38
 
39
+ clf = tree.DecisionTreeClassifier(max_depth=None) # 创建DecisionTreeClassifier()类
40
+ clf = clf.fit(lenses_pd.values.tolist(), lenses_target) # 使用数据,构建决策树
41
 
42
+ dot_data = StringIO()
43
+ tree.export_graphviz(clf, out_file=dot_data, # 绘制决策树
44
  feature_names=lenses_pd.keys(),
45
  class_names=clf.classes_,
46
  filled=True, rounded=True,
47
  special_characters=True)
48
 
49
+ graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
50
  # img = graph.write_jpg("tree.jpg") # 保存绘制好的决策树,以JPG的形式存储。
51
 
52
+ result = clf.predict([sample]) # 预测
53
 
54
+ return result
55
  # print(Tree_Detection([2, 1, 1, 0]))
56
 
57
  def test(image):