# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmpose.registry import MODELS


@MODELS.register_module()
class BCELoss(nn.Module):
    """Binary Cross Entropy loss.

    Args:
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
        with_logits (bool): Whether to use BCEWithLogitsLoss. Default: False.
    """

    def __init__(self,
                 use_target_weight=False,
                 loss_weight=1.,
                 with_logits=False):
        super().__init__()
        self.criterion = F.binary_cross_entropy if not with_logits \
            else F.binary_cross_entropy_with_logits
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_labels: K

        Args:
            output (torch.Tensor[N, K]): Output classification.
            target (torch.Tensor[N, K]): Target classification.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels.
        """
        if self.use_target_weight:
            assert target_weight is not None
            loss = self.criterion(output, target, reduction='none')
            if target_weight.dim() == 1:
                target_weight = target_weight[:, None]
            loss = (loss * target_weight).mean()
        else:
            loss = self.criterion(output, target)

        return loss * self.loss_weight
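

# --- Illustrative usage sketch; not part of the original file. ---
# A minimal example of BCELoss with per-label weighting, assuming raw
# logits as input (hence with_logits=True) and the [N, K] shapes from
# the docstring above. Wrapped in a helper so importing this module
# stays side-effect free.
def _demo_bce_loss():
    loss_fn = BCELoss(use_target_weight=True, with_logits=True)
    output = torch.randn(4, 17)                    # raw logits, [N, K]
    target = torch.randint(0, 2, (4, 17)).float()  # binary labels, [N, K]
    target_weight = torch.ones(4, 17)              # per-label weights
    return loss_fn(output, target, target_weight)
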
@MODELS.register_module()
class JSDiscretLoss(nn.Module):
    """Discrete JS Divergence loss for DSNT with Gaussian Heatmap.

    Modified from `the official implementation
    <https://github.com/anibali/dsntnn/blob/master/dsntnn/__init__.py>`_.

    Args:
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
        size_average (bool): Option to average the loss by the batch_size.
    """

    def __init__(
        self,
        use_target_weight=True,
        size_average: bool = True,
    ):
        super(JSDiscretLoss, self).__init__()
        self.use_target_weight = use_target_weight
        self.size_average = size_average
        self.kl_loss = nn.KLDivLoss(reduction='none')

    def kl(self, p, q):
        """Kullback-Leibler Divergence."""
        eps = 1e-24
        # nn.KLDivLoss expects log-probabilities as input and probabilities
        # as target, so this computes KL(p || q) element-wise.
        kl_values = self.kl_loss((q + eps).log(), p)
        return kl_values

    def js(self, pred_hm, gt_hm):
        """Jensen-Shannon Divergence."""
        m = 0.5 * (pred_hm + gt_hm)
        js_values = 0.5 * (self.kl(pred_hm, m) + self.kl(gt_hm, m))
        return js_values

    def forward(self, pred_hm, gt_hm, target_weight=None):
        """Forward function.

        Args:
            pred_hm (torch.Tensor[N, K, H, W]): Predicted heatmaps.
            gt_hm (torch.Tensor[N, K, H, W]): Target heatmaps.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels.

        Returns:
            torch.Tensor: Loss value.
        """
        if self.use_target_weight:
            assert target_weight is not None
            assert pred_hm.ndim >= target_weight.ndim
            for i in range(pred_hm.ndim - target_weight.ndim):
                target_weight = target_weight.unsqueeze(-1)
            loss = self.js(pred_hm * target_weight, gt_hm * target_weight)
        else:
            loss = self.js(pred_hm, gt_hm)

        if self.size_average:
            loss /= len(gt_hm)

        return loss.sum()
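

# --- Illustrative usage sketch; not part of the original file. ---
# JSDiscretLoss compares heatmaps as discrete probability distributions,
# so this sketch normalizes each joint's H x W map to sum to 1 before
# computing the symmetric JS divergence. All shapes are assumed examples.
def _demo_js_discret_loss():
    loss_fn = JSDiscretLoss(use_target_weight=True)
    pred_hm = torch.rand(2, 17, 64, 48)  # [N, K, H, W]
    gt_hm = torch.rand(2, 17, 64, 48)
    pred_hm = pred_hm / pred_hm.sum(dim=(-2, -1), keepdim=True)
    gt_hm = gt_hm / gt_hm.sum(dim=(-2, -1), keepdim=True)
    target_weight = torch.ones(2, 17)    # [N, K] per-joint weights
    return loss_fn(pred_hm, gt_hm, target_weight)
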
@MODELS.register_module()
class KLDiscretLoss(nn.Module):
    """Discrete KL Divergence loss for SimCC with Gaussian Label Smoothing.

    Modified from `the official implementation
    <https://github.com/leeyegy/SimCC>`_.

    Args:
        beta (float): Temperature factor of Softmax.
        label_softmax (bool): Whether to use Softmax on labels.
        use_target_weight (bool): Option to use weighted loss.
            Different joint types may have different target weights.
    """

    def __init__(self, beta=1.0, label_softmax=False, use_target_weight=True):
        super(KLDiscretLoss, self).__init__()
        self.beta = beta
        self.label_softmax = label_softmax
        self.use_target_weight = use_target_weight

        self.log_softmax = nn.LogSoftmax(dim=1)
        self.kl_loss = nn.KLDivLoss(reduction='none')

    def criterion(self, dec_outs, labels):
        """Criterion function."""
        # `beta` sharpens or softens both distributions before the KL term.
        log_pt = self.log_softmax(dec_outs * self.beta)
        if self.label_softmax:
            labels = F.softmax(labels * self.beta, dim=1)
        loss = torch.mean(self.kl_loss(log_pt, labels), dim=1)
        return loss

    def forward(self, pred_simcc, gt_simcc, target_weight):
        """Forward function.

        Args:
            pred_simcc (Tuple[Tensor, Tensor]): Predicted SimCC vectors of
                x-axis and y-axis.
            gt_simcc (Tuple[Tensor, Tensor]): Target representations.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels.
        """
        num_joints = pred_simcc[0].size(1)
        loss = 0

        if self.use_target_weight:
            weight = target_weight.reshape(-1)
        else:
            weight = 1.

        for pred, target in zip(pred_simcc, gt_simcc):
            pred = pred.reshape(-1, pred.size(-1))
            target = target.reshape(-1, target.size(-1))

            loss += self.criterion(pred, target).mul(weight).sum()

        return loss / num_joints
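

# --- Illustrative usage sketch; not part of the original file. ---
# KLDiscretLoss takes SimCC predictions as an (x, y) tuple of per-axis
# classification vectors. The bin counts Wx and Wy below are assumed
# values for illustration (input size times a SimCC split ratio), and
# the beta/label_softmax settings are likewise only example choices.
def _demo_kl_discret_loss():
    N, K, Wx, Wy = 2, 17, 384, 512
    loss_fn = KLDiscretLoss(beta=10., label_softmax=True)
    pred_simcc = (torch.randn(N, K, Wx), torch.randn(N, K, Wy))
    gt_simcc = (torch.rand(N, K, Wx), torch.rand(N, K, Wy))
    target_weight = torch.ones(N, K)  # [N, K] per-joint weights
    return loss_fn(pred_simcc, gt_simcc, target_weight)
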
@MODELS.register_module()
class InfoNCELoss(nn.Module):
    """InfoNCE loss for training a discriminative representation space in a
    contrastive manner.

    `Representation Learning with Contrastive Predictive Coding
    arXiv: <https://arxiv.org/abs/1807.03748>`_.

    Args:
        temperature (float, optional): The temperature to use in the softmax
            function. Higher temperatures lead to softer probability
            distributions. Defaults to 1.0.
        loss_weight (float, optional): The weight to apply to the loss.
            Defaults to 1.0.
    """

    def __init__(self, temperature: float = 1.0, loss_weight=1.0) -> None:
        super(InfoNCELoss, self).__init__()
        assert temperature > 0, 'the argument `temperature` must be ' \
                                f'positive, but got {temperature}'
        self.temp = temperature
        self.loss_weight = loss_weight

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Computes the InfoNCE loss.

        Args:
            features (Tensor): A tensor containing the feature
                representations of different samples.

        Returns:
            Tensor: A scalar tensor containing the InfoNCE loss.
        """
        n = features.size(0)
        features_norm = F.normalize(features, dim=1)
        logits = features_norm.mm(features_norm.t()) / self.temp

        targets = torch.arange(n, dtype=torch.long, device=features.device)
        loss = F.cross_entropy(logits, targets, reduction='sum')
        return loss * self.loss_weight
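

# --- Illustrative usage sketch; not part of the original file. ---
# InfoNCELoss scores the pairwise cosine-similarity matrix of the batch
# against identity targets: each row's only positive is itself, so the
# loss pushes distinct samples apart in the normalized feature space.
# Batch and feature sizes here are assumed examples.
def _demo_info_nce_loss():
    loss_fn = InfoNCELoss(temperature=0.05)
    features = torch.randn(8, 256)  # [N, D] feature embeddings
    return loss_fn(features)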