Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
# CellViT Trainer Class | |
# | |
# @ Fabian Hörst, fabian.hoerst@uk-essen.de | |
# Institute for Artifical Intelligence in Medicine, | |
# University Medicine Essen | |
import logging | |
from pathlib import Path | |
from typing import Tuple, Union | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import tqdm | |
import math | |
import csv | |
# import wandb | |
from matplotlib import pyplot as plt | |
from skimage.color import rgba2rgb | |
from sklearn.metrics import accuracy_score | |
from torch.optim import Optimizer | |
from torch.optim.lr_scheduler import _LRScheduler | |
from torch.utils.data import DataLoader | |
from torchmetrics.functional import dice | |
from torchmetrics.functional.classification import binary_jaccard_index | |
from base_ml.base_early_stopping import EarlyStopping | |
from base_ml.base_trainer import BaseTrainer | |
from models.segmentation.cell_segmentation.cellvit import DataclassHVStorage | |
from cell_segmentation.utils.metrics import get_fast_pq, remap_label | |
from cell_segmentation.utils.tools import cropping_center | |
from models.segmentation.cell_segmentation.cellvit import CellViT | |
from utils.tools import AverageMeter | |
from timm.utils import ModelEma | |
from torch.cuda.amp import GradScaler, autocast | |
class CellViTTrainer(BaseTrainer): | |
"""CellViT trainer class | |
Args: | |
model (CellViT): CellViT model that should be trained | |
loss_fn_dict (dict): Dictionary with loss functions for each branch with a dictionary of loss functions. | |
Name of branch as top-level key, followed by a dictionary with loss name, loss fn and weighting factor | |
Example: | |
{ | |
"nuclei_binary_map": {"bce": {loss_fn(Callable), weight_factor(float)}, "dice": {loss_fn(Callable), weight_factor(float)}}, | |
"hv_map": {"bce": {loss_fn(Callable), weight_factor(float)}, "dice": {loss_fn(Callable), weight_factor(float)}}, | |
"nuclei_type_map": {"bce": {loss_fn(Callable), weight_factor(float)}, "dice": {loss_fn(Callable), weight_factor(float)}} | |
"tissue_types": {"ce": {loss_fn(Callable), weight_factor(float)}} | |
} | |
Required Keys are: | |
* nuclei_binary_map | |
* hv_map | |
* nuclei_type_map | |
* tissue types | |
optimizer (Optimizer): Optimizer | |
scheduler (_LRScheduler): Learning rate scheduler | |
device (str): Cuda device to use, e.g., cuda:0. | |
logger (logging.Logger): Logger module | |
logdir (Union[Path, str]): Logging directory | |
num_classes (int): Number of nuclei classes | |
dataset_config (dict): Dataset configuration. Required Keys are: | |
* "tissue_types": describing the present tissue types with corresponding integer | |
* "nuclei_types": describing the present nuclei types with corresponding integer | |
experiment_config (dict): Configuration of this experiment | |
early_stopping (EarlyStopping, optional): Early Stopping Class. Defaults to None. | |
log_images (bool, optional): If images should be logged to WandB. Defaults to False. | |
magnification (int, optional): Image magnification. Please select either 40 or 20. Defaults to 40. | |
mixed_precision (bool, optional): If mixed-precision should be used. Defaults to False. | |
""" | |
def __init__( | |
self, | |
model: CellViT, | |
loss_fn_dict: dict, | |
optimizer: Optimizer, | |
scheduler: _LRScheduler, | |
device: str, | |
logger: logging.Logger, | |
logdir: Union[Path, str], | |
num_classes: int, | |
dataset_config: dict, | |
experiment_config: dict, | |
early_stopping: EarlyStopping = None, | |
log_images: bool = False, | |
magnification: int = 40, | |
mixed_precision: bool = False, | |
#model_ema : bool = True, | |
): | |
super().__init__( | |
model=model, | |
loss_fn=None, | |
optimizer=optimizer, | |
scheduler=scheduler, | |
device=device, | |
logger=logger, | |
logdir=logdir, | |
experiment_config=experiment_config, | |
early_stopping=early_stopping, | |
accum_iter=1, | |
log_images=log_images, | |
mixed_precision=mixed_precision, | |
) | |
self.loss_fn_dict = loss_fn_dict | |
self.num_classes = num_classes | |
self.dataset_config = dataset_config | |
self.tissue_types = dataset_config["tissue_types"] | |
self.reverse_tissue_types = {v: k for k, v in self.tissue_types.items()} | |
self.nuclei_types = dataset_config["nuclei_types"] | |
self.magnification = magnification | |
#self.model_ema = model_ema | |
# setup logging objects | |
self.loss_avg_tracker = {"Total_Loss": AverageMeter("Total_Loss", ":.4f")} | |
for branch, loss_fns in self.loss_fn_dict.items(): | |
for loss_name in loss_fns: | |
self.loss_avg_tracker[f"{branch}_{loss_name}"] = AverageMeter( | |
f"{branch}_{loss_name}", ":.4f" | |
) | |
self.batch_avg_tissue_acc = AverageMeter("Batch_avg_tissue_ACC", ":4.f") | |
def train_epoch( | |
self, epoch: int, train_dataloader: DataLoader, unfreeze_epoch: int = 50 | |
) -> Tuple[dict, dict]: | |
"""Training logic for a training epoch | |
Args: | |
epoch (int): Current epoch number | |
train_dataloader (DataLoader): Train dataloader | |
unfreeze_epoch (int, optional): Epoch to unfreeze layers | |
Returns: | |
Tuple[dict, dict]: wandb logging dictionaries | |
* Scalar metrics | |
* Image metrics | |
""" | |
self.model.train() | |
if epoch >= unfreeze_epoch: | |
self.model.unfreeze_encoder() | |
# if self.model_ema and epoch == 0: | |
# self.model_ema_instance = ModelEma( | |
# model=self.model, | |
# decay=0.9999, | |
# device='cuda', | |
# resume='' | |
# ) | |
binary_dice_scores = [] | |
binary_jaccard_scores = [] | |
tissue_pred = [] | |
tissue_gt = [] | |
train_example_img = None | |
# reset metrics | |
self.loss_avg_tracker["Total_Loss"].reset() | |
for branch, loss_fns in self.loss_fn_dict.items(): | |
for loss_name in loss_fns: | |
self.loss_avg_tracker[f"{branch}_{loss_name}"].reset() | |
self.batch_avg_tissue_acc.reset() | |
# randomly select a batch that should be displayed | |
if self.log_images: | |
select_example_image = int(torch.randint(0, len(train_dataloader), (1,))) | |
else: | |
select_example_image = None | |
train_loop = tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader)) | |
for batch_idx, batch in train_loop: | |
return_example_images = batch_idx == select_example_image | |
batch_metrics, example_img = self.train_step( | |
batch, | |
batch_idx, | |
len(train_dataloader), | |
return_example_images=return_example_images, | |
) | |
if example_img is not None: | |
train_example_img = example_img | |
binary_dice_scores = ( | |
binary_dice_scores + batch_metrics["binary_dice_scores"] | |
) | |
binary_jaccard_scores = ( | |
binary_jaccard_scores + batch_metrics["binary_jaccard_scores"] | |
) | |
tissue_pred.append(batch_metrics["tissue_pred"]) | |
tissue_gt.append(batch_metrics["tissue_gt"]) | |
train_loop.set_postfix( | |
{ | |
"Loss": np.round(self.loss_avg_tracker["Total_Loss"].avg, 3), | |
"Dice": np.round(np.nanmean(binary_dice_scores), 3), | |
"Pred-Acc": np.round(self.batch_avg_tissue_acc.avg, 3), | |
} | |
) | |
# calculate global metrics | |
binary_dice_scores = np.array(binary_dice_scores) | |
binary_jaccard_scores = np.array(binary_jaccard_scores) | |
tissue_detection_accuracy = accuracy_score( | |
y_true=np.concatenate(tissue_gt), y_pred=np.concatenate(tissue_pred) | |
) | |
scalar_metrics = { | |
"Loss/Train": self.loss_avg_tracker["Total_Loss"].avg, | |
"Binary-Cell-Dice-Mean/Train": np.nanmean(binary_dice_scores), | |
"Binary-Cell-Jacard-Mean/Train": np.nanmean(binary_jaccard_scores), | |
"Tissue-Multiclass-Accuracy/Train": tissue_detection_accuracy, | |
} | |
for branch, loss_fns in self.loss_fn_dict.items(): | |
for loss_name in loss_fns: | |
scalar_metrics[f"{branch}_{loss_name}/Train"] = self.loss_avg_tracker[ | |
f"{branch}_{loss_name}" | |
].avg | |
self.logger.info( | |
f"{'Training epoch stats:' : <25} " | |
f"Loss: {self.loss_avg_tracker['Total_Loss'].avg:.4f} - " | |
f"Binary-Cell-Dice: {np.nanmean(binary_dice_scores):.4f} - " | |
f"Binary-Cell-Jacard: {np.nanmean(binary_jaccard_scores):.4f} - " | |
f"Tissue-MC-Acc.: {tissue_detection_accuracy:.4f}" | |
) | |
image_metrics = {"Example-Predictions/Train": train_example_img} | |
return scalar_metrics, image_metrics | |
def train_step( | |
self, | |
batch: object, | |
batch_idx: int, | |
num_batches: int, | |
return_example_images: bool, | |
) -> Tuple[dict, Union[plt.Figure, None]]: | |
"""Training step | |
Args: | |
batch (object): Training batch, consisting of images ([0]), masks ([1]), tissue_types ([2]) and figure filenames ([3]) | |
batch_idx (int): Batch index | |
num_batches (int): Total number of batches in epoch | |
return_example_images (bool): If an example preciction image should be returned | |
Returns: | |
Tuple[dict, Union[plt.Figure, None]]: | |
* Batch-Metrics: dictionary with the following keys: | |
* Example prediction image | |
""" | |
# unpack batch | |
imgs = batch[0].to(self.device) # imgs shape: (batch_size, 3, H, W) (16,3,256,256) | |
masks = batch[ | |
1 | |
] # dict: keys: "instance_map", [16,256,256],"nuclei_map",[16,256,256], "nuclei_binary_map",[16,256,256], "hv_map"[16,2,256,256] | |
tissue_types = batch[2] # list[str] | |
#change | |
#scaler = GradScaler(init_scale=2.0) | |
if self.mixed_precision: | |
with torch.autocast(device_type="cuda", dtype=torch.float16): | |
#with torch.cuda.amp.autocast(False): | |
# make predictions | |
predictions_ = self.model.forward(imgs) #img.shape=(16,3,256,256) model.forward(imgs) 'tissue_types'(16,19),'nuclei_binary_map'(16,2,128,128),'hv_map'(16,2,128,128),'nuclei_type_map'(16,6,128,128) | |
# reshaping and postprocessing | |
predictions = self.unpack_predictions(predictions=predictions_) | |
gt = self.unpack_masks(masks=masks, tissue_types=tissue_types) | |
# calculate loss | |
total_loss = self.calculate_loss(predictions, gt) | |
# if torch.isnan(total_loss): | |
# print("nan in loss") | |
#if math.isnan(total_loss.item()): | |
#print("nan") | |
# import pdb; pdb.set_trace() | |
# backward pass | |
self.scaler.scale(total_loss).backward() | |
# 阈值剪切梯度 | |
#torch.nn.utils.clip_grad_value_(self.model.parameters(), clip_value=1.0) | |
# if torch.any(torch.tensor([torch.any(torch.isnan(param.data)) for param in self.model.parameters()])): | |
# print("nan in model parameters") | |
if ( | |
((batch_idx + 1) % self.accum_iter == 0) | |
or ((batch_idx + 1) == num_batches) | |
or (self.accum_iter == 1) | |
): | |
# self.scaler.unscale_(self.optimizer) | |
# torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) | |
self.scaler.step(self.optimizer) | |
self.scaler.update() | |
# if self.model_ema: | |
# self.model_ema_instance.update(self.model) | |
self.optimizer.zero_grad(set_to_none=True) | |
self.model.zero_grad() | |
else: | |
predictions_ = self.model.forward(imgs) | |
predictions = self.unpack_predictions(predictions=predictions_) | |
gt = self.unpack_masks(masks=masks, tissue_types=tissue_types) | |
# calculate loss | |
total_loss = self.calculate_loss(predictions, gt) | |
total_loss.backward() | |
if ( | |
((batch_idx + 1) % self.accum_iter == 0) | |
or ((batch_idx + 1) == num_batches) | |
or (self.accum_iter == 1) | |
): | |
self.optimizer.step() | |
# if self.model_ema: | |
# self.model_ema_instance.update(self.model) | |
self.optimizer.zero_grad(set_to_none=True) | |
self.model.zero_grad() | |
with torch.cuda.device(self.device): | |
torch.cuda.empty_cache() | |
batch_metrics = self.calculate_step_metric_train(predictions, gt) | |
if return_example_images: | |
return_example_images = self.generate_example_image( | |
imgs, predictions, gt, num_images=4, num_nuclei_classes=self.num_classes | |
) | |
else: | |
return_example_images = None | |
return batch_metrics, return_example_images | |
def validation_epoch( | |
self, epoch: int, val_dataloader: DataLoader | |
) -> Tuple[dict, dict, float]: | |
"""Validation logic for a validation epoch | |
Args: | |
epoch (int): Current epoch number | |
val_dataloader (DataLoader): Validation dataloader | |
Returns: | |
Tuple[dict, dict, float]: wandb logging dictionaries | |
* Scalar metrics | |
* Image metrics | |
* Early stopping metric | |
""" | |
self.model.eval() | |
binary_dice_scores = [] | |
binary_jaccard_scores = [] | |
pq_scores = [] | |
cell_type_pq_scores = [] | |
tissue_pred = [] | |
tissue_gt = [] | |
val_example_img = None | |
# reset metrics | |
self.loss_avg_tracker["Total_Loss"].reset() | |
for branch, loss_fns in self.loss_fn_dict.items(): | |
for loss_name in loss_fns: | |
self.loss_avg_tracker[f"{branch}_{loss_name}"].reset() | |
self.batch_avg_tissue_acc.reset() | |
# randomly select a batch that should be displayed | |
if self.log_images: | |
select_example_image = int(torch.randint(0, len(val_dataloader), (1,))) | |
else: | |
select_example_image = None | |
val_loop = tqdm.tqdm(enumerate(val_dataloader), total=len(val_dataloader)) | |
with torch.no_grad(): | |
for batch_idx, batch in val_loop: | |
return_example_images = batch_idx == select_example_image | |
batch_metrics, example_img= self.validation_step( | |
batch, batch_idx, return_example_images | |
) | |
# 检查总体损失是否为NaN | |
# if np.isnan(self.loss_avg_tracker["Total_Loss"].avg): | |
# print("NaN loss for image:", batch_idx) | |
if example_img is not None: | |
val_example_img = example_img | |
binary_dice_scores = ( | |
binary_dice_scores + batch_metrics["binary_dice_scores"] | |
) | |
binary_jaccard_scores = ( | |
binary_jaccard_scores + batch_metrics["binary_jaccard_scores"] | |
) | |
pq_scores = pq_scores + batch_metrics["pq_scores"] | |
cell_type_pq_scores = ( | |
cell_type_pq_scores + batch_metrics["cell_type_pq_scores"] | |
) | |
tissue_pred.append(batch_metrics["tissue_pred"]) | |
tissue_gt.append(batch_metrics["tissue_gt"]) | |
val_loop.set_postfix( | |
{ | |
"Loss": np.round(self.loss_avg_tracker["Total_Loss"].avg, 3), | |
"Dice": np.round(np.nanmean(binary_dice_scores), 3), | |
"Pred-Acc": np.round(self.batch_avg_tissue_acc.avg, 3), | |
} | |
) | |
tissue_types_val = [ | |
self.reverse_tissue_types[t].lower() for t in np.concatenate(tissue_gt) | |
] | |
# calculate global metrics | |
binary_dice_scores = np.array(binary_dice_scores) | |
binary_jaccard_scores = np.array(binary_jaccard_scores) | |
pq_scores = np.array(pq_scores) | |
tissue_detection_accuracy = accuracy_score( | |
y_true=np.concatenate(tissue_gt), y_pred=np.concatenate(tissue_pred) | |
) | |
scalar_metrics = { | |
"Loss/Validation": self.loss_avg_tracker["Total_Loss"].avg, | |
"Binary-Cell-Dice-Mean/Validation": np.nanmean(binary_dice_scores), | |
"Binary-Cell-Jacard-Mean/Validation": np.nanmean(binary_jaccard_scores), | |
"Tissue-Multiclass-Accuracy/Validation": tissue_detection_accuracy, | |
"bPQ/Validation": np.nanmean(pq_scores), | |
"mPQ/Validation": np.nanmean( | |
[np.nanmean(pq) for pq in cell_type_pq_scores] | |
), | |
} | |
for branch, loss_fns in self.loss_fn_dict.items(): | |
for loss_name in loss_fns: | |
scalar_metrics[ | |
f"{branch}_{loss_name}/Validation" | |
] = self.loss_avg_tracker[f"{branch}_{loss_name}"].avg #这里的loss_avg_tracker是在train_step中定义的 | |
# calculate local metrics | |
# per tissue class | |
for tissue in self.tissue_types.keys(): | |
tissue = tissue.lower() | |
tissue_ids = np.where(np.asarray(tissue_types_val) == tissue) | |
scalar_metrics[f"{tissue}-Dice/Validation"] = np.nanmean( | |
binary_dice_scores[tissue_ids] | |
) | |
scalar_metrics[f"{tissue}-Jaccard/Validation"] = np.nanmean( | |
binary_jaccard_scores[tissue_ids] | |
) | |
scalar_metrics[f"{tissue}-bPQ/Validation"] = np.nanmean( | |
pq_scores[tissue_ids] | |
) | |
scalar_metrics[f"{tissue}-mPQ/Validation"] = np.nanmean( | |
[np.nanmean(pq) for pq in np.array(cell_type_pq_scores)[tissue_ids]] | |
) | |
# calculate nuclei metrics | |
for nuc_name, nuc_type in self.nuclei_types.items(): | |
if nuc_name.lower() == "background": | |
continue | |
scalar_metrics[f"{nuc_name}-PQ/Validation"] = np.nanmean( | |
[pq[nuc_type] for pq in cell_type_pq_scores] | |
) | |
self.logger.info( | |
f"{'Validation epoch stats:' : <25} " | |
f"Loss: {self.loss_avg_tracker['Total_Loss'].avg:.4f} - " | |
f"Binary-Cell-Dice: {np.nanmean(binary_dice_scores):.4f} - " | |
f"Binary-Cell-Jacard: {np.nanmean(binary_jaccard_scores):.4f} - " | |
f"bPQ-Score: {np.nanmean(pq_scores):.4f} - " | |
f"mPQ-Score: {scalar_metrics['mPQ/Validation']:.4f} - " | |
f"Tissue-MC-Acc.: {tissue_detection_accuracy:.4f}" | |
) | |
image_metrics = {"Example-Predictions/Validation": val_example_img} | |
return scalar_metrics, image_metrics, np.nanmean(pq_scores) | |
def validation_step( | |
self, | |
batch: object, | |
batch_idx: int, | |
return_example_images: bool, | |
): | |
"""Validation step | |
Args: | |
batch (object): Training batch, consisting of images ([0]), masks ([1]), tissue_types ([2]) and figure filenames ([3]) | |
batch_idx (int): Batch index | |
return_example_images (bool): If an example preciction image should be returned | |
Returns: | |
Tuple[dict, Union[plt.Figure, None]]: | |
* Batch-Metrics: dictionary, structure not fixed yet | |
* Example prediction image | |
""" | |
# unpack batch, for shape compare train_step method | |
imgs = batch[0].to(self.device) | |
masks = batch[1] | |
tissue_types = batch[2] | |
# nan_loss_images = [] | |
# csv_file = "/data3/ziweicui/PanNuke/cellvit-png/fold1_nan_loss_images.csv" | |
self.model.zero_grad() | |
self.optimizer.zero_grad() | |
# with open(csv_file, 'a') as f: | |
# csv_write = csv.writer(f) | |
if self.mixed_precision: | |
with torch.autocast(device_type="cuda", dtype=torch.float16): | |
# make predictions | |
predictions_ = self.model.forward(imgs) | |
# reshaping and postprocessing | |
predictions = self.unpack_predictions(predictions=predictions_) | |
gt = self.unpack_masks(masks=masks, tissue_types=tissue_types) | |
# calculate loss | |
_ = self.calculate_loss(predictions, gt) | |
# 检查损失是否为NaN | |
#loss_value = _.item() | |
# if math.isnan(loss_value): | |
# print("NaN loss for image:", batch[3]) | |
#nan_loss_images.append(batch[3]) | |
else: | |
predictions_ = self.model.forward(imgs) | |
# reshaping and postprocessing | |
predictions = self.unpack_predictions(predictions=predictions_) | |
gt = self.unpack_masks(masks=masks, tissue_types=tissue_types) | |
# calculate loss | |
_ = self.calculate_loss(predictions, gt) | |
# 检查损失是否为NaN | |
loss_value = _.item() | |
if math.isnan(loss_value): | |
print("NaN loss for image:", batch[3]) | |
# get metrics for this batch | |
batch_metrics = self.calculate_step_metric_validation(predictions, gt) | |
if return_example_images: | |
try: | |
return_example_images = self.generate_example_image( | |
imgs, | |
predictions, | |
gt, | |
num_images=4, | |
num_nuclei_classes=self.num_classes, | |
) | |
except AssertionError: | |
self.logger.error( | |
"AssertionError for Example Image. Please check. Continue without image." | |
) | |
return_example_images = None | |
else: | |
return_example_images = None | |
return batch_metrics, return_example_images | |
def unpack_predictions(self, predictions: dict) -> DataclassHVStorage: | |
"""Unpack the given predictions. Main focus lays on reshaping and postprocessing predictions, e.g. separating instances | |
Args: | |
predictions (dict): Dictionary with the following keys: | |
* tissue_types: Logit tissue prediction output. Shape: (batch_size, num_tissue_classes) | |
* nuclei_binary_map: Logit output for binary nuclei prediction branch. Shape: (batch_size, 2, H, W) | |
* hv_map: Logit output for hv-prediction. Shape: (batch_size, 2, H, W) | |
* nuclei_type_map: Logit output for nuclei instance-prediction. Shape: (batch_size, num_nuclei_classes, H, W) | |
Returns: | |
DataclassHVStorage: Processed network output | |
""" | |
predictions["tissue_types"] = predictions["tissue_types"].to(self.device) | |
predictions["nuclei_binary_map"] = F.softmax( | |
predictions["nuclei_binary_map"], dim=1 | |
) # shape: (batch_size, 2, H, W) | |
predictions["nuclei_type_map"] = F.softmax( | |
predictions["nuclei_type_map"], dim=1 | |
) # shape: (batch_size, num_nuclei_classes, H, W) | |
( | |
predictions["instance_map"], | |
predictions["instance_types"], | |
) = self.model.calculate_instance_map( | |
predictions, self.magnification | |
) # shape: (batch_size, H, W) | |
predictions["instance_types_nuclei"] = self.model.generate_instance_nuclei_map( | |
predictions["instance_map"], predictions["instance_types"] | |
).to( | |
self.device | |
) # shape: (batch_size, num_nuclei_classes, H, W) (32, 256, 256, 6) | |
if "regression_map" not in predictions.keys(): | |
predictions["regression_map"] = None | |
predictions = DataclassHVStorage( | |
nuclei_binary_map=predictions["nuclei_binary_map"], | |
hv_map=predictions["hv_map"], | |
nuclei_type_map=predictions["nuclei_type_map"], | |
tissue_types=predictions["tissue_types"], | |
instance_map=predictions["instance_map"], | |
instance_types=predictions["instance_types"], | |
instance_types_nuclei=predictions["instance_types_nuclei"], | |
batch_size=predictions["tissue_types"].shape[0], | |
regression_map=predictions["regression_map"], | |
num_nuclei_classes=self.num_classes, | |
) | |
return predictions | |
def unpack_masks(self, masks: dict, tissue_types: list) -> DataclassHVStorage: | |
"""Unpack the given masks. Main focus lays on reshaping and postprocessing masks to generate one dict | |
Args: | |
masks (dict): Required keys are: | |
* instance_map: Pixel-wise nuclear instance segmentations. Shape: (batch_size, H, W) | |
* nuclei_binary_map: Binary nuclei segmentations. Shape: (batch_size, H, W) | |
* hv_map: HV-Map. Shape: (batch_size, 2, H, W) | |
* nuclei_type_map: Nuclei instance-prediction and segmentation (not binary, each instance has own integer). | |
Shape: (batch_size, num_nuclei_classes, H, W) | |
tissue_types (list): List of string names of ground-truth tissue types | |
Returns: | |
DataclassHVStorage: GT-Results with matching shapes and output types | |
""" | |
# get ground truth values, perform one hot encoding for segmentation maps | |
gt_nuclei_binary_map_onehot = ( | |
F.one_hot(masks["nuclei_binary_map"], num_classes=2) | |
).type( | |
torch.float32 | |
) # background, nuclei | |
#nuclei_type_maps = torch.squeeze(masks["nuclei_type_map"]).type(torch.int64) | |
nuclei_type_maps = masks["nuclei_type_map"].type(torch.int64) | |
gt_nuclei_type_maps_onehot = F.one_hot( | |
nuclei_type_maps, num_classes=self.num_classes | |
).type( | |
torch.float32 | |
) # background + nuclei types | |
# assemble ground truth dictionary | |
gt = { | |
"nuclei_type_map": gt_nuclei_type_maps_onehot.permute(0, 3, 1, 2).to( | |
self.device | |
), # shape: (batch_size, H, W, num_nuclei_classes) | |
"nuclei_binary_map": gt_nuclei_binary_map_onehot.permute(0, 3, 1, 2).to( | |
self.device | |
), # shape: (batch_size, H, W, 2) | |
"hv_map": masks["hv_map"].to(self.device), # shape: (batch_size,2, H, W) | |
"instance_map": masks["instance_map"].to( | |
self.device | |
), # shape: (batch_size, H, W) -> each instance has one integer | |
"instance_types_nuclei": ( | |
gt_nuclei_type_maps_onehot * masks["instance_map"][..., None] | |
) | |
.permute(0, 3, 1, 2) | |
.to( | |
self.device | |
), # shape: (batch_size, num_nuclei_classes, H, W) -> instance has one integer, for each nuclei class | |
"tissue_types": torch.Tensor([self.tissue_types[t] for t in tissue_types]) | |
.type(torch.LongTensor) | |
.to(self.device), # shape: batch_size | |
} | |
if "regression_map" in masks: | |
gt["regression_map"] = masks["regression_map"].to(self.device) | |
gt = DataclassHVStorage( | |
**gt, | |
batch_size=gt["tissue_types"].shape[0], | |
num_nuclei_classes=self.num_classes, | |
) | |
return gt | |
def calculate_loss( | |
self, predictions: DataclassHVStorage, gt: DataclassHVStorage | |
) -> torch.Tensor: | |
"""Calculate the loss | |
Args: | |
predictions (DataclassHVStorage): Predictions | |
gt (DataclassHVStorage): Ground-Truth values | |
Returns: | |
torch.Tensor: Loss | |
""" | |
predictions = predictions.get_dict() | |
gt = gt.get_dict() | |
total_loss = 0 | |
for branch, pred in predictions.items(): | |
if branch in [ | |
"instance_map", | |
"instance_types", | |
"instance_types_nuclei", | |
]: | |
continue | |
if branch not in self.loss_fn_dict: | |
continue | |
branch_loss_fns = self.loss_fn_dict[branch] | |
for loss_name, loss_setting in branch_loss_fns.items(): | |
loss_fn = loss_setting["loss_fn"] | |
weight = loss_setting["weight"] | |
if loss_name == "msge": | |
loss_value = loss_fn( | |
input=pred, | |
target=gt[branch], | |
focus=gt["nuclei_binary_map"], | |
device=self.device, | |
) | |
else: | |
loss_value = loss_fn(input=pred, target=gt[branch]) | |
total_loss = total_loss + weight * loss_value | |
self.loss_avg_tracker[f"{branch}_{loss_name}"].update( | |
loss_value.detach().cpu().numpy() | |
) | |
self.loss_avg_tracker["Total_Loss"].update(total_loss.detach().cpu().numpy()) | |
return total_loss | |
def calculate_step_metric_train( | |
self, predictions: DataclassHVStorage, gt: DataclassHVStorage | |
) -> dict: | |
"""Calculate the metrics for the training step | |
Args: | |
predictions (DataclassHVStorage): Processed network output | |
gt (DataclassHVStorage): Ground truth values | |
Returns: | |
dict: Dictionary with metrics. Keys: | |
binary_dice_scores, binary_jaccard_scores, tissue_pred, tissue_gt | |
""" | |
predictions = predictions.get_dict() | |
gt = gt.get_dict() | |
# Tissue Tpyes logits to probs and argmax to get class | |
predictions["tissue_types_classes"] = F.softmax( | |
predictions["tissue_types"], dim=-1 | |
) | |
pred_tissue = ( | |
torch.argmax(predictions["tissue_types_classes"], dim=-1) | |
.detach() | |
.cpu() | |
.numpy() | |
.astype(np.uint8) | |
) | |
predictions["instance_map"] = predictions["instance_map"].detach().cpu() | |
predictions["instance_types_nuclei"] = ( | |
predictions["instance_types_nuclei"].detach().cpu().numpy().astype("int32") | |
) | |
gt["tissue_types"] = gt["tissue_types"].detach().cpu().numpy().astype(np.uint8) | |
gt["nuclei_binary_map"] = torch.argmax(gt["nuclei_binary_map"], dim=1).type( | |
torch.uint8 | |
) | |
gt["instance_types_nuclei"] = ( | |
gt["instance_types_nuclei"].detach().cpu().numpy().astype("int32") | |
) | |
tissue_detection_accuracy = accuracy_score( | |
y_true=gt["tissue_types"], y_pred=pred_tissue | |
) | |
self.batch_avg_tissue_acc.update(tissue_detection_accuracy) | |
binary_dice_scores = [] | |
binary_jaccard_scores = [] | |
for i in range(len(pred_tissue)): | |
# binary dice score: Score for cell detection per image, without background | |
pred_binary_map = torch.argmax(predictions["nuclei_binary_map"][i], dim=0) | |
target_binary_map = gt["nuclei_binary_map"][i] | |
cell_dice = ( | |
dice(preds=pred_binary_map, target=target_binary_map, ignore_index=0) | |
.detach() | |
.cpu() | |
) | |
binary_dice_scores.append(float(cell_dice)) | |
# binary aji | |
cell_jaccard = ( | |
binary_jaccard_index( | |
preds=pred_binary_map, | |
target=target_binary_map, | |
) | |
.detach() | |
.cpu() | |
) | |
binary_jaccard_scores.append(float(cell_jaccard)) | |
batch_metrics = { | |
"binary_dice_scores": binary_dice_scores, | |
"binary_jaccard_scores": binary_jaccard_scores, | |
"tissue_pred": pred_tissue, | |
"tissue_gt": gt["tissue_types"], | |
} | |
return batch_metrics | |
def calculate_step_metric_validation(self, predictions: dict, gt: dict) -> dict: | |
"""Calculate the metrics for the training step | |
Args: | |
predictions (DataclassHVStorage): OrderedDict: Processed network output | |
gt (DataclassHVStorage): Ground truth values | |
Returns: | |
dict: Dictionary with metrics. Keys: | |
binary_dice_scores, binary_jaccard_scores, tissue_pred, tissue_gt | |
""" | |
predictions = predictions.get_dict() | |
gt = gt.get_dict() | |
# Tissue Tpyes logits to probs and argmax to get class | |
predictions["tissue_types_classes"] = F.softmax( | |
predictions["tissue_types"], dim=-1 | |
) | |
pred_tissue = ( | |
torch.argmax(predictions["tissue_types_classes"], dim=-1) | |
.detach() | |
.cpu() | |
.numpy() | |
.astype(np.uint8) | |
) | |
predictions["instance_map"] = predictions["instance_map"].detach().cpu() | |
predictions["instance_types_nuclei"] = ( | |
predictions["instance_types_nuclei"].detach().cpu().numpy().astype("int32") | |
) | |
#change | |
predictions["instance_types_nuclei"] = predictions["instance_types_nuclei"].transpose(0, 3, 1, 2) | |
instance_maps_gt = gt["instance_map"].detach().cpu() | |
gt["tissue_types"] = gt["tissue_types"].detach().cpu().numpy().astype(np.uint8) | |
gt["nuclei_binary_map"] = torch.argmax(gt["nuclei_binary_map"], dim=1).type( | |
torch.uint8 | |
) | |
gt["instance_types_nuclei"] = ( | |
gt["instance_types_nuclei"].detach().cpu().numpy().astype("int32") | |
) | |
tissue_detection_accuracy = accuracy_score( | |
y_true=gt["tissue_types"], y_pred=pred_tissue | |
) | |
self.batch_avg_tissue_acc.update(tissue_detection_accuracy) | |
binary_dice_scores = [] | |
binary_jaccard_scores = [] | |
cell_type_pq_scores = [] | |
pq_scores = [] | |
for i in range(len(pred_tissue)): | |
# binary dice score: Score for cell detection per image, without background | |
pred_binary_map = torch.argmax(predictions["nuclei_binary_map"][i], dim=0) | |
target_binary_map = gt["nuclei_binary_map"][i] | |
cell_dice = ( | |
dice(preds=pred_binary_map, target=target_binary_map, ignore_index=0) | |
.detach() | |
.cpu() | |
) | |
binary_dice_scores.append(float(cell_dice)) | |
# binary aji | |
cell_jaccard = ( | |
binary_jaccard_index( | |
preds=pred_binary_map, | |
target=target_binary_map, | |
) | |
.detach() | |
.cpu() | |
) | |
binary_jaccard_scores.append(float(cell_jaccard)) | |
# pq values | |
remapped_instance_pred = remap_label(predictions["instance_map"][i]) | |
remapped_gt = remap_label(instance_maps_gt[i]) | |
[_, _, pq], _ = get_fast_pq(true=remapped_gt, pred=remapped_instance_pred) | |
pq_scores.append(pq) | |
#pq values per class (skip background) | |
nuclei_type_pq = [] | |
for j in range(0, self.num_classes): | |
pred_nuclei_instance_class = remap_label( | |
predictions["instance_types_nuclei"][i][j, ...] | |
) | |
target_nuclei_instance_class = remap_label( | |
gt["instance_types_nuclei"][i][j, ...] | |
) | |
# if ground truth is empty, skip from calculation | |
if len(np.unique(target_nuclei_instance_class)) == 1: | |
pq_tmp = np.nan | |
else: | |
[_, _, pq_tmp], _ = get_fast_pq( | |
pred_nuclei_instance_class, | |
target_nuclei_instance_class, | |
match_iou=0.5, | |
) | |
nuclei_type_pq.append(pq_tmp) | |
cell_type_pq_scores.append(nuclei_type_pq) | |
batch_metrics = { | |
"binary_dice_scores": binary_dice_scores, | |
"binary_jaccard_scores": binary_jaccard_scores, | |
"pq_scores": pq_scores, | |
"cell_type_pq_scores": cell_type_pq_scores, | |
"tissue_pred": pred_tissue, | |
"tissue_gt": gt["tissue_types"], | |
} | |
return batch_metrics | |
def generate_example_image( | |
imgs: Union[torch.Tensor, np.ndarray], | |
predictions: DataclassHVStorage, | |
gt: DataclassHVStorage, | |
num_nuclei_classes: int, | |
num_images: int = 2, | |
) -> plt.Figure: | |
"""Generate example plot with image, binary_pred, hv-map and instance map from prediction and ground-truth | |
Args: | |
imgs (Union[torch.Tensor, np.ndarray]): Images to process, a random number (num_images) is selected from this stack | |
Shape: (batch_size, 3, H', W') | |
predictions (DataclassHVStorage): Predictions | |
gt (DataclassHVStorage): gt | |
num_nuclei_classes (int): Number of total nuclei classes including background | |
num_images (int, optional): Number of example patches to display. Defaults to 2. | |
Returns: | |
plt.Figure: Figure with example patches | |
""" | |
predictions = predictions.get_dict() | |
gt = gt.get_dict() | |
assert num_images <= imgs.shape[0] | |
num_images = 4 | |
predictions["nuclei_binary_map"] = predictions["nuclei_binary_map"].permute( | |
0, 2, 3, 1 | |
) | |
predictions["hv_map"] = predictions["hv_map"].permute(0, 2, 3, 1) | |
predictions["nuclei_type_map"] = predictions["nuclei_type_map"].permute( | |
0, 2, 3, 1 | |
) | |
predictions["instance_types_nuclei"] = predictions[ | |
"instance_types_nuclei" | |
].transpose(0, 2, 3, 1) | |
gt["hv_map"] = gt["hv_map"].permute(0, 2, 3, 1) | |
gt["nuclei_type_map"] = gt["nuclei_type_map"].permute(0, 2, 3, 1) | |
predictions["instance_types_nuclei"] = predictions[ | |
"instance_types_nuclei" | |
].transpose(0, 2, 3, 1) | |
h = gt["hv_map"].shape[1] | |
w = gt["hv_map"].shape[2] | |
sample_indices = torch.randint(0, imgs.shape[0], (num_images,)) | |
# convert to rgb and crop to selection | |
sample_images = ( | |
imgs[sample_indices].permute(0, 2, 3, 1).contiguous().cpu().numpy() | |
) # convert to rgb | |
sample_images = cropping_center(sample_images, (h, w), True) | |
# get predictions | |
pred_sample_binary_map = ( | |
predictions["nuclei_binary_map"][sample_indices, :, :, 1] | |
.detach() | |
.cpu() | |
.numpy() | |
) | |
pred_sample_hv_map = ( | |
predictions["hv_map"][sample_indices].detach().cpu().numpy() | |
) | |
pred_sample_instance_maps = ( | |
predictions["instance_map"][sample_indices].detach().cpu().numpy() | |
) | |
pred_sample_type_maps = ( | |
torch.argmax(predictions["nuclei_type_map"][sample_indices], dim=-1) | |
.detach() | |
.cpu() | |
.numpy() | |
) | |
# get ground truth labels | |
gt_sample_binary_map = ( | |
gt["nuclei_binary_map"][sample_indices].detach().cpu().numpy() | |
) | |
gt_sample_hv_map = gt["hv_map"][sample_indices].detach().cpu().numpy() | |
gt_sample_instance_map = ( | |
gt["instance_map"][sample_indices].detach().cpu().numpy() | |
) | |
gt_sample_type_map = ( | |
torch.argmax(gt["nuclei_type_map"][sample_indices], dim=-1) | |
.detach() | |
.cpu() | |
.numpy() | |
) | |
# create colormaps | |
hv_cmap = plt.get_cmap("jet") | |
binary_cmap = plt.get_cmap("jet") | |
instance_map = plt.get_cmap("viridis") | |
# setup plot | |
fig, axs = plt.subplots(num_images, figsize=(6, 2 * num_images), dpi=150) | |
for i in range(num_images): | |
placeholder = np.zeros((2 * h, 6 * w, 3)) | |
# orig image | |
placeholder[:h, :w, :3] = sample_images[i] | |
placeholder[h : 2 * h, :w, :3] = sample_images[i] | |
# binary prediction | |
placeholder[:h, w : 2 * w, :3] = rgba2rgb( | |
binary_cmap(gt_sample_binary_map[i] * 255) | |
) | |
placeholder[h : 2 * h, w : 2 * w, :3] = rgba2rgb( | |
binary_cmap(pred_sample_binary_map[i]) | |
) # *255? | |
# hv maps | |
placeholder[:h, 2 * w : 3 * w, :3] = rgba2rgb( | |
hv_cmap((gt_sample_hv_map[i, :, :, 0] + 1) / 2) | |
) | |
placeholder[h : 2 * h, 2 * w : 3 * w, :3] = rgba2rgb( | |
hv_cmap((pred_sample_hv_map[i, :, :, 0] + 1) / 2) | |
) | |
placeholder[:h, 3 * w : 4 * w, :3] = rgba2rgb( | |
hv_cmap((gt_sample_hv_map[i, :, :, 1] + 1) / 2) | |
) | |
placeholder[h : 2 * h, 3 * w : 4 * w, :3] = rgba2rgb( | |
hv_cmap((pred_sample_hv_map[i, :, :, 1] + 1) / 2) | |
) | |
# instance_predictions | |
placeholder[:h, 4 * w : 5 * w, :3] = rgba2rgb( | |
instance_map( | |
(gt_sample_instance_map[i] - np.min(gt_sample_instance_map[i])) | |
/ ( | |
np.max(gt_sample_instance_map[i]) | |
- np.min(gt_sample_instance_map[i] + 1e-10) | |
) | |
) | |
) | |
placeholder[h : 2 * h, 4 * w : 5 * w, :3] = rgba2rgb( | |
instance_map( | |
( | |
pred_sample_instance_maps[i] | |
- np.min(pred_sample_instance_maps[i]) | |
) | |
/ ( | |
np.max(pred_sample_instance_maps[i]) | |
- np.min(pred_sample_instance_maps[i] + 1e-10) | |
) | |
) | |
) | |
# type_predictions | |
placeholder[:h, 5 * w : 6 * w, :3] = rgba2rgb( | |
binary_cmap(gt_sample_type_map[i] / num_nuclei_classes) | |
) | |
placeholder[h : 2 * h, 5 * w : 6 * w, :3] = rgba2rgb( | |
binary_cmap(pred_sample_type_maps[i] / num_nuclei_classes) | |
) | |
# plotting | |
axs[i].imshow(placeholder) | |
axs[i].set_xticks([], []) | |
# plot labels in first row | |
if i == 0: | |
axs[i].set_xticks(np.arange(w / 2, 6 * w, w)) | |
axs[i].set_xticklabels( | |
[ | |
"Image", | |
"Binary-Cells", | |
"HV-Map-0", | |
"HV-Map-1", | |
"Cell Instances", | |
"Nuclei-Instances", | |
], | |
fontsize=6, | |
) | |
axs[i].xaxis.tick_top() | |
axs[i].set_yticks(np.arange(h / 2, 2 * h, h)) | |
axs[i].set_yticklabels(["GT", "Pred."], fontsize=6) | |
axs[i].tick_params(axis="both", which="both", length=0) | |
grid_x = np.arange(w, 6 * w, w) | |
grid_y = np.arange(h, 2 * h, h) | |
for x_seg in grid_x: | |
axs[i].axvline(x_seg, color="black") | |
for y_seg in grid_y: | |
axs[i].axhline(y_seg, color="black") | |
fig.suptitle(f"Patch Predictions for {num_images} Examples") | |
fig.tight_layout() | |
return fig | |