# LKCell / cell_segmentation / trainer / trainer_cellvit.py
# (repository file-viewer chrome removed: "qingke1 — initial commit, aea73e2, raw, history, blame, 44 kB")
# -*- coding: utf-8 -*-
# CellViT Trainer Class
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import logging
from pathlib import Path
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import math
import csv
# import wandb
from matplotlib import pyplot as plt
from skimage.color import rgba2rgb
from sklearn.metrics import accuracy_score
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torchmetrics.functional import dice
from torchmetrics.functional.classification import binary_jaccard_index
from base_ml.base_early_stopping import EarlyStopping
from base_ml.base_trainer import BaseTrainer
from models.segmentation.cell_segmentation.cellvit import DataclassHVStorage
from cell_segmentation.utils.metrics import get_fast_pq, remap_label
from cell_segmentation.utils.tools import cropping_center
from models.segmentation.cell_segmentation.cellvit import CellViT
from utils.tools import AverageMeter
from timm.utils import ModelEma
from torch.cuda.amp import GradScaler, autocast
class CellViTTrainer(BaseTrainer):
"""CellViT trainer class
Args:
model (CellViT): CellViT model that should be trained
loss_fn_dict (dict): Dictionary with loss functions for each branch with a dictionary of loss functions.
Name of branch as top-level key, followed by a dictionary with loss name, loss fn and weighting factor
Example:
{
"nuclei_binary_map": {"bce": {loss_fn(Callable), weight_factor(float)}, "dice": {loss_fn(Callable), weight_factor(float)}},
"hv_map": {"bce": {loss_fn(Callable), weight_factor(float)}, "dice": {loss_fn(Callable), weight_factor(float)}},
"nuclei_type_map": {"bce": {loss_fn(Callable), weight_factor(float)}, "dice": {loss_fn(Callable), weight_factor(float)}}
"tissue_types": {"ce": {loss_fn(Callable), weight_factor(float)}}
}
Required Keys are:
* nuclei_binary_map
* hv_map
* nuclei_type_map
* tissue_types
optimizer (Optimizer): Optimizer
scheduler (_LRScheduler): Learning rate scheduler
device (str): Cuda device to use, e.g., cuda:0.
logger (logging.Logger): Logger module
logdir (Union[Path, str]): Logging directory
num_classes (int): Number of nuclei classes
dataset_config (dict): Dataset configuration. Required Keys are:
* "tissue_types": describing the present tissue types with corresponding integer
* "nuclei_types": describing the present nuclei types with corresponding integer
experiment_config (dict): Configuration of this experiment
early_stopping (EarlyStopping, optional): Early Stopping Class. Defaults to None.
log_images (bool, optional): If images should be logged to WandB. Defaults to False.
magnification (int, optional): Image magnification. Please select either 40 or 20. Defaults to 40.
mixed_precision (bool, optional): If mixed-precision should be used. Defaults to False.
"""
def __init__(
    self,
    model: CellViT,
    loss_fn_dict: dict,
    optimizer: Optimizer,
    scheduler: _LRScheduler,
    device: str,
    logger: logging.Logger,
    logdir: Union[Path, str],
    num_classes: int,
    dataset_config: dict,
    experiment_config: dict,
    early_stopping: EarlyStopping = None,
    log_images: bool = False,
    magnification: int = 40,
    mixed_precision: bool = False,
):
    """Set up trainer state, dataset lookups and per-branch loss trackers.

    See the class docstring for a description of all arguments.
    """
    super().__init__(
        model=model,
        loss_fn=None,  # losses are handled per-branch via loss_fn_dict instead
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        logger=logger,
        logdir=logdir,
        experiment_config=experiment_config,
        early_stopping=early_stopping,
        accum_iter=1,
        log_images=log_images,
        mixed_precision=mixed_precision,
    )
    self.loss_fn_dict = loss_fn_dict
    self.num_classes = num_classes
    self.dataset_config = dataset_config
    self.tissue_types = dataset_config["tissue_types"]
    # reverse lookup: tissue id (int) -> tissue name (str)
    self.reverse_tissue_types = {v: k for k, v in self.tissue_types.items()}
    self.nuclei_types = dataset_config["nuclei_types"]
    self.magnification = magnification
    # setup logging objects: one running-average meter per loss term
    self.loss_avg_tracker = {"Total_Loss": AverageMeter("Total_Loss", ":.4f")}
    for branch, loss_fns in self.loss_fn_dict.items():
        for loss_name in loss_fns:
            self.loss_avg_tracker[f"{branch}_{loss_name}"] = AverageMeter(
                f"{branch}_{loss_name}", ":.4f"
            )
    # Bugfix: format spec was ":4.f", which is not a valid float format spec;
    # use ":.4f" to match the other AverageMeter instances above.
    self.batch_avg_tissue_acc = AverageMeter("Batch_avg_tissue_ACC", ":.4f")
