menghui753 commited on
Commit
cd5967a
·
verified ·
1 Parent(s): 1d25b94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -37
app.py CHANGED
@@ -72,10 +72,19 @@ for model_name, model_path in MODEL_PATHS.items():
72
  if os.path.isfile(model_path):
73
  try:
74
  print(f"正在加载ONNX模型: {model_path}")
75
- # Create ONNX runtime session
76
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
 
 
 
 
 
 
 
 
 
77
  sessions[model_name] = ort.InferenceSession(model_path, providers=providers)
78
- print(f"{model_name} ONNX模型加载成功")
79
  except Exception as e:
80
  print(f"加载 {model_name} ONNX模型出错: {str(e)}")
81
  # Continue loading other models if one fails
@@ -87,10 +96,26 @@ if not sessions:
87
  raise RuntimeError("无法加载任何ONNX模型,请确保至少有一个模型文件是正确的")
88
 
89
  # Load feature extractor for image preprocessing
90
- feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
 
 
 
 
 
 
 
 
91
 
92
  # Define segmentation classes and corresponding colors
93
  CLASSES = {
 
 
 
 
 
 
 
 
94
  0: "背景",
95
  1: "叶片",
96
  2: "黑痣病",
@@ -211,10 +236,7 @@ def segment_image(image, model_choice="Segformer", visualization_type="overlay")
211
 
212
  # 左侧显示原图
213
  ax1.imshow(original_array)
214
- if font_prop:
215
- ax1.set_title("原始图像", fontproperties=font_prop, fontsize=14, pad=10, fontweight='bold')
216
- else:
217
- ax1.set_title("原始图像", fontsize=14, pad=10, fontweight='bold')
218
  ax1.axis('off')
219
  ax1.set_frame_on(True) # 显示边框
220
  ax1.patch.set_edgecolor('lightgray') # 设置边框颜色
@@ -222,10 +244,7 @@ def segment_image(image, model_choice="Segformer", visualization_type="overlay")
222
 
223
  # 右侧显示分割结果
224
  ax2.imshow(overlay_image)
225
- if font_prop:
226
- ax2.set_title("分割结果", fontproperties=font_prop, fontsize=14, pad=10, fontweight='bold')
227
- else:
228
- ax2.set_title("分割结果", fontsize=14, pad=10, fontweight='bold')
229
  ax2.axis('off')
230
  ax2.set_frame_on(True) # 显示边框
231
  ax2.patch.set_edgecolor('lightgray') # 设置边框颜色
@@ -235,19 +254,13 @@ def segment_image(image, model_choice="Segformer", visualization_type="overlay")
235
  if visualization_type == "detailed":
236
  # 热力图显示
237
  heatmap = ax3.imshow(segmentation_map, cmap=CUSTOM_CMAP, interpolation='nearest')
238
- if font_prop:
239
- ax3.set_title("病害热力图", fontproperties=font_prop, fontsize=14, pad=10, fontweight='bold')
240
- else:
241
- ax3.set_title("病害热力图", fontsize=14, pad=10, fontweight='bold')
242
  ax3.axis('off')
243
 
244
  # 添加颜色条
245
  cbar = plt.colorbar(heatmap, ax=ax3, orientation='horizontal', pad=0.05, shrink=0.8)
246
  cbar.set_ticks([0, 1, 2, 3])
247
- if font_prop:
248
- cbar.set_ticklabels([CLASSES[i] for i in range(4)], fontproperties=font_prop)
249
- else:
250
- cbar.set_ticklabels([CLASSES[i] for i in range(4)])
251
 
252
  # 饼图显示各类别占比
253
  non_zero_classes = {k: v for k, v in class_percentages.items() if v > 0.5} # 只显示占比大于0.5%的类别
@@ -268,16 +281,10 @@ def segment_image(image, model_choice="Segformer", visualization_type="overlay")
268
  )
269
 
270
  # 设置饼图标题
271
- if font_prop:
272
- ax4.set_title("类别分布", fontproperties=font_prop, fontsize=14, pad=10, fontweight='bold')
273
- else:
274
- ax4.set_title("类别分布", fontsize=14, pad=10, fontweight='bold')
275
 
276
  # 自定义饼图图例
277
- if font_prop:
278
- ax4.legend(wedges, labels, title="类别", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1), prop=font_prop)
279
- else:
280
- ax4.legend(wedges, labels, title="类别", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))
281
 
282
  # 添加美观的图例(对于非详细视图)
283
  if visualization_type != "detailed":
@@ -291,16 +298,11 @@ def segment_image(image, model_choice="Segformer", visualization_type="overlay")
291
  legend_elements.append(plt.Rectangle((0, 0), 1, 1, color=color, label=legend_label))
292
 
293
  # 在图像下方添加图例
294
- if font_prop:
295
- fig.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, 0.02),
296
- fancybox=True, shadow=True, ncol=len(legend_elements), fontsize=10, prop=font_prop)
297
- else:
298
- fig.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, 0.02),
299
- fancybox=True, shadow=True, ncol=len(legend_elements), fontsize=10)
300
 
301
  # 添加水印
