import torch
from torch import nn
from tqdm import tqdm
import prettytable
import time
import os
import multiprocessing.pool as mpp
import multiprocessing as mp

from train import *

import argparse
from utils.config import Config
from tools.mask_convert import mask_save
import numpy as np
import csv
import torchmetrics  # used for the metric objects below; imported explicitly rather than relying on `from train import *`
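
# Evaluation script: loads a change-detection checkpoint, runs it over the test
# split, prints an overall OA/Precision/Recall/F1/IoU table, and exports
# per-image metrics (CSV + per-image TXT tables), a threshold-swept PR-curve CSV,
# and (unless --tables-only is given) RGB mask visualizations.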


class PRHistogram:
    """Accumulates per-pixel probability histograms for the positive (changed) and
    negative (unchanged) classes, so that a full precision-recall curve can be
    computed afterwards without storing every prediction."""

    def __init__(self, nbins: int = 1000):
        self.nbins = int(nbins)
        self.pos_hist = np.zeros(self.nbins, dtype=np.int64)
        self.neg_hist = np.zeros(self.nbins, dtype=np.int64)
        self.bin_edges = np.linspace(0.0, 1.0, self.nbins + 1)

    def update(self, probs, mask):
        # Flatten predictions and labels, then histogram the probabilities of
        # changed (g == 1) and unchanged (g == 0) pixels separately.
        p = probs.detach().float().cpu().numpy().ravel()
        g = (mask.detach().cpu().numpy().ravel() > 0).astype(np.uint8)
        pos_counts, _ = np.histogram(p[g == 1], bins=self.bin_edges)
        neg_counts, _ = np.histogram(p[g == 0], bins=self.bin_edges)
        self.pos_hist += pos_counts
        self.neg_hist += neg_counts

    def compute_curve(self):
        # Reversed cumulative sums give, for every threshold, the number of
        # positive / negative pixels scored at or above that threshold.
        pos_cum = np.cumsum(self.pos_hist[::-1])
        neg_cum = np.cumsum(self.neg_hist[::-1])
        TP = pos_cum
        FP = neg_cum
        FN = self.pos_hist.sum() - TP
        # TN is not needed for precision/recall/F1/IoU, so it is not derived here.

        denom_prec = np.maximum(TP + FP, 1)
        denom_rec = np.maximum(TP + FN, 1)
        precision = TP / denom_prec
        recall = TP / denom_rec

        denom_f1 = np.maximum(precision + recall, 1e-12)
        f1 = 2.0 * precision * recall / denom_f1

        denom_iou = np.maximum(TP + FP + FN, 1)
        iou = TP / denom_iou

        # Thresholds run from the highest interior bin edge down to 0.0, matching
        # the order of the reversed cumulative sums above.
        thresholds = self.bin_edges[::-1][1:]
        return thresholds, precision, recall, f1, iou, TP, FP, FN

    def export_csv(self, save_path: str):
        thresholds, precision, recall, f1, iou, TP, FP, FN = self.compute_curve()
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        np.savetxt(
            save_path,
            np.column_stack([thresholds, precision, recall, f1, iou, TP, FP, FN]),
            delimiter=",",
            header="threshold,precision,recall,f1,iou,TP,FP,FN",
            comments=""
        )
        return save_path


_PR = None


def pr_init(nbins: int = 1000):
    global _PR
    if _PR is None:
        _PR = PRHistogram(nbins=nbins)
    return _PR
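
# Note: pr_init(), pr_update_from_outputs() and pr_export() all operate on the
# module-level _PR histogram above; calling pr_init() is optional because
# pr_update_from_outputs() creates the histogram lazily on first use.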


def pr_update_from_outputs(raw_predictions, mask, cfg):
    """Convert raw model outputs into change probabilities and add them to the
    global PR histogram. Handles multi-class heads (softmax over the change class
    when cfg.argmax is set), single-channel sigmoid heads, and the list/tuple
    output of the 'maskcd' network."""
    global _PR
    if _PR is None:
        _PR = PRHistogram(nbins=1000)

    if getattr(cfg, 'argmax', False):
        logits = raw_predictions
        if logits.dim() == 4 and logits.size(1) >= 2:
            probs = torch.softmax(logits, dim=1)[:, 1, :, :]
        else:
            probs = torch.sigmoid(logits.squeeze(1))
    else:
        if getattr(cfg, 'net', '') == 'maskcd':
            if isinstance(raw_predictions, (list, tuple)):
                logits = raw_predictions[0]
            else:
                logits = raw_predictions
            probs = torch.sigmoid(logits).squeeze(1)
        else:
            logits = raw_predictions
            if logits.dim() == 4 and logits.size(1) == 1:
                logits = logits.squeeze(1)
            probs = torch.sigmoid(logits)

    # Drop a singleton channel dimension on the mask, then update with a 0/1 target.
    if mask.dim() == 4 and mask.size(1) == 1:
        mask_ = mask.squeeze(1)
    else:
        mask_ = mask
    _PR.update(probs, (mask_ > 0).to(probs.dtype))


def pr_export(base_dir: str, cfg):
    global _PR
    if _PR is None:
        return None
    save_path = os.path.join(base_dir, f"pr_{getattr(cfg, 'net', 'model')}.csv")
    out = _PR.export_csv(save_path)
    print(f"[PR] saved: {out}")
    return out
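
# Standalone usage sketch of the PR helpers (independent of the test loop below;
# the tensors and output path are purely illustrative):
#   hist = pr_init(nbins=1000)
#   hist.update(torch.rand(2, 256, 256), torch.randint(0, 2, (2, 256, 256)))
#   hist.export_csv("some_output_dir/pr_example.csv")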


def _safe_div(a, b, eps=1e-12):
    return a / max(b, eps)


def per_image_stats(pred_np: np.ndarray, gt_np: np.ndarray):
    """
    pred_np, gt_np: binary (0/1) numpy arrays of shape [H, W].
    Returns a dict with the TP/FP/TN/FN counts and the derived metrics.
    """
    pred_bin = (pred_np > 0).astype(np.uint8)
    gt_bin = (gt_np > 0).astype(np.uint8)

    TP = int(((pred_bin == 1) & (gt_bin == 1)).sum())
    FP = int(((pred_bin == 1) & (gt_bin == 0)).sum())
    TN = int(((pred_bin == 0) & (gt_bin == 0)).sum())
    FN = int(((pred_bin == 0) & (gt_bin == 1)).sum())

    precision = _safe_div(TP, (TP + FP))
    recall = _safe_div(TP, (TP + FN))
    f1 = _safe_div(2 * precision * recall, (precision + recall))
    iou = _safe_div(TP, (TP + FP + FN))
    oa = _safe_div(TP + TN, (TP + TN + FP + FN))

    return {
        "TP": TP, "FP": FP, "TN": TN, "FN": FN,
        "OA": oa, "Precision": precision, "Recall": recall, "F1": f1, "IoU": iou
    }
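
# Worked example (values follow directly from the definitions above):
#   per_image_stats(np.array([[1, 0], [1, 1]]), np.array([[1, 0], [0, 1]]))
# gives TP=2, FP=1, TN=1, FN=0, hence Precision=2/3, Recall=1.0, F1=0.8,
# IoU=2/3 and OA=0.75.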


def get_args():
    parser = argparse.ArgumentParser(description='Change detection of remote sensing images')
    parser.add_argument("-c", "--config", type=str, default="configs/cdlama.py")
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument("--tables-only", action="store_true",
                        help="Only generate tables and CSVs (overall table, per-image CSV, "
                             "per-image TXT, PR-curve CSV); skip the mask visualization images")
    return parser.parse_args()
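
# Example invocation (the script file name and checkpoint path are placeholders;
# only -c/--config, --ckpt, --output_dir and --tables-only are defined above):
#   python <this_script>.py -c configs/cdlama.py --ckpt path/to/checkpoint.ckpt --tables-only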


if __name__ == "__main__":
    args = get_args()
    cfg = Config.fromfile(args.config)

    # The --ckpt flag takes priority; otherwise fall back to the config's test checkpoint.
    ckpt = args.ckpt
    if ckpt is None:
        ckpt = cfg.test_ckpt_path
    assert ckpt is not None

    # Outputs default to the checkpoint's directory unless --output_dir is given.
    if args.output_dir:
        base_dir = args.output_dir
    else:
        base_dir = os.path.dirname(ckpt)

    masks_output_dir = os.path.join(base_dir, "mask_rgb")

    tables_output_dir = os.path.join(base_dir, "tables_only" if args.tables_only else "mask_rgb")
    os.makedirs(tables_output_dir, exist_ok=True)

    # map_location remaps tensors that were saved on cuda:1 onto cuda:0.
    model = myTrain.load_from_checkpoint(ckpt, map_location={'cuda:1': 'cuda:0'}, cfg=cfg)
    model = model.to('cuda')
    model.eval()

    metric_cfg_1 = cfg.metric_cfg1
    metric_cfg_2 = cfg.metric_cfg2

    # Dataset-level metrics accumulated on the GPU over the whole test set.
    test_oa = torchmetrics.Accuracy(**metric_cfg_1).to('cuda')
    test_prec = torchmetrics.Precision(**metric_cfg_2).to('cuda')
    test_recall = torchmetrics.Recall(**metric_cfg_2).to('cuda')
    test_f1 = torchmetrics.F1Score(**metric_cfg_2).to('cuda')
    test_iou = torchmetrics.JaccardIndex(**metric_cfg_2).to('cuda')

    results = []          # (gt, pred, out_dir, name) tuples consumed by mask_save
    per_image_rows = []   # per-image metric dicts for the CSV / TXT exports

    with torch.no_grad():
        test_loader = build_dataloader(cfg.dataset_config, mode='test')

        pr_init(nbins=1000)

        for batch in tqdm(test_loader):
            raw_predictions = model(batch[0].cuda(), batch[1].cuda())
            mask, img_id = batch[2].cuda(), batch[3]

            # Accumulate the threshold-swept PR histogram from the raw outputs.
            pr_update_from_outputs(raw_predictions, mask, cfg)

            if cfg.net == 'SARASNet':
                # SARASNet-specific handling: resize labels to the prediction
                # resolution and bias the change-class logit by a constant.
                mask = Variable(resize_label(mask.data.cpu().numpy(),
                                             size=raw_predictions.data.cpu().numpy().shape[2:]).to('cuda')).long()
                param = 1
                raw_predictions[:, 1, :, :] = raw_predictions[:, 1, :, :] + param

            # Binarize predictions: argmax over classes, or threshold a sigmoid output.
            if cfg.argmax:
                pred = raw_predictions.argmax(dim=1)
            else:
                if cfg.net == 'maskcd':
                    pred = raw_predictions[0]
                    pred = pred > 0.5
                    pred.squeeze_(1)
                else:
                    pred = raw_predictions.squeeze(1)
                    pred = pred > 0.5

            test_oa(pred, mask)
            test_iou(pred, mask)
            test_prec(pred, mask)
            test_f1(pred, mask)
            test_recall(pred, mask)

            # Per-image statistics and (optionally) data for the mask visualizations.
            for i in range(raw_predictions.shape[0]):
                mask_real = mask[i].detach().cpu().numpy()
                mask_pred = pred[i].detach().cpu().numpy()
                mask_name = str(img_id[i])

                stats = per_image_stats(mask_pred, mask_real)
                per_image_rows.append({
                    "img_id": mask_name,
                    "TP": stats["TP"], "FP": stats["FP"], "TN": stats["TN"], "FN": stats["FN"],
                    "OA": stats["OA"], "Precision": stats["Precision"],
                    "Recall": stats["Recall"], "F1": stats["F1"], "IoU": stats["IoU"]
                })

                if not args.tables_only:
                    results.append((mask_real, mask_pred, masks_output_dir, mask_name))
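
    # Each torchmetrics call above only updates running state; .compute() below
    # aggregates over the whole test set. metric_cfg2 is assumed to yield
    # per-class (unaveraged) scores, since class 0 and class 1 are indexed
    # individually when building the table.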

    metrics = [test_prec.compute(),
               test_recall.compute(),
               test_f1.compute(),
               test_iou.compute()]

    total_metrics = [test_oa.compute().cpu().numpy(),
                     np.mean([item.cpu() for item in metrics[0]]),
                     np.mean([item.cpu() for item in metrics[1]]),
                     np.mean([item.cpu() for item in metrics[2]]),
                     np.mean([item.cpu() for item in metrics[3]])]

    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Class', 'OA', 'Precision', 'Recall', 'F1_Score', 'IOU']

    # One row per class; OA is only reported in the 'total' row.
    for i in range(2):
        item = [i, '--']
        for j in range(len(metrics)):
            item.append(np.round(metrics[j][i].cpu().numpy(), 4))
        result_table.add_row(item)

    total = [np.round(v, 4) for v in total_metrics]
    total.insert(0, 'total')
    result_table.add_row(total)
    print(result_table)

    # Append the table, tagged with a timestamp and the network name, to test_res.txt.
    file_name = os.path.join(base_dir, "test_res.txt")
    current_time = time.strftime('%Y_%m_%d %H:%M:%S {}'.format(cfg.net), time.localtime(time.time()))
    with open(file_name, "a") as f:
        f.write(current_time + '\n')
        f.write(str(result_table) + '\n')

    if not args.tables_only:
        if not os.path.exists(masks_output_dir):
            os.makedirs(masks_output_dir)
            print(masks_output_dir)

        # Write the RGB mask visualizations in parallel; mask_save consumes the
        # (gt, pred, out_dir, name) tuples collected above.
        t0 = time.time()
        mpp.Pool(processes=mp.cpu_count()).map(mask_save, results)
        t1 = time.time()
        img_write_time = t1 - t0
        print('images writing spends: {} s'.format(img_write_time))
    else:
        print("[Mode] --tables-only: skipping mask visualizations; exporting tables/CSV only.")

    per_image_csv = os.path.join(base_dir, f"per_image_metrics_{getattr(cfg, 'net', 'model')}.csv")
    with open(per_image_csv, "w", newline="") as wf:
        writer = csv.DictWriter(
            wf,
            fieldnames=["img_id", "TP", "FP", "TN", "FN", "OA", "Precision", "Recall", "F1", "IoU"]
        )
        writer.writeheader()
        for row in per_image_rows:
            row_out = dict(row)
            for k in ["OA", "Precision", "Recall", "F1", "IoU"]:
                row_out[k] = float(np.round(row_out[k], 6))
            writer.writerow(row_out)
    print(f"[Per-Image] saved CSV: {per_image_csv}")

    # One PrettyTable TXT per image, written to the tables output directory.
    for row in per_image_rows:
        txt_path = os.path.join(tables_output_dir, f"{row['img_id']}_metrics.txt")
        pt = prettytable.PrettyTable()
        pt.field_names = ["Metric", "Value"]

        pt.add_row(["TP", row["TP"]])
        pt.add_row(["FP", row["FP"]])
        pt.add_row(["TN", row["TN"]])
        pt.add_row(["FN", row["FN"]])

        pt.add_row(["OA", f"{row['OA']:.6f}"])
        pt.add_row(["Precision", f"{row['Precision']:.6f}"])
        pt.add_row(["Recall", f"{row['Recall']:.6f}"])
        pt.add_row(["F1", f"{row['F1']:.6f}"])
        pt.add_row(["IoU", f"{row['IoU']:.6f}"])
        with open(txt_path, "w") as wf:
            wf.write(str(pt))
    print(f"[Per-Image] per-image tables saved to: {tables_output_dir}")

    # Export the accumulated PR curve; a failure here should not abort the rest
    # of the evaluation output.
    try:
        pr_export(base_dir, cfg)
    except Exception as e:
        print(f"[PR] export skipped or failed: {e}")