def train_epoch(
    self, epoch: int, train_dataloader: DataLoader, unfreeze_epoch: int = 50
) -> Tuple[dict, dict]:
    """Run one training epoch over the given dataloader.

    Args:
        epoch (int): Current epoch number
        train_dataloader (DataLoader): Train dataloader
        unfreeze_epoch (int, optional): Epoch from which the encoder is unfrozen

    Returns:
        Tuple[dict, dict]: wandb logging dictionaries
            * Scalar metrics
            * Image metrics
    """
    self.model.train()
    if epoch >= unfreeze_epoch:
        self.model.unfreeze_encoder()

    binary_dice_scores = []
    binary_jaccard_scores = []
    tissue_pred = []
    tissue_gt = []
    train_example_img = None

    # reset the running-average trackers for this epoch
    self.loss_avg_tracker["Total_Loss"].reset()
    for branch, loss_fns in self.loss_fn_dict.items():
        for loss_name in loss_fns:
            self.loss_avg_tracker[f"{branch}_{loss_name}"].reset()
    self.batch_avg_tissue_acc.reset()

    # pick one random batch whose predictions get plotted (only if image logging is on)
    select_example_image = (
        int(torch.randint(0, len(train_dataloader), (1,)))
        if self.log_images
        else None
    )

    train_loop = tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    for batch_idx, batch in train_loop:
        batch_metrics, example_img = self.train_step(
            batch,
            batch_idx,
            len(train_dataloader),
            return_example_images=batch_idx == select_example_image,
        )
        if example_img is not None:
            train_example_img = example_img
        binary_dice_scores.extend(batch_metrics["binary_dice_scores"])
        binary_jaccard_scores.extend(batch_metrics["binary_jaccard_scores"])
        tissue_pred.append(batch_metrics["tissue_pred"])
        tissue_gt.append(batch_metrics["tissue_gt"])
        train_loop.set_postfix(
            {
                "Loss": np.round(self.loss_avg_tracker["Total_Loss"].avg, 3),
                "Dice": np.round(np.nanmean(binary_dice_scores), 3),
                "Pred-Acc": np.round(self.batch_avg_tissue_acc.avg, 3),
            }
        )

    # aggregate epoch-level metrics
    binary_dice_scores = np.array(binary_dice_scores)
    binary_jaccard_scores = np.array(binary_jaccard_scores)
    tissue_detection_accuracy = accuracy_score(
        y_true=np.concatenate(tissue_gt), y_pred=np.concatenate(tissue_pred)
    )
    scalar_metrics = {
        "Loss/Train": self.loss_avg_tracker["Total_Loss"].avg,
        "Binary-Cell-Dice-Mean/Train": np.nanmean(binary_dice_scores),
        "Binary-Cell-Jacard-Mean/Train": np.nanmean(binary_jaccard_scores),
        "Tissue-Multiclass-Accuracy/Train": tissue_detection_accuracy,
    }
    for branch, loss_fns in self.loss_fn_dict.items():
        for loss_name in loss_fns:
            scalar_metrics[f"{branch}_{loss_name}/Train"] = self.loss_avg_tracker[
                f"{branch}_{loss_name}"
            ].avg

    self.logger.info(
        f"{'Training epoch stats:' : <25} "
        f"Loss: {self.loss_avg_tracker['Total_Loss'].avg:.4f} - "
        f"Binary-Cell-Dice: {np.nanmean(binary_dice_scores):.4f} - "
        f"Binary-Cell-Jacard: {np.nanmean(binary_jaccard_scores):.4f} - "
        f"Tissue-MC-Acc.: {tissue_detection_accuracy:.4f}"
    )

    return scalar_metrics, {"Example-Predictions/Train": train_example_img}
def train_step(
    self,
    batch: object,
    batch_idx: int,
    num_batches: int,
    return_example_images: bool,
) -> Tuple[dict, Union[plt.Figure, None]]:
    """Run one optimization step: forward, loss, backward, optimizer update.

    Supports optional fp16 autocast (with the GradScaler held by the base
    trainer) and gradient accumulation over ``self.accum_iter`` batches.

    Args:
        batch (object): Training batch, consisting of images ([0]), masks ([1]),
            tissue_types ([2]) and figure filenames ([3])
        batch_idx (int): Batch index
        num_batches (int): Total number of batches in epoch
        return_example_images (bool): If an example prediction image should be returned

    Returns:
        Tuple[dict, Union[plt.Figure, None]]:
            * Batch-Metrics: dictionary with dice/jaccard scores and tissue predictions
            * Example prediction image (or None)
    """
    # unpack batch
    imgs = batch[0].to(self.device)  # imgs shape: (batch_size, 3, H, W), e.g. (16, 3, 256, 256)
    masks = batch[
        1
    ]  # dict keys: "instance_map" [16,256,256], "nuclei_map" [16,256,256], "nuclei_binary_map" [16,256,256], "hv_map" [16,2,256,256]
    tissue_types = batch[2]  # list[str]
    if self.mixed_precision:
        # fp16 forward/loss under autocast; backward/step go through the GradScaler
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            # make predictions
            predictions_ = self.model.forward(imgs)
            # reshaping and postprocessing
            predictions = self.unpack_predictions(predictions=predictions_)
            gt = self.unpack_masks(masks=masks, tissue_types=tissue_types)
            # calculate loss
            total_loss = self.calculate_loss(predictions, gt)
        # backward pass (loss is scaled to avoid fp16 gradient underflow)
        self.scaler.scale(total_loss).backward()
        # step the optimizer only on an accumulation boundary or the last batch
        if (
            ((batch_idx + 1) % self.accum_iter == 0)
            or ((batch_idx + 1) == num_batches)
            or (self.accum_iter == 1)
        ):
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.optimizer.zero_grad(set_to_none=True)
            self.model.zero_grad()
    else:
        predictions_ = self.model.forward(imgs)
        predictions = self.unpack_predictions(predictions=predictions_)
        gt = self.unpack_masks(masks=masks, tissue_types=tissue_types)
        # calculate loss
        total_loss = self.calculate_loss(predictions, gt)
        total_loss.backward()
        # same accumulation-boundary logic as the mixed-precision branch
        if (
            ((batch_idx + 1) % self.accum_iter == 0)
            or ((batch_idx + 1) == num_batches)
            or (self.accum_iter == 1)
        ):
            self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)
            self.model.zero_grad()
    # NOTE(review): emptying the CUDA cache on every step is expensive;
    # presumably a workaround for memory fragmentation/OOM — confirm before removing.
    with torch.cuda.device(self.device):
        torch.cuda.empty_cache()
    batch_metrics = self.calculate_step_metric_train(predictions, gt)
    if return_example_images:
        return_example_images = self.generate_example_image(
            imgs, predictions, gt, num_images=4, num_nuclei_classes=self.num_classes
        )
    else:
        return_example_images = None
    return batch_metrics, return_example_images
