File size: 1,360 Bytes
482ab8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn
from einops import rearrange


def get_volume_mask_loss(opt):
    """Build and return a new ``VolumeMaskLoss`` criterion.

    Args:
        opt: options object; it is not read here (presumably accepted so this
            factory shares a signature with other loss factories — verify
            against the caller).

    Returns:
        A freshly constructed ``VolumeMaskLoss`` instance.
    """
    criterion = VolumeMaskLoss()
    return criterion


class VolumeMaskLoss(nn.Module):
    """Binary cross-entropy between a predicted pairwise volume and a mask-derived target.

    For every pair of spatial positions ``(i, j)`` of the (thresholded) mask, the
    target is 0 when both positions fall on the same side of the 0.5 threshold
    (consistent) and 1 when they differ (inconsistent).
    """

    def __init__(self):
        super().__init__()

        self.bce_loss = nn.BCELoss(reduction="mean")

    def _get_volume_mask(self, mask):
        """Build the pairwise consistency target from ``mask``.

        Args:
            mask: tensor of shape (b, c, h, w). The target lies in {0, 1} only
                when c == 1. With c > 1 the per-channel agreements are summed
                by the bmm, and the target can leave [0, 1]. Presumably callers
                pass single-channel masks — TODO confirm.

        Returns:
            Tensor of shape (b, h, w, h, w), where entry [n, y1, x1, y2, x2] is 0
            when positions (y1, x1) and (y2, x2) agree and 1 otherwise.

        The input tensor is NOT modified.
        """
        with torch.no_grad():
            b, c, h, w = mask.shape
            dtype = mask.dtype if mask.is_floating_point() else torch.float32
            # Threshold out of place: writing into ``mask`` would clobber the
            # caller's tensor whenever forward() did not resize it.
            binary = (mask > 0.5).to(dtype)
            flat = binary.reshape(b, c, h * w)
            # Encode each position as the orthogonal vectors [1, 0] / [0, 1]
            # (mask, complement); their dot product is 1 iff two positions agree.
            encoded = torch.cat([flat, 1 - flat], dim=1)
            agreement = torch.bmm(encoded.transpose(-1, -2), encoded)
            volume = agreement.reshape(b, h, w, h, w)
            # 0 indicates consistency, and 1 indicates inconsistency
            return 1 - volume

    def forward(self, out_volume, mask):
        """Compute the mean BCE between ``out_volume`` and the mask-derived target.

        Args:
            out_volume: predicted probabilities (nn.BCELoss requires values in
                [0, 1]). Its shape must equal the target's, i.e. (b, h, w, h, w)
                where (h, w) are its last two dims.
            mask: (b, c, h', w') mask. It is resized bilinearly to the last two
                dims of ``out_volume`` when they differ. It is not modified.

        Returns:
            dict with key "loss" holding the scalar mean BCE.
        """
        if not mask.is_floating_point():
            # Bilinear interpolation and the complement ``1 - mask`` need floats.
            mask = mask.float()
        volume_size = out_volume.shape[-2:]
        if volume_size != mask.shape[-2:]:
            mask = nn.functional.interpolate(
                mask, size=volume_size, mode="bilinear", align_corners=False
            )
        volume_mask = self._get_volume_mask(mask)
        # Match the prediction's dtype/device so BCELoss accepts the target.
        volume_mask = volume_mask.to(device=out_volume.device, dtype=out_volume.dtype)
        loss = self.bce_loss(out_volume, volume_mask)

        return {"loss": loss}