File size: 4,049 Bytes
437b5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import torch.nn.functional as F
import numpy as np

from .geom import gather_nd

# input: [batch_size, C, H, W]
# output: [batch_size, C, H, W], [batch_size, C, H, W]
def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1):
    """Compute spatial and channel-wise peakiness of a feature map.

    Args:
        inputs: feature tensor of shape [batch_size, C, H, W].
        moving_instance_max: scalar (or broadcastable tensor) the features
            are divided by before scoring.
        ksize: side length of the square averaging window. Expected to be
            odd so that the averaged map keeps the input resolution.
        dilation: dilation of the averaging window.

    Returns:
        A tuple ``(alpha, beta)``, both of shape [batch_size, C, H, W]:
        ``alpha`` is softplus(feature - local spatial average) and
        ``beta`` is softplus(feature - average over channels).
    """
    inputs = inputs / moving_instance_max

    _, num_channels, _, _ = inputs.shape

    # A dilated k x k window spans dilation * (k - 1) + 1 pixels, so
    # dilation * (k // 2) pixels of context are needed on every side to keep
    # the output the same size as the input (for odd k).
    pad_size = dilation * (ksize // 2)

    # One box filter per channel (depthwise convolution); created in the
    # input's dtype so float64 / float16 feature maps are accepted as well.
    kernel = torch.full(
        (num_channels, 1, ksize, ksize),
        1.0 / (ksize * ksize),
        dtype=inputs.dtype,
        device=inputs.device,
    )

    pad_inputs = F.pad(inputs, [pad_size] * 4, mode='reflect')

    avg_spatial_inputs = F.conv2d(
        pad_inputs,
        kernel,
        stride=1,
        dilation=dilation,
        padding=0,
        groups=num_channels,
    )
    # Mean over the channel dimension (dim 1), kept for broadcasting.
    avg_channel_inputs = torch.mean(inputs, dim=1, keepdim=True)

    alpha = F.softplus(inputs - avg_spatial_inputs)
    beta = F.softplus(inputs - avg_channel_inputs)

    return alpha, beta


# input: score_map [batch_size, 1, H, W]
# output: indices [batch_size, n, 2] as (row, col), scores [batch_size, n], with n <= k
def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_size=5):
    """Select the top-scoring keypoint locations from a score map.

    A location is a candidate when its score exceeds ``score_thld`` and it
    passes every enabled filter: local-maximum test (``nms_size > 0``),
    distance from the image border (``eof_size > 0``) and the edge-response
    test of ``edge_mask`` (``edge_thld > 0``).

    Args:
        score_map: tensor of shape [batch_size, 1, H, W]; only channel 0 of
            each batch element is read when collecting keypoints.
        k: maximum number of keypoints kept per batch element.
        score_thld: candidates must score strictly above this value.
        edge_thld: threshold forwarded to ``edge_mask``; 0 disables the test.
        nms_size: window size of the local-maximum test; 0 disables it.
        eof_size: width of the border band that is excluded; 0 disables it.

    Returns:
        ``(indices, scores)`` where ``indices`` has shape [batch_size, n, 2]
        (row, col) and ``scores`` has shape [batch_size, n], sorted by
        descending score. ``n`` is the smallest per-image candidate count,
        capped at ``k``, so that the results can be stacked.
    """
    h = score_map.shape[2]
    w = score_map.shape[3]

    mask = score_map > score_thld
    if nms_size > 0:
        # A pixel is a local maximum when it equals the max of its window.
        local_max = F.max_pool2d(
            score_map, kernel_size=nms_size, stride=1, padding=nms_size // 2
        )
        mask = torch.logical_and(torch.eq(score_map, local_max), mask)
    if eof_size > 0:
        # True only inside the border band. When the band covers the whole
        # image the slice is empty and no keypoint survives, instead of
        # failing on a negative tensor size.
        eof_mask = torch.zeros((1, 1, h, w), dtype=torch.bool, device=score_map.device)
        eof_mask[:, :, eof_size:h - eof_size, eof_size:w - eof_size] = True
        mask = torch.logical_and(eof_mask, mask)
    if edge_thld > 0:
        non_edge_mask = edge_mask(score_map, 1, dilation=3, edge_thld=edge_thld)
        mask = torch.logical_and(non_edge_mask, mask)

    indices = []
    scores = []
    for img_mask, img_scores in zip(mask, score_map):
        candidates = torch.nonzero(img_mask[0])
        candidate_scores = gather_nd(img_scores[0], candidates)
        # Keep the k best candidates, best first.
        best = torch.sort(candidate_scores, descending=True)[1][0:k]
        indices.append(candidates[best])
        scores.append(candidate_scores[best])

    # Images may yield different numbers of keypoints; truncate all of them
    # to the smallest count so they stack into a single tensor.
    min_num = min(len(img_indices) for img_indices in indices)
    indices = torch.stack([img_indices[:min_num] for img_indices in indices], dim=0)
    scores = torch.stack([img_scores[:min_num] for img_scores in scores], dim=0)
    return indices, scores


def edge_mask(inputs, n_channel, dilation=1, edge_thld=5):
    """Return a boolean mask that is True where the response is not edge-like.

    The second-order derivatives of every channel are estimated with 3x3
    finite-difference filters, and a location is kept when the determinant
    of the resulting 2x2 matrix is positive and ``trace^2 / det`` does not
    exceed ``(edge_thld + 1)^2 / edge_thld``.

    Args:
        inputs: tensor of shape [b, c, h, w]; each channel is processed
            independently.
        n_channel: unused; kept for interface compatibility.
        dilation: dilation (and zero padding) of the derivative filters.
        edge_thld: ratio threshold; must be non-zero.

    Returns:
        Boolean tensor of shape [b, c, h, w].

    Raises:
        ZeroDivisionError: if ``edge_thld`` is 0.
    """
    b, c, h, w = inputs.size()

    # Filters for dii (along rows), dij (mixed) and djj (along columns),
    # stacked as three output channels and created directly in the input's
    # dtype and device.
    hessian_filters = torch.tensor(
        [
            [[0., 1., 0.], [0., -2., 0.], [0., 1., 0.]],
            [[0.25, 0., -0.25], [0., 0., 0.], [-0.25, 0., 0.25]],
            [[0., 0., 0.], [1., -2., 1.], [0., 0., 0.]],
        ],
        dtype=inputs.dtype,
        device=inputs.device,
    ).unsqueeze(1)

    # reshape (not view) so that non-contiguous inputs are accepted.
    hessian = F.conv2d(
        inputs.reshape(-1, 1, h, w),
        hessian_filters,
        padding=dilation,
        dilation=dilation,
    ).view(b, c, 3, h, w)
    dii, dij, djj = hessian.unbind(dim=2)

    det = dii * djj - dij * dij
    tr = dii + djj

    threshold = (edge_thld + 1) ** 2 / edge_thld
    # Where det == 0 the ratio is inf/nan and compares False; det > 0 also
    # rejects those locations explicitly.
    is_not_edge = torch.logical_and(tr * tr / det <= threshold, det > 0)

    return is_not_edge