import torch.nn as nn
import torch
from torch.nn.functional import mse_loss
# for NEW: losses when calculated on keypoint locations
# see https://kornia.readthedocs.io/en/latest/_modules/kornia/geometry/subpix/dsnt.html
# from kornia.geometry import dsnt            # old kornia version  
from kornia.geometry.subpix import dsnt     # kornia 0.4.0

def joints_mse_loss_orig(output, target, target_weight=None):
    # standard per-joint MSE loss on heatmaps (as in the original HRNet code)
    # output/target: (batch_size, num_joints, h, w), target_weight: (batch_size, num_joints, 1)
    batch_size = output.size(0)
    num_joints = output.size(1)
    heatmaps_pred = output.view((batch_size, num_joints, -1)).split(1, 1)
    heatmaps_gt = target.view((batch_size, num_joints, -1)).split(1, 1)

    loss = 0
    for idx in range(num_joints):
        heatmap_pred = heatmaps_pred[idx].squeeze(1)    # (batch_size, h*w)
        heatmap_gt = heatmaps_gt[idx].squeeze(1)
        if target_weight is None:
            loss += 0.5 * mse_loss(heatmap_pred, heatmap_gt, reduction='mean')
        else:
            # target_weight[:, idx] has shape (batch_size, 1) and broadcasts over h*w
            loss += 0.5 * mse_loss(
                heatmap_pred.mul(target_weight[:, idx]),
                heatmap_gt.mul(target_weight[:, idx]),
                reduction='mean'
            )

    return loss / num_joints
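
# Illustrative sketch (not part of the original training code): how joints_mse_loss_orig
# could be called. The tensor sizes (batch of 6, 20 joints, 64x64 heatmaps) are assumptions
# chosen to match the shapes mentioned in comments elsewhere in this file.
def _demo_joints_mse_loss_orig():
    output = torch.rand(6, 20, 64, 64)       # hypothetical predicted heatmaps
    target = torch.rand(6, 20, 64, 64)       # hypothetical ground-truth heatmaps
    target_weight = torch.ones(6, 20, 1)     # per-joint visibility weights
    return joints_mse_loss_orig(output, target, target_weight)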


class JointsMSELoss(nn.Module):
    def __init__(self, use_target_weight=True):
        super().__init__()
        self.use_target_weight = use_target_weight

    def forward(self, output, target, target_weight):
        if not self.use_target_weight:
            target_weight = None
        return joints_mse_loss_orig(output, target, target_weight)




# ----- NEW: losses when calculated on keypoint locations instead of keypoint heatmaps -----


def joints_mse_loss_onKPloc(output, target, meta, target_weight=None):
    # loss computed on keypoint locations: a heatmap MSE term plus a soft-argmax
    # keypoint distance term (older kornia versions expose spatial_softmax_2d /
    # spatial_softargmax_2d; kornia 0.4.0 renames them to spatial_softmax2d /
    # spatial_expectation2d, which are used below)

    # the target heatmaps are already normalized, so no explicit renormalization is needed:
    # target_sum = target.sum(axis=3).sum(axis=2)[:, :, None, None]
    # target_sum[target_sum == 0] = 1e-2
    # target_norm = target / target_sum
    target_norm = target

    # normalize predictions -> from logits to a per-joint probability distribution
    output_norm = dsnt.spatial_softmax2d(output, temperature=torch.tensor(1.0))

    # heatmap loss (MSE between the normalized heatmaps)
    heatmap_loss = joints_mse_loss_orig(output_norm, target_norm, target_weight)

    # keypoint distance loss: visibility-weighted mean Euclidean distance (in pixels)
    # between the soft-argmax keypoints and the ground-truth keypoints in meta['tpts']
    output_kp = dsnt.spatial_expectation2d(output_norm, normalized_coordinates=False) + 1   # (bs, 20, 2)
    target_kp = meta['tpts'].to(output_kp.device)        # (bs, 20, 3): x, y, visibility
    output_kp_resh = output_kp.reshape((-1, 2))
    target_kp_resh = target_kp[:, :, :2].reshape((-1, 2))
    weights_resh = target_kp[:, :, 2].reshape((-1))
    # only visible keypoints (weight > 0) contribute; the denominator is clamped to avoid division by zero
    dist_loss = (((output_kp_resh - target_kp_resh)[weights_resh > 0]**2).sum(axis=1).sqrt()
                 * weights_resh[weights_resh > 0]).sum() / max(weights_resh[weights_resh > 0].sum(), 1e-5)


    # weighted combination of both terms; for ablations, the heatmap term alone
    # (heatmap_loss*100) or the distance term alone (dist_loss*1e-4) can be returned instead
    return dist_loss * 1e-4 + heatmap_loss * 100
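
# Illustrative sketch (not part of the original training code): building a minimal,
# hypothetical meta dict for joints_mse_loss_onKPloc. The real meta comes from the
# dataset loader; the keypoint count and heatmap size used here are assumptions.
def _demo_joints_mse_loss_onKPloc():
    bs, n_joints, hm_size = 6, 20, 64
    output = torch.randn(bs, n_joints, hm_size, hm_size)    # raw network logits
    target = torch.rand(bs, n_joints, hm_size, hm_size)     # ground-truth heatmaps
    meta = {'tpts': torch.cat([torch.rand(bs, n_joints, 2) * hm_size,    # x, y in pixels
                               torch.ones(bs, n_joints, 1)], dim=2)}     # visibility flags
    return joints_mse_loss_onKPloc(output, target, meta, target_weight=None)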




class JointsMSELoss_onKPloc(nn.Module):
    def __init__(self, use_target_weight=True):
        super().__init__()
        self.use_target_weight = use_target_weight

    def forward(self, output, target, meta, target_weight):
        if not self.use_target_weight:
            target_weight = None
        return joints_mse_loss_onKPloc(output, target, meta, target_weight)





# ----- NEW: segmentation loss -----

import torch.nn.functional as F

# resizing helper, kept for reference (torch.autograd.Variable is deprecated):
# def resize2d(img, size):
#     return F.adaptive_avg_pool2d(Variable(img, volatile=True), size).data
#     # e.g. F.adaptive_avg_pool2d(meta['silh'], (64, 64))

def segmentation_loss(output, meta):
    # output: per-pixel class logits, e.g. (6, 2, 64, 64) for background/foreground
    # meta.keys(): ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight', 'breed_index', 'silh']
    # prepare target silhouettes
    target_silh = meta['silh']
    criterion_ce = nn.CrossEntropyLoss()
    if output.shape[2] == 64:
        # downsample the full-resolution silhouette to 64x64 and re-binarize it
        target_silh_64 = F.adaptive_avg_pool2d(target_silh, (64, 64))
        target_silh_64[target_silh_64 > 0.5] = 1
        target_silh_64[target_silh_64 <= 0.5] = 0
        target_silh_64_l = target_silh_64.to(torch.long)
        loss_silh_64 = criterion_ce(output, target_silh_64_l)
        return loss_silh_64
    else:
        # prediction already matches the silhouette resolution
        target_silh_l = target_silh.to(torch.long)
        loss_silh_l = criterion_ce(output, target_silh_l)
        return loss_silh_l
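
# Illustrative sketch (not part of the original training code): calling segmentation_loss
# with a hypothetical two-channel (background/foreground) prediction and a binary silhouette.
# The 6x2x64x64 output shape follows the comment above; the 256x256 silhouette size is an assumption.
def _demo_segmentation_loss():
    output = torch.randn(6, 2, 64, 64)                         # per-pixel class logits
    meta = {'silh': (torch.rand(6, 256, 256) > 0.5).float()}   # binary silhouettes
    return segmentation_loss(output, meta)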