def validation_epoch(
    self, epoch: int, val_dataloader: DataLoader
) -> Tuple[dict, dict, float]:
    """Run one validation epoch without gradients.

    Aggregates loss trackers, dice/jaccard/PQ scores (globally, per tissue
    type and per nuclei type) and returns the mean bPQ as the early-stopping
    metric.

    Args:
        epoch (int): Current epoch number
        val_dataloader (DataLoader): Validation dataloader

    Returns:
        Tuple[dict, dict, float]: wandb logging dictionaries
            * Scalar metrics
            * Image metrics
            * Early stopping metric (mean bPQ)
    """
    self.model.eval()
    binary_dice_scores = []
    binary_jaccard_scores = []
    pq_scores = []
    cell_type_pq_scores = []
    tissue_pred = []
    tissue_gt = []
    val_example_img = None
    # reset metrics
    self.loss_avg_tracker["Total_Loss"].reset()
    for branch, loss_fns in self.loss_fn_dict.items():
        for loss_name in loss_fns:
            self.loss_avg_tracker[f"{branch}_{loss_name}"].reset()
    self.batch_avg_tissue_acc.reset()
    # randomly select a batch that should be displayed
    if self.log_images:
        select_example_image = int(torch.randint(0, len(val_dataloader), (1,)))
    else:
        select_example_image = None
    val_loop = tqdm.tqdm(enumerate(val_dataloader), total=len(val_dataloader))
    with torch.no_grad():
        for batch_idx, batch in val_loop:
            return_example_images = batch_idx == select_example_image
            batch_metrics, example_img = self.validation_step(
                batch, batch_idx, return_example_images
            )
            if example_img is not None:
                val_example_img = example_img
            binary_dice_scores = (
                binary_dice_scores + batch_metrics["binary_dice_scores"]
            )
            binary_jaccard_scores = (
                binary_jaccard_scores + batch_metrics["binary_jaccard_scores"]
            )
            pq_scores = pq_scores + batch_metrics["pq_scores"]
            cell_type_pq_scores = (
                cell_type_pq_scores + batch_metrics["cell_type_pq_scores"]
            )
            tissue_pred.append(batch_metrics["tissue_pred"])
            tissue_gt.append(batch_metrics["tissue_gt"])
            val_loop.set_postfix(
                {
                    "Loss": np.round(self.loss_avg_tracker["Total_Loss"].avg, 3),
                    "Dice": np.round(np.nanmean(binary_dice_scores), 3),
                    "Pred-Acc": np.round(self.batch_avg_tissue_acc.avg, 3),
                }
            )
    # tissue names (lowercased) per sample, used for the per-tissue breakdown below
    tissue_types_val = [
        self.reverse_tissue_types[t].lower() for t in np.concatenate(tissue_gt)
    ]
    # calculate global metrics
    binary_dice_scores = np.array(binary_dice_scores)
    binary_jaccard_scores = np.array(binary_jaccard_scores)
    pq_scores = np.array(pq_scores)
    tissue_detection_accuracy = accuracy_score(
        y_true=np.concatenate(tissue_gt), y_pred=np.concatenate(tissue_pred)
    )
    scalar_metrics = {
        "Loss/Validation": self.loss_avg_tracker["Total_Loss"].avg,
        "Binary-Cell-Dice-Mean/Validation": np.nanmean(binary_dice_scores),
        "Binary-Cell-Jacard-Mean/Validation": np.nanmean(binary_jaccard_scores),
        "Tissue-Multiclass-Accuracy/Validation": tissue_detection_accuracy,
        "bPQ/Validation": np.nanmean(pq_scores),
        "mPQ/Validation": np.nanmean(
            [np.nanmean(pq) for pq in cell_type_pq_scores]
        ),
    }
    # per-branch loss averages (the trackers were updated inside validation_step)
    for branch, loss_fns in self.loss_fn_dict.items():
        for loss_name in loss_fns:
            scalar_metrics[
                f"{branch}_{loss_name}/Validation"
            ] = self.loss_avg_tracker[f"{branch}_{loss_name}"].avg
    # calculate local metrics
    # per tissue class
    for tissue in self.tissue_types.keys():
        tissue = tissue.lower()
        tissue_ids = np.where(np.asarray(tissue_types_val) == tissue)
        scalar_metrics[f"{tissue}-Dice/Validation"] = np.nanmean(
            binary_dice_scores[tissue_ids]
        )
        scalar_metrics[f"{tissue}-Jaccard/Validation"] = np.nanmean(
            binary_jaccard_scores[tissue_ids]
        )
        scalar_metrics[f"{tissue}-bPQ/Validation"] = np.nanmean(
            pq_scores[tissue_ids]
        )
        scalar_metrics[f"{tissue}-mPQ/Validation"] = np.nanmean(
            [np.nanmean(pq) for pq in np.array(cell_type_pq_scores)[tissue_ids]]
        )
    # calculate nuclei metrics: per nuclei type, skipping background
    for nuc_name, nuc_type in self.nuclei_types.items():
        if nuc_name.lower() == "background":
            continue
        scalar_metrics[f"{nuc_name}-PQ/Validation"] = np.nanmean(
            [pq[nuc_type] for pq in cell_type_pq_scores]
        )
    self.logger.info(
        f"{'Validation epoch stats:' : <25} "
        f"Loss: {self.loss_avg_tracker['Total_Loss'].avg:.4f} - "
        f"Binary-Cell-Dice: {np.nanmean(binary_dice_scores):.4f} - "
        f"Binary-Cell-Jacard: {np.nanmean(binary_jaccard_scores):.4f} - "
        f"bPQ-Score: {np.nanmean(pq_scores):.4f} - "
        f"mPQ-Score: {scalar_metrics['mPQ/Validation']:.4f} - "
        f"Tissue-MC-Acc.: {tissue_detection_accuracy:.4f}"
    )
    image_metrics = {"Example-Predictions/Validation": val_example_img}
    return scalar_metrics, image_metrics, np.nanmean(pq_scores)
