# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/
master/code/train_LA_HD.py (Apache-2.0 License)"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage import distance_transform_edt as distance
from torch import Tensor

from mmseg.registry import MODELS
from .utils import get_class_weight, weighted_loss


def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor:
    """Compute the distance transform map of the foreground in the mask.

    Args:
        img_gt: Ground truth of the image, (b, h, w).
        pred: Predictions of the segmentation head after softmax, (b, c, h, w).

    Returns:
        output: the foreground distance transform map (DTM),
            dtm(x) = 0 for x on the segmentation boundary or background,
            and inf_{y in boundary} |x - y| for x inside the segmentation.
    """
    fg_dtm = torch.zeros_like(pred)
    out_shape = pred.shape
    for b in range(out_shape[0]):  # batch size
        for c in range(1, out_shape[1]):  # channel 0 is background by default
            posmask = img_gt[b].byte()
            if posmask.any():
                # Euclidean distance of every foreground pixel to the nearest
                # background pixel; the same map is reused for each
                # foreground channel.
                posdis = distance(posmask)
                fg_dtm[b][c] = torch.from_numpy(posdis)
    return fg_dtm
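

# Illustrative sketch only (not part of the upstream file): for a toy 1x4x4
# mask with a 2x2 foreground square, `compute_dtm` behaves roughly as follows
# (variable names here are made up for the example):
#
#     gt = torch.zeros(1, 4, 4)
#     gt[0, 1:3, 1:3] = 1
#     soft = torch.rand(1, 2, 4, 4).softmax(dim=1)
#     dtm = compute_dtm(gt, soft)  # shape (1, 2, 4, 4)
#     # dtm[0, 0] stays all zeros (background channel is skipped); dtm[0, 1]
#     # is 0 on the background and 1.0 on the 2x2 foreground square, since
#     # every foreground pixel is one step away from a background pixel.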


@weighted_loss
def hd_loss(seg_soft: Tensor,
            gt: Tensor,
            seg_dtm: Tensor,
            gt_dtm: Tensor,
            class_weight=None,
            ignore_index=255) -> Tensor:
    """Compute the Hausdorff distance loss for segmentation.

    Args:
        seg_soft: softmax results, shape=(b, c, x, y).
        gt: ground truth, shape=(b, x, y).
        seg_dtm: segmentation distance transform map, shape=(b, c, x, y).
        gt_dtm: ground truth distance transform map, shape=(b, c, x, y).

    Returns:
        output: hd_loss.
    """
    assert seg_soft.shape[0] == gt.shape[0]
    total_loss = 0
    num_class = seg_soft.shape[1]
    if class_weight is not None:
        # One weight per class is expected.
        assert class_weight.shape[0] == num_class
    for i in range(1, num_class):  # skip the background class
        if i != ignore_index:
            delta_s = (seg_soft[:, i, ...] - gt.float())**2
            s_dtm = seg_dtm[:, i, ...]**2
            g_dtm = gt_dtm[:, i, ...]**2
            dtm = s_dtm + g_dtm
            multiplied = torch.einsum('bxy, bxy->bxy', delta_s, dtm)
            hd_loss = multiplied.mean()
            if class_weight is not None:
                hd_loss *= class_weight[i]
            total_loss += hd_loss
    return total_loss / num_class
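

# Rough sketch of the math implemented above (editorial note, hedged, not a
# statement from the upstream docstring): for each foreground class c,
#
#     loss_c = mean_{b, x, y} (p_c - g)^2 * (dtm_seg_c^2 + dtm_gt_c^2)
#
# where p_c is the softmax probability of class c, g is the raw label map,
# and the squared distance-transform maps weight prediction errors by their
# distance to the predicted / ground-truth boundaries. The returned value is
# the (optionally class-weighted) sum of these terms over classes 1..C-1,
# divided by the total number of classes C.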


@MODELS.register_module()
class HuasdorffDisstanceLoss(nn.Module):
    """HuasdorffDisstanceLoss. This loss is proposed in `How Distance
    Transform Maps Boost Segmentation CNNs: An Empirical Study
    <http://proceedings.mlr.press/v121/ma20b.html>`_.

    Args:
        reduction (str, optional): The method used to reduce the loss into
            a scalar. Defaults to 'mean'.
        class_weight (list[float] | str, optional): Weight of each class. If
            in str format, read them from a file. Defaults to None.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        ignore_index (int | None): The label index to be ignored.
            Default: 255.
        loss_name (str): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_huasdorff_disstance'.
    """

    def __init__(self,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 ignore_index=255,
                 loss_name='loss_huasdorff_disstance',
                 **kwargs):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self._loss_name = loss_name
        self.ignore_index = ignore_index

    def forward(self,
                pred: Tensor,
                target: Tensor,
                avg_factor=None,
                reduction_override=None,
                **kwargs) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predictions of the segmentation head. (B, C, H, W)
            target (Tensor): Ground truth of the image. (B, H, W)
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used
                to override the original reduction method of the loss.
                Options are "none", "mean" and "sum".

        Returns:
            Tensor: Loss tensor.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = pred.new_tensor(self.class_weight)
        else:
            class_weight = None
        pred_soft = F.softmax(pred, dim=1)
        # Zero out ignored pixels before computing the distance maps.
        valid_mask = (target != self.ignore_index).long()
        target = target * valid_mask

        with torch.no_grad():
            # Distance transform maps are computed on CPU and detached from
            # the graph; gradients flow only through the softmax predictions.
            gt_dtm = compute_dtm(target.cpu(), pred_soft)
            gt_dtm = gt_dtm.float()
            seg_dtm2 = compute_dtm(
                pred_soft.argmax(dim=1, keepdim=False).cpu(), pred_soft)
            seg_dtm2 = seg_dtm2.float()

        loss_hd = self.loss_weight * hd_loss(
            pred_soft,
            target,
            seg_dtm=seg_dtm2,
            gt_dtm=gt_dtm,
            reduction=reduction,
            avg_factor=avg_factor,
            class_weight=class_weight,
            ignore_index=self.ignore_index)
        return loss_hd

    @property
    def loss_name(self):
        return self._loss_name
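

# Hedged usage sketch (editorial illustration, not part of the upstream
# module): with mmsegmentation installed, this loss is normally selected from
# a decode-head config, e.g. loss_decode=dict(type='HuasdorffDisstanceLoss').
# The helper below is hypothetical and only shows a direct call on random
# data; shapes are assumptions, not taken from the source.
def _demo_huasdorff_loss() -> Tensor:
    """Run the loss once on random logits/labels and return the scalar."""
    loss_fn = HuasdorffDisstanceLoss(loss_weight=1.0, ignore_index=255)
    logits = torch.randn(2, 4, 32, 32)          # (B, C, H, W) raw logits
    labels = torch.randint(0, 4, (2, 32, 32))   # (B, H, W) class indices
    return loss_fn(logits, labels)              # scalar loss tensor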