File size: 2,115 Bytes
15fa80a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn
import torch.nn.functional as F
class ReConsLoss(nn.Module):
    """Reconstruction loss between a predicted and a ground-truth motion tensor.

    Only the leading ``motion_dim`` channels of the last dimension are
    compared; any trailing channels are ignored.

    Args:
        recons_loss: Name of the element-wise criterion: ``'l1'``
            (``L1Loss``), ``'l2'`` (``MSELoss``) or ``'l1_smooth'``
            (``SmoothL1Loss``).
        nb_joints: Number of joints, used to derive the channel layout.

    Raises:
        ValueError: If ``recons_loss`` is not one of the supported names.
    """

    def __init__(self, recons_loss, nb_joints):
        super(ReConsLoss, self).__init__()

        loss_types = {
            'l1': torch.nn.L1Loss,
            'l2': torch.nn.MSELoss,
            'l1_smooth': torch.nn.SmoothL1Loss,
        }
        # Fail at construction time: otherwise ``self.Loss`` would be left
        # unset and the error would only show up on the first forward call.
        if recons_loss not in loss_types:
            raise ValueError(
                "Unsupported recons_loss {!r}; expected one of {}".format(
                    recons_loss, sorted(loss_types)
                )
            )
        self.Loss = loss_types[recons_loss]()

        # Channel budget of the last dimension:
        #   4  global motion associated to root
        #   12 local motion per non-root joint (3 local xyz, 3 vel xyz, 6 rot6d)
        #   3  global vel xyz
        #   4  foot contact
        self.nb_joints = nb_joints
        self.motion_dim = (nb_joints - 1) * 12 + 4 + 3 + 4

    def forward(self, motion_pred, motion_gt):
        """Return the loss over the first ``motion_dim`` channels of the last dim."""
        loss = self.Loss(motion_pred[..., : self.motion_dim], motion_gt[..., : self.motion_dim])
        return loss

    def forward_vel(self, motion_pred, motion_gt):
        """Return the loss over the ``(nb_joints - 1) * 3`` channels after the first 4.

        NOTE(review): the method name suggests these channels are velocities;
        confirm against the feature layout produced by the data pipeline.
        """
        end = (self.nb_joints - 1) * 3 + 4
        loss = self.Loss(motion_pred[..., 4:end], motion_gt[..., 4:end])
        return loss
    
def loss_robust(origin_prediction, perturbation_prediction, loss='l2'):
    """Return a discrepancy between the original and the perturbed prediction.

    Args:
        origin_prediction: Tensor predicted from the unperturbed input.
        perturbation_prediction: Tensor predicted from the perturbed input.
        loss: Criterion name: ``'l1'``, ``'l2'``, ``'l1_smooth'`` or ``'jsd'``.

    Returns:
        A scalar tensor holding the selected loss value.

    Raises:
        ValueError: If ``loss`` is not one of the supported names.
    """
    # The functional forms compute the loss directly. Calling the module
    # classes with the two tensors would pass them as constructor options
    # instead of as (input, target).
    if loss == 'l1':
        return F.l1_loss(origin_prediction, perturbation_prediction)
    if loss == 'l2':
        return F.mse_loss(origin_prediction, perturbation_prediction)
    if loss == 'l1_smooth':
        return F.smooth_l1_loss(origin_prediction, perturbation_prediction)
    if loss == 'jsd':
        # 'jsd' selects the Jensen-Shannon divergence, not the raw KL term.
        return calculate_jsd_loss(origin_prediction, perturbation_prediction)
    raise ValueError(
        "Unsupported loss {!r}; expected 'l1', 'l2', 'l1_smooth' or 'jsd'".format(loss)
    )


def calculate_kl_divergence(p, q):
    """Return the base-2 KL divergence ``sum(p * (log2(p) - log2(q)))``.

    The sum runs over every element, so the result is a scalar tensor.
    Entries where ``p`` is exactly zero contribute 0 (the usual
    ``0 * log 0 = 0`` convention) instead of producing NaN. Neither input is
    normalized here.

    Args:
        p: Tensor of the reference distribution.
        q: Tensor of the approximating distribution.

    Returns:
        Scalar tensor with the summed divergence in bits.
    """
    support = p != 0
    # Substitute 1 (log2(1) == 0) wherever p is zero so the logarithms stay
    # finite there; multiplying by p == 0 then yields an exact 0 term and
    # keeps gradients free of NaN.
    safe_p = torch.where(support, p, torch.ones_like(p))
    safe_q = torch.where(support, q, torch.ones_like(q))
    return torch.sum(p * (torch.log2(safe_p) - torch.log2(safe_q)))

def calculate_jsd_loss(p, q):
    """Return the Jensen-Shannon divergence between ``p`` and ``q``.

    Both inputs are first L1-normalized along their last dimension. The
    result is the average of the KL divergences of each normalized input
    from their midpoint distribution, as computed by
    ``calculate_kl_divergence``.
    """
    # Rescale each input to unit L1 norm along the last dimension.
    p_dist, q_dist = (F.normalize(t, p=1, dim=-1) for t in (p, q))

    # Midpoint (average) distribution.
    mixture = 0.5 * (p_dist + q_dist)

    # KL divergence of each distribution from the midpoint, then their mean.
    kl_from_p = calculate_kl_divergence(p_dist, mixture)
    kl_from_q = calculate_kl_divergence(q_dist, mixture)
    return 0.5 * (kl_from_p + kl_from_q)