def validation_step(
    self,
    batch: object,
    batch_idx: int,
    return_example_images: bool,
):
    """Run a single validation step: forward pass, loss tracking and metrics.

    Args:
        batch (object): Validation batch, consisting of images ([0]), masks ([1]),
            tissue_types ([2]) and figure filenames ([3])
        batch_idx (int): Batch index
        return_example_images (bool): If an example prediction image should be returned

    Returns:
        Tuple[dict, Union[plt.Figure, None]]:
            * Batch-Metrics: dictionary, structure not fixed yet
            * Example prediction image (or None)
    """
    # unpack batch; for shapes compare the train_step method
    imgs = batch[0].to(self.device)
    masks = batch[1]
    tissue_types = batch[2]

    # make sure no stale gradients are kept around during evaluation
    self.model.zero_grad()
    self.optimizer.zero_grad()

    if self.mixed_precision:
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            predictions_ = self.model.forward(imgs)
            # reshaping and postprocessing
            predictions = self.unpack_predictions(predictions=predictions_)
            gt = self.unpack_masks(masks=masks, tissue_types=tissue_types)
            # loss is computed only to update the running loss trackers
            _ = self.calculate_loss(predictions, gt)
    else:
        predictions_ = self.model.forward(imgs)
        # reshaping and postprocessing
        predictions = self.unpack_predictions(predictions=predictions_)
        gt = self.unpack_masks(masks=masks, tissue_types=tissue_types)
        _ = self.calculate_loss(predictions, gt)
        # flag images that yield a NaN loss
        loss_value = _.item()
        if math.isnan(loss_value):
            print("NaN loss for image:", batch[3])

    # get metrics for this batch
    batch_metrics = self.calculate_step_metric_validation(predictions, gt)

    if not return_example_images:
        return batch_metrics, None
    try:
        example_figure = self.generate_example_image(
            imgs,
            predictions,
            gt,
            num_images=4,
            num_nuclei_classes=self.num_classes,
        )
    except AssertionError:
        self.logger.error(
            "AssertionError for Example Image. Please check. Continue without image."
        )
        example_figure = None
    return batch_metrics, example_figure
