File size: 3,918 Bytes
2ca2f68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import monai
import torch
import numpy as np
import matplotlib.pyplot as plt


def warp_func():
    """Build a MONAI ``Warp`` block with ``"bilinear"`` interpolation and border padding."""
    return monai.networks.blocks.Warp(mode="bilinear", padding_mode="border")


def warp_nearest_func():
    """Build a MONAI ``Warp`` block with nearest-neighbour interpolation and border padding."""
    return monai.networks.blocks.Warp(padding_mode="border", mode="nearest")


def lncc_loss_func():
    """Build the local normalized cross-correlation (LNCC) image-similarity loss.

    Configured for 3 spatial dimensions with a size-3 rectangular kernel,
    "mean" reduction, and 1e-5 numerator/denominator smoothing terms.
    """
    lncc_settings = dict(
        spatial_dims=3,
        kernel_size=3,
        kernel_type='rectangular',
        reduction="mean",
        smooth_nr=1e-5,
        smooth_dr=1e-5,
    )
    return monai.losses.LocalNormalizedCrossCorrelationLoss(**lncc_settings)


def similarity_loss(displacement_field, image_pair):
    """Compute the LNCC similarity between the warped moving image and the target.

    Accepts a batch of displacement fields, shape (B,3,H,W,D),
    and a batch of image pairs, shape (B,2,H,W,D).

    Channel 0 of ``image_pair`` is the target image and channel 1 is the
    moving image. The moving image is resampled with ``displacement_field``
    and then compared against the target.

    Returns:
        The loss value produced by ``lncc_loss_func()`` for
        (prediction=warped moving image, target=target image).
    """
    warp = warp_func()
    lncc_loss = lncc_loss_func()
    # Index channels with a list ([1], [0]) instead of an int so the channel
    # dimension is kept and each slice stays (B,1,H,W,D).
    moving = image_pair[:, [1], :, :, :]
    target = image_pair[:, [0], :, :, :]
    warped_moving = warp(moving, displacement_field)
    return lncc_loss(
        warped_moving,  # prediction
        target,  # target
    )


def regularization_loss_func():
    """Build a bending-energy smoothness penalty for displacement fields.

    Uses ``normalize=True`` and "mean" reduction.
    """
    bending_energy = monai.losses.BendingEnergyLoss(normalize=True, reduction='mean')
    return bending_energy


def dice_loss_func():
    """Build a Dice loss for predictions that are already class probabilities.

    Background is included, targets are NOT one-hot encoded by the loss
    (they must already be one-hot / probabilistic), no softmax is applied to
    the predictions, and the result uses "mean" reduction.
    """
    return monai.losses.DiceLoss(
        reduction="mean",
        softmax=False,
        to_onehot_y=False,
        include_background=True,
    )


def dice_loss_func2():
    """Build a Dice loss for raw logits against label-map targets.

    Softmax is applied to the predictions and targets are one-hot encoded
    inside the loss. Background is included and the result uses "mean"
    reduction.
    """
    return monai.losses.DiceLoss(
        reduction="mean",
        softmax=True,
        to_onehot_y=True,
        include_background=True,
    )


def _segmentation_probabilities(image_pair, channel, seg_net, gt_seg, num_segmentation_classes):
    """Return per-voxel class probabilities for one image (``channel``) of ``image_pair``.

    Uses the one-hot encoding of ``gt_seg`` when ground truth is given;
    otherwise uses the softmax of ``seg_net`` applied to that image
    (a "noisy ground truth").

    Raises:
        ValueError: if ``gt_seg`` is given but ``num_segmentation_classes`` is None,
            since one-hot encoding needs the class count.
    """
    if gt_seg is not None:
        if num_segmentation_classes is None:
            raise ValueError(
                "num_segmentation_classes must be provided when a ground truth "
                "segmentation is supplied"
            )
        return monai.networks.one_hot(gt_seg, num_segmentation_classes)
    # List index keeps the channel dim: seg_net receives (B,1,H,W,D).
    return seg_net(image_pair[:, [channel], :, :, :]).softmax(dim=1)


def anatomy_loss(displacement_field, image_pair, seg_net, gt_seg1=None, gt_seg2=None, num_segmentation_classes=None):
    """Dice loss between the warped moving-image segmentation and the target segmentation.

    Accepts a batch of displacement fields, shape (B,3,H,W,D),
    and a batch of image pairs, shape (B,2,H,W,D).
    seg_net is the model used to segment an image,
      mapping (B,1,H,W,D) to (B,C,H,W,D) where C is the number of segmentation classes.
    gt_seg1 and gt_seg2 are ground truth segmentations for the images in image_pair, if ground truth is available;
      if unavailable then they can be None.
      gt_seg1 and gt_seg2 are expected to be in the form of class labels, with shape (B,1,H,W,D).
    num_segmentation_classes is required whenever gt_seg1 or gt_seg2 is given.

    Raises:
        ValueError: if a ground truth segmentation is given without
            ``num_segmentation_classes``.
    """
    # Target image (channel 0), then moving image (channel 1).
    seg1 = _segmentation_probabilities(image_pair, 0, seg_net, gt_seg1, num_segmentation_classes)
    seg2 = _segmentation_probabilities(image_pair, 1, seg_net, gt_seg2, num_segmentation_classes)

    # seg1 and seg2 are now class probabilities at each voxel, so linear
    # interpolation by `warp` (Warp's "bilinear" mode, which is trilinear for
    # 3-D volumes) keeps warped seg2 a valid per-voxel probability mix.
    dice_loss = dice_loss_func()
    warp = warp_func()
    return dice_loss(
        warp(seg2, displacement_field),  # warp of moving image segmentation
        seg1  # target image segmentation
    )


def reg_losses(batch, device, reg_net, seg_net, num_segmentation_classes):
    """Run the registration network on a batch and compute its losses.

    Args:
        batch: mapping with an ``'img12'`` image-pair tensor and, optionally,
            ``'seg1'`` / ``'seg2'`` ground-truth label maps.
        device: device the batch tensors are moved to.
        reg_net: model mapping the image pair to a displacement field.
        seg_net: segmentation model, used by ``anatomy_loss`` when ground
            truth is missing.
        num_segmentation_classes: number of classes, needed to one-hot
            encode ground-truth segmentations.

    Returns:
        Tuple ``(loss_sim, loss_reg, loss_ana, displacement_field12)``.
    """
    img12 = batch['img12'].to(device)
    displacement_field12 = reg_net(img12)

    loss_sim = similarity_loss(displacement_field12, img12)
    regularization_loss = regularization_loss_func()
    loss_reg = regularization_loss(displacement_field12)

    # Ground-truth segmentations are optional; anatomy_loss falls back to
    # seg_net predictions for any that are None.
    gt_seg1 = batch['seg1'].to(device) if 'seg1' in batch else None
    gt_seg2 = batch['seg2'].to(device) if 'seg2' in batch else None
    loss_ana = anatomy_loss(displacement_field12, img12,
                            seg_net, gt_seg1, gt_seg2, num_segmentation_classes)

    return loss_sim, loss_reg, loss_ana, displacement_field12