Tudohuang commited on
Commit
6fe3e70
1 Parent(s): 2ba9206

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -13
app.py CHANGED
@@ -151,36 +151,48 @@ demo.launch()
151
 
152
  def generate_heatmap(last_conv_output, pred_class):
153
  """
154
- 生成熱力圖。
155
  """
156
- # 獲取目標類別的輸出
157
- target_output = last_conv_output[0, pred_class]
158
-
159
- # 執行全局平均池化以獲得類別特定的特徵圖
160
- heatmap = torch.mean(target_output, dim=0)
161
-
162
- # 標準化熱力圖
163
  heatmap = np.maximum(heatmap.detach().cpu().numpy(), 0)
164
  heatmap /= np.max(heatmap)
165
 
 
 
 
166
  return heatmap
167
 
 
168
  def overlay_heatmap(image, heatmap, intensity=0.5, colormap=cv2.COLORMAP_JET):
169
  """
170
- 將熱力圖疊加到原始圖片上。
171
  """
172
- # 將熱力圖調整為與原始圖片相同的大小
173
- heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0])) # 使用 NumPy 數組的尺寸
174
  heatmap = np.uint8(255 * heatmap)
175
 
176
- # 將熱力圖應用到原始圖片上
177
  heatmap = cv2.applyColorMap(heatmap, colormap)
178
- superimposed_img = heatmap * intensity + image # 直接使用 NumPy 數組
 
 
 
 
 
 
 
179
  superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
180
  superimposed_img = Image.fromarray(superimposed_img)
181
 
182
  return superimposed_img
183
 
 
184
  def predict(image):
185
  """
186
  預測並生成熱力圖。
 
151
 
152
  def generate_heatmap(last_conv_output, pred_class):
153
  """
154
+ 生成局部化的热力图。
155
  """
156
+
157
+ class_conv_output = last_conv_output[0, pred_class]
158
+
159
+
160
+ heatmap = torch.mean(class_conv_output, dim=0)
161
+
162
+
163
  heatmap = np.maximum(heatmap.detach().cpu().numpy(), 0)
164
  heatmap /= np.max(heatmap)
165
 
166
+
167
+ heatmap[heatmap < 0.5] = 0 # 只保留高信心区域
168
+
169
  return heatmap
170
 
171
+
172
  def overlay_heatmap(image, heatmap, intensity=0.5, colormap=cv2.COLORMAP_JET):
173
  """
174
+ 调整热力图叠加到原始图像上的逻辑。
175
  """
176
+ # 调整热力图大小
177
+ heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))
178
  heatmap = np.uint8(255 * heatmap)
179
 
180
+ # 应用颜色映射
181
  heatmap = cv2.applyColorMap(heatmap, colormap)
182
+
183
+ # 创建一个只有热力图区域的蒙版
184
+ mask = heatmap > 0
185
+
186
+ # 创建叠加图像
187
+ superimposed_img = image.copy()
188
+ superimposed_img[mask] = superimposed_img[mask] * (1 - intensity) + heatmap[mask] * intensity
189
+
190
  superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
191
  superimposed_img = Image.fromarray(superimposed_img)
192
 
193
  return superimposed_img
194
 
195
+
196
  def predict(image):
197
  """
198
  預測並生成熱力圖。