def unpack_predictions(self, predictions: dict) -> DataclassHVStorage:
    """Post-process raw network output into a DataclassHVStorage.

    Applies softmax to the segmentation logits, derives the instance map and
    per-instance types from the model, and bundles everything together.

    Args:
        predictions (dict): Dictionary with the following keys:
            * tissue_types: Logit tissue prediction output. Shape: (batch_size, num_tissue_classes)
            * nuclei_binary_map: Logit output for binary nuclei prediction branch. Shape: (batch_size, 2, H, W)
            * hv_map: Logit output for hv-prediction. Shape: (batch_size, 2, H, W)
            * nuclei_type_map: Logit output for nuclei instance-prediction. Shape: (batch_size, num_nuclei_classes, H, W)

    Returns:
        DataclassHVStorage: Processed network output
    """
    predictions["tissue_types"] = predictions["tissue_types"].to(self.device)
    # logits -> probabilities for both segmentation branches
    predictions["nuclei_binary_map"] = F.softmax(
        predictions["nuclei_binary_map"], dim=1
    )  # shape: (batch_size, 2, H, W)
    predictions["nuclei_type_map"] = F.softmax(
        predictions["nuclei_type_map"], dim=1
    )  # shape: (batch_size, num_nuclei_classes, H, W)

    # instance separation; shape: (batch_size, H, W)
    instance_map, instance_types = self.model.calculate_instance_map(
        predictions, self.magnification
    )
    predictions["instance_map"] = instance_map
    predictions["instance_types"] = instance_types

    # per-class instance maps
    predictions["instance_types_nuclei"] = self.model.generate_instance_nuclei_map(
        instance_map, instance_types
    ).to(self.device)

    # some model variants do not emit a regression branch
    predictions.setdefault("regression_map", None)

    return DataclassHVStorage(
        nuclei_binary_map=predictions["nuclei_binary_map"],
        hv_map=predictions["hv_map"],
        nuclei_type_map=predictions["nuclei_type_map"],
        tissue_types=predictions["tissue_types"],
        instance_map=instance_map,
        instance_types=instance_types,
        instance_types_nuclei=predictions["instance_types_nuclei"],
        batch_size=predictions["tissue_types"].shape[0],
        regression_map=predictions["regression_map"],
        num_nuclei_classes=self.num_classes,
    )
def unpack_masks(self, masks: dict, tissue_types: list) -> DataclassHVStorage:
    """Convert ground-truth masks into a DataclassHVStorage on the target device.

    One-hot encodes the binary and type segmentation maps and derives the
    per-class instance maps.

    Args:
        masks (dict): Required keys are:
            * instance_map: Pixel-wise nuclear instance segmentations. Shape: (batch_size, H, W)
            * nuclei_binary_map: Binary nuclei segmentations. Shape: (batch_size, H, W)
            * hv_map: HV-Map. Shape: (batch_size, 2, H, W)
            * nuclei_type_map: Nuclei type map, integer class per pixel. Shape: (batch_size, H, W)
        tissue_types (list): List of string names of ground-truth tissue types

    Returns:
        DataclassHVStorage: GT-Results with matching shapes and output types
    """
    # one-hot encode the segmentation maps (background + foreground classes)
    binary_onehot = (
        F.one_hot(masks["nuclei_binary_map"], num_classes=2)
    ).type(torch.float32)  # background, nuclei
    type_map_int = masks["nuclei_type_map"].type(torch.int64)
    type_onehot = F.one_hot(type_map_int, num_classes=self.num_classes).type(
        torch.float32
    )  # background + nuclei types

    # tissue names -> integer labels
    tissue_labels = (
        torch.Tensor([self.tissue_types[t] for t in tissue_types])
        .type(torch.LongTensor)
        .to(self.device)
    )

    # assemble ground truth dictionary
    gt = {
        # shape: (batch_size, num_nuclei_classes, H, W)
        "nuclei_type_map": type_onehot.permute(0, 3, 1, 2).to(self.device),
        # shape: (batch_size, 2, H, W)
        "nuclei_binary_map": binary_onehot.permute(0, 3, 1, 2).to(self.device),
        # shape: (batch_size, 2, H, W)
        "hv_map": masks["hv_map"].to(self.device),
        # shape: (batch_size, H, W) -> each instance has one integer
        "instance_map": masks["instance_map"].to(self.device),
        # shape: (batch_size, num_nuclei_classes, H, W) -> instance id per nuclei class
        "instance_types_nuclei": (
            type_onehot * masks["instance_map"][..., None]
        )
        .permute(0, 3, 1, 2)
        .to(self.device),
        # shape: (batch_size,)
        "tissue_types": tissue_labels,
    }
    if "regression_map" in masks:
        gt["regression_map"] = masks["regression_map"].to(self.device)
    return DataclassHVStorage(
        **gt,
        batch_size=gt["tissue_types"].shape[0],
        num_nuclei_classes=self.num_classes,
    )
def calculate_loss(
    self, predictions: DataclassHVStorage, gt: DataclassHVStorage
) -> torch.Tensor:
    """Compute the weighted multi-branch training loss.

    Iterates over all prediction branches that have configured loss functions,
    accumulates their weighted contributions and updates the running loss
    trackers.

    Args:
        predictions (DataclassHVStorage): Predictions
        gt (DataclassHVStorage): Ground-Truth values

    Returns:
        torch.Tensor: Total (weighted) loss
    """
    predictions = predictions.get_dict()
    gt = gt.get_dict()

    # branches carrying derived instance information, not direct network output
    skip_branches = ("instance_map", "instance_types", "instance_types_nuclei")

    total_loss = 0
    for branch, pred in predictions.items():
        if branch in skip_branches or branch not in self.loss_fn_dict:
            continue
        for loss_name, loss_setting in self.loss_fn_dict[branch].items():
            loss_fn = loss_setting["loss_fn"]
            weight = loss_setting["weight"]
            if loss_name == "msge":
                # the gradient loss needs the binary focus mask and the device
                loss_value = loss_fn(
                    input=pred,
                    target=gt[branch],
                    focus=gt["nuclei_binary_map"],
                    device=self.device,
                )
            else:
                loss_value = loss_fn(input=pred, target=gt[branch])
            total_loss = total_loss + weight * loss_value
            self.loss_avg_tracker[f"{branch}_{loss_name}"].update(
                loss_value.detach().cpu().numpy()
            )
    self.loss_avg_tracker["Total_Loss"].update(total_loss.detach().cpu().numpy())
    return total_loss
