|
import monai |
|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def warp_func():
    """Return a fresh MONAI ``Warp`` block using bilinear interpolation.

    ``padding_mode="border"`` makes sample locations that fall outside the
    image take the nearest border value instead of zero.
    """
    return monai.networks.blocks.Warp(padding_mode="border", mode="bilinear")
|
|
|
|
|
def warp_nearest_func():
    """Return a fresh MONAI ``Warp`` block using nearest-neighbour interpolation.

    Out-of-bounds samples are filled from the image border
    (``padding_mode="border"``). Nearest-neighbour sampling presumably
    exists for warping discrete label maps; it is not used in this section.
    """
    return monai.networks.blocks.Warp(padding_mode="border", mode="nearest")
|
|
|
|
|
def lncc_loss_func():
    """Return a local normalized cross-correlation loss for 3-D volumes.

    The loss uses a 3-voxel rectangular window and reduces with "mean".
    It adds small smoothing terms (1e-5) to the numerator and denominator
    to keep the ratio finite in flat regions.
    """
    smoothing = 1e-5
    return monai.losses.LocalNormalizedCrossCorrelationLoss(
        spatial_dims=3,
        kernel_size=3,
        kernel_type='rectangular',
        reduction="mean",
        smooth_nr=smoothing,
        smooth_dr=smoothing,
    )
|
|
|
|
|
def similarity_loss(displacement_field, image_pair):
    """Image-similarity term: LNCC between the warped moving image and the fixed image.

    Args:
        displacement_field: batch of displacement fields, shape (B, 3, H, W, D).
        image_pair: batch of image pairs, shape (B, 2, H, W, D). Channel 0 is
            the fixed (target) image; channel 1 is the moving image that gets
            warped by ``displacement_field``.

    Returns:
        The local normalized cross-correlation loss between the warped moving
        image and the fixed image, as produced by ``lncc_loss_func()``.
    """
    warp = warp_func()
    lncc_loss = lncc_loss_func()

    # Index with a list ([0] / [1]) so the channel axis is kept and each
    # image stays (B, 1, H, W, D), the layout Warp and the loss expect.
    fixed_img = image_pair[:, [0], :, :, :]
    moving_img = image_pair[:, [1], :, :, :]

    warped_moving = warp(moving_img, displacement_field)
    return lncc_loss(warped_moving, fixed_img)
|
|
|
|
|
def regularization_loss_func():
    """Return a bending-energy smoothness penalty for displacement fields.

    ``normalize=True`` asks MONAI to divide out the spatial size, so the
    penalty is roughly independent of sampling resolution. The result is
    averaged over the batch (``reduction='mean'``).
    """
    bending_energy = monai.losses.BendingEnergyLoss(reduction='mean', normalize=True)
    return bending_energy
|
|
|
|
|
def dice_loss_func():
    """Return a Dice loss for inputs that are already per-class probability maps.

    The loss applies no softmax to the prediction and does not one-hot encode
    the target, so both arguments must already have one channel per class.
    The background channel is included and the result is averaged ("mean").
    """
    options = dict(
        include_background=True,
        to_onehot_y=False,
        softmax=False,
        reduction="mean",
    )
    return monai.losses.DiceLoss(**options)
|
|
|
|
|
def dice_loss_func2():
    """Return a Dice loss that takes raw logits and an integer label target.

    The prediction is softmaxed over the channel axis and the target is
    one-hot encoded internally. The background channel is included and the
    result is averaged ("mean").
    """
    options = dict(
        include_background=True,
        to_onehot_y=True,
        softmax=True,
        reduction="mean",
    )
    return monai.losses.DiceLoss(**options)
|
|
|
|
|
def anatomy_loss(displacement_field, image_pair, seg_net, gt_seg1=None, gt_seg2=None, num_segmentation_classes=None):
    """Anatomy-guided term: Dice between the warped moving segmentation and the fixed one.

    Args:
        displacement_field: batch of displacement fields, shape (B, 3, H, W, D).
        image_pair: batch of image pairs, shape (B, 2, H, W, D). Channel 0 is
            the fixed image; channel 1 is the moving image.
        seg_net: segmentation model mapping (B, 1, H, W, D) to (B, C, H, W, D),
            where C is the number of segmentation classes. Used only for an
            image whose ground truth is not supplied.
        gt_seg1, gt_seg2: optional ground-truth label maps for the fixed and
            moving images, shape (B, 1, H, W, D) with integer class labels.
            Pass None when unavailable.
        num_segmentation_classes: number of classes C. Required whenever
            ``gt_seg1`` or ``gt_seg2`` is given, because ground truth is
            one-hot encoded to C channels.

    Returns:
        Dice loss between the moving segmentation warped by
        ``displacement_field`` and the fixed segmentation.

    Raises:
        ValueError: if a ground-truth segmentation is given but
            ``num_segmentation_classes`` is None.
    """
    # Fail early with a clear message instead of an opaque error from one_hot.
    if (gt_seg1 is not None or gt_seg2 is not None) and num_segmentation_classes is None:
        raise ValueError(
            "num_segmentation_classes must be provided when gt_seg1 or gt_seg2 is given"
        )

    def _class_probabilities(gt_seg, channel):
        # Ground truth becomes a one-hot (B, C, H, W, D) map. Otherwise
        # seg_net's softmax output serves as a "noisy ground truth".
        if gt_seg is not None:
            return monai.networks.one_hot(gt_seg, num_segmentation_classes)
        return seg_net(image_pair[:, [channel], :, :, :]).softmax(dim=1)

    seg1 = _class_probabilities(gt_seg1, 0)  # fixed image
    seg2 = _class_probabilities(gt_seg2, 1)  # moving image

    dice_loss = dice_loss_func()
    warp = warp_func()
    # Warp the moving segmentation into the fixed image's frame. Bilinear
    # warping keeps the soft class maps differentiable.
    return dice_loss(warp(seg2, displacement_field), seg1)
|
|
|
|
|
def reg_losses(batch, device, reg_net, seg_net, num_segmentation_classes):
    """Run the registration network on one batch and compute its loss terms.

    Args:
        batch: mapping with an ``'img12'`` image-pair tensor and, optionally,
            ``'seg1'`` / ``'seg2'`` ground-truth label maps.
        device: device the tensors are moved to.
        reg_net: registration model mapping the image pair to a displacement field.
        seg_net: segmentation model, used by the anatomy term for any image
            whose ground truth is missing.
        num_segmentation_classes: number of segmentation classes.

    Returns:
        Tuple ``(loss_sim, loss_reg, loss_ana, displacement_field12)``.
    """
    def optional_to_device(key):
        # Missing segmentation keys become None so anatomy_loss falls back to seg_net.
        return batch[key].to(device) if key in batch else None

    pair = batch['img12'].to(device)
    field = reg_net(pair)

    sim_term = similarity_loss(field, pair)
    smooth_term = regularization_loss_func()(field)

    fixed_seg = optional_to_device('seg1')
    moving_seg = optional_to_device('seg2')
    anatomy_term = anatomy_loss(field, pair, seg_net, fixed_seg, moving_seg, num_segmentation_classes)

    return sim_term, smooth_term, anatomy_term, field
|
|