# WSCL/losses/volume_mask_loss.py
# Page header recovered from the hosting site: "release code", revision 482ab8a, uploaded by yhzhai (1.36 kB).
import torch
import torch.nn as nn
from einops import rearrange
def get_volume_mask_loss(opt):
    """Create and return a new ``VolumeMaskLoss`` module.

    ``opt`` is accepted but not read by this function; the loss is built
    with its default configuration.
    """
    loss_module = VolumeMaskLoss()
    return loss_module
class VolumeMaskLoss(nn.Module):
    """BCE loss between a predicted volume and a pairwise target built from a mask.

    The target marks, for every pair of spatial positions, whether the two
    positions carry different (1) or identical (0) binarized mask labels.
    """

    def __init__(self):
        super().__init__()
        self.bce_loss = nn.BCELoss(reduction="mean")

    def _get_volume_mask(self, mask):
        """Build the pairwise inconsistency target from ``mask``.

        Args:
            mask: 4-D tensor ``(b, c, h, w)``. Values ``> 0.5`` are treated
                as 1 and values ``<= 0.5`` as 0. The tensor is NOT modified.

        Returns:
            Tensor of shape ``(b, h, w, h, w)`` computed without gradient
            tracking. For a single-channel mask, entry ``[n, i, j, k, l]``
            is 0 when positions ``(i, j)`` and ``(k, l)`` share the same
            binarized label and 1 otherwise.

        Raises:
            ValueError: if ``mask`` is not 4-dimensional (shape unpacking).
        """
        with torch.no_grad():
            b, c, h, w = mask.shape

            # Threshold a copy: writing into `mask` directly would overwrite
            # the caller's tensor whenever forward() did not interpolate.
            binary = mask.clone()
            binary[mask > 0.5] = 1.0
            binary[mask <= 0.5] = 0.0

            # Flatten the spatial dims: (b, c, h, w) -> (b, c, h*w).
            flat = binary.reshape(b, c, h * w)

            # Stack the mask with its complement so each position is encoded
            # by the orthogonal vectors [1, 0] / [0, 1]; the dot product of
            # two positions is then 1 iff their labels agree (c == 1).
            encoded = torch.cat([flat, 1 - flat], dim=1)
            agreement = torch.bmm(encoded.transpose(-1, -2), encoded)

            # (b, h*w, h*w) -> (b, h, w, h, w)
            agreement = agreement.reshape(b, h, w, h, w)

            # 0 indicates consistency, and 1 indicates inconsistency
            return 1 - agreement

    def forward(self, out_volume, mask):
        """Compute the loss of ``out_volume`` against the target derived from ``mask``.

        Args:
            out_volume: predicted volume; its last two dims give the spatial
                size the mask is resized to. It is passed to ``nn.BCELoss``
                as the input, so its values must lie in ``[0, 1]``.
            mask: 4-D mask tensor ``(b, c, h, w)``; bilinearly resized when
                its spatial size differs from that of ``out_volume``.

        Returns:
            ``{"loss": loss}`` where ``loss`` is the mean binary
            cross-entropy.
        """
        volume_size = out_volume.shape[-2:]
        if volume_size != mask.shape[-2:]:
            mask = nn.functional.interpolate(
                mask, size=volume_size, mode="bilinear", align_corners=False
            )
        volume_mask = self._get_volume_mask(mask)
        loss = self.bce_loss(out_volume, volume_mask)
        return {"loss": loss}