def calculate_step_metric_train(
    self, predictions: DataclassHVStorage, gt: DataclassHVStorage
) -> dict:
    """Compute per-batch training metrics (dice, jaccard, tissue accuracy).

    Args:
        predictions (DataclassHVStorage): Processed network output
        gt (DataclassHVStorage): Ground truth values

    Returns:
        dict: Dictionary with metrics. Keys:
            binary_dice_scores, binary_jaccard_scores, tissue_pred, tissue_gt
    """
    predictions = predictions.get_dict()
    gt = gt.get_dict()

    # tissue-type logits -> probabilities -> class indices
    predictions["tissue_types_classes"] = F.softmax(
        predictions["tissue_types"], dim=-1
    )
    pred_tissue = (
        torch.argmax(predictions["tissue_types_classes"], dim=-1)
        .detach()
        .cpu()
        .numpy()
        .astype(np.uint8)
    )
    predictions["instance_map"] = predictions["instance_map"].detach().cpu()
    predictions["instance_types_nuclei"] = (
        predictions["instance_types_nuclei"].detach().cpu().numpy().astype("int32")
    )
    gt["tissue_types"] = gt["tissue_types"].detach().cpu().numpy().astype(np.uint8)
    # one-hot -> label map
    gt["nuclei_binary_map"] = torch.argmax(gt["nuclei_binary_map"], dim=1).type(
        torch.uint8
    )
    gt["instance_types_nuclei"] = (
        gt["instance_types_nuclei"].detach().cpu().numpy().astype("int32")
    )

    self.batch_avg_tissue_acc.update(
        accuracy_score(y_true=gt["tissue_types"], y_pred=pred_tissue)
    )

    binary_dice_scores = []
    binary_jaccard_scores = []
    for idx in range(len(pred_tissue)):
        # binary dice score: cell detection per image, background excluded
        pred_map = torch.argmax(predictions["nuclei_binary_map"][idx], dim=0)
        target_map = gt["nuclei_binary_map"][idx]
        cell_dice = (
            dice(preds=pred_map, target=target_map, ignore_index=0)
            .detach()
            .cpu()
        )
        binary_dice_scores.append(float(cell_dice))
        # binary jaccard index (IoU)
        cell_jaccard = (
            binary_jaccard_index(preds=pred_map, target=target_map)
            .detach()
            .cpu()
        )
        binary_jaccard_scores.append(float(cell_jaccard))

    return {
        "binary_dice_scores": binary_dice_scores,
        "binary_jaccard_scores": binary_jaccard_scores,
        "tissue_pred": pred_tissue,
        "tissue_gt": gt["tissue_types"],
    }
