InPeerReview committed on
Commit 032c113 · verified · 1 parent: b0481db

Upload 9 files

Files changed (9)
  1. eval_tables_only.py +351 -0
  2. five_connect.py +61 -0
  3. grad_cam_CNN.py +74 -0
  4. mask_connect_test.py +54 -0
  5. params_flops.py +53 -0
  6. requirements.txt +13 -0
  7. test.py +128 -0
  8. train.py +389 -0
  9. utilss.py +249 -0
eval_tables_only.py ADDED
@@ -0,0 +1,351 @@
import torch
from torch import nn
from tqdm import tqdm
import prettytable
import time
import os
import multiprocessing.pool as mpp
import multiprocessing as mp

from train import *

import argparse
from utils.config import Config
from tools.mask_convert import mask_save
import numpy as np  # [PR] for histogram-based PR accumulation
import csv

# =========================== [PR] Utilities BEGIN ===========================
class PRHistogram:
    # Memory-friendly PR accumulator. Call update(probs, mask) repeatedly inside
    # your test loop, then call export_csv(path) after the loop.
    # - probs: torch.Tensor in [0,1], shape [B,H,W], "change" probability
    # - mask: torch.Tensor of 0/1 (or 0/255), shape [B,H,W]
    def __init__(self, nbins: int = 1000):
        import numpy as _np
        self.nbins = int(nbins)
        self.pos_hist = _np.zeros(self.nbins, dtype=_np.int64)
        self.neg_hist = _np.zeros(self.nbins, dtype=_np.int64)
        self.bin_edges = _np.linspace(0.0, 1.0, self.nbins + 1)

    def update(self, probs, mask):
        import numpy as _np
        p = probs.detach().float().cpu().numpy().ravel()
        g = (mask.detach().cpu().numpy().ravel() > 0).astype(_np.uint8)
        pos_counts, _ = _np.histogram(p[g == 1], bins=self.bin_edges)
        neg_counts, _ = _np.histogram(p[g == 0], bins=self.bin_edges)
        self.pos_hist += pos_counts
        self.neg_hist += neg_counts

    def compute_curve(self):
        import numpy as _np
        # Cumulative sums give TP/FP from the highest threshold down to the lowest
        pos_cum = _np.cumsum(self.pos_hist[::-1])
        neg_cum = _np.cumsum(self.neg_hist[::-1])
        TP = pos_cum
        FP = neg_cum
        FN = self.pos_hist.sum() - TP
        TN = None  # TN is not needed for the curve

        denom_prec = _np.maximum(TP + FP, 1)
        denom_rec = _np.maximum(TP + FN, 1)
        precision = TP / denom_prec
        recall = TP / denom_rec

        # F1 = 2PR / (P + R)
        denom_f1 = _np.maximum(precision + recall, 1e-12)
        f1 = 2.0 * precision * recall / denom_f1

        # IoU = TP / (TP + FP + FN)
        denom_iou = _np.maximum(TP + FP + FN, 1)
        iou = TP / denom_iou

        thresholds = self.bin_edges[::-1][1:]  # thresholds in the same (descending) order as the cumulative sums
        return thresholds, precision, recall, f1, iou, TP, FP, FN

    def export_csv(self, save_path: str):
        thresholds, precision, recall, f1, iou, TP, FP, FN = self.compute_curve()
        import numpy as _np, os as _os
        _os.makedirs(_os.path.dirname(save_path), exist_ok=True)
        _np.savetxt(
            save_path,
            _np.column_stack([thresholds, precision, recall, f1, iou, TP, FP, FN]),
            delimiter=",",
            header="threshold,precision,recall,f1,iou,TP,FP,FN",
            comments=""
        )
        return save_path

# Global PR object (created on demand)
_PR = None

def pr_init(nbins: int = 1000):
    global _PR
    if _PR is None:
        _PR = PRHistogram(nbins=nbins)
    return _PR

def pr_update_from_outputs(raw_predictions, mask, cfg):
    # Try to derive probs ∈ [0,1] from the various model outputs in this repo.
    # This covers:
    #   - cfg.argmax=True: 2-channel logits -> softmax class-1 prob
    #   - single-channel logits -> sigmoid
    #   - net == 'maskcd' (list/tuple outputs)
    # Modify here if your network has a special head.
    import torch
    global _PR
    if _PR is None:
        _PR = PRHistogram(nbins=1000)

    if getattr(cfg, 'argmax', False):
        logits = raw_predictions
        if logits.dim() == 4 and logits.size(1) >= 2:
            probs = torch.softmax(logits, dim=1)[:, 1, :, :]
        else:
            probs = torch.sigmoid(logits.squeeze(1))
    else:
        if getattr(cfg, 'net', '') == 'maskcd':
            if isinstance(raw_predictions, (list, tuple)):
                logits = raw_predictions[0]
            else:
                logits = raw_predictions
            probs = torch.sigmoid(logits).squeeze(1)
        else:
            logits = raw_predictions
            if logits.dim() == 4 and logits.size(1) == 1:
                logits = logits.squeeze(1)
            probs = torch.sigmoid(logits)

    if mask.dim() == 4 and mask.size(1) == 1:
        mask_ = mask.squeeze(1)
    else:
        mask_ = mask
    _PR.update(probs, (mask_ > 0).to(probs.dtype))

def pr_export(base_dir: str, cfg):
    # Export the PR CSV to base_dir/pr_<net>.csv
    import os
    global _PR
    if _PR is None:
        return None
    save_path = os.path.join(base_dir, f"pr_{getattr(cfg, 'net', 'model')}.csv")
    out = _PR.export_csv(save_path)
    print(f"[PR] saved: {out}")
    return out
# ============================ [PR] Utilities END ============================

# -------------------- [Per-Image] per-image metric helpers --------------------
def _safe_div(a, b, eps=1e-12):
    return a / max(b, eps)

def per_image_stats(pred_np: np.ndarray, gt_np: np.ndarray):
    """
    pred_np, gt_np: binary 0/1 numpy arrays, shape [H,W]
    Returns: dict with TP/FP/TN/FN and the derived metrics
    """
    pred_bin = (pred_np > 0).astype(np.uint8)
    gt_bin = (gt_np > 0).astype(np.uint8)

    TP = int(((pred_bin == 1) & (gt_bin == 1)).sum())
    FP = int(((pred_bin == 1) & (gt_bin == 0)).sum())
    TN = int(((pred_bin == 0) & (gt_bin == 0)).sum())
    FN = int(((pred_bin == 0) & (gt_bin == 1)).sum())

    precision = _safe_div(TP, (TP + FP))
    recall = _safe_div(TP, (TP + FN))
    f1 = _safe_div(2 * precision * recall, (precision + recall))
    iou = _safe_div(TP, (TP + FP + FN))
    oa = _safe_div(TP + TN, (TP + TN + FP + FN))

    return {
        "TP": TP, "FP": FP, "TN": TN, "FN": FN,
        "OA": oa, "Precision": precision, "Recall": recall, "F1": f1, "IoU": iou
    }
# --------------------------------------------------------------------

def get_args():
    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    # New: tables-only mode (no visualization images are exported)
    parser.add_argument("--tables-only", action="store_true",
                        help="Only generate tables and CSVs (overall table, per-image CSV, "
                             "per-image TXT, PR-curve CSV); do not render mask visualization images")
    return parser.parse_args()

if __name__ == "__main__":
    args = get_args()
    cfg = Config.fromfile(args.config)

    ckpt = args.ckpt
    if ckpt is None:
        ckpt = cfg.test_ckpt_path
    assert ckpt is not None

    if args.output_dir:
        base_dir = args.output_dir
    else:
        base_dir = os.path.dirname(ckpt)

    # Mask image output directory (only used when images are written)
    masks_output_dir = os.path.join(base_dir, "mask_rgb")
    # Table output directory (per-image .txt tables); kept under tables_only in --tables-only mode
    tables_output_dir = os.path.join(base_dir, "tables_only" if args.tables_only else "mask_rgb")
    os.makedirs(tables_output_dir, exist_ok=True)

    model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1': 'cuda:0'}, cfg=cfg)
    model = model.to('cuda')
    model.eval()

    metric_cfg_1 = cfg.metric_cfg1
    metric_cfg_2 = cfg.metric_cfg2

    test_oa = torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
    test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
    test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
    test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
    test_iou = torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')

    results = []         # only used when visualization images are generated
    per_image_rows = []  # [Per-Image] collects per-image metrics

    with torch.no_grad():
        test_loader = build_dataloader(cfg.dataset_config, mode='test')
        # === Call 1: initialize the PR accumulator ===
        pr_init(nbins=1000)

        for input in tqdm(test_loader):
            raw_predictions, mask, img_id = model(input[0].cuda(), input[1].cuda()), input[2].cuda(), input[3]
            # === Call 2: update the PR accumulator ===
            pr_update_from_outputs(raw_predictions, mask, cfg)

            if cfg.net == 'SARASNet':
                mask = Variable(resize_label(mask.data.cpu().numpy(),
                                size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()
                param = 1  # This parameter balances precision and recall to obtain a higher F1-score
                raw_predictions[:, 1, :, :] = raw_predictions[:, 1, :, :] + param

            if cfg.argmax:
                pred = raw_predictions.argmax(dim=1)
            else:
                if cfg.net == 'maskcd':
                    pred = raw_predictions[0]
                    pred = pred > 0.5
                    pred.squeeze_(1)
                else:
                    pred = raw_predictions.squeeze(1)
                    pred = pred > 0.5

            # ====== Accumulate overall test metrics ======
            test_oa(pred, mask)
            test_iou(pred, mask)
            test_prec(pred, mask)
            test_f1(pred, mask)
            test_recall(pred, mask)

            # ====== [Per-Image] compute and collect per-image metrics ======
            for i in range(pred.shape[0]):
                mask_real = mask[i].detach().cpu().numpy()
                mask_pred = pred[i].detach().cpu().numpy()
                mask_name = str(img_id[i])

                # per-image statistics
                stats = per_image_stats(mask_pred, mask_real)
                per_image_rows.append({
                    "img_id": mask_name,
                    "TP": stats["TP"], "FP": stats["FP"], "TN": stats["TN"], "FN": stats["FN"],
                    "OA": stats["OA"], "Precision": stats["Precision"],
                    "Recall": stats["Recall"], "F1": stats["F1"], "IoU": stats["IoU"]
                })

                # Collect image-writing tasks only when visualizations are requested
                if not args.tables_only:
                    results.append((mask_real, mask_pred, masks_output_dir, mask_name))

    # ====== Print overall metrics ======
    metrics = [test_prec.compute(),
               test_recall.compute(),
               test_f1.compute(),
               test_iou.compute()]

    total_metrics = [test_oa.compute().cpu().numpy(),
                     np.mean([item.cpu() for item in metrics[0]]),
                     np.mean([item.cpu() for item in metrics[1]]),
                     np.mean([item.cpu() for item in metrics[2]]),
                     np.mean([item.cpu() for item in metrics[3]])]

    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']

    for i in range(2):
        item = [i, '--']
        for j in range(len(metrics)):
            item.append(np.round(metrics[j][i].cpu().numpy(), 4))
        result_table.add_row(item)

    total = [np.round(v, 4) for v in total_metrics]
    total.insert(0, 'total')
    result_table.add_row(total)
    print(result_table)

    file_name = os.path.join(base_dir, "test_res.txt")
    with open(file_name, "a") as f:
        current_time = time.strftime('%Y_%m_%d %H:%M:%S {}'.format(cfg.net), time.localtime(time.time()))
        f.write(current_time + '\n')
        f.write(str(result_table) + '\n')

    # ====== Write visualization images or not, depending on the mode ======
    if not args.tables_only:
        if not os.path.exists(masks_output_dir):
            os.makedirs(masks_output_dir)
            print(masks_output_dir)

        # write images with a multiprocessing pool
        t0 = time.time()
        mpp.Pool(processes=mp.cpu_count()).map(mask_save, results)
        t1 = time.time()
        img_write_time = t1 - t0
        print('images writing spends: {} s'.format(img_write_time))
    else:
        print("[Mode] --tables-only: skipping visualization images; exporting tables/CSVs only.")

    # ====== [Per-Image] write all per-image metrics into one CSV ======
    per_image_csv = os.path.join(base_dir, f"per_image_metrics_{getattr(cfg, 'net', 'model')}.csv")
    with open(per_image_csv, "w", newline="") as wf:
        writer = csv.DictWriter(
            wf,
            fieldnames=["img_id", "TP", "FP", "TN", "FN", "OA", "Precision", "Recall", "F1", "IoU"]
        )
        writer.writeheader()
        for row in per_image_rows:
            row_out = dict(row)
            for k in ["OA", "Precision", "Recall", "F1", "IoU"]:
                row_out[k] = float(np.round(row_out[k], 6))
            writer.writerow(row_out)
    print(f"[Per-Image] saved CSV: {per_image_csv}")

    # ====== [Per-Image] write a small table (.txt) for each image ======
    for row in per_image_rows:
        txt_path = os.path.join(tables_output_dir, f"{row['img_id']}_metrics.txt")
        pt = prettytable.PrettyTable()
        pt.field_names = ["Metric", "Value"]
        # confusion-matrix entries first
        pt.add_row(["TP", row["TP"]])
        pt.add_row(["FP", row["FP"]])
        pt.add_row(["TN", row["TN"]])
        pt.add_row(["FN", row["FN"]])
        # then the ratio metrics
        pt.add_row(["OA", f"{row['OA']:.6f}"])
        pt.add_row(["Precision", f"{row['Precision']:.6f}"])
        pt.add_row(["Recall", f"{row['Recall']:.6f}"])
        pt.add_row(["F1", f"{row['F1']:.6f}"])
        pt.add_row(["IoU", f"{row['IoU']:.6f}"])
        with open(txt_path, "w") as wf:
            wf.write(str(pt))
    print(f"[Per-Image] per-image tables saved to: {tables_output_dir}")

    # ===== [PR] Export at program end =====
    try:
        pr_export(base_dir, cfg)
    except Exception as e:
        print(f"[PR] export skipped or failed: {e}")
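Usage note (editor's sketch, not part of the commit): the PRHistogram accumulator above can be exercised on its own with synthetic tensors. The tmp/ output path below is hypothetical, and the import assumes the repository's dependencies are installed (importing eval_tables_only pulls in train and its PyTorch Lightning stack):

import torch
from eval_tables_only import PRHistogram

pr = PRHistogram(nbins=100)
for _ in range(5):                       # stand-in for 5 test batches
    probs = torch.rand(2, 64, 64)        # [B,H,W] "change" probabilities in [0,1]
    mask = torch.rand(2, 64, 64) > 0.7   # [B,H,W] binary ground truth
    pr.update(probs, mask)

thresholds, precision, recall, f1, iou, TP, FP, FN = pr.compute_curve()
print(f"best F1 over all thresholds: {f1.max():.4f}")
pr.export_csv("tmp/pr_demo.csv")         # hypothetical path; columns: threshold,precision,recall,f1,iou,TP,FP,FN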
five_connect.py ADDED
@@ -0,0 +1,61 @@
import cv2
import os
import glob
import numpy as np

def concat_image_heatmap(img1_path, img2_path, label_path, mask_path, heatmap_path, output_path):
    img1 = cv2.imread(img1_path)
    img2 = cv2.imread(img2_path)
    mask = cv2.imread(mask_path)
    heatmap = cv2.imread(heatmap_path)
    label = cv2.imread(label_path) if label_path and os.path.exists(label_path) else None

    if img1 is None or img2 is None or mask is None or heatmap is None:
        print(f"❌ Missing image: {img1_path}, {img2_path}, {mask_path}, {heatmap_path}")
        return

    h, w = img1.shape[:2]
    img2 = cv2.resize(img2, (w, h))
    mask = cv2.resize(mask, (w, h))
    heatmap = cv2.resize(heatmap, (w, h))
    label = cv2.resize(label, (w, h)) if label is not None else np.zeros_like(img1)

    top_row = np.concatenate([img1, img2, label], axis=1)
    bottom_row = np.concatenate([mask, heatmap], axis=1)

    # pad the rows so both have the same width
    max_width = max(top_row.shape[1], bottom_row.shape[1])
    if top_row.shape[1] < max_width:
        pad = max_width - top_row.shape[1]
        top_row = cv2.copyMakeBorder(top_row, 0, 0, 0, pad, cv2.BORDER_CONSTANT, value=0)
    if bottom_row.shape[1] < max_width:
        pad = max_width - bottom_row.shape[1]
        bottom_row = cv2.copyMakeBorder(bottom_row, 0, 0, 0, pad, cv2.BORDER_CONSTANT, value=0)

    full_image = np.concatenate([top_row, bottom_row], axis=0)
    cv2.imwrite(output_path, full_image)
    print(f"✅ Saved: {output_path}")

def batch_process(img1_dir, img2_dir, label_dir, mask_dir, heatmap_dir, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    img1_paths = glob.glob(os.path.join(img1_dir, "*.png"))

    for img1_path in img1_paths:
        filename = os.path.basename(img1_path)
        img2_path = os.path.join(img2_dir, filename)
        label_path = os.path.join(label_dir, filename) if label_dir else None
        mask_path = os.path.join(mask_dir, filename)
        heatmap_path = os.path.join(heatmap_dir, filename)
        output_path = os.path.join(output_dir, filename.replace(".png", "_full.png"))

        concat_image_heatmap(img1_path, img2_path, label_path, mask_path, heatmap_path, output_path)

# path settings
img1_dir = "data/WHU_CD/test/image1"
img2_dir = "data/WHU_CD/test/image2"
label_dir = "data/WHU_CD/test/label"  # may be set to None
mask_dir = "mask_connect_test_dir/mask_rgb"
heatmap_dir = "mask_connect_test_dir/grad_cam/model.net.decoderhead.LHBlock2"
output_dir = "mask_heatmap_concat_dir"

batch_process(img1_dir, img2_dir, label_dir, mask_dir, heatmap_dir, output_dir)
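Editor's sketch (not in the commit): concat_image_heatmap also works on arbitrary PNGs, which makes it easy to smoke-test; the demo_inputs/ paths are hypothetical. Note that importing five_connect executes the batch_process call at the bottom of the file, which only creates the (empty) mask_heatmap_concat_dir output directory when the hard-coded WHU_CD paths are absent:

import os
import cv2
import numpy as np
from five_connect import concat_image_heatmap

os.makedirs("demo_inputs", exist_ok=True)            # hypothetical scratch directory
names = ["img1", "img2", "label", "mask", "heatmap"]
for n in names:                                      # fabricate five random 64x64 PNGs
    cv2.imwrite(f"demo_inputs/{n}.png",
                np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))

paths = [f"demo_inputs/{n}.png" for n in names]
concat_image_heatmap(*paths, "demo_inputs/full.png")  # writes the two-row collage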
grad_cam_CNN.py ADDED
@@ -0,0 +1,74 @@
import os
import sys
sys.path.append('.')

import matplotlib.pyplot as plt
from utilss import GradCAM, show_cam_on_image, center_crop_img

import argparse
from utils.config import Config
from train import *

def get_args():
    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdxformer.py")
    parser.add_argument("--output_dir", default=None)
    parser.add_argument("--layer", default=None)
    return parser.parse_args()

def main():
    args = get_args()

    if args.layer is None:
        raise NameError("Please ensure the parameter '--layer' is not None!\n"
                        " e.g. --layer=model.net.decoderhead.LHBlock2.mlp_l")

    cfg = Config.fromfile(args.config)

    model = myTrain.load_from_checkpoint(cfg.test_ckpt_path, cfg=cfg)
    model = model.to('cuda')

    # print(dict(model.named_modules()).keys())

    test_loader = build_dataloader(cfg.dataset_config, mode='test')

    if args.output_dir:
        base_dir = args.output_dir
    else:
        base_dir = os.path.dirname(cfg.test_ckpt_path)
    gradcam_output_dir = os.path.join(base_dir, "grad_cam", args.layer)
    if os.path.exists(gradcam_output_dir):
        raise NameError("Please ensure gradcam_output_dir does not exist!")

    os.makedirs(gradcam_output_dir)

    for input in tqdm(test_loader):
        target_layers = [eval(args.layer)]  # resolve the layer object from its dotted name
        mask, img_id = input[2].cuda(), input[3]

        cam = GradCAM(cfg, model=model.net, target_layers=target_layers, use_cuda=True)
        target_category = 1  # class 1 = "change"

        grayscale_cam_all = cam(input_tensor=(input[0], input[1]), target_category=target_category)

        for i in range(grayscale_cam_all.shape[0]):
            grayscale_cam = grayscale_cam_all[i, :]
            # img=0: visualize the heatmap alone, with no base image blended in
            visualization = show_cam_on_image(0,
                                              grayscale_cam,
                                              use_rgb=True)
            fig = plt.figure()
            ax = fig.add_subplot(111)
            ax.imshow(visualization)
            # ax = fig.add_subplot(122)
            # ax.imshow(mask[i].cpu().numpy())
            ax.set_xticks([])
            ax.set_yticks([])
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.spines['left'].set_visible(False)
            plt.savefig(os.path.join(gradcam_output_dir, '{}.png'.format(img_id[i])))
            plt.close()


if __name__ == '__main__':
    main()
mask_connect_test.py ADDED
@@ -0,0 +1,54 @@
import cv2
import os
import glob
import numpy as np

def concat_change_detection_images(img1_path, img2_path, label_path, pred_path, output_path):
    img1 = cv2.imread(img1_path)
    img2 = cv2.imread(img2_path)
    label = cv2.imread(label_path) if label_path and os.path.exists(label_path) else None  # tolerate label_path=None
    pred = cv2.imread(pred_path)

    if img1 is None or img2 is None or pred is None:
        print(f"Missing or unreadable image: {img1_path}, {img2_path}, {pred_path}")
        return

    # resize all images to the same size (using img1 as the reference)
    h, w = img1.shape[:2]
    img2 = cv2.resize(img2, (w, h))
    pred = cv2.resize(pred, (w, h))
    if label is not None:
        label = cv2.resize(label, (w, h))

    # concatenate the images (the label panel is skipped when absent)
    if label is not None:
        concat = np.concatenate([img1, img2, label, pred], axis=1)
    else:
        concat = np.concatenate([img1, img2, pred], axis=1)

    cv2.imwrite(output_path, concat)

def batch_process(img1_dir, img2_dir, label_dir, pred_dir, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    img1_paths = glob.glob(os.path.join(img1_dir, "*.png"))
    for img1_path in img1_paths:
        filename = os.path.basename(img1_path)
        img2_path = os.path.join(img2_dir, filename)
        label_path = os.path.join(label_dir, filename) if label_dir else None
        pred_path = os.path.join(pred_dir, filename)
        output_path = os.path.join(output_dir, filename.replace(".png", "_concat.png"))

        print(f"[INFO] img1: {img1_path}, img2: {img2_path}")
        print(f"[INFO] label: {label_path}, pred: {pred_path}")

        concat_change_detection_images(img1_path, img2_path, label_path, pred_path, output_path)
        print(f"Saved: {output_path}")

# path settings
img1_dir = "data/WHU_CD/test/image1"
img2_dir = "data/WHU_CD/test/image2"
label_dir = "data/WHU_CD/test/label"  # set to None if there are no label images
pred_dir = "work_dirs/CLCD_BS4_epoch200/CDXFormer/version_0/ckpts/test/mask_rgb"
output_dir = "mask_connect_test_dir"

batch_process(img1_dir, img2_dir, label_dir, pred_dir, output_dir)
params_flops.py ADDED
@@ -0,0 +1,53 @@
import sys
sys.path.append('.')
from train import *
from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count, parameter_count
from rscd.models.backbones.lamba_util.csms6s import flops_selective_scan_fn, flops_selective_scan_ref, selective_scan_flop_jit

def parse_args():
    parser = argparse.ArgumentParser(description='count params and flops')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
    parser.add_argument("--size", type=int, default=256)
    args = parser.parse_args()
    return args

def flops_mamba(model, shape=(3, 224, 224)):
    # shape = self.__input_shape__[1:]
    supported_ops = {
        "aten::silu": None,  # as relu is in _IGNORED_OPS
        "aten::neg": None,   # as relu is in _IGNORED_OPS
        "aten::exp": None,   # as relu is in _IGNORED_OPS
        "aten::flip": None,  # as permute is in _IGNORED_OPS
        # "prim::PythonOp.CrossScan": None,
        # "prim::PythonOp.CrossMerge": None,
        "prim::PythonOp.SelectiveScanCuda": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanMamba": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanOflex": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanCore": selective_scan_flop_jit,
        "prim::PythonOp.SelectiveScanNRow": selective_scan_flop_jit,
    }

    model.cuda().eval()

    input1 = torch.randn((1, *shape), device=next(model.parameters()).device)
    input2 = torch.randn((1, *shape), device=next(model.parameters()).device)
    params = parameter_count(model)[""]
    Gflops, unsupported = flop_count(model=model, inputs=(input1, input2), supported_ops=supported_ops)

    del model, input1, input2
    # return sum(Gflops.values()) * 1e9
    return f"params {params / 1e6} M, GFLOPs {sum(Gflops.values())}"

if __name__ == "__main__":
    args = parse_args()
    cfg = Config.fromfile(args.config)
    net = myTrain(cfg).net.cuda()

    size = args.size
    input = torch.rand((1, 3, size, size)).cuda()

    net.eval()
    flops = FlopCountAnalysis(net, (input, input))
    print(flop_count_table(flops, max_depth=2))

    print(flops_mamba(net, (3, size, size)))
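For a dependency-light check of the same fvcore pattern the script relies on, here is a self-contained sketch (editor's addition; ToyNet is a made-up stand-in for the repo's two-input networks) that runs on CPU:

import torch
from torch import nn
from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count

class ToyNet(nn.Module):
    # toy two-input model mimicking the (x1, x2) change-detection interface
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.head = nn.Conv2d(8, 2, 1)

    def forward(self, x1, x2):
        return self.head(self.conv(x1) - self.conv(x2))

net = ToyNet().eval()
x = torch.rand(1, 3, 256, 256)
flops = FlopCountAnalysis(net, (x, x))       # extra inputs are passed as a tuple
print(flop_count_table(flops, max_depth=2))  # per-module params / FLOPs table
print("params (M):", parameter_count(net)[""] / 1e6)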
requirements.txt ADDED
@@ -0,0 +1,13 @@
torchmetrics==0.11.4
pytorch-lightning==2.0.6
scikit-image==0.21.0

catalyst==20.9
albumentations==1.3.1
ttach==0.0.3
einops==0.6.1
timm==0.6.7
addict==2.4.0
soundfile==0.12.1
prettytable==3.8.0
fvcore
test.py ADDED
@@ -0,0 +1,128 @@
import torch
from torch import nn
from tqdm import tqdm
import prettytable
import time
import os
import multiprocessing.pool as mpp
import multiprocessing as mp

from train import *

import argparse
from utils.config import Config
from tools.mask_convert import mask_save

def get_args():
    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    return parser.parse_args()

if __name__ == "__main__":
    args = get_args()
    cfg = Config.fromfile(args.config)

    ckpt = args.ckpt
    if ckpt is None:
        ckpt = cfg.test_ckpt_path
    assert ckpt is not None

    if args.output_dir:
        base_dir = args.output_dir
    else:
        base_dir = os.path.dirname(ckpt)
    masks_output_dir = os.path.join(base_dir, "mask_rgb")

    model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1': 'cuda:0'}, cfg=cfg)
    model = model.to('cuda')

    model.eval()

    metric_cfg_1 = cfg.metric_cfg1
    metric_cfg_2 = cfg.metric_cfg2

    test_oa = torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
    test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
    test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
    test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
    test_iou = torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')

    results = []
    with torch.no_grad():
        test_loader = build_dataloader(cfg.dataset_config, mode='test')
        for input in tqdm(test_loader):

            raw_predictions, mask, img_id = model(input[0].cuda(), input[1].cuda()), input[2].cuda(), input[3]

            if cfg.net == 'SARASNet':
                mask = Variable(resize_label(mask.data.cpu().numpy(),
                                size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()
                param = 1  # This parameter balances precision and recall to obtain a higher F1-score
                raw_predictions[:, 1, :, :] = raw_predictions[:, 1, :, :] + param

            if cfg.argmax:
                pred = raw_predictions.argmax(dim=1)
            else:
                if cfg.net == 'maskcd':
                    pred = raw_predictions[0]
                    pred = pred > 0.5
                    pred.squeeze_(1)
                else:
                    pred = raw_predictions.squeeze(1)
                    pred = pred > 0.5

            test_oa(pred, mask)
            test_iou(pred, mask)
            test_prec(pred, mask)
            test_f1(pred, mask)
            test_recall(pred, mask)

            for i in range(pred.shape[0]):
                mask_real = mask[i].cpu().numpy()
                mask_pred = pred[i].cpu().numpy()
                mask_name = str(img_id[i])
                results.append((mask_real, mask_pred, masks_output_dir, mask_name))

    metrics = [test_prec.compute(),
               test_recall.compute(),
               test_f1.compute(),
               test_iou.compute()]

    total_metrics = [test_oa.compute().cpu().numpy(),
                     np.mean([item.cpu() for item in metrics[0]]),
                     np.mean([item.cpu() for item in metrics[1]]),
                     np.mean([item.cpu() for item in metrics[2]]),
                     np.mean([item.cpu() for item in metrics[3]])]

    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']

    for i in range(2):
        item = [i, '--']
        for j in range(len(metrics)):
            item.append(np.round(metrics[j][i].cpu().numpy(), 4))
        result_table.add_row(item)

    total = [np.round(v, 4) for v in total_metrics]
    total.insert(0, 'total')
    result_table.add_row(total)

    print(result_table)

    file_name = os.path.join(base_dir, "test_res.txt")
    with open(file_name, "a") as f:
        current_time = time.strftime('%Y_%m_%d %H:%M:%S {}'.format(cfg.net), time.localtime(time.time()))
        f.write(current_time + '\n')
        f.write(str(result_table) + '\n')

    if not os.path.exists(masks_output_dir):
        os.makedirs(masks_output_dir)
        print(masks_output_dir)

    t0 = time.time()
    mpp.Pool(processes=mp.cpu_count()).map(mask_save, results)
    t1 = time.time()
    img_write_time = t1 - t0
    print('images writing spends: {} s'.format(img_write_time))
train.py ADDED
@@ -0,0 +1,389 @@
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
import torchmetrics
from tqdm import tqdm

import prettytable
import numpy as np
import argparse
from rscd.models.build_model import build_model
from rscd.datasets import build_dataloader
from rscd.optimizers import build_optimizer
from rscd.losses import build_loss
from utils.config import Config

from torch.autograd import Variable

import sys
sys.path.append('rscd')

seed_everything(1234, workers=True)

import os
import time  # for timing

def resize_label(label, size):

    label = np.expand_dims(label, axis=0)
    label_resized = np.zeros((1, label.shape[1], size[0], size[1]))
    interp = nn.Upsample(size=(size[0], size[1]), mode='bilinear')

    labelVar = Variable(torch.from_numpy(label).float())
    label_resized[:, :, :, :] = interp(labelVar).data.numpy()
    label_resized = np.array(label_resized, dtype=np.int32)
    return torch.from_numpy(np.squeeze(label_resized, axis=0)).float()

def get_args():
    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlamba.py")
    return parser.parse_args()

class myTrain(LightningModule):
    def __init__(self, cfg, log_dir=None):
        super(myTrain, self).__init__()

        self.cfg = cfg
        self.log_dir = log_dir
        self.net = build_model(cfg.model_config)
        self.loss = build_loss(cfg.loss_config)

        self.loss.to('cuda:{}'.format(cfg.gpus[0]))

        metric_cfg1 = cfg.metric_cfg1
        metric_cfg2 = cfg.metric_cfg2

        self.tr_oa = torchmetrics.Accuracy(**metric_cfg1)
        self.tr_prec = torchmetrics.Precision(**metric_cfg2)
        self.tr_recall = torchmetrics.Recall(**metric_cfg2)
        self.tr_f1 = torchmetrics.F1Score(**metric_cfg2)
        self.tr_iou = torchmetrics.JaccardIndex(**metric_cfg2)

        self.val_oa = torchmetrics.Accuracy(**metric_cfg1)
        self.val_prec = torchmetrics.Precision(**metric_cfg2)
        self.val_recall = torchmetrics.Recall(**metric_cfg2)
        self.val_f1 = torchmetrics.F1Score(**metric_cfg2)
        self.val_iou = torchmetrics.JaccardIndex(**metric_cfg2)

        self.test_oa = torchmetrics.Accuracy(**metric_cfg1)
        self.test_prec = torchmetrics.Precision(**metric_cfg2)
        self.test_recall = torchmetrics.Recall(**metric_cfg2)
        self.test_f1 = torchmetrics.F1Score(**metric_cfg2)
        self.test_iou = torchmetrics.JaccardIndex(**metric_cfg2)

        self.test_max_f1 = [0 for _ in range(10)]

        self.test_loader = build_dataloader(cfg.dataset_config, mode='test')

    def forward(self, x1, x2):
        pred = self.net(x1, x2)
        return pred

    def configure_optimizers(self):
        optimizer, scheduler = build_optimizer(self.cfg.optimizer_config, self.net)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': self.cfg.monitor_val}

    def train_dataloader(self):
        loader = build_dataloader(self.cfg.dataset_config, mode='train')
        return loader

    def val_dataloader(self):
        loader = build_dataloader(self.cfg.dataset_config, mode='val')
        return loader

    def output(self, metrics, total_metrics, mode, test_idx=0, test_value=None):
        result_table = prettytable.PrettyTable()
        result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']

        for i in range(len(metrics[0])):
            item = [i, '--']
            for j in range(len(metrics)):
                item.append(np.round(metrics[j][i].cpu().numpy(), 4))
            result_table.add_row(item)

        total = list(total_metrics.values())
        total = [np.round(v, 4) for v in total]
        total.insert(0, 'total')
        result_table.add_row(total)

        if mode == 'val' or mode == 'test':
            print(mode)
            print(result_table)

        if self.log_dir:
            base_dir = self.log_dir
        else:
            base_dir = os.path.join('work_dirs', self.cfg.exp_name)

        if mode == 'test':
            if self.cfg.argmax:
                file_name = os.path.join(base_dir, "test_metrics_{}.txt".format(test_idx))
                if metrics[2][1] > self.test_max_f1[test_idx]:
                    self.test_max_f1[test_idx] = metrics[2][1]
                    file_name = os.path.join(base_dir, "test_max_metrics_{}.txt".format(test_idx))
            else:
                file_name = os.path.join(base_dir, "test_metrics_{}_{}.txt".format(test_idx, str(test_value)))
                if metrics[2][1] > self.test_max_f1[test_idx]:
                    self.test_max_f1[test_idx] = metrics[2][1]
                    file_name = os.path.join(base_dir, "test_max_metrics_{}_{}.txt".format(test_idx, '%.1f' % test_value))
        else:
            file_name = os.path.join(base_dir, "train_metrics.txt")
        f = open(file_name, "a")
        f.write('epoch:{}/{} {}\n'.format(self.current_epoch, self.cfg.epoch, mode))
        f.write(str(result_table) + '\n')
        f.close()

    def training_step(self, batch, batch_idx):
        imgA, imgB, mask = batch[0], batch[1], batch[2]
        preds = self(imgA, imgB)

        if self.cfg.net == 'SARASNet':
            mask = Variable(resize_label(mask.data.cpu().numpy(),
                            size=preds.data.cpu().numpy().shape[2:]).to('cuda')).long()
            param = 1  # This parameter balances precision and recall to obtain a higher F1-score
            preds[:, 1, :, :] = preds[:, 1, :, :] + param

        if self.cfg.argmax:
            loss = self.loss(preds, mask)
            pred = preds.argmax(dim=1)
        else:
            if self.cfg.net == 'maskcd':
                loss = self.loss(preds[1], mask)
                pred = preds[0]
                pred = pred > 0.5
                pred.squeeze_(1)
            else:
                pred = preds.squeeze(1)
                loss = self.loss(pred, mask)
                pred = pred > 0.5

        self.tr_oa(pred, mask)
        self.tr_prec(pred, mask)
        self.tr_recall(pred, mask)
        self.tr_f1(pred, mask)
        self.tr_iou(pred, mask)

        self.log('tr_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        metrics = [self.tr_prec.compute(),
                   self.tr_recall.compute(),
                   self.tr_f1.compute(),
                   self.tr_iou.compute()]

        log = {'tr_oa': float(self.tr_oa.compute().cpu()),
               'tr_prec': np.mean([item.cpu() for item in metrics[0]]),
               'tr_recall': np.mean([item.cpu() for item in metrics[1]]),
               'tr_f1': np.mean([item.cpu() for item in metrics[2]]),
               'tr_miou': np.mean([item.cpu() for item in metrics[3]])}

        self.output(metrics, log, 'train')

        for key, value in log.items():
            self.log(key, value, on_step=False, on_epoch=True, prog_bar=True)
        self.log('tr_change_f1', metrics[2][1], on_step=False, on_epoch=True, prog_bar=True)

        self.tr_oa.reset()
        self.tr_prec.reset()
        self.tr_recall.reset()
        self.tr_f1.reset()
        self.tr_iou.reset()

    def validation_step(self, batch, batch_idx):
        imgA, imgB, mask = batch[0], batch[1], batch[2]
        preds = self(imgA, imgB)

        if self.cfg.net == 'SARASNet':
            mask = Variable(resize_label(mask.data.cpu().numpy(),
                            size=preds.data.cpu().numpy().shape[2:]).to('cuda')).long()
            param = 1  # This parameter balances precision and recall to obtain a higher F1-score
            preds[:, 1, :, :] = preds[:, 1, :, :] + param

        if self.cfg.argmax:
            loss = self.loss(preds, mask)
            pred = preds.argmax(dim=1)
        else:
            if self.cfg.net == 'maskcd':
                loss = self.loss(preds[1], mask)
                pred = preds[0]
                pred = pred > 0.5
                pred.squeeze_(1)
            else:
                pred = preds.squeeze(1)
                loss = self.loss(pred, mask)
                pred = pred > 0.5

        self.val_oa(pred, mask)
        self.val_prec(pred, mask)
        self.val_recall(pred, mask)
        self.val_f1(pred, mask)
        self.val_iou(pred, mask)

        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def on_validation_epoch_end(self):
        metrics = [self.val_prec.compute(),
                   self.val_recall.compute(),
                   self.val_f1.compute(),
                   self.val_iou.compute()]

        log = {'val_oa': float(self.val_oa.compute().cpu()),
               'val_prec': np.mean([item.cpu() for item in metrics[0]]),
               'val_recall': np.mean([item.cpu() for item in metrics[1]]),
               'val_f1': np.mean([item.cpu() for item in metrics[2]]),
               'val_miou': np.mean([item.cpu() for item in metrics[3]])}

        self.output(metrics, log, 'val')

        for key, value in log.items():
            self.log(key, value, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_change_f1', metrics[2][1], on_step=False, on_epoch=True, prog_bar=True)

        self.val_oa.reset()
        self.val_prec.reset()
        self.val_recall.reset()
        self.val_f1.reset()
        self.val_iou.reset()

        for idx in range(len(self.cfg.monitor_test)):
            if self.cfg.argmax:
                self.log(self.cfg.monitor_test[idx], self.test(idx), on_step=False, on_epoch=True, prog_bar=True)
            else:
                t = 0.2 + 0.1 * idx
                self.log(self.cfg.monitor_test[idx], self.test(idx, t), on_step=False, on_epoch=True, prog_bar=True)

    def test(self, idx, value=None):
        for input in tqdm(self.test_loader):
            raw_predictions, mask_test = self(input[0].cuda(self.cfg.gpus[0]), input[1].cuda(self.cfg.gpus[0])), input[2].cuda(self.cfg.gpus[0])

            if self.cfg.net == 'SARASNet':
                mask_test = Variable(resize_label(mask_test.data.cpu().numpy(),
                                     size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()
                param = 1  # This parameter balances precision and recall to obtain a higher F1-score
                raw_predictions[:, 1, :, :] = raw_predictions[:, 1, :, :] + param

            if self.cfg.argmax:
                pred_test = raw_predictions.argmax(dim=1)
            else:
                if self.cfg.net == 'maskcd':
                    raw_prediction = raw_predictions[0]
                    pred_test = raw_prediction > value
                    pred_test.squeeze_(1)
                else:
                    pred_test = raw_predictions.squeeze(1)
                    pred_test = pred_test > 0.5

            self.test_oa(pred_test, mask_test)
            self.test_iou(pred_test, mask_test)
            self.test_prec(pred_test, mask_test)
            self.test_f1(pred_test, mask_test)
            self.test_recall(pred_test, mask_test)

        metrics_test = [self.test_prec.compute(),
                        self.test_recall.compute(),
                        self.test_f1.compute(),
                        self.test_iou.compute()]

        log = {'test_oa': float(self.test_oa.compute().cpu()),
               'test_prec': np.mean([item.cpu() for item in metrics_test[0]]),
               'test_recall': np.mean([item.cpu() for item in metrics_test[1]]),
               'test_f1': np.mean([item.cpu() for item in metrics_test[2]]),
               'test_miou': np.mean([item.cpu() for item in metrics_test[3]])}

        self.output(metrics_test, log, 'test', idx, value)

        self.test_oa.reset()
        self.test_prec.reset()
        self.test_recall.reset()
        self.test_f1.reset()
        self.test_iou.reset()

        return metrics_test[2][1]

if __name__ == "__main__":
    args = get_args()
    cfg = Config.fromfile(args.config)
    logger = TensorBoardLogger(save_dir="work_dirs",
                               sub_dir='log',
                               name=cfg.exp_name,
                               default_hp_metric=False)

    log_dir = os.path.dirname(logger.log_dir)

    model = myTrain(cfg, log_dir)
    # —— Inference-FPS benchmark inserted here —— #
    device = torch.device(f'cuda:{cfg.gpus[0]}' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    # take one batch from the validation dataloader
    val_loader = model.val_dataloader()
    batch_iter = iter(val_loader)
    try:
        batch = next(batch_iter)
        imgA_batch = batch[0]
        imgB_batch = batch[1]
    except StopIteration:
        raise RuntimeError("The validation dataloader is empty; please check the dataset configuration.")

    # move the inputs to the same device
    imgA_batch = imgA_batch.to(device)
    imgB_batch = imgB_batch.to(device)

    # warm up with 10 inference passes
    with torch.no_grad():
        for _ in range(10):
            _ = model(imgA_batch, imgB_batch)

    # time N inference passes
    N = 100
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    start_time = time.time()
    with torch.no_grad():
        for _ in range(N):
            _ = model(imgA_batch, imgB_batch)
    if device.type == 'cuda':
        torch.cuda.synchronize(device)
    elapsed = time.time() - start_time
    fps = N / elapsed
    print(f"[Inference FPS] input resolution = {imgA_batch.shape[2]}×{imgA_batch.shape[3]}, "
          f"batch size = {imgA_batch.shape[0]}, {N} passes took {elapsed:.4f} s, FPS = {fps:.2f}")
    # —— end of the inserted benchmark —— #

    pbar = TQDMProgressBar(refresh_rate=1)
    lr_monitor = LearningRateMonitor(logging_interval=cfg.logging_interval)
    callbacks = [pbar, lr_monitor]

    ckpt_cb = ModelCheckpoint(dirpath=f'{log_dir}/ckpts/val',
                              filename='{' + cfg.monitor_val + ':.4f}' + '-{epoch:d}',
                              monitor=cfg.monitor_val,
                              mode='max',
                              save_top_k=cfg.save_top_k,
                              save_last=True)
    callbacks.append(ckpt_cb)

    for m_test in cfg.monitor_test:
        ckpt_cb = ModelCheckpoint(dirpath=f'{log_dir}/ckpts/test/{m_test}',
                                  filename='{' + m_test + ':.4f}' + '-{epoch:d}',
                                  monitor=m_test,
                                  mode='max',
                                  save_top_k=cfg.save_top_k,
                                  save_last=True)
        callbacks.append(ckpt_cb)

    trainer = Trainer(max_epochs=cfg.epoch,
                      # precision='16-mixed',
                      callbacks=callbacks,
                      logger=logger,
                      enable_model_summary=True,
                      accelerator='auto',
                      devices=cfg.gpus,
                      num_sanity_val_steps=2,
                      benchmark=True)

    trainer.fit(model, ckpt_path=cfg.resume_ckpt_path)
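The inline benchmark above can be factored into a reusable helper; a sketch follows (editor's addition, same logic as the script). The torch.cuda.synchronize calls matter because CUDA kernels launch asynchronously, and timing without them would mostly measure kernel-launch overhead:

import time
import torch

@torch.no_grad()
def measure_fps(model, inputs, n_warmup=10, n_timed=100):
    device = next(model.parameters()).device
    model.eval()
    inputs = [x.to(device) for x in inputs]
    for _ in range(n_warmup):            # warm-up: cuDNN autotuning, allocator, caches
        model(*inputs)
    if device.type == 'cuda':
        torch.cuda.synchronize(device)   # drain pending kernels before starting the clock
    start = time.time()
    for _ in range(n_timed):
        model(*inputs)
    if device.type == 'cuda':
        torch.cuda.synchronize(device)   # wait for the last kernel before stopping the clock
    return n_timed / (time.time() - start)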
utilss.py ADDED
@@ -0,0 +1,249 @@
import cv2
import numpy as np
from torch.nn import functional as F
import torch

class ActivationsAndGradients:
    """ Class for extracting activations and
    registering gradients from targeted intermediate layers """

    def __init__(self, model, target_layers, reshape_transform):
        self.model = model
        self.gradients = []
        self.activations = []
        self.reshape_transform = reshape_transform
        self.handles = []
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(
                    self.save_activation))
            # Backward compatibility with older pytorch versions:
            if hasattr(target_layer, 'register_full_backward_hook'):
                self.handles.append(
                    target_layer.register_full_backward_hook(
                        self.save_gradient))
            else:
                self.handles.append(
                    target_layer.register_backward_hook(
                        self.save_gradient))

    def save_activation(self, module, input, output):
        activation = output
        if self.reshape_transform is not None:
            activation = self.reshape_transform(activation)
        self.activations.append(activation.cpu().detach())

    def save_gradient(self, module, grad_input, grad_output):
        # Gradients are computed in reverse order
        grad = grad_output[0]
        if self.reshape_transform is not None:
            grad = self.reshape_transform(grad)
        self.gradients = [grad.cpu().detach()] + self.gradients

    def __call__(self, x, y):
        self.gradients = []
        self.activations = []
        return self.model(x, y)

    def release(self):
        for handle in self.handles:
            handle.remove()


class GradCAM:
    def __init__(self,
                 cfg,
                 model,
                 target_layers,
                 reshape_transform=None,
                 use_cuda=False):
        self.cfg = cfg
        self.model = model.eval()
        self.target_layers = target_layers
        self.reshape_transform = reshape_transform
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()
        self.activations_and_grads = ActivationsAndGradients(
            self.model, target_layers, reshape_transform)

    """ Get a vector of weights for every channel in the target layer.
        Methods that return weights channels,
        will typically need to only implement this function. """

    @staticmethod
    def get_cam_weights(grads):
        return np.mean(grads, axis=(2, 3), keepdims=True)

    @staticmethod
    def get_loss(output, target_category):
        # Accumulate each sample's raw output map; the caller reduces with .sum()
        loss = 0
        for i in range(len(target_category)):
            loss = loss + output[i]
        return loss

    def get_cam_image(self, activations, grads):
        weights = self.get_cam_weights(grads)
        weighted_activations = weights * activations
        cam = weighted_activations.sum(axis=1)

        return cam

    @staticmethod
    def get_target_width_height(input_tensor):
        width, height = input_tensor.size(-1), input_tensor.size(-2)
        return width, height

    def compute_cam_per_layer(self, input_tensor):
        activations_list = [a.cpu().data.numpy()
                            for a in self.activations_and_grads.activations]
        grads_list = [g.cpu().data.numpy()
                      for g in self.activations_and_grads.gradients]
        target_size = self.get_target_width_height(input_tensor)

        cam_per_target_layer = []
        # Loop over the saliency image from every layer

        for layer_activations, layer_grads in zip(activations_list, grads_list):
            cam = self.get_cam_image(layer_activations, layer_grads)
            cam[cam < 0] = 0  # ReLU: drop negative contributions before the min-max scaling in scale_cam_image
            scaled = self.scale_cam_image(cam, target_size)
            cam_per_target_layer.append(scaled[:, None, :])

        return cam_per_target_layer

    def aggregate_multi_layers(self, cam_per_target_layer):
        cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
        cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
        result = np.mean(cam_per_target_layer, axis=1)
        return self.scale_cam_image(result)

    @staticmethod
    def scale_cam_image(cam, target_size=None):
        result = []
        for img in cam:
            img = img - np.min(img)
            img = img / (1e-7 + np.max(img))
            if target_size is not None:
                img = cv2.resize(img, target_size)
            result.append(img)
        result = np.float32(result)

        return result

    def __call__(self, input_tensor, target_category=None):
        x, y = input_tensor
        if self.cuda:
            x = x.cuda()
            y = y.cuda()

        # Forward pass to obtain the network logits (no softmax applied)
        if self.cfg.net == 'cdmask':
            o, outputs = self.activations_and_grads(x, y)
            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = outputs["pred_masks"]
            mask_pred_results = F.interpolate(
                mask_pred_results,
                scale_factor=(4, 4),
                mode="bilinear",
                align_corners=False,
            )
            mask_cls = F.softmax(mask_cls_results, dim=-1)[..., 1:]
            mask_pred = mask_pred_results.sigmoid()
            output = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        else:
            output = self.activations_and_grads(x, y)

        if isinstance(target_category, int):
            target_category = [target_category] * x.size(0)

        if target_category is None:
            target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
            print(f"category id: {target_category}")
        else:
            assert (len(target_category) == x.size(0))

        self.model.zero_grad()
        loss = self.get_loss(output, target_category).sum()
        loss.backward(retain_graph=True)

        # In most of the saliency attribution papers, the saliency is
        # computed with a single target layer.
        # Commonly it is the last convolutional layer.
        # Here we support passing a list with multiple target layers.
        # It will compute the saliency image for every image,
        # and then aggregate them (with a default mean aggregation).
        # This gives you more flexibility in case you just want to
        # use all conv layers for example, all Batchnorm layers,
        # or something else.
        cam_per_layer = self.compute_cam_per_layer(x)
        return self.aggregate_multi_layers(cam_per_layer)

    def __del__(self):
        self.activations_and_grads.release()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.activations_and_grads.release()
        if isinstance(exc_value, IndexError):
            # Handle IndexError here...
            print(
                f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
            return True


def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """ This function overlays the cam mask on the image as a heatmap.
    By default the heatmap is in BGR format.

    :param img: The base image in RGB or BGR format.
    :param mask: The cam mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap; set to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The default image with the cam overlay.
    """

    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    if np.max(img) > 1:
        raise Exception(
            "The input image should be np.float32 in the range [0, 1]")

    cam = heatmap + img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)


def center_crop_img(img: np.ndarray, size: int):
    h, w, c = img.shape

    if w == h == size:
        return img

    if w < h:
        ratio = size / w
        new_w = size
        new_h = int(h * ratio)
    else:
        ratio = size / h
        new_h = size
        new_w = int(w * ratio)

    img = cv2.resize(img, dsize=(new_w, new_h))

    if new_w == size:
        h = (new_h - size) // 2
        img = img[h: h + size]
    else:
        w = (new_w - size) // 2
        img = img[:, w: w + size]

    return img
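To see the GradCAM class run end to end without the repo's configs, here is a toy sketch (editor's addition; the SimpleNamespace cfg and ToyNet are invented stand-ins, chosen so that cfg.net != 'cdmask' takes the plain forward branch):

import types
import numpy as np
import torch
from torch import nn
from utilss import GradCAM, show_cam_on_image

class ToyNet(nn.Module):
    # invented two-input model standing in for the repo's change-detection nets
    def __init__(self):
        super().__init__()
        self.feat = nn.Conv2d(3, 8, 3, padding=1)   # target layer for the CAM
        self.head = nn.Conv2d(8, 2, 1)              # 2 classes: unchanged / changed

    def forward(self, x1, x2):
        return self.head(self.feat(x1) - self.feat(x2))

cfg = types.SimpleNamespace(net='toy')              # minimal stand-in for the repo config
net = ToyNet()
cam = GradCAM(cfg, model=net, target_layers=[net.feat])

x1, x2 = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
grayscale = cam(input_tensor=(x1, x2), target_category=1)    # (1, 64, 64), values in [0, 1]

base = np.float32(x1[0].permute(1, 2, 0).numpy())            # base image in [0, 1]
overlay = show_cam_on_image(base, grayscale[0], use_rgb=True)
print(overlay.shape, overlay.dtype)                          # (64, 64, 3) uint8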