302
- fig.text(0.99, 0.01, '植物叶片病害分割系统', fontsize=8, color='gray', ha='right', va='bottom', alpha=0.7,
303
- fontproperties=font_prop if font_prop else None)
304
 
305
  # 调整布局,确保图例不会被裁剪
306
  plt.tight_layout()
@@ -356,4 +358,4 @@ iface = gr.Interface(
356
  ".gr-image {border-radius: 8px; border: 1px solid #ddd;}"
357
  )
358
 
359
- iface.launch(share=True)
 
72
  if os.path.isfile(model_path):
73
  try:
74
  print(f"正在加载ONNX模型: {model_path}")
75
+ # 检查可用的执行提供程序
76
+ available_providers = ort.get_available_providers()
77
+ providers = []
78
+
79
+ # 优先使用CUDA,如果可用
80
+ if 'CUDAExecutionProvider' in available_providers:
81
+ providers.append('CUDAExecutionProvider')
82
+
83
+ # 总是添加CPU作为后备
84
+ providers.append('CPUExecutionProvider')
85
+
86
  sessions[model_name] = ort.InferenceSession(model_path, providers=providers)
87
+ print(f"{model_name} ONNX模型加载成功,使用提供程序: {providers}")
88
  except Exception as e:
89
  print(f"加载 {model_name} ONNX模型出错: {str(e)}")
90
  # Continue loading other models if one fails
 
96
  raise RuntimeError("无法加载任何ONNX模型,请确保至少有一个模型文件是正确的")
97
 
98
  # Load feature extractor for image preprocessing
99
+ try:
100
+ # 尝试使用新的API (如果transformers版本支持)
101
+ from transformers import AutoImageProcessor
102
+ feature_extractor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
103
+ print("使用AutoImageProcessor加载特征提取器")
104
+ except (ImportError, AttributeError):
105
+ # 回退到旧API
106
+ feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
107
+ print("使用SegformerFeatureExtractor加载特征提取器")
108
 
109
  # Define segmentation classes and corresponding colors
110
  CLASSES = {
111
+ 0: "Background",
112
+ 1: "Leaf",
113
+ 2: "Black Spot",
114
+ 3: "Black Rot"
115
+ }
116
+
117
+ # 中文类别名称 (用于UI界面)
118
+ CLASSES_CN = {
119
  0: "背景",
120
  1: "叶片",
121
  2: "黑痣病",
 
236
 
237
  # 左侧显示原图
238
  ax1.imshow(original_array)
239
+ ax1.set_title("Original Image", fontsize=14, pad=10, fontweight='bold')
 
 
 
240
  ax1.axis('off')
241
  ax1.set_frame_on(True) # 显示边框
242
  ax1.patch.set_edgecolor('lightgray') # 设置边框颜色
 
244
 
245
  # 右侧显示分割结果
246
  ax2.imshow(overlay_image)
247
+ ax2.set_title("Segmentation Result", fontsize=14, pad=10, fontweight='bold')
 
 
 
248
  ax2.axis('off')
249
  ax2.set_frame_on(True) # 显示边框
250
  ax2.patch.set_edgecolor('lightgray') # 设置边框颜色
 
254
  if visualization_type == "detailed":
255
  # 热力图显示
256
  heatmap = ax3.imshow(segmentation_map, cmap=CUSTOM_CMAP, interpolation='nearest')
257
+ ax3.set_title("Disease Heatmap", fontsize=14, pad=10, fontweight='bold')
 
 
 
258
  ax3.axis('off')
259
 
260
  # 添加颜色条
261
  cbar = plt.colorbar(heatmap, ax=ax3, orientation='horizontal', pad=0.05, shrink=0.8)
262
  cbar.set_ticks([0, 1, 2, 3])
263
+ cbar.set_ticklabels([CLASSES[i] for i in range(4)])
 
 
 
264
 
265
  # 饼图显示各类别占比
266
  non_zero_classes = {k: v for k, v in class_percentages.items() if v > 0.5} # 只显示占比大于0.5%的类别
 
281
  )
282
 
283
  # 设置饼图标题
284
+ ax4.set_title("Class Distribution", fontsize=14, pad=10, fontweight='bold')
 
 
 
285
 
286
  # 自定义饼图图例
287
+ ax4.legend(wedges, labels, title="Classes", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))
 
 
 
288
 
289
  # 添加美观的图例(对于非详细视图)
290
  if visualization_type != "detailed":
 
298
  legend_elements.append(plt.Rectangle((0, 0), 1, 1, color=color, label=legend_label))
299
 
300
  # 在图像下方添加图例
301
+ fig.legend(handles=legend_elements, loc='lower center', bbox_to_anchor=(0.5, 0.02),
302
+ fancybox=True, shadow=True, ncol=len(legend_elements), fontsize=10)
 
 
 
 
303
 
304
  # 添加水印
305
+ fig.text(0.99, 0.01, 'Plant Leaf Disease Segmentation System', fontsize=8, color='gray', ha='right', va='bottom', alpha=0.7)
 
306
 
307
  # 调整布局,确保图例不会被裁剪
308
  plt.tight_layout()
 
358
  ".gr-image {border-radius: 8px; border: 1px solid #ddd;}"
359
  )
360
 
361
+ iface.launch()