# ZIP-B / losses / zero_inflated_poisson_nll.py
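"""Zero-inflated Poisson (ZIP) losses for density-map crowd counting.

For each spatial cell the model predicts mixture logits (pi_0, pi_1) and a
Poisson rate lambda. Under the ZIP model the likelihood of a ground-truth
count y is

    P(y = 0) = pi_0 + pi_1 * exp(-lambda)
    P(y > 0) = pi_1 * Poisson(y; lambda)

``ZIPoissonNLL`` below is the corresponding negative log-likelihood and
``ZICrossEntropy`` is a cross-entropy over binned non-zero counts.
"""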
import torch
from torch import nn, Tensor
from einops import rearrange
from typing import Dict, List, Tuple
from .utils import _reshape_density, _bin_count

EPS = 1e-8


class ZIPoissonNLL(nn.Module):
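    """Negative log-likelihood of a zero-inflated Poisson (ZIP) density map.

    Per cell: ``-log(pi_0 + pi_1 * exp(-lambda))`` when the ground-truth count
    is zero, and ``-log(pi_1) - (y * log(lambda) - lambda)`` otherwise (the
    constant ``log(y!)`` term of the Poisson log-likelihood is dropped).
    """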
def __init__(
self,
reduction: str = "mean",
) -> None:
super().__init__()
assert reduction in ["none", "mean", "sum"], f"Expected reduction to be one of ['none', 'mean', 'sum'], got {reduction}."
        self.reduction = reduction

    def forward(
self,
logit_pi_maps: Tensor,
lambda_maps: Tensor,
gt_den_maps: Tensor,
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
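        """Compute the per-image ZIP negative log-likelihood.

        Args:
            logit_pi_maps: (B, 2, H, W) logits of the mixture weights (pi_0, pi_1).
            lambda_maps: (B, 1, H, W) non-negative Poisson rates.
            gt_den_maps: (B, 1, H', W') ground-truth density maps; reshaped to
                (B, 1, H, W) via ``_reshape_density`` when the resolutions differ.

        Returns:
            The reduced loss and a dict with the detached loss for logging.
        """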
assert len(logit_pi_maps.shape) == len(lambda_maps.shape) == len(gt_den_maps.shape) == 4, f"Expected 4D (B, C, H, W) tensor, got {logit_pi_maps.shape}, {lambda_maps.shape}, and {gt_den_maps.shape}"
B, _, H, W = lambda_maps.shape
assert logit_pi_maps.shape == (B, 2, H, W), f"Expected logit_pi_maps to have shape (B, 2, H, W), got {logit_pi_maps.shape}"
assert lambda_maps.shape == (B, 1, H, W), f"Expected lambda_maps to have shape (B, 1, H, W), got {lambda_maps.shape}"
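        # If the GT density map is at a higher resolution, bring it down to the prediction grid via _reshape_density.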
if gt_den_maps.shape[2:] != (H, W):
gt_h, gt_w = gt_den_maps.shape[-2], gt_den_maps.shape[-1]
assert gt_h % H == 0 and gt_w % W == 0 and gt_h // H == gt_w // W, f"Expected the spatial dimension of gt_den_maps to be a multiple of that of lambda_maps, got {gt_den_maps.shape} and {lambda_maps.shape}"
gt_den_maps = _reshape_density(gt_den_maps, block_size=gt_h // H)
assert gt_den_maps.shape == (B, 1, H, W), f"Expected gt_den_maps to have shape (B, 1, H, W), got {gt_den_maps.shape}"
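        # Mixture weights: channel 0 is pi_0 (the extra-zero component), channel 1 is pi_1 (the Poisson component).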
pi_maps = logit_pi_maps.softmax(dim=1)
zero_indices = (gt_den_maps == 0).float()
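        # Zero-count cells: -log(pi_0 + pi_1 * exp(-lambda)).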
zero_loss = -torch.log(pi_maps[:, 0:1] + pi_maps[:, 1:] * torch.exp(-lambda_maps) + EPS) * zero_indices
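        # Positive-count cells: -log(pi_1) - (y * log(lambda) - lambda);
        # the log(y!) term is constant w.r.t. the predictions and is omitted.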
poisson_log_p = gt_den_maps * torch.log(lambda_maps + EPS) - lambda_maps
nonzero_loss = (-torch.log(pi_maps[:, 1:] + EPS) - poisson_log_p) * (1.0 - zero_indices)
loss = (zero_loss + nonzero_loss).sum(dim=(-1, -2))
if self.reduction == "mean":
loss = loss.mean()
elif self.reduction == "sum":
loss = loss.sum()
        return loss, {"zipnll": loss.detach()}


class ZICrossEntropy(nn.Module):
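    """Cross-entropy over binned ground-truth counts.

    Ground-truth density maps are mapped to class indices via ``_bin_count``.
    Cells whose class index is 0 are excluded from the loss, and the remaining
    labels are shifted down by one before the cross-entropy against the
    predicted class logits is computed.
    """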
def __init__(
self,
bins: List[Tuple[int, int]],
reduction: str = "mean",
) -> None:
super().__init__()
assert all([low <= high for low, high in bins]), f"Expected bins to be a list of tuples (low, high) where low <= high, got {bins}"
        assert reduction in ["mean", "sum"], f"Expected reduction to be one of ['mean', 'sum'], got {reduction}."
self.bins = bins
self.reduction = reduction
        self.ce_loss_fn = nn.CrossEntropyLoss(reduction="none")

    def forward(
self,
logit_maps: Tensor,
gt_den_maps: Tensor,
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
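        """Compute the classification loss over cells with a positive class.

        Args:
            logit_maps: (B, C, H, W) class logits per cell.
            gt_den_maps: (B, 1, H', W') ground-truth density maps; reshaped to
                (B, 1, H, W) via ``_reshape_density`` when the resolutions differ.

        Returns:
            The reduced loss and a dict with the detached loss for logging.
        """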
assert len(logit_maps.shape) == len(gt_den_maps.shape) == 4, f"Expected 4D (B, C, H, W) tensor, got {logit_maps.shape} and {gt_den_maps.shape}"
        B, _, H, W = logit_maps.shape  # logit_maps: (B, C, H, W)
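        # If the GT density map is at a higher resolution, bring it down to the prediction grid via _reshape_density.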
if gt_den_maps.shape[2:] != (H, W):
gt_h, gt_w = gt_den_maps.shape[-2], gt_den_maps.shape[-1]
assert gt_h % H == 0 and gt_w % W == 0 and gt_h // H == gt_w // W, f"Expected the spatial dimension of gt_den_maps to be a multiple of that of logit_maps, got {gt_den_maps.shape} and {logit_maps.shape}"
gt_den_maps = _reshape_density(gt_den_maps, block_size=gt_h // H)
assert gt_den_maps.shape == (B, 1, H, W), f"Expected gt_den_maps to have shape (B, 1, H, W), got {gt_den_maps.shape}"
gt_class_maps = _bin_count(gt_den_maps, bins=self.bins)
gt_class_maps = rearrange(gt_class_maps, "B H W -> B (H W)") # flatten spatial dimensions
logit_maps = rearrange(logit_maps, "B C H W -> B (H W) C") # flatten spatial dimensions
        # Accumulate as a tensor so .detach() below also works when no cell has a positive class.
        loss = logit_maps.new_zeros(())
        for idx in range(gt_class_maps.shape[0]):
gt_class_map, logit_map = gt_class_maps[idx], logit_maps[idx]
            # Keep only cells with a positive class and shift the labels down by one,
            # so the remaining classes are indexed 0..C-1 for nn.CrossEntropyLoss.
            mask = gt_class_map > 0
            gt_class_map = gt_class_map[mask] - 1
            logit_map = logit_map[mask]
loss += self.ce_loss_fn(logit_map, gt_class_map).sum()
if self.reduction == "mean":
loss /= gt_class_maps.shape[0]
return loss, {"cls_zice": loss.detach()}
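

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original training code): the shapes
    # below are illustrative assumptions, used only to show how ZIPoissonNLL is called.
    # Because of the relative import above, run this as a module from the repo root,
    # e.g. `python -m losses.zero_inflated_poisson_nll`.
    B, H, W = 2, 32, 32
    logit_pi_maps = torch.randn(B, 2, H, W)              # mixture-weight logits (pi_0, pi_1)
    lambda_maps = torch.rand(B, 1, H, W)                 # non-negative Poisson rates
    gt_den_maps = torch.randint(0, 4, (B, 1, H, W)).float()

    zip_nll = ZIPoissonNLL(reduction="mean")
    loss, logs = zip_nll(logit_pi_maps, lambda_maps, gt_den_maps)
    print("zipnll:", logs["zipnll"].item())

    # ZICrossEntropy additionally expects (low, high) count bins (matching the repo's
    # `_bin_count` helper) and per-class logit maps, and is called the same way:
    # `zice(logit_maps, gt_den_maps)`.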