import itertools
import os
import random
import shutil
from math import ceil
from pathlib import Path
from typing import Dict, List

import numpy as np
import prettytable as pt
import torch
import torch.nn as nn
from fast_pytorch_kmeans import KMeans
from scipy.stats import hmean
from sklearn import metrics
from termcolor import cprint
from torchvision.utils import draw_segmentation_masks, make_grid, save_image

import utils.misc as misc
from losses import get_spixel_tgt_map, get_volume_seg_map
from utils.convcrf import convcrf
from utils.crf import DenseCRF


def train(
    model: nn.Module,
    dataloader,
    dataset_title: str,
    optimizer_dict: Dict,
    criterion,
    epoch: int,
    writer,
    suffix: str,
    opt,
):
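    """Run one training epoch of `model` on `dataloader`.

    Each modality in `opt.modality` has its own sub-model and its own optimizer
    in `optimizer_dict`; `criterion` returns a dict of losses whose
    "total_loss" entry is backpropagated. Image-level scores are accumulated
    per modality (plus the "ensemble" output) and logged together with the
    losses through the metric logger and `writer`.
    """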
    metric_logger = misc.MetricLogger(writer=writer, suffix=suffix)
    cprint(f"Epoch {epoch}: training on {dataset_title}", "blue")
    model.train()
    roc_auc_elements = {
        modality: {"map_scores": [], "vol_scores": []}
        for modality in itertools.chain(opt.modality, ["ensemble"])
    }
    roc_auc_elements["labels"] = []

    for i, data in metric_logger.log_every(
        dataloader, print_freq=opt.print_freq, header=f"[{suffix} {epoch}]"
    ):
        if (opt.debug or opt.wholetest) and i > 50:
            break

        for modality, optimizer in optimizer_dict.items():
            optimizer.zero_grad()

        image = data["image"].to(opt.device)
        unnormalized_image = data["unnormalized_image"].to(opt.device)
        label = data["label"].to(opt.device)
        mask = data["mask"].to(opt.device)
        spixel = data["spixel"].to(opt.device) if opt.mvc_spixel else None

        outputs = model(
            image,
            seg_size=None
            if opt.loss_on_mid_map
            else [image.shape[-2], image.shape[-1]],
        )
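
        # `criterion` returns a dict of loss terms; only "total_loss" is
        # backpropagated, but every entry is logged to the metric logger below.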
        losses = criterion(
            outputs,
            label,
            mask,
            epoch=epoch,
            max_epoch=opt.epochs,
            spixel=spixel,
            raw_image=unnormalized_image,
        )
        total_loss = losses["total_loss"]
        total_loss.backward()

        # Clip and step each modality's sub-model with its own optimizer.
        for modality in opt.modality:
            if opt.grad_clip > 0.0:
                grad_norm = nn.utils.clip_grad_norm_(
                    model.sub_models[modality].parameters(), opt.grad_clip
                )
                metric_logger.update(**{f"grad_norm/{modality}": grad_norm})

            optimizer_dict[modality].step()

        roc_auc_elements["labels"].extend(label.tolist())
        for modality in itertools.chain(opt.modality, ["ensemble"]):
            roc_auc_elements[modality]["map_scores"].extend(
                outputs[modality]["map_pred"].tolist()
            )
            roc_auc_elements[modality]["vol_scores"].extend(
                (outputs[modality]["vol_pred"]).tolist()
            )

        metric_logger.update(**losses)

    image_metrics = update_image_roc_auc_metric(
        opt.modality + ["ensemble"], roc_auc_elements, None
    )
    metric_logger.update(**image_metrics)

    metric_logger.write_tensorboard(epoch)
    print("Average status:")
    print(metric_logger.stat_table())


def bundled_evaluate(
    model: nn.Module, dataloaders: Dict, criterion, epoch, writer, suffix, opt
):
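    """Evaluate `model` on every dataloader in `dataloaders`, then log the
    per-metric averages across datasets under an "AVG" prefix via a separate
    `suffix + "_avg"` logger.
    """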
    metric_logger = misc.MetricLogger(writer=writer, suffix=suffix + "_avg")
    for dataset, dataloader in dataloaders.items():
        outputs = evaluate(
            model,
            dataloader,
            criterion,
            dataset,
            epoch,
            writer,
            suffix + f"_{dataset}",
            opt,
        )
        # Rename the per-dataset meter keys to a shared "AVG" prefix so the
        # averaged logger collects one set of meters across all datasets.
        old_keys = list(outputs.keys())
        for k in old_keys:
            outputs[k.replace(dataset.upper(), "AVG")] = outputs[k]
        for k in old_keys:
            del outputs[k]

        metric_logger.update(**outputs)

    metric_logger.write_tensorboard(epoch)
    print("Average status:")
    print(metric_logger.stat_table())
    return metric_logger.get_meters()


def evaluate(
    model: nn.Module,
    dataloader,
    criterion,
    dataset_title: str,
    epoch: int,
    writer,
    suffix: str,
    opt,
):
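    """Evaluate `model` for one epoch on a single dataset.

    Predicted maps of each modality (and of the "ensemble") can optionally be
    refined with max pooling, DenseCRF, convolutional CRF, or superpixel
    post-processing, depending on the corresponding `opt` flags. Pixel-level
    F1 is computed against the ground-truth masks for positive samples,
    image-level F1/AUC over the whole dataset, and binarized ensemble maps can
    be saved as figures. Returns the logger's meters, prefixed with the
    upper-cased `dataset_title`.
    """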
    metric_logger = misc.MetricLogger(writer=writer, suffix=suffix)
    cprint(f"Epoch {epoch}: evaluation on {dataset_title.upper()}", "blue")

    model.eval()

    # Build the optional CRF post-processor used to refine the predicted maps.
    if opt.crf_postproc:
        postprocess = DenseCRF(
            iter_max=opt.crf_iter_max,
            pos_w=opt.crf_pos_w,
            pos_xy_std=opt.crf_pos_xy_std,
            bi_w=opt.crf_bi_w,
            bi_xy_std=opt.crf_bi_xy_std,
            bi_rgb_std=opt.crf_bi_rgb_std,
        )
    elif opt.convcrf_postproc:
        convcrf_config = convcrf.default_conf
        convcrf_config["skip_init_softmax"] = True
        convcrf_config["final_softmax"] = True
        shape = [opt.convcrf_shape, opt.convcrf_shape]
        postprocess = convcrf.GaussCRF(
            conf=convcrf_config, shape=shape, nclasses=2, use_gpu=True
        ).to(opt.device)

    figure_path = opt.figure_path + f"_{dataset_title.upper()}"
    if opt.save_figure:
        if os.path.exists(figure_path):
            shutil.rmtree(figure_path)
        os.mkdir(figure_path)
        cprint("Saving figures to {}".format(figure_path), "blue")

    if opt.max_pool_postproc > 1:
        max_pool = nn.MaxPool2d(
            kernel_size=opt.max_pool_postproc,
            stride=1,
            padding=(opt.max_pool_postproc - 1) // 2,
        ).to(opt.device)
    else:
        max_pool = nn.Identity().to(opt.device)

    roc_auc_elements = {
        modality: {"map_scores": [], "vol_scores": []}
        for modality in itertools.chain(opt.modality, ["ensemble"])
    }
    roc_auc_elements["labels"] = []
    with torch.no_grad():
        for i, data in metric_logger.log_every(
            dataloader, print_freq=opt.print_freq, header=f"[{suffix} {epoch}]"
        ):
            if (opt.debug or opt.wholetest) and i > 50:
                break

            image_size = data["image"].shape[-2:]
            label = data["label"]
            mask = data["mask"]
            if opt.crf_postproc or opt.spixel_postproc or opt.convcrf_postproc:
                spixel = data["spixel"].to(opt.device)
            if max(image_size) > opt.tile_size and opt.large_image_strategy == "slide":
                # Large images are predicted tile by tile and stitched back together.
                outputs = sliding_predict(
                    model, data, opt.tile_size, opt.tile_overlap, opt
                )
            else:
                image = data["image"].to(opt.device)
                outputs = model(image, seg_size=image.shape[-2:])

            if opt.max_pool_postproc > 1:
                for modality in itertools.chain(opt.modality, ["ensemble"]):
                    outputs[modality]["out_map"] = max_pool(
                        outputs[modality]["out_map"]
                    )

            # DenseCRF refinement of each modality's map, guided by the raw RGB image.
            if opt.crf_postproc:
                raw_prob = outputs["ensemble"]["out_map"]
                image = data["unnormalized_image"] * 255.0
                if opt.crf_downsample > 1:
                    image = (
                        torch.nn.functional.interpolate(
                            image,
                            size=(
                                image_size[0] // opt.crf_downsample,
                                image_size[1] // opt.crf_downsample,
                            ),
                            mode="bilinear",
                            align_corners=False,
                        )
                        .clamp(0, 255)
                        .int()
                    )
                image = image.squeeze(0).numpy().astype(np.uint8).transpose(1, 2, 0)
                for modality in itertools.chain(opt.modality, ["ensemble"]):
                    prob = outputs[modality]["out_map"].squeeze(1)
                    if opt.crf_downsample > 1:
                        prob = (
                            torch.nn.functional.interpolate(
                                prob,
                                size=(
                                    image_size[0] // opt.crf_downsample,
                                    image_size[1] // opt.crf_downsample,
                                ),
                                mode="bilinear",
                                align_corners=False,
                            )
                            .clamp(0, 1)
                            .squeeze(0)
                        )
                    prob = torch.cat([prob, 1 - prob], dim=0).detach().cpu().numpy()
                    prob = postprocess(image, prob)
                    prob = prob[None, 0, ...]
                    prob = torch.tensor(prob, device=opt.device).unsqueeze(0)
                    if opt.crf_downsample > 1:
                        prob = torch.nn.functional.interpolate(
                            prob, size=image_size, mode="bilinear", align_corners=False
                        ).clamp(0, 1)
                    outputs[modality]["out_map"] = prob
                    outputs[modality]["map_pred"] = (
                        outputs[modality]["out_map"].max().unsqueeze(0)
                    )
            # Convolutional CRF refinement at a fixed working resolution
            # (opt.convcrf_shape), resized back to the original size afterwards.
            elif opt.convcrf_postproc:
                raw_prob = outputs["ensemble"]["out_map"]
                image = data["unnormalized_image"].to(opt.device) * 255.0
                image = (
                    torch.nn.functional.interpolate(
                        image,
                        size=(opt.convcrf_shape, opt.convcrf_shape),
                        mode="bilinear",
                        align_corners=False,
                    )
                    .clamp(0, 255)
                    .int()
                )
                for modality in itertools.chain(opt.modality, ["ensemble"]):
                    prob = outputs[modality]["out_map"]
                    prob = torch.cat([prob, 1 - prob], dim=1)
                    prob = torch.nn.functional.interpolate(
                        prob,
                        size=(opt.convcrf_shape, opt.convcrf_shape),
                        mode="bilinear",
                        align_corners=False,
                    ).clamp(0, 1)
                    prob = postprocess(unary=prob, img=image)
                    prob = torch.nn.functional.interpolate(
                        prob, size=image_size, mode="bilinear", align_corners=False
                    ).clamp(0, 1)
                    outputs[modality]["out_map"] = prob[:, 0, None, ...]
                    outputs[modality]["map_pred"] = (
                        outputs[modality]["out_map"].max().unsqueeze(0)
                    )
            # Project each modality's map onto the precomputed superpixels.
            elif opt.spixel_postproc:
                raw_prob = outputs["ensemble"]["out_map"]
                for modality in itertools.chain(opt.modality, ["ensemble"]):
                    outputs[modality]["out_map"] = get_spixel_tgt_map(
                        outputs[modality]["out_map"], spixel
                    )
roc_auc_elements["labels"].extend(label.detach().cpu().tolist()) |
|
for modality in itertools.chain(opt.modality, ["ensemble"]): |
|
roc_auc_elements[modality]["map_scores"].extend( |
|
outputs[modality]["map_pred"].detach().cpu().tolist() |
|
) |
|
roc_auc_elements[modality]["vol_scores"].extend( |
|
(outputs[modality]["vol_pred"]).detach().cpu().tolist() |
|
) |
|
|
|
|
|
out_map = { |
|
modality: outputs[modality]["out_map"] > opt.mask_threshold |
|
for modality in itertools.chain(opt.modality, ["ensemble"]) |
|
} |
|
|
|
|
|
if label.item() == 1.0: |
|
for modality in itertools.chain(opt.modality, ["ensemble"]): |
|
pixel_metrics = misc.calculate_pixel_f1( |
|
out_map[modality].float().detach().cpu().numpy().flatten(), |
|
mask.detach().cpu().numpy().flatten(), |
|
suffix=f"/{modality}", |
|
) |
|
metric_logger.update(**pixel_metrics) |
|
|
|
|
|

            if opt.save_figure:
                unnormalized_image = data["unnormalized_image"]

                image_id = Path(data["id"][0]).stem
                save_image(
                    (
                        outputs["ensemble"]["out_map"][0, ...] > opt.mask_threshold
                    ).float()
                    * 255,
                    os.path.join(figure_path, f"{image_id}_ensemble_map.png"),
                )

    image_metrics = update_image_roc_auc_metric(
        opt.modality + ["ensemble"],
        roc_auc_elements,
        {
            modality: metric_logger.meters[f"pixel_f1/{modality}"].avg
            for modality in itertools.chain(opt.modality, ["ensemble"])
        },
    )
    metric_logger.update(**image_metrics)

    metric_logger.prepend_subprefix(f"{dataset_title.upper()}_")
    metric_logger.write_tensorboard(epoch)
    print("Average status:")
    print(metric_logger.stat_table())

    return metric_logger.get_meters()


def update_image_roc_auc_metric(modalities: List, roc_auc_elements, pixel_f1=None):
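    """Compute image-level classification metrics for every modality.

    Map scores above 0.5 count as positive predictions. When per-modality
    `pixel_f1` averages are given, a combined F1 is added as the harmonic mean
    of image-level and pixel-level F1, e.g. image_f1 = 0.8 and pixel_f1 = 0.5
    give hmean([0.8, 0.5]) ≈ 0.615. ROC-AUC is only added when both classes
    appear in the collected labels.
    """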
    result = {}
    for modality in modalities:
        image_metrics = misc.calculate_img_score(
            np.array(roc_auc_elements[modality]["map_scores"]) > 0.5,
            (np.array(roc_auc_elements["labels"]) > 0).astype(int),
            suffix=f"/{modality}",
        )
        if pixel_f1 is not None:
            image_f1 = image_metrics[f"image_f1/{modality}"]
            combined_f1 = hmean([image_f1, pixel_f1[modality]])
            image_metrics[f"comb_f1/{modality}"] = float(combined_f1)
        # ROC-AUC is only defined when both classes appear in the labels.
        if 0.0 in roc_auc_elements["labels"] and 1.0 in roc_auc_elements["labels"]:
            image_auc = metrics.roc_auc_score(
                roc_auc_elements["labels"], roc_auc_elements[modality]["map_scores"]
            )
            image_metrics[f"image_auc/{modality}"] = image_auc
        result.update(image_metrics)

    return result


def pad_image(image, target_size):
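    """Zero-pad `image` on the bottom and right so its spatial size matches
    `target_size` ([height, width]); e.g. a (1, 3, 250, 256) tile padded to
    [256, 256] gains six rows of zeros at the bottom.
    """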
    image_size = image.shape[-2:]
    if tuple(image_size) != tuple(target_size):
        row_missing = target_size[0] - image_size[0]
        col_missing = target_size[1] - image_size[1]
        # F.pad pads the last dimension first: (left, right, top, bottom),
        # so the width (columns) padding must come before the height (rows).
        image = nn.functional.pad(
            image, (0, col_missing, 0, row_missing), "constant", 0
        )
    return image


def sliding_predict(model: nn.Module, data, tile_size, tile_overlap, opt):
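    """Predict a large image by sliding a `tile_size` window with fractional
    `tile_overlap`, averaging the per-modality maps where tiles overlap;
    e.g. tile_size=512 with tile_overlap=0.25 gives a stride of
    ceil(512 * 0.75) = 384 pixels.
    """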
    image = data["image"].to(opt.device)
    mask = data["mask"]
    image_size = image.shape[-2:]
    stride = ceil(tile_size * (1 - tile_overlap))
    tile_rows = int(ceil((image_size[0] - tile_size) / stride) + 1)
    tile_cols = int(ceil((image_size[1] - tile_size) / stride) + 1)
    result = {}
    for modality in itertools.chain(opt.modality, ["ensemble"]):
        result[modality] = {
            "out_map": torch.zeros_like(
                mask, requires_grad=False, dtype=torch.float32, device=opt.device
            ),
            "out_vol_map": torch.zeros_like(
                mask, requires_grad=False, dtype=torch.float32, device=opt.device
            ),
        }
    map_counter = torch.zeros_like(
        mask, requires_grad=False, dtype=torch.float32, device=opt.device
    )
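    # map_counter tracks how many tiles cover each pixel so the accumulated
    # maps can be averaged over overlapping tiles after the sliding pass.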

    with torch.no_grad():
        for row in range(tile_rows):
            for col in range(tile_cols):
                x1 = int(col * stride)
                y1 = int(row * stride)
                x2 = min(x1 + tile_size, image_size[1])
                y2 = min(y1 + tile_size, image_size[0])
                # Pull the last tile in each row/column back inside the image so
                # that every window keeps the full tile size where possible.
                x1 = max(int(x2 - tile_size), 0)
                y1 = max(int(y2 - tile_size), 0)

                image_tile = image[:, :, y1:y2, x1:x2]
                image_tile = pad_image(image_tile, [opt.tile_size, opt.tile_size])
                tile_outputs = model(image_tile, seg_size=(image_tile.shape[-2:]))
                for modality in itertools.chain(opt.modality, ["ensemble"]):
                    result[modality]["out_map"][:, :, y1:y2, x1:x2] += tile_outputs[
                        modality
                    ]["out_map"][:, :, : y2 - y1, : x2 - x1]
                    out_vol_map = get_volume_seg_map(
                        tile_outputs[modality]["out_vol"],
                        size=image_tile.shape[-2:],
                        label=data["label"],
                        kmeans=KMeans(2) if opt.consistency_kmeans else None,
                    )[:, :, : y2 - y1, : x2 - x1]
                    result[modality]["out_vol_map"][:, :, y1:y2, x1:x2] += out_vol_map
                map_counter[:, :, y1:y2, x1:x2] += 1

    for modality in itertools.chain(opt.modality, ["ensemble"]):
        result[modality]["out_map"] /= map_counter
        result[modality]["out_vol_map"] /= map_counter
        result[modality]["map_pred"] = result[modality]["out_map"].max().unsqueeze(0)
        result[modality]["vol_pred"] = (
            result[modality]["out_vol_map"].max().unsqueeze(0)
        )

    return result