File size: 4,145 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
c74a070
a80d6bb
 
 
 
c74a070
 
a80d6bb
 
c74a070
a80d6bb
c74a070
 
 
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c74a070
 
 
a80d6bb
 
 
c74a070
 
 
 
 
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c74a070
a80d6bb
c74a070
a80d6bb
c74a070
a80d6bb
 
c74a070
 
 
 
a80d6bb
 
c74a070
 
 
 
a80d6bb
 
c74a070
 
 
 
a80d6bb
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import torch
import torch.nn.functional as F
import numpy as np

from .geom import gather_nd

def peakiness_score(inputs, moving_instance_max, ksize=3, dilation=1):
    """Score how strongly each activation stands out spatially and across channels.

    ``inputs`` is first divided by ``moving_instance_max``; then

    * ``alpha = softplus(x - local_mean(x))``. Here ``local_mean`` averages each
      channel over a ``ksize`` x ``ksize`` window dilated by ``dilation``. The
      input is reflect-padded so the spatial size is preserved.
    * ``beta = softplus(x - mean_c(x))``. Here ``mean_c`` averages over the
      channel dimension.

    Args:
        inputs: tensor of shape [batch_size, C, H, W].
        moving_instance_max: normaliser that ``inputs`` is divided by. It can be
            a scalar or a tensor broadcastable to ``inputs``.
        ksize: odd side length of the spatial averaging window.
        dilation: dilation of the spatial averaging window. Reflect padding
            requires ``dilation * (ksize // 2)`` to be smaller than H and W.

    Returns:
        ``(alpha, beta)``, each with the same shape as ``inputs``.

    Raises:
        ValueError: if ``ksize`` is even, because no symmetric "same" padding
            exists for it.
    """
    if ksize % 2 == 0:
        raise ValueError(f"ksize must be odd, got {ksize}")

    inputs = inputs / moving_instance_max

    _, C, _, _ = inputs.shape

    # A dilated k x k kernel spans dilation * (ksize - 1) + 1 pixels, so the
    # padding that preserves H x W is dilation * (ksize // 2) on each side.
    pad_size = dilation * (ksize // 2)
    # Depthwise box filter built in the input's dtype so float64/half work.
    kernel = torch.ones(
        [C, 1, ksize, ksize], dtype=inputs.dtype, device=inputs.device
    ) / (ksize * ksize)

    pad_inputs = F.pad(inputs, [pad_size] * 4, mode="reflect")

    avg_spatial_inputs = F.conv2d(
        pad_inputs, kernel, stride=1, dilation=dilation, padding=0, groups=C
    )
    avg_channel_inputs = torch.mean(inputs, dim=1, keepdim=True)

    alpha = F.softplus(inputs - avg_spatial_inputs)
    beta = F.softplus(inputs - avg_channel_inputs)

    return alpha, beta


def extract_kpts(score_map, k=256, score_thld=0, edge_thld=0, nms_size=3, eof_size=5):
    """Select the highest-scoring keypoints from a single-channel score map.

    A pixel is a candidate when it passes every enabled filter:

    * its score is strictly greater than ``score_thld``;
    * (``nms_size > 0``) it equals the maximum of its ``nms_size`` x
      ``nms_size`` neighbourhood. Ties with that maximum all survive.
    * (``eof_size > 0``) it lies at least ``eof_size`` pixels from every
      image border;
    * (``edge_thld > 0``) ``edge_mask`` (dilation 3) marks it as not an edge.

    Args:
        score_map: tensor of shape [batch_size, 1, H, W].
        k: maximum number of keypoints kept per image.
        score_thld: exclusive lower bound on candidate scores.
        edge_thld: threshold forwarded to ``edge_mask``; 0 disables the test.
        nms_size: NMS window size (assumed odd); 0 disables NMS.
        eof_size: width of the excluded border band; 0 disables it.

    Returns:
        indices: int64 tensor [batch_size, n, 2] of (row, col) positions,
            ordered by descending score within each image.
        scores: tensor [batch_size, n] with the score at each position.

        ``n = min(k, fewest candidates found in any image)``. Every image is
        truncated to the same length so the results can be stacked.
    """
    bs, _, h, w = score_map.shape

    mask = score_map > score_thld
    if nms_size > 0:
        # A pixel survives NMS when it equals the max over its window.
        local_max = F.max_pool2d(
            score_map, kernel_size=nms_size, stride=1, padding=nms_size // 2
        )
        mask = torch.logical_and(mask, torch.eq(score_map, local_max))
    if eof_size > 0:
        # Exclude a border band of eof_size pixels. Slicing yields an
        # all-False mask, rather than a negative-size error, when the band
        # covers the whole image.
        eof_mask = torch.zeros((1, 1, h, w), dtype=torch.bool, device=score_map.device)
        eof_mask[..., eof_size:h - eof_size, eof_size:w - eof_size] = True
        mask = torch.logical_and(mask, eof_mask)
    if edge_thld > 0:
        non_edge_mask = edge_mask(score_map, 1, dilation=3, edge_thld=edge_thld)
        mask = torch.logical_and(mask, non_edge_mask)

    indices = []
    scores = []
    for i in range(bs):
        cand_indices = torch.nonzero(mask[i, 0])
        cand_scores = gather_nd(score_map[i, 0], cand_indices)
        order = torch.sort(cand_scores, descending=True)[1][:k]
        indices.append(cand_indices[order])
        scores.append(cand_scores[order])

    # Images can yield different numbers of candidates; truncate all of them
    # to the shortest so they can be stacked into one batch tensor.
    min_num = min(len(idx) for idx in indices)
    indices = torch.stack([idx[:min_num] for idx in indices], dim=0)
    scores = torch.stack([s[:min_num] for s in scores], dim=0)
    return indices, scores


def edge_mask(inputs, n_channel, dilation=1, edge_thld=5):
    """Flag pixels whose local Hessian is *not* edge-like.

    Each channel is processed independently. Second derivatives are estimated
    with 3x3 finite-difference kernels dilated by ``dilation``; zero padding
    keeps the H x W size:

    * d_ii is taken along the height axis;
    * d_jj is taken along the width axis;
    * d_ij is the mixed derivative.

    A pixel is kept when the Hessian determinant is positive and
    ``trace**2 / det <= (edge_thld + 1)**2 / edge_thld``. This is equivalent
    to the ratio of the two Hessian eigenvalues being at most ``edge_thld``.

    Args:
        inputs: tensor of shape [b, c, h, w].
        n_channel: unused; kept for backward compatibility.
        dilation: spacing, in pixels, between the finite-difference taps.
        edge_thld: maximum eigenvalue ratio. Must be non-zero, because the
            threshold divides by it.

    Returns:
        Bool tensor [b, c, h, w], True where the pixel is not an edge.
    """
    b, c, h, w = inputs.shape

    # The three derivative kernels are stacked as output channels of one conv:
    # channel 0 = d_ii, 1 = d_ij, 2 = d_jj. They are created in the input's
    # dtype and device so non-float32 inputs work.
    filters = torch.tensor(
        [
            [[0.0, 1.0, 0.0], [0.0, -2.0, 0.0], [0.0, 1.0, 0.0]],
            [[0.25, 0.0, -0.25], [0.0, 0.0, 0.0], [-0.25, 0.0, 0.25]],
            [[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]],
        ],
        dtype=inputs.dtype,
        device=inputs.device,
    ).unsqueeze(1)  # [3, 1, 3, 3]

    # reshape (not view) so that non-contiguous inputs are accepted.
    derivs = F.conv2d(
        inputs.reshape(b * c, 1, h, w),
        filters,
        padding=dilation,
        dilation=dilation,
    ).view(b, c, 3, h, w)
    dii, dij, djj = derivs.unbind(dim=2)

    det = dii * djj - dij * dij
    tr = dii + djj

    threshold = (edge_thld + 1) ** 2 / edge_thld
    # Pixels with det <= 0 (saddles or degenerate points) are rejected
    # outright. This also discards the inf/nan from dividing by a zero det.
    is_not_edge = torch.logical_and(tr * tr / det <= threshold, det > 0)

    return is_not_edge