def calculate_step_metric_validation(self, predictions: dict, gt: dict) -> dict:
    """Compute per-batch validation metrics (dice, jaccard, PQ and per-class PQ).

    Args:
        predictions (DataclassHVStorage): Processed network output
        gt (DataclassHVStorage): Ground truth values

    Returns:
        dict: Dictionary with metrics. Keys:
            binary_dice_scores, binary_jaccard_scores, pq_scores,
            cell_type_pq_scores, tissue_pred, tissue_gt
    """
    predictions = predictions.get_dict()
    gt = gt.get_dict()
    # tissue-type logits to probs and argmax to get the predicted class
    predictions["tissue_types_classes"] = F.softmax(
        predictions["tissue_types"], dim=-1
    )
    pred_tissue = (
        torch.argmax(predictions["tissue_types_classes"], dim=-1)
        .detach()
        .cpu()
        .numpy()
        .astype(np.uint8)
    )
    predictions["instance_map"] = predictions["instance_map"].detach().cpu()
    predictions["instance_types_nuclei"] = (
        predictions["instance_types_nuclei"].detach().cpu().numpy().astype("int32")
    )
    # NOTE(review): presumably converts (B, H, W, C) -> (B, C, H, W) so the
    # per-class indexing below ([i][j, ...]) works — confirm against the layout
    # produced by generate_instance_nuclei_map.
    predictions["instance_types_nuclei"] = predictions["instance_types_nuclei"].transpose(0, 3, 1, 2)
    instance_maps_gt = gt["instance_map"].detach().cpu()
    gt["tissue_types"] = gt["tissue_types"].detach().cpu().numpy().astype(np.uint8)
    # one-hot -> label map
    gt["nuclei_binary_map"] = torch.argmax(gt["nuclei_binary_map"], dim=1).type(
        torch.uint8
    )
    gt["instance_types_nuclei"] = (
        gt["instance_types_nuclei"].detach().cpu().numpy().astype("int32")
    )
    tissue_detection_accuracy = accuracy_score(
        y_true=gt["tissue_types"], y_pred=pred_tissue
    )
    self.batch_avg_tissue_acc.update(tissue_detection_accuracy)
    binary_dice_scores = []
    binary_jaccard_scores = []
    cell_type_pq_scores = []
    pq_scores = []
    for i in range(len(pred_tissue)):
        # binary dice score: score for cell detection per image, without background
        pred_binary_map = torch.argmax(predictions["nuclei_binary_map"][i], dim=0)
        target_binary_map = gt["nuclei_binary_map"][i]
        cell_dice = (
            dice(preds=pred_binary_map, target=target_binary_map, ignore_index=0)
            .detach()
            .cpu()
        )
        binary_dice_scores.append(float(cell_dice))
        # binary jaccard index (IoU)
        cell_jaccard = (
            binary_jaccard_index(
                preds=pred_binary_map,
                target=target_binary_map,
            )
            .detach()
            .cpu()
        )
        binary_jaccard_scores.append(float(cell_jaccard))
        # panoptic quality: instance ids are remapped to a contiguous range first
        remapped_instance_pred = remap_label(predictions["instance_map"][i])
        remapped_gt = remap_label(instance_maps_gt[i])
        [_, _, pq], _ = get_fast_pq(true=remapped_gt, pred=remapped_instance_pred)
        pq_scores.append(pq)
        # pq values per class (empty ground-truth classes yield NaN and are skipped)
        nuclei_type_pq = []
        for j in range(0, self.num_classes):
            pred_nuclei_instance_class = remap_label(
                predictions["instance_types_nuclei"][i][j, ...]
            )
            target_nuclei_instance_class = remap_label(
                gt["instance_types_nuclei"][i][j, ...]
            )
            # if ground truth is empty, skip from calculation
            if len(np.unique(target_nuclei_instance_class)) == 1:
                pq_tmp = np.nan
            else:
                # NOTE(review): the binary-PQ call above passes true=/pred= by
                # keyword, but here the prediction is passed as the first
                # positional argument (the `true` slot) — verify against the
                # get_fast_pq signature whether the arguments are swapped.
                [_, _, pq_tmp], _ = get_fast_pq(
                    pred_nuclei_instance_class,
                    target_nuclei_instance_class,
                    match_iou=0.5,
                )
            nuclei_type_pq.append(pq_tmp)
        cell_type_pq_scores.append(nuclei_type_pq)
    batch_metrics = {
        "binary_dice_scores": binary_dice_scores,
        "binary_jaccard_scores": binary_jaccard_scores,
        "pq_scores": pq_scores,
        "cell_type_pq_scores": cell_type_pq_scores,
        "tissue_pred": pred_tissue,
        "tissue_gt": gt["tissue_types"],
    }
    return batch_metrics
