"""Zero-inflated losses over density maps: a Poisson NLL and a binned cross-entropy."""

import torch
from torch import nn, Tensor
from einops import rearrange
from typing import Dict, List, Tuple

from .utils import _reshape_density, _bin_count

EPS = 1e-8  # small constant to keep log() arguments away from zero


class ZIPoissonNLL(nn.Module):
    """Zero-inflated Poisson negative log-likelihood over per-block count maps."""

    def __init__(
        self,
        reduction: str = "mean",
    ) -> None:
        super().__init__()
        assert reduction in ["none", "mean", "sum"], f"Expected reduction to be one of ['none', 'mean', 'sum'], got {reduction}."
        self.reduction = reduction

    def forward(
        self,
        logit_pi_maps: Tensor,
        lambda_maps: Tensor,
        gt_den_maps: Tensor,
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        assert logit_pi_maps.ndim == lambda_maps.ndim == gt_den_maps.ndim == 4, f"Expected 4D (B, C, H, W) tensors, got {logit_pi_maps.shape}, {lambda_maps.shape}, and {gt_den_maps.shape}"
        B, _, H, W = lambda_maps.shape
        assert logit_pi_maps.shape == (B, 2, H, W), f"Expected logit_pi_maps to have shape (B, 2, H, W), got {logit_pi_maps.shape}"
        assert lambda_maps.shape == (B, 1, H, W), f"Expected lambda_maps to have shape (B, 1, H, W), got {lambda_maps.shape}"
        if gt_den_maps.shape[2:] != (H, W):
            # Pool the ground truth down to the prediction resolution before comparing.
            gt_h, gt_w = gt_den_maps.shape[-2], gt_den_maps.shape[-1]
            assert gt_h % H == 0 and gt_w % W == 0 and gt_h // H == gt_w // W, f"Expected the spatial dimensions of gt_den_maps to be an integer multiple of those of lambda_maps, got {gt_den_maps.shape} and {lambda_maps.shape}"
            gt_den_maps = _reshape_density(gt_den_maps, block_size=gt_h // H)
        assert gt_den_maps.shape == (B, 1, H, W), f"Expected gt_den_maps to have shape (B, 1, H, W), got {gt_den_maps.shape}"

        # Mixture weights: channel 0 is the structural-zero component, channel 1 the
        # Poisson component.
        pi_maps = logit_pi_maps.softmax(dim=1)

        # Zero blocks: P(0) = pi_0 + pi_1 * exp(-lambda).
        zero_indices = (gt_den_maps == 0).float()
        zero_loss = -torch.log(pi_maps[:, 0:1] + pi_maps[:, 1:] * torch.exp(-lambda_maps) + EPS) * zero_indices

        # Nonzero blocks with count y: log P(y) = log(pi_1) + y * log(lambda) - lambda,
        # dropping the log(y!) term, which is constant with respect to the parameters.
        poisson_log_p = gt_den_maps * torch.log(lambda_maps + EPS) - lambda_maps
        nonzero_loss = (-torch.log(pi_maps[:, 1:] + EPS) - poisson_log_p) * (1.0 - zero_indices)

        # Per-image NLL: sum over all spatial locations.
        loss = (zero_loss + nonzero_loss).sum(dim=(-1, -2))

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss, {"zipnll": loss.detach()}
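

# A minimal usage sketch for ZIPoissonNLL (a hypothetical call site, not part of the
# module API). The head layout and the 8x resolution gap between ground truth and
# predictions are assumptions for illustration.
def _example_zip_nll() -> None:
    criterion = ZIPoissonNLL(reduction="mean")
    logit_pi_maps = torch.randn(2, 2, 16, 16)                    # zero-inflation logits
    lambda_maps = torch.rand(2, 1, 16, 16)                       # non-negative Poisson rates
    gt_den_maps = torch.randint(0, 3, (2, 1, 128, 128)).float()  # pooled to 16x16 internally
    loss, logs = criterion(logit_pi_maps, lambda_maps, gt_den_maps)
    print(f"zipnll: {loss.item():.4f}", logs)
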
class ZICrossEntropy(nn.Module):
    """Cross-entropy over binned nonzero counts; zero-count positions are excluded,
    as they are handled by the zero-inflation term."""

    def __init__(
        self,
        bins: List[Tuple[int, int]],
        reduction: str = "mean",
    ) -> None:
        super().__init__()
        assert all(low <= high for low, high in bins), f"Expected bins to be a list of tuples (low, high) where low <= high, got {bins}"
        assert reduction in ["mean", "sum"], f"Expected reduction to be one of ['mean', 'sum'], got {reduction}."

        self.bins = bins
        self.reduction = reduction
        self.ce_loss_fn = nn.CrossEntropyLoss(reduction="none")

    def forward(
        self,
        logit_maps: Tensor,
        gt_den_maps: Tensor,
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        assert logit_maps.ndim == gt_den_maps.ndim == 4, f"Expected 4D (B, C, H, W) tensors, got {logit_maps.shape} and {gt_den_maps.shape}"
        B, _, H, W = logit_maps.shape
        if gt_den_maps.shape[2:] != (H, W):
            # Pool the ground truth down to the prediction resolution before binning.
            gt_h, gt_w = gt_den_maps.shape[-2], gt_den_maps.shape[-1]
            assert gt_h % H == 0 and gt_w % W == 0 and gt_h // H == gt_w // W, f"Expected the spatial dimensions of gt_den_maps to be an integer multiple of those of logit_maps, got {gt_den_maps.shape} and {logit_maps.shape}"
            gt_den_maps = _reshape_density(gt_den_maps, block_size=gt_h // H)
        assert gt_den_maps.shape == (B, 1, H, W), f"Expected gt_den_maps to have shape (B, 1, H, W), got {gt_den_maps.shape}"

        # Bin each block count into a class index and flatten the spatial dimensions.
        gt_class_maps = _bin_count(gt_den_maps, bins=self.bins)
        gt_class_maps = rearrange(gt_class_maps, "B H W -> B (H W)")
        logit_maps = rearrange(logit_maps, "B C H W -> B (H W) C")

        loss = torch.zeros((), device=logit_maps.device)
        for idx in range(gt_class_maps.shape[0]):
            gt_class_map, logit_map = gt_class_maps[idx], logit_maps[idx]
            # Keep only positions whose class is nonzero, then shift the labels so that
            # class 1 maps to logit channel 0.
            mask = gt_class_map > 0
            gt_class_map = gt_class_map[mask] - 1
            logit_map = logit_map[mask]
            loss += self.ce_loss_fn(logit_map, gt_class_map).sum()

        if self.reduction == "mean":
            loss /= gt_class_maps.shape[0]

        return loss, {"cls_zice": loss.detach()}
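

# A minimal usage sketch for ZICrossEntropy (hypothetical, for illustration). It assumes
# _bin_count assigns class 0 to zero counts and classes 1..len(bins) to the bins below,
# so the number of logit channels equals len(bins).
def _example_zi_ce() -> None:
    bins = [(1, 1), (2, 2), (3, 3)]
    criterion = ZICrossEntropy(bins=bins, reduction="mean")
    logit_maps = torch.randn(2, len(bins), 16, 16)             # one channel per nonzero bin
    gt_den_maps = torch.randint(0, 4, (2, 1, 16, 16)).float()  # block counts in {0, 1, 2, 3}
    loss, logs = criterion(logit_maps, gt_den_maps)
    print(f"cls_zice: {loss.item():.4f}", logs)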