# mm3dtest/projects/CENet/cenet/boundary_loss.py
# giantmonkeyTC
# 2344
# 34d1f8b
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from mmdet3d.registry import MODELS
def one_hot(label: Tensor,
            n_classes: int,
            requires_grad: bool = True) -> Tensor:
    """Encode an integer label map as a channel-first one-hot tensor.

    Each label selects a row of an ``n_classes x n_classes`` identity
    matrix, which appends a trailing class axis. That axis is then moved to
    position 1 while the remaining axes keep their relative order.

    Args:
        label (Tensor): Integer class indices. The axis move addresses
            dimension 3 of the indexed result, so a label with at least
            three dimensions is needed (presumably (N, H, W) -- confirm
            against the caller).
        n_classes (int): Number of classes; side length of the identity
            matrix the rows are taken from.
        requires_grad (bool): Forwarded to ``torch.eye`` when the identity
            matrix is created. Defaults to True.

    Returns:
        Tensor: One-hot encoding with the class axis at dimension 1, e.g.
        (N, n_classes, H, W) for a (N, H, W) label.
    """
    identity = torch.eye(
        n_classes, device=label.device, requires_grad=requires_grad)
    # Row lookup: result has the shape of ``label`` plus a last class axis.
    encoded = identity[label]
    # Class axis (dim 3) goes right after the batch axis.
    return encoded.movedim(3, 1)
@MODELS.register_module()
class BoundaryLoss(nn.Module):
    """Loss derived from a soft boundary F1 score.

    A boundary response is extracted from both the softmax prediction and
    the one-hot ground truth with a stride-1 max pooling. Per-class
    precision and recall between the two responses are combined into an F1
    score, and the loss is the mean of ``1 - F1`` over batch and classes.

    Args:
        theta0 (int): Kernel size of the max pooling used to extract the
            boundary response. Defaults to 3.
        theta (int): Stored on the module; not read by ``forward``.
            Defaults to 5.
        loss_weight (float): Factor the final loss is multiplied by.
            Defaults to 1.0.
    """

    def __init__(self, theta0=3, theta=5, loss_weight: float = 1.0) -> None:
        super().__init__()
        self.theta0 = theta0
        self.theta = theta
        self.loss_weight = loss_weight

    def _boundary_map(self, score_map: Tensor) -> Tensor:
        """Return max-pooled ``1 - score_map`` minus ``1 - score_map``."""
        inverted = 1 - score_map
        pooled = F.max_pool2d(
            inverted,
            kernel_size=self.theta0,
            stride=1,
            padding=(self.theta0 - 1) // 2)
        # In-place subtraction on the freshly pooled tensor.
        return pooled.sub_(inverted)

    def forward(self, pred: Tensor, gt: Tensor) -> Tensor:
        """Compute the weighted boundary loss.

        Args:
            pred (Tensor): 4-D model output with class scores along
                dimension 1; softmax is applied over that dimension.
            gt (Tensor): Ground truth map of integer class indices
                (assumed to match ``pred`` without its class dimension --
                confirm against the caller).

        Returns:
            Tensor: Scalar loss, already multiplied by ``loss_weight``.
        """
        probs = F.softmax(pred, dim=1)
        batch_size, num_classes, _, _ = probs.shape

        # Ground truth as a one-hot tensor laid out like ``probs``.
        target = one_hot(gt, num_classes)

        # Boundary responses, flattened over the two spatial dimensions.
        gt_edges = self._boundary_map(target).view(
            batch_size, num_classes, -1)
        pred_edges = self._boundary_map(probs).view(
            batch_size, num_classes, -1)

        # Precision and recall share the same numerator.
        overlap = torch.sum(pred_edges * gt_edges, dim=2)
        precision = overlap / (torch.sum(pred_edges, dim=2) + 1e-7)
        recall = overlap / (torch.sum(gt_edges, dim=2) + 1e-7)

        # Boundary F1 score per sample and class; epsilon avoids 0 / 0.
        bf1 = 2 * precision * recall / (precision + recall + 1e-7)

        # Average (1 - BF1) over every sample and class.
        return self.loss_weight * torch.mean(1 - bf1)