@staticmethod
def generate_example_image(
    imgs: Union[torch.Tensor, np.ndarray],
    predictions: DataclassHVStorage,
    gt: DataclassHVStorage,
    num_nuclei_classes: int,
    num_images: int = 2,
) -> plt.Figure:
    """Generate example plot with image, binary prediction, hv-maps and instance maps.

    A random selection of ``num_images`` patches is drawn; for each patch the
    ground truth (top) and prediction (bottom) panels are rendered side by side.

    Args:
        imgs (Union[torch.Tensor, np.ndarray]): Images to process, a random number (num_images)
            is selected from this stack. Shape: (batch_size, 3, H', W')
        predictions (DataclassHVStorage): Predictions
        gt (DataclassHVStorage): Ground-truth values
        num_nuclei_classes (int): Number of total nuclei classes including background
        num_images (int, optional): Number of example patches to display. Defaults to 2.

    Returns:
        plt.Figure: Figure with example patches
    """
    predictions = predictions.get_dict()
    gt = gt.get_dict()

    assert num_images <= imgs.shape[0]
    # Bugfix: the parameter was previously overwritten with a hard-coded
    # `num_images = 4` right after the assert, which ignored the caller's
    # choice and could exceed the batch size; the parameter is now respected.

    # channels-first -> channels-last for plotting
    predictions["nuclei_binary_map"] = predictions["nuclei_binary_map"].permute(
        0, 2, 3, 1
    )
    predictions["hv_map"] = predictions["hv_map"].permute(0, 2, 3, 1)
    predictions["nuclei_type_map"] = predictions["nuclei_type_map"].permute(
        0, 2, 3, 1
    )
    # Bugfix: this transpose was applied twice in a row, which scrambled the
    # axis order instead of producing channels-last; apply it exactly once.
    predictions["instance_types_nuclei"] = predictions[
        "instance_types_nuclei"
    ].transpose(0, 2, 3, 1)
    gt["hv_map"] = gt["hv_map"].permute(0, 2, 3, 1)
    gt["nuclei_type_map"] = gt["nuclei_type_map"].permute(0, 2, 3, 1)

    h = gt["hv_map"].shape[1]
    w = gt["hv_map"].shape[2]

    # random selection of patches to display; convert to rgb and crop to selection
    sample_indices = torch.randint(0, imgs.shape[0], (num_images,))
    sample_images = (
        imgs[sample_indices].permute(0, 2, 3, 1).contiguous().cpu().numpy()
    )  # convert to rgb
    sample_images = cropping_center(sample_images, (h, w), True)

    # get predictions for the selected samples
    pred_sample_binary_map = (
        predictions["nuclei_binary_map"][sample_indices, :, :, 1]
        .detach()
        .cpu()
        .numpy()
    )
    pred_sample_hv_map = (
        predictions["hv_map"][sample_indices].detach().cpu().numpy()
    )
    pred_sample_instance_maps = (
        predictions["instance_map"][sample_indices].detach().cpu().numpy()
    )
    pred_sample_type_maps = (
        torch.argmax(predictions["nuclei_type_map"][sample_indices], dim=-1)
        .detach()
        .cpu()
        .numpy()
    )

    # get ground truth labels for the selected samples
    gt_sample_binary_map = (
        gt["nuclei_binary_map"][sample_indices].detach().cpu().numpy()
    )
    gt_sample_hv_map = gt["hv_map"][sample_indices].detach().cpu().numpy()
    gt_sample_instance_map = (
        gt["instance_map"][sample_indices].detach().cpu().numpy()
    )
    gt_sample_type_map = (
        torch.argmax(gt["nuclei_type_map"][sample_indices], dim=-1)
        .detach()
        .cpu()
        .numpy()
    )

    # create colormaps
    hv_cmap = plt.get_cmap("jet")
    binary_cmap = plt.get_cmap("jet")
    instance_map = plt.get_cmap("viridis")

    def _norm01(arr: np.ndarray) -> np.ndarray:
        # Min-max normalization for display. Bugfix: the 1e-10 epsilon was
        # previously added to the array *inside* np.min instead of to the
        # denominator, so a constant map still produced a zero denominator;
        # it now guards the denominator itself.
        return (arr - np.min(arr)) / (np.max(arr) - np.min(arr) + 1e-10)

    # setup plot: one row per sample, GT on top and prediction below
    fig, axs = plt.subplots(num_images, figsize=(6, 2 * num_images), dpi=150)
    for i in range(num_images):
        placeholder = np.zeros((2 * h, 6 * w, 3))
        # orig image
        placeholder[:h, :w, :3] = sample_images[i]
        placeholder[h : 2 * h, :w, :3] = sample_images[i]
        # binary prediction
        placeholder[:h, w : 2 * w, :3] = rgba2rgb(
            binary_cmap(gt_sample_binary_map[i] * 255)
        )
        placeholder[h : 2 * h, w : 2 * w, :3] = rgba2rgb(
            binary_cmap(pred_sample_binary_map[i])
        )
        # hv maps: values in [-1, 1] are mapped to [0, 1] for the colormap
        placeholder[:h, 2 * w : 3 * w, :3] = rgba2rgb(
            hv_cmap((gt_sample_hv_map[i, :, :, 0] + 1) / 2)
        )
        placeholder[h : 2 * h, 2 * w : 3 * w, :3] = rgba2rgb(
            hv_cmap((pred_sample_hv_map[i, :, :, 0] + 1) / 2)
        )
        placeholder[:h, 3 * w : 4 * w, :3] = rgba2rgb(
            hv_cmap((gt_sample_hv_map[i, :, :, 1] + 1) / 2)
        )
        placeholder[h : 2 * h, 3 * w : 4 * w, :3] = rgba2rgb(
            hv_cmap((pred_sample_hv_map[i, :, :, 1] + 1) / 2)
        )
        # instance predictions, min-max normalized for display
        placeholder[:h, 4 * w : 5 * w, :3] = rgba2rgb(
            instance_map(_norm01(gt_sample_instance_map[i]))
        )
        placeholder[h : 2 * h, 4 * w : 5 * w, :3] = rgba2rgb(
            instance_map(_norm01(pred_sample_instance_maps[i]))
        )
        # type predictions, scaled by the number of classes
        placeholder[:h, 5 * w : 6 * w, :3] = rgba2rgb(
            binary_cmap(gt_sample_type_map[i] / num_nuclei_classes)
        )
        placeholder[h : 2 * h, 5 * w : 6 * w, :3] = rgba2rgb(
            binary_cmap(pred_sample_type_maps[i] / num_nuclei_classes)
        )
        # plotting
        axs[i].imshow(placeholder)
        axs[i].set_xticks([], [])
        # plot labels in first row
        if i == 0:
            axs[i].set_xticks(np.arange(w / 2, 6 * w, w))
            axs[i].set_xticklabels(
                [
                    "Image",
                    "Binary-Cells",
                    "HV-Map-0",
                    "HV-Map-1",
                    "Cell Instances",
                    "Nuclei-Instances",
                ],
                fontsize=6,
            )
            axs[i].xaxis.tick_top()
        axs[i].set_yticks(np.arange(h / 2, 2 * h, h))
        axs[i].set_yticklabels(["GT", "Pred."], fontsize=6)
        axs[i].tick_params(axis="both", which="both", length=0)
        grid_x = np.arange(w, 6 * w, w)
        grid_y = np.arange(h, 2 * h, h)
        for x_seg in grid_x:
            axs[i].axvline(x_seg, color="black")
        for y_seg in grid_y:
            axs[i].axhline(y_seg, color="black")
    fig.suptitle(f"Patch Predictions for {num_images} Examples")
    fig.tight_layout()
    return fig