File size: 7,152 Bytes
7629b39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# Modified from:
#   https://github.com/anibali/pytorch-stacked-hourglass 
#   https://github.com/bearpaw/pytorch-pose

import math
import torch
from kornia.geometry.subpix import dsnt     # kornia 0.4.0
import torch.nn.functional as F
from .transforms import transform_preds

__all__ = ['get_preds', 'get_preds_soft', 'calc_dists', 'dist_acc', 'accuracy', 'final_preds_untransformed',
           'final_preds', 'AverageMeter']

def get_preds(scores, return_maxval=False):
    '''Extract hard-argmax keypoint locations from heatmaps.

    scores: (batch, joints, H, W) score-map tensor.
    Returns 1-based (x, y) coordinates as a float tensor of shape
    (batch, joints, 2); joints whose peak score is <= 0 are zeroed.
    With return_maxval, also returns the (batch, joints, 1) peak scores.
    '''
    assert scores.dim() == 4, 'Score maps should be 4-dim'
    batch, joints = scores.size(0), scores.size(1)
    width = scores.size(3)

    flat = scores.view(batch, joints, -1)
    maxval, idx = torch.max(flat, 2)

    maxval = maxval.view(batch, joints, 1)
    idx = idx.view(batch, joints, 1) + 1  # 1-based flat index

    preds = idx.repeat(1, 1, 2).float()

    # Unravel the flat index into 1-based column (x) and row (y).
    preds[:, :, 0] = (preds[:, :, 0] - 1) % width + 1
    preds[:, :, 1] = torch.floor((preds[:, :, 1] - 1) / width) + 1

    # Suppress joints whose best activation is not positive.
    visible = maxval.gt(0).repeat(1, 1, 2).float()
    preds *= visible

    if return_maxval:
        return preds, maxval
    return preds


def get_preds_soft(scores, return_maxval=False, norm_coords=False, norm_and_unnorm_coords=False):
    ''' Soft-argmax predictions from logit score maps.

    scores: (batch, joints, H, W) logits. A spatial softmax turns each
    map into a probability map whose 2D expectation gives sub-pixel
    keypoint coordinates.

    Depending on the flags, returns normalized coordinates in [-1, 1]
    (norm_coords), unnormalized 1-based pixel coordinates (default), or
    both (norm_and_unnorm_coords, requires return_maxval). With
    return_maxval, a (batch, joints, 1) confidence is also returned: the
    summed probability mass of the 3x3 neighbourhood around each
    predicted location.

    Fix vs. original: removed the duplicate, unused `gs_input` tensor
    (it was an identical copy of `gs_input_single` that was never read).
    '''
    # Probability maps from logits. NOTE(review): temperature is an int
    # tensor; kornia treats 1 as the identity scaling — confirm if a
    # float temperature is ever needed.
    scores_norm = dsnt.spatial_softmax2d(scores, temperature=torch.tensor(1))
    # Normalized coordinates follow kornia's grid convention:
    #   xs = (xs / (width - 1) - 0.5) * 2   (likewise for ys)
    # see kornia/utils/grid.py

    device = scores.device

    if return_maxval:
        preds_normalized = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=True)

        # One channel per (batch, joint) map: (B*J, 1, H, W).
        gs_input_single = scores_norm.reshape((-1, 1, scores_norm.shape[2], scores_norm.shape[3]))

        # Build 9 shifted copies of each map so a single nearest-neighbour
        # grid_sample gathers the whole 3x3 neighbourhood at once.
        half_pad = 2
        gs_input_single_padded = F.pad(input=gs_input_single, pad=(half_pad, half_pad, half_pad, half_pad, 0, 0, 0, 0), mode='constant', value=0)
        gs_input_all = torch.zeros((gs_input_single.shape[0], 9, gs_input_single.shape[2], gs_input_single.shape[3])).to(device)
        ind_tot = 0
        for ind0 in [-1, 0, 1]:
            for ind1 in [-1, 0, 1]:
                gs_input_all[:, ind_tot, :, :] = gs_input_single_padded[:, 0, half_pad+ind0:-half_pad+ind0, half_pad+ind1:-half_pad+ind1]
                ind_tot += 1

        # Sample all 9 shifted maps at the predicted (normalized) location
        # and sum: total probability mass in the 3x3 window = confidence.
        gs_grid = preds_normalized.reshape((-1, 2))[:, None, None, :]   # (B*J, 1, 1, 2)
        gs_output_all = F.grid_sample(gs_input_all, gs_grid, mode='nearest', padding_mode='zeros', align_corners=True).reshape((gs_input_all.shape[0], gs_input_all.shape[1], 1))
        gs_output = gs_output_all.sum(axis=1)
        gs_output_resh = gs_output.reshape((scores_norm.shape[0], scores_norm.shape[1], 1))

        if norm_and_unnorm_coords:
            preds = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1
            return preds, preds_normalized, gs_output_resh
        elif norm_coords:
            return preds_normalized, gs_output_resh
        else:
            preds = dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1
            return preds, gs_output_resh
    else:
        if norm_coords:
            return dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=True)
        else:
            # +1 converts kornia's 0-based pixel coords to this file's
            # 1-based convention (matches get_preds).
            return dsnt.spatial_expectation2d(scores_norm, normalized_coordinates=False) + 1


