Upload 9 files
Browse files- eval_tables_only.py +351 -0
- five_connect.py +61 -0
- grad_cam_CNN.py +74 -0
- mask_connect_test.py +54 -0
- params_flops.py +53 -0
- requirements.txt +13 -0
- test.py +128 -0
- train.py +389 -0
- utilss.py +249 -0
eval_tables_only.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import prettytable
|
| 5 |
+
import time
|
| 6 |
+
import os
|
| 7 |
+
import multiprocessing.pool as mpp
|
| 8 |
+
import multiprocessing as mp
|
| 9 |
+
|
| 10 |
+
from train import *
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
from utils.config import Config
|
| 14 |
+
from tools.mask_convert import mask_save
|
| 15 |
+
import numpy as np # [PR] for histogram-based PR accumulation
|
| 16 |
+
import csv
|
| 17 |
+
|
| 18 |
+
# =========================== [PR] Utilities BEGIN ===========================
|
| 19 |
+
class PRHistogram:
|
| 20 |
+
# Memory-friendly PR accumulator. Call update(probs, mask) repeatedly inside
|
| 21 |
+
# your test loop, then call export_csv(path) after the loop.
|
| 22 |
+
# - probs: torch.Tensor in [0,1], shape [B,H,W], "change" probability
|
| 23 |
+
# - mask: torch.Tensor of 0/1 (or 0/255), shape [B,H,W]
|
| 24 |
+
def __init__(self, nbins: int = 1000):
|
| 25 |
+
import numpy as _np
|
| 26 |
+
self.nbins = int(nbins)
|
| 27 |
+
self.pos_hist = _np.zeros(self.nbins, dtype=_np.int64)
|
| 28 |
+
self.neg_hist = _np.zeros(self.nbins, dtype=_np.int64)
|
| 29 |
+
self.bin_edges = _np.linspace(0.0, 1.0, self.nbins + 1)
|
| 30 |
+
|
| 31 |
+
def update(self, probs, mask):
|
| 32 |
+
import numpy as _np
|
| 33 |
+
p = probs.detach().float().cpu().numpy().ravel()
|
| 34 |
+
g = (mask.detach().cpu().numpy().ravel() > 0).astype(_np.uint8)
|
| 35 |
+
pos_counts, _ = _np.histogram(p[g == 1], bins=self.bin_edges)
|
| 36 |
+
neg_counts, _ = _np.histogram(p[g == 0], bins=self.bin_edges)
|
| 37 |
+
self.pos_hist += pos_counts
|
| 38 |
+
self.neg_hist += neg_counts
|
| 39 |
+
|
| 40 |
+
def compute_curve(self):
|
| 41 |
+
import numpy as _np
|
| 42 |
+
# 累加得到从高阈值到低阈值的 TP/FP
|
| 43 |
+
pos_cum = _np.cumsum(self.pos_hist[::-1])
|
| 44 |
+
neg_cum = _np.cumsum(self.neg_hist[::-1])
|
| 45 |
+
TP = pos_cum
|
| 46 |
+
FP = neg_cum
|
| 47 |
+
FN = self.pos_hist.sum() - TP
|
| 48 |
+
TN = None # 曲线里用不到 TN
|
| 49 |
+
|
| 50 |
+
denom_prec = _np.maximum(TP + FP, 1)
|
| 51 |
+
denom_rec = _np.maximum(TP + FN, 1)
|
| 52 |
+
precision = TP / denom_prec
|
| 53 |
+
recall = TP / denom_rec
|
| 54 |
+
|
| 55 |
+
# F1 = 2PR/(P+R)
|
| 56 |
+
denom_f1 = _np.maximum(precision + recall, 1e-12)
|
| 57 |
+
f1 = 2.0 * precision * recall / denom_f1
|
| 58 |
+
|
| 59 |
+
# IoU = TP / (TP + FP + FN)
|
| 60 |
+
denom_iou = _np.maximum(TP + FP + FN, 1)
|
| 61 |
+
iou = TP / denom_iou
|
| 62 |
+
|
| 63 |
+
thresholds = self.bin_edges[::-1][1:] # 与上述累积方向一致的阈值序列
|
| 64 |
+
return thresholds, precision, recall, f1, iou, TP, FP, FN
|
| 65 |
+
|
| 66 |
+
def export_csv(self, save_path: str):
|
| 67 |
+
thresholds, precision, recall, f1, iou, TP, FP, FN = self.compute_curve()
|
| 68 |
+
import numpy as _np, os as _os
|
| 69 |
+
_os.makedirs(_os.path.dirname(save_path), exist_ok=True)
|
| 70 |
+
_np.savetxt(
|
| 71 |
+
save_path,
|
| 72 |
+
_np.column_stack([thresholds, precision, recall, f1, iou, TP, FP, FN]),
|
| 73 |
+
delimiter=",",
|
| 74 |
+
header="threshold,precision,recall,f1,iou,TP,FP,FN",
|
| 75 |
+
comments=""
|
| 76 |
+
)
|
| 77 |
+
return save_path
|
| 78 |
+
|
| 79 |
+
# Global PR object (create when needed)
|
| 80 |
+
_PR = None
|
| 81 |
+
|
| 82 |
+
def pr_init(nbins: int = 1000):
|
| 83 |
+
global _PR
|
| 84 |
+
if _PR is None:
|
| 85 |
+
_PR = PRHistogram(nbins=nbins)
|
| 86 |
+
return _PR
|
| 87 |
+
|
| 88 |
+
def pr_update_from_outputs(raw_predictions, mask, cfg):
|
| 89 |
+
# Try to derive probs ∈ [0,1] from various model outputs in this repo.
|
| 90 |
+
# This covers:
|
| 91 |
+
# - cfg.argmax=True: 2-channel logits -> softmax class-1 prob
|
| 92 |
+
# - single-channel logits -> sigmoid
|
| 93 |
+
# - net == 'maskcd' (list/tuple outputs)
|
| 94 |
+
# Modify here if your network has a special head.
|
| 95 |
+
import torch
|
| 96 |
+
global _PR
|
| 97 |
+
if _PR is None:
|
| 98 |
+
_PR = PRHistogram(nbins=1000)
|
| 99 |
+
|
| 100 |
+
if getattr(cfg, 'argmax', False):
|
| 101 |
+
logits = raw_predictions
|
| 102 |
+
if logits.dim() == 4 and logits.size(1) >= 2:
|
| 103 |
+
probs = torch.softmax(logits, dim=1)[:, 1, :, :]
|
| 104 |
+
else:
|
| 105 |
+
probs = torch.sigmoid(logits.squeeze(1))
|
| 106 |
+
else:
|
| 107 |
+
if getattr(cfg, 'net', '') == 'maskcd':
|
| 108 |
+
if isinstance(raw_predictions, (list, tuple)):
|
| 109 |
+
logits = raw_predictions[0]
|
| 110 |
+
else:
|
| 111 |
+
logits = raw_predictions
|
| 112 |
+
probs = torch.sigmoid(logits).squeeze(1)
|
| 113 |
+
else:
|
| 114 |
+
logits = raw_predictions
|
| 115 |
+
if logits.dim() == 4 and logits.size(1) == 1:
|
| 116 |
+
logits = logits.squeeze(1)
|
| 117 |
+
probs = torch.sigmoid(logits)
|
| 118 |
+
|
| 119 |
+
if mask.dim() == 4 and mask.size(1) == 1:
|
| 120 |
+
mask_ = mask.squeeze(1)
|
| 121 |
+
else:
|
| 122 |
+
mask_ = mask
|
| 123 |
+
_PR.update(probs, (mask_ > 0).to(probs.dtype))
|
| 124 |
+
|
| 125 |
+
def pr_export(base_dir: str, cfg):
|
| 126 |
+
# Export PR CSV to base_dir/pr_<net>.csv
|
| 127 |
+
import os
|
| 128 |
+
global _PR
|
| 129 |
+
if _PR is None:
|
| 130 |
+
return None
|
| 131 |
+
save_path = os.path.join(base_dir, f"pr_{getattr(cfg,'net','model')}.csv")
|
| 132 |
+
out = _PR.export_csv(save_path)
|
| 133 |
+
print(f"[PR] saved: {out}")
|
| 134 |
+
return out
|
| 135 |
+
# ============================ [PR] Utilities END ============================
|
| 136 |
+
|
| 137 |
+
# -------------------- [Per-Image] 逐图指标工具 --------------------
|
| 138 |
+
def _safe_div(a, b, eps=1e-12):
|
| 139 |
+
return a / max(b, eps)
|
| 140 |
+
|
| 141 |
+
def per_image_stats(pred_np: np.ndarray, gt_np: np.ndarray):
|
| 142 |
+
"""
|
| 143 |
+
pred_np, gt_np: 0/1 二值 numpy 数组, shape [H,W]
|
| 144 |
+
返回: dict 包含 TP/FP/TN/FN 与各类指标
|
| 145 |
+
"""
|
| 146 |
+
pred_bin = (pred_np > 0).astype(np.uint8)
|
| 147 |
+
gt_bin = (gt_np > 0).astype(np.uint8)
|
| 148 |
+
|
| 149 |
+
TP = int(((pred_bin == 1) & (gt_bin == 1)).sum())
|
| 150 |
+
FP = int(((pred_bin == 1) & (gt_bin == 0)).sum())
|
| 151 |
+
TN = int(((pred_bin == 0) & (gt_bin == 0)).sum())
|
| 152 |
+
FN = int(((pred_bin == 0) & (gt_bin == 1)).sum())
|
| 153 |
+
|
| 154 |
+
precision = _safe_div(TP, (TP + FP))
|
| 155 |
+
recall = _safe_div(TP, (TP + FN))
|
| 156 |
+
f1 = _safe_div(2 * precision * recall, (precision + recall))
|
| 157 |
+
iou = _safe_div(TP, (TP + FP + FN))
|
| 158 |
+
oa = _safe_div(TP + TN, (TP + TN + FP + FN))
|
| 159 |
+
|
| 160 |
+
return {
|
| 161 |
+
"TP": TP, "FP": FP, "TN": TN, "FN": FN,
|
| 162 |
+
"OA": oa, "Precision": precision, "Recall": recall, "F1": f1, "IoU": iou
|
| 163 |
+
}
|
| 164 |
+
# --------------------------------------------------------------------
|
| 165 |
+
|
| 166 |
+
def get_args():
|
| 167 |
+
parser = argparse.ArgumentParser('description=Change detection of remote sensing images')
|
| 168 |
+
parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
|
| 169 |
+
parser.add_argument("--ckpt", type=str, default=None)
|
| 170 |
+
parser.add_argument("--output_dir", type=str, default=None)
|
| 171 |
+
# 新增:仅生成表格模式(不导出可视化图片)
|
| 172 |
+
parser.add_argument("--tables-only", action="store_true",
|
| 173 |
+
help="仅生成表格与CSV(总体表、逐图CSV、逐图TXT、小计PR曲线CSV),不生成mask可视化图片")
|
| 174 |
+
return parser.parse_args()
|
| 175 |
+
|
| 176 |
+
if __name__ == "__main__":
|
| 177 |
+
args = get_args()
|
| 178 |
+
cfg = Config.fromfile(args.config)
|
| 179 |
+
|
| 180 |
+
ckpt = args.ckpt
|
| 181 |
+
if ckpt is None:
|
| 182 |
+
ckpt = cfg.test_ckpt_path
|
| 183 |
+
assert ckpt is not None
|
| 184 |
+
|
| 185 |
+
if args.output_dir:
|
| 186 |
+
base_dir = args.output_dir
|
| 187 |
+
else:
|
| 188 |
+
base_dir = os.path.dirname(ckpt)
|
| 189 |
+
|
| 190 |
+
# 原图像输出目录(仅在需要写图时使用)
|
| 191 |
+
masks_output_dir = os.path.join(base_dir, "mask_rgb")
|
| 192 |
+
# 表格输出目录(逐图表格 .txt),如果 tables-only 则单独放在 tables_only 下
|
| 193 |
+
tables_output_dir = os.path.join(base_dir, "tables_only" if args.tables_only else "mask_rgb")
|
| 194 |
+
os.makedirs(tables_output_dir, exist_ok=True)
|
| 195 |
+
|
| 196 |
+
model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1':'cuda:0'}, cfg = cfg)
|
| 197 |
+
model = model.to('cuda')
|
| 198 |
+
model.eval()
|
| 199 |
+
|
| 200 |
+
metric_cfg_1 = cfg.metric_cfg1
|
| 201 |
+
metric_cfg_2 = cfg.metric_cfg2
|
| 202 |
+
|
| 203 |
+
test_oa=torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
|
| 204 |
+
test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
|
| 205 |
+
test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
|
| 206 |
+
test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
|
| 207 |
+
test_iou=torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')
|
| 208 |
+
|
| 209 |
+
results = [] # 仅在生成图片时使用
|
| 210 |
+
per_image_rows = [] # [Per-Image] 收集逐图指标
|
| 211 |
+
|
| 212 |
+
with torch.no_grad():
|
| 213 |
+
test_loader = build_dataloader(cfg.dataset_config, mode='test')
|
| 214 |
+
# === 调用1: 初始化 ===
|
| 215 |
+
pr_init(nbins=1000)
|
| 216 |
+
|
| 217 |
+
for input in tqdm(test_loader):
|
| 218 |
+
raw_predictions, mask, img_id = model(input[0].cuda(), input[1].cuda()), input[2].cuda(), input[3]
|
| 219 |
+
# === 调用2: 更新 ===
|
| 220 |
+
pr_update_from_outputs(raw_predictions, mask, cfg)
|
| 221 |
+
|
| 222 |
+
if cfg.net == 'SARASNet':
|
| 223 |
+
mask = Variable(resize_label(mask.data.cpu().numpy(), \
|
| 224 |
+
size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()
|
| 225 |
+
param = 1 # This parameter is balance precision and recall to get higher F1-score
|
| 226 |
+
raw_predictions[:,1,:,:] = raw_predictions[:,1,:,:] + param
|
| 227 |
+
|
| 228 |
+
if cfg.argmax:
|
| 229 |
+
pred = raw_predictions.argmax(dim=1)
|
| 230 |
+
else:
|
| 231 |
+
if cfg.net == 'maskcd':
|
| 232 |
+
pred = raw_predictions[0]
|
| 233 |
+
pred = pred > 0.5
|
| 234 |
+
pred.squeeze_(1)
|
| 235 |
+
else:
|
| 236 |
+
pred = raw_predictions.squeeze(1)
|
| 237 |
+
pred = pred > 0.5
|
| 238 |
+
|
| 239 |
+
# ====== 累计整体验证指标 ======
|
| 240 |
+
test_oa(pred, mask)
|
| 241 |
+
test_iou(pred, mask)
|
| 242 |
+
test_prec(pred, mask)
|
| 243 |
+
test_f1(pred, mask)
|
| 244 |
+
test_recall(pred, mask)
|
| 245 |
+
|
| 246 |
+
# ====== [Per-Image] 逐图指标计算与收集 ======
|
| 247 |
+
for i in range(raw_predictions.shape[0]):
|
| 248 |
+
mask_real = mask[i].detach().cpu().numpy()
|
| 249 |
+
mask_pred = pred[i].detach().cpu().numpy()
|
| 250 |
+
mask_name = str(img_id[i])
|
| 251 |
+
|
| 252 |
+
# 逐图统计
|
| 253 |
+
stats = per_image_stats(mask_pred, mask_real)
|
| 254 |
+
per_image_rows.append({
|
| 255 |
+
"img_id": mask_name,
|
| 256 |
+
"TP": stats["TP"], "FP": stats["FP"], "TN": stats["TN"], "FN": stats["FN"],
|
| 257 |
+
"OA": stats["OA"], "Precision": stats["Precision"],
|
| 258 |
+
"Recall": stats["Recall"], "F1": stats["F1"], "IoU": stats["IoU"]
|
| 259 |
+
})
|
| 260 |
+
|
| 261 |
+
# 仅在需要生成可视化图片时才收集写图任务
|
| 262 |
+
if not args.tables_only:
|
| 263 |
+
results.append((mask_real, mask_pred, masks_output_dir, mask_name))
|
| 264 |
+
|
| 265 |
+
# ====== 打印总体指标 ======
|
| 266 |
+
metrics = [test_prec.compute(),
|
| 267 |
+
test_recall.compute(),
|
| 268 |
+
test_f1.compute(),
|
| 269 |
+
test_iou.compute()]
|
| 270 |
+
|
| 271 |
+
total_metrics = [test_oa.compute().cpu().numpy(),
|
| 272 |
+
np.mean([item.cpu() for item in metrics[0]]),
|
| 273 |
+
np.mean([item.cpu() for item in metrics[1]]),
|
| 274 |
+
np.mean([item.cpu() for item in metrics[2]]),
|
| 275 |
+
np.mean([item.cpu() for item in metrics[3]])]
|
| 276 |
+
|
| 277 |
+
result_table = prettytable.PrettyTable()
|
| 278 |
+
result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']
|
| 279 |
+
|
| 280 |
+
for i in range(2):
|
| 281 |
+
item = [i, '--']
|
| 282 |
+
for j in range(len(metrics)):
|
| 283 |
+
item.append(np.round(metrics[j][i].cpu().numpy(), 4))
|
| 284 |
+
result_table.add_row(item)
|
| 285 |
+
|
| 286 |
+
total = [np.round(v, 4) for v in total_metrics]
|
| 287 |
+
total.insert(0, 'total')
|
| 288 |
+
result_table.add_row(total)
|
| 289 |
+
print(result_table)
|
| 290 |
+
|
| 291 |
+
file_name = os.path.join(base_dir, "test_res.txt")
|
| 292 |
+
f = open(file_name,"a")
|
| 293 |
+
current_time = time.strftime('%Y_%m_%d %H:%M:%S {}'.format(cfg.net),time.localtime(time.time()))
|
| 294 |
+
f.write(current_time+'\n')
|
| 295 |
+
f.write(str(result_table)+'\n')
|
| 296 |
+
|
| 297 |
+
# ====== 根据模式选择是否写图 ======
|
| 298 |
+
if not args.tables_only:
|
| 299 |
+
if not os.path.exists(masks_output_dir):
|
| 300 |
+
os.makedirs(masks_output_dir)
|
| 301 |
+
print(masks_output_dir)
|
| 302 |
+
|
| 303 |
+
# 多进程写图
|
| 304 |
+
t0 = time.time()
|
| 305 |
+
mpp.Pool(processes=mp.cpu_count()).map(mask_save, results)
|
| 306 |
+
t1 = time.time()
|
| 307 |
+
img_write_time = t1 - t0
|
| 308 |
+
print('images writing spends: {} s'.format(img_write_time))
|
| 309 |
+
else:
|
| 310 |
+
print("[Mode] --tables-only: 跳过可视化图片的生成,仅导出表格/CSV。")
|
| 311 |
+
|
| 312 |
+
# ====== [Per-Image] 将逐图指标写成一个总 CSV ======
|
| 313 |
+
per_image_csv = os.path.join(base_dir, f"per_image_metrics_{getattr(cfg,'net','model')}.csv")
|
| 314 |
+
with open(per_image_csv, "w", newline="") as wf:
|
| 315 |
+
writer = csv.DictWriter(
|
| 316 |
+
wf,
|
| 317 |
+
fieldnames=["img_id","TP","FP","TN","FN","OA","Precision","Recall","F1","IoU"]
|
| 318 |
+
)
|
| 319 |
+
writer.writeheader()
|
| 320 |
+
for row in per_image_rows:
|
| 321 |
+
row_out = dict(row)
|
| 322 |
+
for k in ["OA","Precision","Recall","F1","IoU"]:
|
| 323 |
+
row_out[k] = float(np.round(row_out[k], 6))
|
| 324 |
+
writer.writerow(row_out)
|
| 325 |
+
print(f"[Per-Image] saved CSV: {per_image_csv}")
|
| 326 |
+
|
| 327 |
+
# ====== [Per-Image] 为每张图各自写一个小表(.txt) ======
|
| 328 |
+
for row in per_image_rows:
|
| 329 |
+
txt_path = os.path.join(tables_output_dir, f"{row['img_id']}_metrics.txt")
|
| 330 |
+
pt = prettytable.PrettyTable()
|
| 331 |
+
pt.field_names = ["Metric", "Value"]
|
| 332 |
+
# 先放混淆矩阵元素
|
| 333 |
+
pt.add_row(["TP", row["TP"]])
|
| 334 |
+
pt.add_row(["FP", row["FP"]])
|
| 335 |
+
pt.add_row(["TN", row["TN"]])
|
| 336 |
+
pt.add_row(["FN", row["FN"]])
|
| 337 |
+
# 再放比率类指标
|
| 338 |
+
pt.add_row(["OA", f"{row['OA']:.6f}"])
|
| 339 |
+
pt.add_row(["Precision",f"{row['Precision']:.6f}"])
|
| 340 |
+
pt.add_row(["Recall", f"{row['Recall']:.6f}"])
|
| 341 |
+
pt.add_row(["F1", f"{row['F1']:.6f}"])
|
| 342 |
+
pt.add_row(["IoU", f"{row['IoU']:.6f}"])
|
| 343 |
+
with open(txt_path, "w") as wf:
|
| 344 |
+
wf.write(str(pt))
|
| 345 |
+
print(f"[Per-Image] per-image tables saved to: {tables_output_dir}")
|
| 346 |
+
|
| 347 |
+
# ===== [PR] Export at program end =====
|
| 348 |
+
try:
|
| 349 |
+
pr_export(base_dir, cfg)
|
| 350 |
+
except Exception as e:
|
| 351 |
+
print(f"[PR] export skipped or failed: {e}")
|
five_connect.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import os
|
| 3 |
+
import glob
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
def concat_image_heatmap(img1_path, img2_path, label_path, mask_path, heatmap_path, output_path):
|
| 7 |
+
img1 = cv2.imread(img1_path)
|
| 8 |
+
img2 = cv2.imread(img2_path)
|
| 9 |
+
mask = cv2.imread(mask_path)
|
| 10 |
+
heatmap = cv2.imread(heatmap_path)
|
| 11 |
+
label = cv2.imread(label_path) if label_path and os.path.exists(label_path) else None
|
| 12 |
+
|
| 13 |
+
if img1 is None or img2 is None or mask is None or heatmap is None:
|
| 14 |
+
print(f"❌ Missing image: {img1_path}, {img2_path}, {mask_path}, {heatmap_path}")
|
| 15 |
+
return
|
| 16 |
+
|
| 17 |
+
h, w = img1.shape[:2]
|
| 18 |
+
img2 = cv2.resize(img2, (w, h))
|
| 19 |
+
mask = cv2.resize(mask, (w, h))
|
| 20 |
+
heatmap = cv2.resize(heatmap, (w, h))
|
| 21 |
+
label = cv2.resize(label, (w, h)) if label is not None else np.zeros_like(img1)
|
| 22 |
+
|
| 23 |
+
top_row = np.concatenate([img1, img2, label], axis=1)
|
| 24 |
+
bottom_row = np.concatenate([mask, heatmap], axis=1)
|
| 25 |
+
|
| 26 |
+
# 补齐对齐
|
| 27 |
+
max_width = max(top_row.shape[1], bottom_row.shape[1])
|
| 28 |
+
if top_row.shape[1] < max_width:
|
| 29 |
+
pad = max_width - top_row.shape[1]
|
| 30 |
+
top_row = cv2.copyMakeBorder(top_row, 0, 0, 0, pad, cv2.BORDER_CONSTANT, value=0)
|
| 31 |
+
if bottom_row.shape[1] < max_width:
|
| 32 |
+
pad = max_width - bottom_row.shape[1]
|
| 33 |
+
bottom_row = cv2.copyMakeBorder(bottom_row, 0, 0, 0, pad, cv2.BORDER_CONSTANT, value=0)
|
| 34 |
+
|
| 35 |
+
full_image = np.concatenate([top_row, bottom_row], axis=0)
|
| 36 |
+
cv2.imwrite(output_path, full_image)
|
| 37 |
+
print(f"✅ Saved: {output_path}")
|
| 38 |
+
|
| 39 |
+
def batch_process(img1_dir, img2_dir, label_dir, mask_dir, heatmap_dir, output_dir):
|
| 40 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 41 |
+
img1_paths = glob.glob(os.path.join(img1_dir, "*.png"))
|
| 42 |
+
|
| 43 |
+
for img1_path in img1_paths:
|
| 44 |
+
filename = os.path.basename(img1_path)
|
| 45 |
+
img2_path = os.path.join(img2_dir, filename)
|
| 46 |
+
label_path = os.path.join(label_dir, filename) if label_dir else None
|
| 47 |
+
mask_path = os.path.join(mask_dir, filename)
|
| 48 |
+
heatmap_path = os.path.join(heatmap_dir, filename)
|
| 49 |
+
output_path = os.path.join(output_dir, filename.replace(".png", "_full.png"))
|
| 50 |
+
|
| 51 |
+
concat_image_heatmap(img1_path, img2_path, label_path, mask_path, heatmap_path, output_path)
|
| 52 |
+
|
| 53 |
+
# 设置路径
|
| 54 |
+
img1_dir = "data/WHU_CD/test/image1"
|
| 55 |
+
img2_dir = "data/WHU_CD/test/image2"
|
| 56 |
+
label_dir = "data/WHU_CD/test/label" # 可设为 None
|
| 57 |
+
mask_dir = "mask_connect_test_dir/mask_rgb"
|
| 58 |
+
heatmap_dir = "mask_connect_test_dir/grad_cam/model.net.decoderhead.LHBlock2"
|
| 59 |
+
output_dir = "mask_heatmap_concat_dir"
|
| 60 |
+
|
| 61 |
+
batch_process(img1_dir, img2_dir, label_dir, mask_dir, heatmap_dir, output_dir)
|
grad_cam_CNN.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
sys.path.append('.')
|
| 4 |
+
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from utilss import GradCAM, show_cam_on_image, center_crop_img
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
from utils.config import Config
|
| 10 |
+
from train import *
|
| 11 |
+
|
| 12 |
+
def get_args():
|
| 13 |
+
parser = argparse.ArgumentParser('description=Change detection of remote sensing images')
|
| 14 |
+
parser.add_argument("-c", "--config", type=str, default="configs\cdxformer.py")
|
| 15 |
+
parser.add_argument("--output_dir", default=None)
|
| 16 |
+
parser.add_argument("--layer", default=None)
|
| 17 |
+
return parser.parse_args()
|
| 18 |
+
|
| 19 |
+
def main():
|
| 20 |
+
args = get_args()
|
| 21 |
+
|
| 22 |
+
if args.layer == None:
|
| 23 |
+
raise NameError("Please ensure the parameter '--layer' is not None!\n e.g. --layer=model.net.decoderhead.LHBlock2.mlp_l")
|
| 24 |
+
|
| 25 |
+
cfg = Config.fromfile(args.config)
|
| 26 |
+
|
| 27 |
+
model = myTrain.load_from_checkpoint(cfg.test_ckpt_path, cfg = cfg)
|
| 28 |
+
model = model.to('cuda')
|
| 29 |
+
|
| 30 |
+
# print(dict(model.named_modules()).keys())
|
| 31 |
+
|
| 32 |
+
test_loader = build_dataloader(cfg.dataset_config, mode='test')
|
| 33 |
+
|
| 34 |
+
if args.output_dir:
|
| 35 |
+
base_dir = args.output_dir
|
| 36 |
+
else:
|
| 37 |
+
base_dir = os.path.dirname(cfg.test_ckpt_path)
|
| 38 |
+
gradcam_output_dir = os.path.join(base_dir, "grad_cam", args.layer)
|
| 39 |
+
if os.path.exists(gradcam_output_dir):
|
| 40 |
+
raise NameError("Please ensure gradcam_output_dir does not exist!")
|
| 41 |
+
|
| 42 |
+
os.makedirs(gradcam_output_dir)
|
| 43 |
+
|
| 44 |
+
for input in tqdm(test_loader):
|
| 45 |
+
target_layers = [eval(args.layer)] # name of the network layer
|
| 46 |
+
mask, img_id = input[2].cuda(), input[3]
|
| 47 |
+
|
| 48 |
+
cam = GradCAM(cfg, model=model.net, target_layers=target_layers, use_cuda=True)
|
| 49 |
+
target_category = 1 # tabby, tabby cat
|
| 50 |
+
|
| 51 |
+
grayscale_cam_all = cam(input_tensor=(input[0], input[1]), target_category=target_category)
|
| 52 |
+
|
| 53 |
+
for i in range(grayscale_cam_all.shape[0]):
|
| 54 |
+
grayscale_cam = grayscale_cam_all[i, :]
|
| 55 |
+
visualization = show_cam_on_image(0,
|
| 56 |
+
grayscale_cam,
|
| 57 |
+
use_rgb=True)
|
| 58 |
+
fig = plt.figure()
|
| 59 |
+
ax = fig.add_subplot(111)
|
| 60 |
+
ax.imshow(visualization)
|
| 61 |
+
# ax = fig.add_subplot(122)
|
| 62 |
+
# ax.imshow(mask[i].cpu().numpy())
|
| 63 |
+
ax.set_xticks([])
|
| 64 |
+
ax.set_yticks([])
|
| 65 |
+
ax.spines['top'].set_visible(False)
|
| 66 |
+
ax.spines['right'].set_visible(False)
|
| 67 |
+
ax.spines['bottom'].set_visible(False)
|
| 68 |
+
ax.spines['left'].set_visible(False)
|
| 69 |
+
plt.savefig(os.path.join(gradcam_output_dir, '{}.png'.format(img_id[i])))
|
| 70 |
+
plt.close()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
if __name__ == '__main__':
|
| 74 |
+
main()
|
mask_connect_test.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import os
|
| 3 |
+
import glob
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
def concat_change_detection_images(img1_path, img2_path, label_path, pred_path, output_path):
|
| 7 |
+
img1 = cv2.imread(img1_path)
|
| 8 |
+
img2 = cv2.imread(img2_path)
|
| 9 |
+
label = cv2.imread(label_path) if os.path.exists(label_path) else None
|
| 10 |
+
pred = cv2.imread(pred_path)
|
| 11 |
+
|
| 12 |
+
if img1 is None or img2 is None or pred is None:
|
| 13 |
+
print(f"Missing or unreadable image: {img1_path}, {img2_path}, {pred_path}")
|
| 14 |
+
return
|
| 15 |
+
|
| 16 |
+
# resize 所有图片为相同大小(以 img1 为基准)
|
| 17 |
+
h, w = img1.shape[:2]
|
| 18 |
+
img2 = cv2.resize(img2, (w, h))
|
| 19 |
+
pred = cv2.resize(pred, (w, h))
|
| 20 |
+
if label is not None:
|
| 21 |
+
label = cv2.resize(label, (w, h))
|
| 22 |
+
|
| 23 |
+
# 组合图像(无 label 时跳过)
|
| 24 |
+
if label is not None:
|
| 25 |
+
concat = np.concatenate([img1, img2, label, pred], axis=1)
|
| 26 |
+
else:
|
| 27 |
+
concat = np.concatenate([img1, img2, pred], axis=1)
|
| 28 |
+
|
| 29 |
+
cv2.imwrite(output_path, concat)
|
| 30 |
+
|
| 31 |
+
def batch_process(img1_dir, img2_dir, label_dir, pred_dir, output_dir):
|
| 32 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 33 |
+
img1_paths = glob.glob(os.path.join(img1_dir, "*.png"))
|
| 34 |
+
for img1_path in img1_paths:
|
| 35 |
+
filename = os.path.basename(img1_path)
|
| 36 |
+
img2_path = os.path.join(img2_dir, filename)
|
| 37 |
+
label_path = os.path.join(label_dir, filename) if label_dir else None
|
| 38 |
+
pred_path = os.path.join(pred_dir, filename)
|
| 39 |
+
output_path = os.path.join(output_dir, filename.replace(".png", "_concat.png"))
|
| 40 |
+
|
| 41 |
+
print(f"[INFO] img1: {img1_path}, img2: {img2_path}")
|
| 42 |
+
print(f"[INFO] label: {label_path}, pred: {pred_path}")
|
| 43 |
+
|
| 44 |
+
concat_change_detection_images(img1_path, img2_path, label_path, pred_path, output_path)
|
| 45 |
+
print(f"Saved: {output_path}")
|
| 46 |
+
|
| 47 |
+
# 设置路径
|
| 48 |
+
img1_dir = "data/WHU_CD/test/image1"
|
| 49 |
+
img2_dir = "data/WHU_CD/test/image2"
|
| 50 |
+
label_dir = "data/WHU_CD/test/label" # 如果没有标签图可以设为 None
|
| 51 |
+
pred_dir = "work_dirs/CLCD_BS4_epoch200/CDXFormer/version_0/ckpts/test/mask_rgb"
|
| 52 |
+
output_dir = "mask_connect_test_dir"
|
| 53 |
+
|
| 54 |
+
batch_process(img1_dir, img2_dir, label_dir, pred_dir, output_dir)
|
params_flops.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.append('.')
|
| 3 |
+
from train import *
|
| 4 |
+
from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count, parameter_count
|
| 5 |
+
from rscd.models.backbones.lamba_util.csms6s import flops_selective_scan_fn, flops_selective_scan_ref, selective_scan_flop_jit
|
| 6 |
+
|
| 7 |
+
def parse_args():
|
| 8 |
+
parser = argparse.ArgumentParser(description='count params and flops')
|
| 9 |
+
parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
|
| 10 |
+
parser.add_argument("--size", type=int, default=256)
|
| 11 |
+
args = parser.parse_args()
|
| 12 |
+
return args
|
| 13 |
+
|
| 14 |
+
def flops_mamba(model, shape=(3, 224, 224)):
|
| 15 |
+
# shape = self.__input_shape__[1:]
|
| 16 |
+
supported_ops = {
|
| 17 |
+
"aten::silu": None, # as relu is in _IGNORED_OPS
|
| 18 |
+
"aten::neg": None, # as relu is in _IGNORED_OPS
|
| 19 |
+
"aten::exp": None, # as relu is in _IGNORED_OPS
|
| 20 |
+
"aten::flip": None, # as permute is in _IGNORED_OPS
|
| 21 |
+
# "prim::PythonOp.CrossScan": None,
|
| 22 |
+
# "prim::PythonOp.CrossMerge": None,
|
| 23 |
+
"prim::PythonOp.SelectiveScanCuda": selective_scan_flop_jit,
|
| 24 |
+
"prim::PythonOp.SelectiveScanMamba": selective_scan_flop_jit,
|
| 25 |
+
"prim::PythonOp.SelectiveScanOflex": selective_scan_flop_jit,
|
| 26 |
+
"prim::PythonOp.SelectiveScanCore": selective_scan_flop_jit,
|
| 27 |
+
"prim::PythonOp.SelectiveScanNRow": selective_scan_flop_jit,
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
model.cuda().eval()
|
| 31 |
+
|
| 32 |
+
input1 = torch.randn((1, *shape), device=next(model.parameters()).device)
|
| 33 |
+
input2 = torch.randn((1, *shape), device=next(model.parameters()).device)
|
| 34 |
+
params = parameter_count(model)[""]
|
| 35 |
+
Gflops, unsupported = flop_count(model=model, inputs=(input1,input2), supported_ops=supported_ops)
|
| 36 |
+
|
| 37 |
+
del model, input1, input2
|
| 38 |
+
# return sum(Gflops.values()) * 1e9
|
| 39 |
+
return f"params {params / 1e6} GFLOPs {sum(Gflops.values())}"
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
|
| 42 |
+
args = parse_args()
|
| 43 |
+
cfg = Config.fromfile(args.config)
|
| 44 |
+
net = myTrain(cfg).net.cuda()
|
| 45 |
+
|
| 46 |
+
size = args.size
|
| 47 |
+
input = torch.rand((1, 3, size, size)).cuda()
|
| 48 |
+
|
| 49 |
+
net.eval()
|
| 50 |
+
flops = FlopCountAnalysis(net, (input, input))
|
| 51 |
+
print(flop_count_table(flops, max_depth = 2))
|
| 52 |
+
|
| 53 |
+
print(flops_mamba(net, (3, size, size)))
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torchmetrics==0.11.4
|
| 2 |
+
pytorch-lightning==2.0.6
|
| 3 |
+
scikit-image==0.21.0
|
| 4 |
+
|
| 5 |
+
catalyst==20.9
|
| 6 |
+
albumentations==1.3.1
|
| 7 |
+
ttach==0.0.3
|
| 8 |
+
einops==0.6.1
|
| 9 |
+
timm==0.6.7
|
| 10 |
+
addict==2.4.0
|
| 11 |
+
soundfile==0.12.1
|
| 12 |
+
prettytable==3.8.0
|
| 13 |
+
fvcore
|
test.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
import prettytable
|
| 5 |
+
import time
|
| 6 |
+
import os
|
| 7 |
+
import multiprocessing.pool as mpp
|
| 8 |
+
import multiprocessing as mp
|
| 9 |
+
|
| 10 |
+
from train import *
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
from utils.config import Config
|
| 14 |
+
from tools.mask_convert import mask_save
|
| 15 |
+
|
| 16 |
+
def get_args():
|
| 17 |
+
parser = argparse.ArgumentParser('description=Change detection of remote sensing images')
|
| 18 |
+
parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
|
| 19 |
+
parser.add_argument("--ckpt", type=str, default=None)
|
| 20 |
+
parser.add_argument("--output_dir", type=str, default=None)
|
| 21 |
+
return parser.parse_args()
|
| 22 |
+
|
| 23 |
+
if __name__ == "__main__":
|
| 24 |
+
args = get_args()
|
| 25 |
+
cfg = Config.fromfile(args.config)
|
| 26 |
+
|
| 27 |
+
ckpt = args.ckpt
|
| 28 |
+
if ckpt is None:
|
| 29 |
+
ckpt = cfg.test_ckpt_path
|
| 30 |
+
assert ckpt is not None
|
| 31 |
+
|
| 32 |
+
if args.output_dir:
|
| 33 |
+
base_dir = args.output_dir
|
| 34 |
+
else:
|
| 35 |
+
base_dir = os.path.dirname(ckpt)
|
| 36 |
+
masks_output_dir = os.path.join(base_dir, "mask_rgb")
|
| 37 |
+
|
| 38 |
+
model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1':'cuda:0'}, cfg = cfg)
|
| 39 |
+
model = model.to('cuda')
|
| 40 |
+
|
| 41 |
+
model.eval()
|
| 42 |
+
|
| 43 |
+
metric_cfg_1 = cfg.metric_cfg1
|
| 44 |
+
metric_cfg_2 = cfg.metric_cfg2
|
| 45 |
+
|
| 46 |
+
test_oa=torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
|
| 47 |
+
test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
|
| 48 |
+
test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
|
| 49 |
+
test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
|
| 50 |
+
test_iou=torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')
|
| 51 |
+
|
| 52 |
+
results = []
|
| 53 |
+
with torch.no_grad():
|
| 54 |
+
test_loader = build_dataloader(cfg.dataset_config, mode='test')
|
| 55 |
+
for input in tqdm(test_loader):
|
| 56 |
+
|
| 57 |
+
raw_predictions, mask, img_id = model(input[0].cuda(), input[1].cuda()), input[2].cuda(), input[3]
|
| 58 |
+
|
| 59 |
+
if cfg.net == 'SARASNet':
|
| 60 |
+
mask = Variable(resize_label(mask.data.cpu().numpy(), \
|
| 61 |
+
size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()
|
| 62 |
+
param = 1 # This parameter is balance precision and recall to get higher F1-score
|
| 63 |
+
raw_predictions[:,1,:,:] = raw_predictions[:,1,:,:] + param
|
| 64 |
+
|
| 65 |
+
if cfg.argmax:
|
| 66 |
+
pred = raw_predictions.argmax(dim=1)
|
| 67 |
+
else:
|
| 68 |
+
if cfg.net == 'maskcd':
|
| 69 |
+
pred = raw_predictions[0]
|
| 70 |
+
pred = pred > 0.5
|
| 71 |
+
pred.squeeze_(1)
|
| 72 |
+
else:
|
| 73 |
+
pred = raw_predictions.squeeze(1)
|
| 74 |
+
pred = pred > 0.5
|
| 75 |
+
|
| 76 |
+
test_oa(pred, mask)
|
| 77 |
+
test_iou(pred, mask)
|
| 78 |
+
test_prec(pred, mask)
|
| 79 |
+
test_f1(pred, mask)
|
| 80 |
+
test_recall(pred, mask)
|
| 81 |
+
|
| 82 |
+
for i in range(raw_predictions.shape[0]):
|
| 83 |
+
mask_real = mask[i].cpu().numpy()
|
| 84 |
+
mask_pred = pred[i].cpu().numpy()
|
| 85 |
+
mask_name = str(img_id[i])
|
| 86 |
+
results.append((mask_real, mask_pred, masks_output_dir, mask_name))
|
| 87 |
+
|
| 88 |
+
metrics = [test_prec.compute(),
|
| 89 |
+
test_recall.compute(),
|
| 90 |
+
test_f1.compute(),
|
| 91 |
+
test_iou.compute()]
|
| 92 |
+
|
| 93 |
+
total_metrics = [test_oa.compute().cpu().numpy(),
|
| 94 |
+
np.mean([item.cpu() for item in metrics[0]]),
|
| 95 |
+
np.mean([item.cpu() for item in metrics[1]]),
|
| 96 |
+
np.mean([item.cpu() for item in metrics[2]]),
|
| 97 |
+
np.mean([item.cpu() for item in metrics[3]])]
|
| 98 |
+
|
| 99 |
+
result_table = prettytable.PrettyTable()
|
| 100 |
+
result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']
|
| 101 |
+
|
| 102 |
+
for i in range(2):
|
| 103 |
+
item = [i, '--']
|
| 104 |
+
for j in range(len(metrics)):
|
| 105 |
+
item.append(np.round(metrics[j][i].cpu().numpy(), 4))
|
| 106 |
+
result_table.add_row(item)
|
| 107 |
+
|
| 108 |
+
total = [np.round(v, 4) for v in total_metrics]
|
| 109 |
+
total.insert(0, 'total')
|
| 110 |
+
result_table.add_row(total)
|
| 111 |
+
|
| 112 |
+
print(result_table)
|
| 113 |
+
|
| 114 |
+
file_name = os.path.join(base_dir, "test_res.txt")
|
| 115 |
+
f = open(file_name,"a")
|
| 116 |
+
current_time = time.strftime('%Y_%m_%d %H:%M:%S {}'.format(cfg.net),time.localtime(time.time()))
|
| 117 |
+
f.write(current_time+'\n')
|
| 118 |
+
f.write(str(result_table)+'\n')
|
| 119 |
+
|
| 120 |
+
if not os.path.exists(masks_output_dir):
|
| 121 |
+
os.makedirs(masks_output_dir)
|
| 122 |
+
print(masks_output_dir)
|
| 123 |
+
|
| 124 |
+
t0 = time.time()
|
| 125 |
+
mpp.Pool(processes=mp.cpu_count()).map(mask_save, results)
|
| 126 |
+
t1 = time.time()
|
| 127 |
+
img_write_time = t1 - t0
|
| 128 |
+
print('images writing spends: {} s'.format(img_write_time))
|
train.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from pytorch_lightning import LightningModule, Trainer, seed_everything
|
| 4 |
+
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, TQDMProgressBar
|
| 5 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
| 6 |
+
import torchmetrics
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
import prettytable
|
| 10 |
+
import numpy as np
|
| 11 |
+
import argparse
|
| 12 |
+
from rscd.models.build_model import build_model
|
| 13 |
+
from rscd.datasets import build_dataloader
|
| 14 |
+
from rscd.optimizers import build_optimizer
|
| 15 |
+
from rscd.losses import build_loss
|
| 16 |
+
from utils.config import Config
|
| 17 |
+
|
| 18 |
+
from torch.autograd import Variable
|
| 19 |
+
|
| 20 |
+
import sys
|
| 21 |
+
sys.path.append('rscd')
|
| 22 |
+
|
| 23 |
+
seed_everything(1234, workers=True)
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import os
|
| 27 |
+
|
| 28 |
+
import time # 用于计时
|
| 29 |
+
|
| 30 |
+
def resize_label(label, size):
|
| 31 |
+
|
| 32 |
+
label = np.expand_dims(label,axis=0)
|
| 33 |
+
label_resized = np.zeros((1,label.shape[1],size[0],size[1]))
|
| 34 |
+
interp = nn.Upsample(size=(size[0], size[1]),mode='bilinear')
|
| 35 |
+
|
| 36 |
+
labelVar = Variable(torch.from_numpy(label).float())
|
| 37 |
+
label_resized[:, :,:,:] = interp(labelVar).data.numpy()
|
| 38 |
+
label_resized = np.array(label_resized, dtype=np.int32)
|
| 39 |
+
return torch.from_numpy(np.squeeze(label_resized,axis=0)).float()
|
| 40 |
+
|
| 41 |
+
def get_args():
|
| 42 |
+
parser = argparse.ArgumentParser('description=Change detection of remote sensing images')
|
| 43 |
+
parser.add_argument("-c", "--config", type=str, default="configs/cdlamba.py")
|
| 44 |
+
return parser.parse_args()
|
| 45 |
+
|
| 46 |
+
class myTrain(LightningModule):
|
| 47 |
+
def __init__(self, cfg, log_dir = None):
|
| 48 |
+
super(myTrain, self).__init__()
|
| 49 |
+
|
| 50 |
+
self.cfg = cfg
|
| 51 |
+
self.log_dir = log_dir
|
| 52 |
+
self.net = build_model(cfg.model_config)
|
| 53 |
+
self.loss = build_loss(cfg.loss_config)
|
| 54 |
+
|
| 55 |
+
self.loss.to('cuda:{}'.format(cfg.gpus[0]))
|
| 56 |
+
|
| 57 |
+
metric_cfg1 = cfg.metric_cfg1
|
| 58 |
+
metric_cfg2 = cfg.metric_cfg2
|
| 59 |
+
|
| 60 |
+
self.tr_oa=torchmetrics.Accuracy(**metric_cfg1)
|
| 61 |
+
self.tr_prec = torchmetrics.Precision(**metric_cfg2)
|
| 62 |
+
self.tr_recall = torchmetrics.Recall(**metric_cfg2)
|
| 63 |
+
self.tr_f1 = torchmetrics.F1Score(**metric_cfg2)
|
| 64 |
+
self.tr_iou=torchmetrics.JaccardIndex(**metric_cfg2)
|
| 65 |
+
|
| 66 |
+
self.val_oa=torchmetrics.Accuracy(**metric_cfg1)
|
| 67 |
+
self.val_prec = torchmetrics.Precision(**metric_cfg2)
|
| 68 |
+
self.val_recall = torchmetrics.Recall(**metric_cfg2)
|
| 69 |
+
self.val_f1 = torchmetrics.F1Score(**metric_cfg2)
|
| 70 |
+
self.val_iou=torchmetrics.JaccardIndex(**metric_cfg2)
|
| 71 |
+
|
| 72 |
+
self.test_oa=torchmetrics.Accuracy(**metric_cfg1)
|
| 73 |
+
self.test_prec = torchmetrics.Precision(**metric_cfg2)
|
| 74 |
+
self.test_recall = torchmetrics.Recall(**metric_cfg2)
|
| 75 |
+
self.test_f1 = torchmetrics.F1Score(**metric_cfg2)
|
| 76 |
+
self.test_iou=torchmetrics.JaccardIndex(**metric_cfg2)
|
| 77 |
+
|
| 78 |
+
self.test_max_f1 = [0 for _ in range(10)]
|
| 79 |
+
|
| 80 |
+
self.test_loader = build_dataloader(cfg.dataset_config, mode='test')
|
| 81 |
+
|
| 82 |
+
def forward(self, x1, x2) :
|
| 83 |
+
pred = self.net(x1, x2)
|
| 84 |
+
return pred
|
| 85 |
+
|
| 86 |
+
def configure_optimizers(self):
|
| 87 |
+
optimizer, scheduler = build_optimizer(self.cfg.optimizer_config, self.net)
|
| 88 |
+
return {'optimizer':optimizer,'lr_scheduler':scheduler, 'monitor': self.cfg.monitor_val}
|
| 89 |
+
|
| 90 |
+
def train_dataloader(self):
|
| 91 |
+
loader = build_dataloader(self.cfg.dataset_config, mode='train')
|
| 92 |
+
return loader
|
| 93 |
+
|
| 94 |
+
def val_dataloader(self):
|
| 95 |
+
loader = build_dataloader(self.cfg.dataset_config, mode='val')
|
| 96 |
+
return loader
|
| 97 |
+
|
| 98 |
+
def output(self, metrics, total_metrics, mode, test_idx=0, test_value=None):
|
| 99 |
+
result_table = prettytable.PrettyTable()
|
| 100 |
+
result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']
|
| 101 |
+
|
| 102 |
+
for i in range(len(metrics[0])):
|
| 103 |
+
item = [i, '--']
|
| 104 |
+
for j in range(len(metrics)):
|
| 105 |
+
item.append(np.round(metrics[j][i].cpu().numpy(), 4))
|
| 106 |
+
result_table.add_row(item)
|
| 107 |
+
|
| 108 |
+
total = list(total_metrics.values())
|
| 109 |
+
total = [np.round(v, 4) for v in total]
|
| 110 |
+
total.insert(0, 'total')
|
| 111 |
+
result_table.add_row(total)
|
| 112 |
+
|
| 113 |
+
if mode == 'val' or mode == 'test':
|
| 114 |
+
print(mode)
|
| 115 |
+
print(result_table)
|
| 116 |
+
|
| 117 |
+
if self.log_dir:
|
| 118 |
+
base_dir = self.log_dir
|
| 119 |
+
else:
|
| 120 |
+
base_dir = os.path.join('work_dirs', cfg.exp_name)
|
| 121 |
+
|
| 122 |
+
if mode == 'test':
|
| 123 |
+
if self.cfg.argmax:
|
| 124 |
+
file_name = os.path.join(base_dir, "test_metrics_{}.txt".format(test_idx))
|
| 125 |
+
if metrics[2][1] > self.test_max_f1[test_idx]:
|
| 126 |
+
self.test_max_f1[test_idx] = metrics[2][1]
|
| 127 |
+
file_name = os.path.join(base_dir, "test_max_metrics_{}.txt".format(test_idx))
|
| 128 |
+
else:
|
| 129 |
+
file_name = os.path.join(base_dir, "test_metrics_{}_{}.txt".format(test_idx, str(test_value)))
|
| 130 |
+
if metrics[2][1] > self.test_max_f1[test_idx]:
|
| 131 |
+
self.test_max_f1[test_idx] = metrics[2][1]
|
| 132 |
+
file_name = os.path.join(base_dir, "test_max_metrics_{}_{}.txt".format(test_idx, '%.1f' % test_value))
|
| 133 |
+
else:
|
| 134 |
+
file_name = os.path.join(base_dir, "train_metrics.txt")
|
| 135 |
+
f = open(file_name,"a")
|
| 136 |
+
f.write('epoch:{}/{} {}\n'.format(self.current_epoch, self.cfg.epoch, mode))
|
| 137 |
+
f.write(str(result_table)+'\n')
|
| 138 |
+
f.close()
|
| 139 |
+
|
| 140 |
+
def training_step(self, batch, batch_idx):
|
| 141 |
+
imgA, imgB, mask = batch[0], batch[1], batch[2]
|
| 142 |
+
preds = self(imgA, imgB)
|
| 143 |
+
|
| 144 |
+
if self.cfg.net == 'SARASNet':
|
| 145 |
+
mask = Variable(resize_label(mask.data.cpu().numpy(), \
|
| 146 |
+
size=preds.data.cpu().numpy().shape[2:]).to('cuda')).long()
|
| 147 |
+
param = 1 # This parameter is balance precision and recall to get higher F1-score
|
| 148 |
+
preds[:,1,:,:] = preds[:,1,:,:] + param
|
| 149 |
+
|
| 150 |
+
if self.cfg.argmax:
|
| 151 |
+
loss = self.loss(preds, mask)
|
| 152 |
+
pred = preds.argmax(dim=1)
|
| 153 |
+
else:
|
| 154 |
+
if self.cfg.net == 'maskcd':
|
| 155 |
+
loss = self.loss(preds[1], mask)
|
| 156 |
+
pred = preds[0]
|
| 157 |
+
pred = pred > 0.5
|
| 158 |
+
pred.squeeze_(1)
|
| 159 |
+
else:
|
| 160 |
+
pred = preds.squeeze(1)
|
| 161 |
+
loss = self.loss(pred, mask)
|
| 162 |
+
pred = pred > 0.5
|
| 163 |
+
|
| 164 |
+
self.tr_oa(pred, mask)
|
| 165 |
+
self.tr_prec(pred, mask)
|
| 166 |
+
self.tr_recall(pred, mask)
|
| 167 |
+
self.tr_f1(pred, mask)
|
| 168 |
+
self.tr_iou(pred, mask)
|
| 169 |
+
|
| 170 |
+
self.log('tr_loss', loss, on_step=True,on_epoch=True,prog_bar=True)
|
| 171 |
+
return loss
|
| 172 |
+
|
| 173 |
+
def on_train_epoch_end(self):
|
| 174 |
+
metrics = [self.tr_prec.compute(),
|
| 175 |
+
self.tr_recall.compute(),
|
| 176 |
+
self.tr_f1.compute(),
|
| 177 |
+
self.tr_iou.compute()]
|
| 178 |
+
|
| 179 |
+
log = {'tr_oa': float(self.tr_oa.compute().cpu()),
|
| 180 |
+
'tr_prec': np.mean([item.cpu() for item in metrics[0]]),
|
| 181 |
+
'tr_recall': np.mean([item.cpu() for item in metrics[1]]),
|
| 182 |
+
'tr_f1': np.mean([item.cpu() for item in metrics[2]]),
|
| 183 |
+
'tr_miou': np.mean([item.cpu() for item in metrics[3]])}
|
| 184 |
+
|
| 185 |
+
self.output(metrics, log, 'train')
|
| 186 |
+
|
| 187 |
+
for key, value in zip(log.keys(), log.values()):
|
| 188 |
+
self.log(key, value, on_step=False,on_epoch=True,prog_bar=True)
|
| 189 |
+
self.log('tr_change_f1', metrics[2][1], on_step=False,on_epoch=True,prog_bar=True)
|
| 190 |
+
|
| 191 |
+
self.tr_oa.reset()
|
| 192 |
+
self.tr_prec.reset()
|
| 193 |
+
self.tr_recall.reset()
|
| 194 |
+
self.tr_f1.reset()
|
| 195 |
+
self.tr_iou.reset()
|
| 196 |
+
|
| 197 |
+
def validation_step(self, batch, batch_idx):
|
| 198 |
+
imgA, imgB, mask = batch[0], batch[1], batch[2]
|
| 199 |
+
preds = self(imgA, imgB)
|
| 200 |
+
|
| 201 |
+
if self.cfg.net == 'SARASNet':
|
| 202 |
+
mask = Variable(resize_label(mask.data.cpu().numpy(), \
|
| 203 |
+
size=preds.data.cpu().numpy().shape[2:]).to('cuda')).long()
|
| 204 |
+
param = 1 # This parameter is balance precision and recall to get higher F1-score
|
| 205 |
+
preds[:,1,:,:] = preds[:,1,:,:] + param
|
| 206 |
+
|
| 207 |
+
if self.cfg.argmax:
|
| 208 |
+
loss = self.loss(preds, mask)
|
| 209 |
+
pred = preds.argmax(dim=1)
|
| 210 |
+
else:
|
| 211 |
+
if self.cfg.net == 'maskcd':
|
| 212 |
+
loss = self.loss(preds[1], mask)
|
| 213 |
+
pred = preds[0]
|
| 214 |
+
pred = pred > 0.5
|
| 215 |
+
pred.squeeze_(1)
|
| 216 |
+
else:
|
| 217 |
+
pred = preds.squeeze(1)
|
| 218 |
+
loss = self.loss(pred, mask)
|
| 219 |
+
pred = pred > 0.5
|
| 220 |
+
|
| 221 |
+
self.val_oa(pred, mask)
|
| 222 |
+
self.val_prec(pred, mask)
|
| 223 |
+
self.val_recall(pred, mask)
|
| 224 |
+
self.val_f1(pred, mask)
|
| 225 |
+
self.val_iou(pred, mask)
|
| 226 |
+
|
| 227 |
+
self.log('val_loss', loss, on_step=True,on_epoch=True,prog_bar=True)
|
| 228 |
+
return loss
|
| 229 |
+
|
| 230 |
+
def on_validation_epoch_end(self):
|
| 231 |
+
metrics = [self.val_prec.compute(),
|
| 232 |
+
self.val_recall.compute(),
|
| 233 |
+
self.val_f1.compute(),
|
| 234 |
+
self.val_iou.compute()]
|
| 235 |
+
|
| 236 |
+
log = {'val_oa': float(self.val_oa.compute().cpu()),
|
| 237 |
+
'val_prec': np.mean([item.cpu() for item in metrics[0]]),
|
| 238 |
+
'val_recall': np.mean([item.cpu() for item in metrics[1]]),
|
| 239 |
+
'val_f1': np.mean([item.cpu() for item in metrics[2]]),
|
| 240 |
+
'val_miou': np.mean([item.cpu() for item in metrics[3]])}
|
| 241 |
+
|
| 242 |
+
self.output(metrics, log, 'val')
|
| 243 |
+
|
| 244 |
+
for key, value in zip(log.keys(), log.values()):
|
| 245 |
+
self.log(key, value, on_step=False,on_epoch=True,prog_bar=True)
|
| 246 |
+
self.log('val_change_f1', metrics[2][1], on_step=False,on_epoch=True,prog_bar=True)
|
| 247 |
+
|
| 248 |
+
self.val_oa.reset()
|
| 249 |
+
self.val_prec.reset()
|
| 250 |
+
self.val_recall.reset()
|
| 251 |
+
self.val_f1.reset()
|
| 252 |
+
self.val_iou.reset()
|
| 253 |
+
|
| 254 |
+
for idx in range(0, len(self.cfg.monitor_test), 1):
|
| 255 |
+
if self.cfg.argmax:
|
| 256 |
+
self.log(self.cfg.monitor_test[idx], self.test(idx), on_step=False,on_epoch=True,prog_bar=True)
|
| 257 |
+
else:
|
| 258 |
+
t = 0.2 + 0.1 * idx
|
| 259 |
+
self.log(self.cfg.monitor_test[idx], self.test(idx, t), on_step=False,on_epoch=True,prog_bar=True)
|
| 260 |
+
|
| 261 |
+
def test(self, idx, value = None):
|
| 262 |
+
for input in tqdm(self.test_loader):
|
| 263 |
+
raw_predictions, mask_test = self(input[0].cuda(cfg.gpus[0]), input[1].cuda(cfg.gpus[0])), input[2].cuda(cfg.gpus[0])
|
| 264 |
+
|
| 265 |
+
if self.cfg.net == 'SARASNet':
|
| 266 |
+
mask_test = Variable(resize_label(mask_test.data.cpu().numpy(), \
|
| 267 |
+
size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()
|
| 268 |
+
param = 1 # This parameter is balance precision and recall to get higher F1-score
|
| 269 |
+
raw_predictions[:,1,:,:] = raw_predictions[:,1,:,:] + param
|
| 270 |
+
|
| 271 |
+
if self.cfg.argmax:
|
| 272 |
+
pred_test = raw_predictions.argmax(dim=1)
|
| 273 |
+
else:
|
| 274 |
+
if self.cfg.net == 'maskcd':
|
| 275 |
+
raw_prediction = raw_predictions[0]
|
| 276 |
+
pred_test = raw_prediction > value
|
| 277 |
+
pred_test.squeeze_(1)
|
| 278 |
+
else:
|
| 279 |
+
pred_test = raw_predictions.squeeze(1)
|
| 280 |
+
pred_test = pred_test > 0.5
|
| 281 |
+
|
| 282 |
+
self.test_oa(pred_test, mask_test)
|
| 283 |
+
self.test_iou(pred_test, mask_test)
|
| 284 |
+
self.test_prec(pred_test, mask_test)
|
| 285 |
+
self.test_f1(pred_test, mask_test)
|
| 286 |
+
self.test_recall(pred_test, mask_test)
|
| 287 |
+
|
| 288 |
+
metrics_test = [self.test_prec.compute(),
|
| 289 |
+
self.test_recall.compute(),
|
| 290 |
+
self.test_f1.compute(),
|
| 291 |
+
self.test_iou.compute()]
|
| 292 |
+
|
| 293 |
+
log = {'test_oa': float(self.test_oa.compute().cpu()),
|
| 294 |
+
'test_prec': np.mean([item.cpu() for item in metrics_test[0]]),
|
| 295 |
+
'test_recall': np.mean([item.cpu() for item in metrics_test[1]]),
|
| 296 |
+
'test_f1': np.mean([item.cpu() for item in metrics_test[2]]),
|
| 297 |
+
'test_miou': np.mean([item.cpu() for item in metrics_test[3]])}
|
| 298 |
+
|
| 299 |
+
self.output(metrics_test, log, 'test', idx, value)
|
| 300 |
+
|
| 301 |
+
self.test_oa.reset()
|
| 302 |
+
self.test_prec.reset()
|
| 303 |
+
self.test_recall.reset()
|
| 304 |
+
self.test_f1.reset()
|
| 305 |
+
self.test_iou.reset()
|
| 306 |
+
|
| 307 |
+
return metrics_test[2][1]
|
| 308 |
+
|
| 309 |
+
if __name__ == "__main__":
|
| 310 |
+
args = get_args()
|
| 311 |
+
cfg = Config.fromfile(args.config)
|
| 312 |
+
logger = TensorBoardLogger(save_dir = "work_dirs",
|
| 313 |
+
sub_dir = 'log',
|
| 314 |
+
name = cfg.exp_name,
|
| 315 |
+
default_hp_metric = False)
|
| 316 |
+
|
| 317 |
+
log_dir = os.path.dirname(logger.log_dir)
|
| 318 |
+
|
| 319 |
+
model = myTrain(cfg, log_dir)
|
| 320 |
+
# —— 在这里插入“推理 FPS 测试”功能 —— #
|
| 321 |
+
device = torch.device(f'cuda:{cfg.gpus[0]}' if torch.cuda.is_available() else 'cpu')
|
| 322 |
+
model = model.to(device)
|
| 323 |
+
model.eval()
|
| 324 |
+
|
| 325 |
+
# 从验证集 dataloader 里取一个 batch
|
| 326 |
+
val_loader = model.val_dataloader()
|
| 327 |
+
batch_iter = iter(val_loader)
|
| 328 |
+
try:
|
| 329 |
+
batch = next(batch_iter)
|
| 330 |
+
imgA_batch = batch[0]
|
| 331 |
+
imgB_batch = batch[1]
|
| 332 |
+
except StopIteration:
|
| 333 |
+
raise RuntimeError("验证集 dataloader 为空,请检查数据集配置。")
|
| 334 |
+
|
| 335 |
+
# 将输入搬到同一个设备
|
| 336 |
+
imgA_batch = imgA_batch.to(device)
|
| 337 |
+
imgB_batch = imgB_batch.to(device)
|
| 338 |
+
|
| 339 |
+
# 热身推理 10 次
|
| 340 |
+
with torch.no_grad():
|
| 341 |
+
for _ in range(10):
|
| 342 |
+
_ = model(imgA_batch, imgB_batch)
|
| 343 |
+
|
| 344 |
+
# 正式计时 N 次推理
|
| 345 |
+
N = 100
|
| 346 |
+
torch.cuda.synchronize(device)
|
| 347 |
+
start_time = time.time()
|
| 348 |
+
with torch.no_grad():
|
| 349 |
+
for _ in range(N):
|
| 350 |
+
_ = model(imgA_batch, imgB_batch)
|
| 351 |
+
torch.cuda.synchronize(device)
|
| 352 |
+
elapsed = time.time() - start_time
|
| 353 |
+
fps = N / elapsed
|
| 354 |
+
print(f"[推理 FPS 测试] 输入分辨率 = {imgA_batch.shape[2]}×{imgA_batch.shape[3]},"
|
| 355 |
+
f"Batch Size = {imgA_batch.shape[0]},推理 {N} 次总耗时:{elapsed:.4f} 秒,FPS = {fps:.2f}")
|
| 356 |
+
# —— 插入结束 —— #
|
| 357 |
+
|
| 358 |
+
pbar = TQDMProgressBar(refresh_rate=1)
|
| 359 |
+
lr_monitor=LearningRateMonitor(logging_interval = cfg.logging_interval)
|
| 360 |
+
callbacks = [pbar, lr_monitor]
|
| 361 |
+
|
| 362 |
+
ckpt_cb = ModelCheckpoint(dirpath = f'{log_dir}/ckpts/val',
|
| 363 |
+
filename = '{' + cfg.monitor_val + ':.4f}' + '-{epoch:d}',
|
| 364 |
+
monitor = cfg.monitor_val,
|
| 365 |
+
mode = 'max',
|
| 366 |
+
save_top_k = cfg.save_top_k,
|
| 367 |
+
save_last=True)
|
| 368 |
+
callbacks.append(ckpt_cb)
|
| 369 |
+
|
| 370 |
+
for m_test in cfg.monitor_test:
|
| 371 |
+
ckpt_cb = ModelCheckpoint(dirpath = f'{log_dir}/ckpts/test/{m_test}',
|
| 372 |
+
filename = '{' + m_test + ':.4f}' + '-{epoch:d}',
|
| 373 |
+
monitor = m_test,
|
| 374 |
+
mode = 'max',
|
| 375 |
+
save_top_k = cfg.save_top_k,
|
| 376 |
+
save_last=True)
|
| 377 |
+
callbacks.append(ckpt_cb)
|
| 378 |
+
|
| 379 |
+
trainer = Trainer(max_epochs = cfg.epoch,
|
| 380 |
+
# precision='16-mixed',
|
| 381 |
+
callbacks = callbacks,
|
| 382 |
+
logger = logger,
|
| 383 |
+
enable_model_summary = True,
|
| 384 |
+
accelerator = 'auto',
|
| 385 |
+
devices = cfg.gpus,
|
| 386 |
+
num_sanity_val_steps = 2,
|
| 387 |
+
benchmark = True)
|
| 388 |
+
|
| 389 |
+
trainer.fit(model, ckpt_path=cfg.resume_ckpt_path)
|
utilss.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
class ActivationsAndGradients:
|
| 7 |
+
""" Class for extracting activations and
|
| 8 |
+
registering gradients from targeted intermediate layers """
|
| 9 |
+
|
| 10 |
+
def __init__(self, model, target_layers, reshape_transform):
|
| 11 |
+
self.model = model
|
| 12 |
+
self.gradients = []
|
| 13 |
+
self.activations = []
|
| 14 |
+
self.reshape_transform = reshape_transform
|
| 15 |
+
self.handles = []
|
| 16 |
+
for target_layer in target_layers:
|
| 17 |
+
self.handles.append(
|
| 18 |
+
target_layer.register_forward_hook(
|
| 19 |
+
self.save_activation))
|
| 20 |
+
# Backward compatibility with older pytorch versions:
|
| 21 |
+
if hasattr(target_layer, 'register_full_backward_hook'):
|
| 22 |
+
self.handles.append(
|
| 23 |
+
target_layer.register_full_backward_hook(
|
| 24 |
+
self.save_gradient))
|
| 25 |
+
else:
|
| 26 |
+
self.handles.append(
|
| 27 |
+
target_layer.register_backward_hook(
|
| 28 |
+
self.save_gradient))
|
| 29 |
+
|
| 30 |
+
def save_activation(self, module, input, output):
|
| 31 |
+
activation = output
|
| 32 |
+
if self.reshape_transform is not None:
|
| 33 |
+
activation = self.reshape_transform(activation)
|
| 34 |
+
self.activations.append(activation.cpu().detach())
|
| 35 |
+
|
| 36 |
+
def save_gradient(self, module, grad_input, grad_output):
|
| 37 |
+
# Gradients are computed in reverse order
|
| 38 |
+
grad = grad_output[0]
|
| 39 |
+
if self.reshape_transform is not None:
|
| 40 |
+
grad = self.reshape_transform(grad)
|
| 41 |
+
self.gradients = [grad.cpu().detach()] + self.gradients
|
| 42 |
+
|
| 43 |
+
def __call__(self, x, y):
|
| 44 |
+
self.gradients = []
|
| 45 |
+
self.activations = []
|
| 46 |
+
return self.model(x, y)
|
| 47 |
+
|
| 48 |
+
def release(self):
|
| 49 |
+
for handle in self.handles:
|
| 50 |
+
handle.remove()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class GradCAM:
|
| 54 |
+
def __init__(self,
|
| 55 |
+
cfg,
|
| 56 |
+
model,
|
| 57 |
+
target_layers,
|
| 58 |
+
reshape_transform=None,
|
| 59 |
+
use_cuda=False):
|
| 60 |
+
self.cfg = cfg
|
| 61 |
+
self.model = model.eval()
|
| 62 |
+
self.target_layers = target_layers
|
| 63 |
+
self.reshape_transform = reshape_transform
|
| 64 |
+
self.cuda = use_cuda
|
| 65 |
+
if self.cuda:
|
| 66 |
+
self.model = model.cuda()
|
| 67 |
+
self.activations_and_grads = ActivationsAndGradients(
|
| 68 |
+
self.model, target_layers, reshape_transform)
|
| 69 |
+
|
| 70 |
+
""" Get a vector of weights for every channel in the target layer.
|
| 71 |
+
Methods that return weights channels,
|
| 72 |
+
will typically need to only implement this function. """
|
| 73 |
+
|
| 74 |
+
@staticmethod
|
| 75 |
+
def get_cam_weights(grads):
|
| 76 |
+
return np.mean(grads, axis=(2, 3), keepdims=True)
|
| 77 |
+
|
| 78 |
+
@staticmethod
|
| 79 |
+
def get_loss(output, target_category):
|
| 80 |
+
loss = 0
|
| 81 |
+
for i in range(len(target_category)):
|
| 82 |
+
loss = loss + output[i]
|
| 83 |
+
return loss
|
| 84 |
+
|
| 85 |
+
def get_cam_image(self, activations, grads):
|
| 86 |
+
weights = self.get_cam_weights(grads)
|
| 87 |
+
weighted_activations = weights * activations
|
| 88 |
+
cam = weighted_activations.sum(axis=1)
|
| 89 |
+
|
| 90 |
+
return cam
|
| 91 |
+
|
| 92 |
+
@staticmethod
|
| 93 |
+
def get_target_width_height(input_tensor):
|
| 94 |
+
width, height = input_tensor.size(-1), input_tensor.size(-2)
|
| 95 |
+
return width, height
|
| 96 |
+
|
| 97 |
+
def compute_cam_per_layer(self, input_tensor):
|
| 98 |
+
activations_list = [a.cpu().data.numpy()
|
| 99 |
+
for a in self.activations_and_grads.activations]
|
| 100 |
+
grads_list = [g.cpu().data.numpy()
|
| 101 |
+
for g in self.activations_and_grads.gradients]
|
| 102 |
+
target_size = self.get_target_width_height(input_tensor)
|
| 103 |
+
|
| 104 |
+
cam_per_target_layer = []
|
| 105 |
+
# Loop over the saliency image from every layer
|
| 106 |
+
|
| 107 |
+
for layer_activations, layer_grads in zip(activations_list, grads_list):
|
| 108 |
+
cam = self.get_cam_image(layer_activations, layer_grads)
|
| 109 |
+
cam[cam < 0] = 0 # works like mute the min-max scale in the function of scale_cam_image
|
| 110 |
+
scaled = self.scale_cam_image(cam, target_size)
|
| 111 |
+
cam_per_target_layer.append(scaled[:, None, :])
|
| 112 |
+
|
| 113 |
+
return cam_per_target_layer
|
| 114 |
+
|
| 115 |
+
def aggregate_multi_layers(self, cam_per_target_layer):
|
| 116 |
+
cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
|
| 117 |
+
cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
|
| 118 |
+
result = np.mean(cam_per_target_layer, axis=1)
|
| 119 |
+
return self.scale_cam_image(result)
|
| 120 |
+
|
| 121 |
+
@staticmethod
|
| 122 |
+
def scale_cam_image(cam, target_size=None):
|
| 123 |
+
result = []
|
| 124 |
+
for img in cam:
|
| 125 |
+
img = img - np.min(img)
|
| 126 |
+
img = img / (1e-7 + np.max(img))
|
| 127 |
+
if target_size is not None:
|
| 128 |
+
img = cv2.resize(img, target_size)
|
| 129 |
+
result.append(img)
|
| 130 |
+
result = np.float32(result)
|
| 131 |
+
|
| 132 |
+
return result
|
| 133 |
+
|
| 134 |
+
def __call__(self, input_tensor, target_category=None):
|
| 135 |
+
x, y = input_tensor
|
| 136 |
+
if self.cuda:
|
| 137 |
+
x = x.cuda()
|
| 138 |
+
y = y.cuda()
|
| 139 |
+
|
| 140 |
+
# 正向传播得到网络输出logits(未经过softmax)
|
| 141 |
+
if self.cfg.net == 'cdmask':
|
| 142 |
+
o, outputs = self.activations_and_grads(x, y)
|
| 143 |
+
mask_cls_results = outputs["pred_logits"]
|
| 144 |
+
mask_pred_results = outputs["pred_masks"]
|
| 145 |
+
mask_pred_results = F.interpolate(
|
| 146 |
+
mask_pred_results,
|
| 147 |
+
scale_factor=(4,4),
|
| 148 |
+
mode="bilinear",
|
| 149 |
+
align_corners=False,
|
| 150 |
+
)
|
| 151 |
+
mask_cls = F.softmax(mask_cls_results, dim=-1)[...,1:]
|
| 152 |
+
mask_pred = mask_pred_results.sigmoid()
|
| 153 |
+
output = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
|
| 154 |
+
else:
|
| 155 |
+
output = self.activations_and_grads(x, y)
|
| 156 |
+
|
| 157 |
+
if isinstance(target_category, int):
|
| 158 |
+
target_category = [target_category] * x.size(0)
|
| 159 |
+
|
| 160 |
+
if target_category is None:
|
| 161 |
+
target_category = np.argmax(output.cpu().data.numpy(), axis=-1)
|
| 162 |
+
print(f"category id: {target_category}")
|
| 163 |
+
else:
|
| 164 |
+
assert (len(target_category) == x.size(0))
|
| 165 |
+
|
| 166 |
+
self.model.zero_grad()
|
| 167 |
+
loss = self.get_loss(output, target_category).sum()
|
| 168 |
+
loss.backward(retain_graph=True)
|
| 169 |
+
|
| 170 |
+
# In most of the saliency attribution papers, the saliency is
|
| 171 |
+
# computed with a single target layer.
|
| 172 |
+
# Commonly it is the last convolutional layer.
|
| 173 |
+
# Here we support passing a list with multiple target layers.
|
| 174 |
+
# It will compute the saliency image for every image,
|
| 175 |
+
# and then aggregate them (with a default mean aggregation).
|
| 176 |
+
# This gives you more flexibility in case you just want to
|
| 177 |
+
# use all conv layers for example, all Batchnorm layers,
|
| 178 |
+
# or something else.
|
| 179 |
+
cam_per_layer = self.compute_cam_per_layer(x)
|
| 180 |
+
return self.aggregate_multi_layers(cam_per_layer)
|
| 181 |
+
|
| 182 |
+
def __del__(self):
|
| 183 |
+
self.activations_and_grads.release()
|
| 184 |
+
|
| 185 |
+
def __enter__(self):
|
| 186 |
+
return self
|
| 187 |
+
|
| 188 |
+
def __exit__(self, exc_type, exc_value, exc_tb):
|
| 189 |
+
self.activations_and_grads.release()
|
| 190 |
+
if isinstance(exc_value, IndexError):
|
| 191 |
+
# Handle IndexError here...
|
| 192 |
+
print(
|
| 193 |
+
f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
|
| 194 |
+
return True
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def show_cam_on_image(img: np.ndarray,
|
| 198 |
+
mask: np.ndarray,
|
| 199 |
+
use_rgb: bool = False,
|
| 200 |
+
colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
|
| 201 |
+
""" This function overlays the cam mask on the image as an heatmap.
|
| 202 |
+
By default the heatmap is in BGR format.
|
| 203 |
+
|
| 204 |
+
:param img: The base image in RGB or BGR format.
|
| 205 |
+
:param mask: The cam mask.
|
| 206 |
+
:param use_rgb: Whether to use an RGB or BGR heatmap, this should be set to True if 'img' is in RGB format.
|
| 207 |
+
:param colormap: The OpenCV colormap to be used.
|
| 208 |
+
:returns: The default image with the cam overlay.
|
| 209 |
+
"""
|
| 210 |
+
|
| 211 |
+
heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
|
| 212 |
+
if use_rgb:
|
| 213 |
+
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
|
| 214 |
+
heatmap = np.float32(heatmap) / 255
|
| 215 |
+
|
| 216 |
+
if np.max(img) > 1:
|
| 217 |
+
raise Exception(
|
| 218 |
+
"The input image should np.float32 in the range [0, 1]")
|
| 219 |
+
|
| 220 |
+
cam = heatmap + img
|
| 221 |
+
cam = cam / np.max(cam)
|
| 222 |
+
return np.uint8(255 * cam)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def center_crop_img(img: np.ndarray, size: int):
|
| 226 |
+
h, w, c = img.shape
|
| 227 |
+
|
| 228 |
+
if w == h == size:
|
| 229 |
+
return img
|
| 230 |
+
|
| 231 |
+
if w < h:
|
| 232 |
+
ratio = size / w
|
| 233 |
+
new_w = size
|
| 234 |
+
new_h = int(h * ratio)
|
| 235 |
+
else:
|
| 236 |
+
ratio = size / h
|
| 237 |
+
new_h = size
|
| 238 |
+
new_w = int(w * ratio)
|
| 239 |
+
|
| 240 |
+
img = cv2.resize(img, dsize=(new_w, new_h))
|
| 241 |
+
|
| 242 |
+
if new_w == size:
|
| 243 |
+
h = (new_h - size) // 2
|
| 244 |
+
img = img[h: h+size]
|
| 245 |
+
else:
|
| 246 |
+
w = (new_w - size) // 2
|
| 247 |
+
img = img[:, w: w+size]
|
| 248 |
+
|
| 249 |
+
return img
|