import os
import sys
sys.path.append(os.getcwd())

import torch
import torch.nn as nn
import torch.nn.functional as F


class KeypointLoss(nn.Module):
    """MSE loss over keypoint sequences, optionally masked by confidence."""
    def __init__(self):
        super(KeypointLoss, self).__init__()

    def forward(self, pred_seq, gt_seq, gt_conf=None):
        # pred_seq, gt_seq: (B, C, T); gt_conf (optional): per-element confidence scores
        if gt_conf is not None:
            gt_conf = gt_conf >= 0.01  # keep only entries with usable confidence
            return F.mse_loss(pred_seq[gt_conf], gt_seq[gt_conf], reduction='mean')
        else:
            return F.mse_loss(pred_seq, gt_seq)


class KLLoss(nn.Module):
    """KL divergence to a standard normal prior, with a free-bits style tolerance."""
    def __init__(self, kl_tolerance):
        super(KLLoss, self).__init__()
        self.kl_tolerance = kl_tolerance

    def forward(self, mu, var, mul=1):
        # mu, var: (B, D); var holds the log-variance of the approximate posterior
        kld_loss = -0.5 * torch.sum(1 + var - mu ** 2 - var.exp(), dim=1)
        if self.kl_tolerance is not None:
            # Scale the tolerance with the latent dimension (normalized to D=64),
            # then clamp the per-sample KL from below so samples already under
            # the tolerance contribute a constant (free bits).
            kl_tolerance = self.kl_tolerance * mul * var.shape[1] / 64
            kld_loss = torch.clamp(kld_loss, min=kl_tolerance)
        kld_loss = torch.mean(kld_loss)
        return kld_loss


class L2KLLoss(nn.Module):
    """Squared-norm penalty on the latent code, applied only once it exceeds a tolerance."""
    def __init__(self, kl_tolerance):
        super(L2KLLoss, self).__init__()
        self.kl_tolerance = kl_tolerance

    def forward(self, x):
        # TODO: check
        # x: (B, D)
        kld_loss = torch.sum(x ** 2, dim=1)
        if self.kl_tolerance is not None:
            # Penalize the batch only if at least one sample exceeds the tolerance.
            above_line = kld_loss[kld_loss > self.kl_tolerance]
            if len(above_line) > 0:
                kld_loss = torch.mean(kld_loss)
            else:
                kld_loss = 0
        else:
            kld_loss = torch.mean(kld_loss)
        return kld_loss


class L2RegLoss(nn.Module):
    """L2 regularization: sum of squared entries."""
    def __init__(self):
        super(L2RegLoss, self).__init__()

    def forward(self, x):
        # TODO: check
        return torch.sum(x ** 2)


class L2Loss(nn.Module):
    """Sum of squared entries (same computation as L2RegLoss, kept under a separate name)."""
    def __init__(self):
        super(L2Loss, self).__init__()

    def forward(self, x):
        # TODO: check
        return torch.sum(x ** 2)


class AudioLoss(nn.Module):
    """MSE between predicted dynamics and ground-truth poses centered along the time axis."""
    def __init__(self):
        super(AudioLoss, self).__init__()

    def forward(self, dynamics, gt_poses):
        # Note: gt_poses are mean-centered over the time dimension before comparison.
        mean = torch.mean(gt_poses, dim=-1).unsqueeze(-1)
        gt = gt_poses - mean
        return F.mse_loss(dynamics, gt)


L1Loss = nn.L1Loss
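

# Illustrative usage sketch (not part of the original module): exercises the losses
# above with assumed tensor shapes -- a batch of 4, 64 keypoint channels, 32 time
# steps, and a 64-dimensional latent. The shapes and the kl_tolerance value are
# assumptions chosen for demonstration only.
if __name__ == "__main__":
    B, C, T, D = 4, 64, 32, 64
    pred_seq = torch.randn(B, C, T)
    gt_seq = torch.randn(B, C, T)
    gt_conf = torch.rand(B, C, T)  # confidences in [0, 1); entries < 0.01 are masked out

    keypoint_loss = KeypointLoss()
    print('keypoint:', keypoint_loss(pred_seq, gt_seq, gt_conf).item())

    mu, logvar = torch.randn(B, D), torch.randn(B, D)
    kl_loss = KLLoss(kl_tolerance=0.02)  # tolerance value is a placeholder
    print('kl:', kl_loss(mu, logvar).item())

    audio_loss = AudioLoss()
    print('audio:', audio_loss(pred_seq, gt_seq).item())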