"""Layer functions"""

import torch
import torch.nn.functional as F

import cirtorch.layers.functional as CF


def smoothing_avg_pooling(feats, kernel_size):
    """Smoothing average pooling

    :param torch.Tensor feats: Feature map
    :param int kernel_size: kernel size of pooling
    :return torch.Tensor: Smoothend feature map
    """
    pad = kernel_size // 2
    return F.avg_pool2d(feats, (kernel_size, kernel_size), stride=1, padding=pad,
                        count_include_pad=False)

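# A minimal usage sketch (the shapes and kernel size below are illustrative
# assumptions, not values taken from this repository):
#
#   feats = torch.randn(1, 128, 32, 32)
#   smoothed = smoothing_avg_pooling(feats, kernel_size=3)
#   # with stride=1 and padding=kernel_size // 2, an odd kernel size
#   # preserves the spatial dimensions: smoothed.shape == (1, 128, 32, 32)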

def weighted_spoc(ms_feats, ms_weights):
    """Weighted SPoC pooling, summed over scales.

    :param list ms_feats: A list of feature maps, each at a different scale
    :param list ms_weights: A list of weights, each at a different scale
    :return torch.Tensor: L2-normalized global descriptor
    """
    desc = torch.zeros((1, ms_feats[0].shape[1]), dtype=torch.float32, device=ms_feats[0].device)
    for feats, weights in zip(ms_feats, ms_weights):
        # weighted sum-pooling over the spatial dimensions at this scale
        desc += (feats * weights).sum((-2, -1)).squeeze()
    return CF.l2n(desc)

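# A minimal usage sketch (two scales with hypothetical channel count and
# spatial sizes; these values are illustrative assumptions):
#
#   ms_feats = [torch.randn(1, 128, 32, 32), torch.randn(1, 128, 16, 16)]
#   ms_weights = [torch.rand(1, 1, 32, 32), torch.rand(1, 1, 16, 16)]
#   desc = weighted_spoc(ms_feats, ms_weights)
#   # desc: L2-normalized global descriptor of shape (1, 128)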

def how_select_local(ms_feats, ms_masks, *, scales, features_num):
    """Convert multi-scale feature maps with attentions to a list of local descriptors

    :param list ms_feats: A list of feature maps, each at a different scale
    :param list ms_masks: A list of attentions, each at a different scale
    :param list scales: A list of scales (floats)
    :param int features_num: Number of features to be returned (sorted by attenions)
    :return tuple: A list of descriptors, attentions, locations (x_coor, y_coor) and scales where
            elements from each list correspond to each other
    """
    device = ms_feats[0].device
    # scalar (0-dim) masks carry no spatial locations and are skipped below
    size = sum(x.shape[0] * x.shape[1] for x in ms_masks if x.dim() == 2)

    desc = torch.zeros(size, ms_feats[0].shape[1], dtype=torch.float32, device=device)
    atts = torch.zeros(size, dtype=torch.float32, device=device)
    locs = torch.zeros(size, 2, dtype=torch.int16, device=device)
    scls = torch.zeros(size, dtype=torch.float16, device=device)

    pointer = 0
    for sc, vs, ms in zip(scales, ms_feats, ms_masks):
        # skip scalar masks (no spatial support)
        if ms.dim() == 0:
            continue

        height, width = ms.shape
        numel = torch.numel(ms)
        slc = slice(pointer, pointer+numel)
        pointer += numel

        # flatten the (1, D, H, W) feature map into (H*W, D) descriptor rows
        desc[slc] = vs.squeeze(0).reshape(vs.shape[1], -1).T
        atts[slc] = ms.reshape(-1)
        width_arr = torch.arange(width, dtype=torch.int16, device=device)
        locs[slc, 0] = width_arr.repeat(height)  # x axis
        height_arr = torch.arange(height, dtype=torch.int16, device=device)
        locs[slc, 1] = height_arr.view(-1, 1).repeat(1, width).reshape(-1)  # y axis
        scls[slc] = sc

    # keep the highest-attention features (all of them if features_num is None)
    keep_n = min(features_num, atts.shape[0]) if features_num is not None else atts.shape[0]
    idx = atts.sort(descending=True)[1][:keep_n]

    return desc[idx], atts[idx], locs[idx], scls[idx]
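
# A minimal usage sketch (shapes, scales and the feature budget below are
# illustrative assumptions, not values fixed by this module):
#
#   ms_feats = [torch.randn(1, 128, 32, 32), torch.randn(1, 128, 16, 16)]
#   ms_masks = [torch.rand(32, 32), torch.rand(16, 16)]
#   des, att, loc, scl = how_select_local(ms_feats, ms_masks,
#                                         scales=[1.0, 0.5], features_num=500)
#   # des: (500, 128) descriptors, att: (500,) attentions in decreasing order,
#   # loc: (500, 2) integer (x, y) grid coordinates, scl: (500,) scales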