def calc_dists(preds, target, normalize):
    '''Normalized per-joint euclidean distances between preds and target.

    Returns a (joints, batch) tensor; entries are -1 where the target
    joint is not visible (either coordinate <= 1).
    '''
    preds = preds.float()
    target = target.float()
    batch, joints = preds.size(0), preds.size(1)
    dists = torch.zeros(joints, batch)
    for b in range(batch):
        for j in range(joints):
            visible = target[b, j, 0] > 1 and target[b, j, 1] > 1
            if visible:
                dists[j, b] = torch.dist(preds[b, j, :], target[b, j, :]) / normalize[b]
            else:
                dists[j, b] = -1
    return dists

def dist_acc(dist, thr=0.5):
    '''Fraction of valid distances (!= -1) below thr; -1 if none are valid.'''
    valid = dist[dist != -1]
    n = len(valid)
    if n == 0:
        return -1
    return (valid < thr).float().sum().item() / n

def accuracy(output, target, idxs=None, thr=0.5):
    ''' PCK-style accuracy from predicted heatmaps vs. ground-truth heatmaps.

    Returns a tensor of length len(idxs)+1: element 0 is the average
    accuracy over joints with valid measurements, elements 1.. are the
    per-joint accuracies (or -1 where no valid joint was visible).
    '''
    if idxs is None:
        idxs = list(range(target.shape[-3]))

    preds = get_preds_soft(output)      # soft-argmax predictions
    gts = get_preds(target)             # hard argmax of ground-truth maps
    # Normalization: one tenth of the heatmap width, per sample.
    norm = torch.ones(preds.size(0)) * output.size(3) / 10
    dists = calc_dists(preds, gts, norm)

    acc = torch.zeros(len(idxs) + 1)
    total = 0
    valid = 0
    for pos, joint in enumerate(idxs, start=1):
        acc[pos] = dist_acc(dists[joint], thr=thr)
        if acc[pos] >= 0:
            total = total + acc[pos]
            valid += 1

    if valid != 0:
        acc[0] = total / valid
    return acc

def final_preds_untransformed(output, res):
    '''Soft-argmax predictions refined by a quarter-pixel shift toward the
    higher neighbouring heatmap value; returned as 0-based coordinates.

    output: (batch, joints, H, W) heatmaps; res: (width, height).
    '''
    coords = get_preds_soft(output)  # float, 1-based coordinates

    # Nudge each prediction 0.25 px in the direction of the local gradient.
    for n in range(coords.size(0)):
        for p in range(coords.size(1)):
            hm = output[n][p]
            px = int(math.floor(coords[n][p][0]))
            py = int(math.floor(coords[n][p][1]))
            inside = 1 < px < res[0] and 1 < py < res[1]
            if inside:
                diff = torch.Tensor([
                    hm[py - 1][px] - hm[py - 1][px - 2],
                    hm[py][px - 1] - hm[py - 2][px - 1],
                ])
                coords[n][p] += diff.sign() * .25
    coords += 0.5

    if coords.dim() < 3:
        coords = coords.unsqueeze(0)

    coords -= 1  # convert from 1-based to 0-based coordinates

    return coords

def final_preds(output, center, scale, res):
    '''Heatmap predictions mapped back to original-image coordinates.

    Refines heatmap-space coordinates, then applies the inverse of the
    per-sample crop/scale transform.
    '''
    coords = final_preds_untransformed(output, res)
    preds = coords.clone()

    # Undo the crop transform sample by sample.
    for idx in range(coords.size(0)):
        preds[idx] = transform_preds(coords[idx], center[idx], scale[idx], res)

    if preds.dim() < 3:
        preds = preds.unsqueeze(0)

    return preds


class AverageMeter(object):
    """Tracks the most recent value and a running (weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0    # most recent value
        self.avg = 0    # running mean
        self.sum = 0    # weighted sum of all values
        self.count = 0  # total weight seen so far

    def update(self, val, n=1):
        """Record `val` with weight `n` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count