# Copyright (C) 2021-2022 Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

import os
import torch
from torch import nn
import torchvision

from cirtorch.networks import imageretrievalnet

from how import layers
from how.layers import functional as HF

from lit import LocalfeatureIntegrationTransformer

from how.networks.how_net import HOWNet, CORERCF_SIZE

class FIReNet(HOWNet):
    """FIRe retrieval network: a HOW-style backbone whose local features are aggregated
    into super-features by the Local feature Integration Transformer (LIT) before
    attention weighting and optional dimensionality reduction."""

    def __init__(self, features, attention, lit, dim_reduction, meta, runtime):
        super().__init__(features, attention, None, dim_reduction, meta, runtime)
        self.lit = lit
        self.return_global = False
        
    def copy_excluding_dim_reduction(self):
        """Return a copy of this network without the dim_reduction layer"""
        meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
        return self.__class__(self.features, self.attention, self.lit, None, meta, self.runtime)

    def copy_with_runtime(self, runtime):
        """Return a copy of this network with a different runtime dict"""
        return self.__class__(self.features, self.attention, self.lit, self.dim_reduction, self.meta, runtime)

    def parameter_groups(self):
        """Return torch parameter groups"""
        param_layers = [self.features, self.attention, self.smoothing, self.lit]
        parameters = [{'params': x.parameters()} for x in param_layers if x is not None]
        if self.dim_reduction:
            # Do not update the dimensionality reduction layer
            parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
        return parameters

    def get_superfeatures(self, x, *, scales):
        """Extract super-features at the requested scales.

        Returns a tuple of lists (features, attention_maps, strengths), each with one
        entry per requested scale: features entries have shape BxDxNx1, attention_maps
        entries have shape BxNxHxW, and strengths entries are the per-superfeature
        attention scores.
        """
        feats = []
        attns = []
        strengths = []
        for s in scales:
            # Rescale the input to the requested scale
            xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
            o = self.features(xs)  # backbone activation map
            o, attn = self.lit(o)  # aggregate local features into super-features (LIT)
            strength = self.attention(o)  # attention strength of each super-feature
            if self.smoothing:
                o = self.smoothing(o)
            if self.dim_reduction:
                o = self.dim_reduction(o)
            feats.append(o)
            attns.append(attn)
            strengths.append(strength)
        return feats, attns, strengths
        
    def forward(self, x):
        """Return the global descriptor if `return_global` is set, otherwise super-features"""
        if self.return_global:
            return self.forward_global(x, scales=self.runtime['training_scales'])
        return self.get_superfeatures(x, scales=self.runtime['training_scales'])
        
    def forward_global(self, x, *, scales):
        """Return global descriptor"""
        feats, _, strengths = self.get_superfeatures(x, scales=scales)
        return HF.weighted_spoc(feats, strengths)
        
    def forward_local(self, x, *, features_num, scales):
        """Return selected super features"""
        feats, _, strengths = self.get_superfeatures(x, scales=scales)
        return HF.how_select_local(feats, strengths, scales=scales, features_num=features_num)

def init_network(architecture, pretrained, skip_layer, dim_reduction, lit, runtime):
    """Initialize FIRe network

    :param str architecture: Network backbone architecture (e.g. resnet18)
    :param str pretrained: Path to a pretrained checkpoint (None for random initialization)
    :param int skip_layer: How many layers of blocks should be skipped (from the end)
    :param dict dim_reduction: Options for the dimensionality reduction layer
    :param dict lit: Options for the LIT layer
    :param dict runtime: Runtime options to be stored in the network
    :return FIReNet: Initialized network
    """
    # Take the convolutional layers as features, always ending with a ReLU so the last activations are non-negative
    net_in = getattr(torchvision.models, architecture)(pretrained=False)  # random init; trained weights (including the LIT module) are loaded from the checkpoint below
    if architecture.startswith('alexnet') or architecture.startswith('vgg'):
        features = list(net_in.features.children())[:-1]
    elif architecture.startswith('resnet'):
        features = list(net_in.children())[:-2]
    elif architecture.startswith('densenet'):
        features = list(net_in.features.children()) + [nn.ReLU(inplace=True)]
    elif architecture.startswith('squeezenet'):
        features = list(net_in.features.children())
    else:
        raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))

    if skip_layer > 0:
        features = features[:-skip_layer]
    backbone_dim = imageretrievalnet.OUTPUT_DIM[architecture] // (2 ** skip_layer)

    att_layer = layers.attention.L2Attention()

    lit_layer = LocalfeatureIntegrationTransformer(**lit, input_dim=backbone_dim)

    reduction_layer = None
    if dim_reduction:
        reduction_layer = layers.dim_reduction.ConvDimReduction(**dim_reduction, input_dim=lit['dim'])

    meta = {
        "architecture": architecture,
        "backbone_dim": lit['dim'],
        "outputdim": reduction_layer.out_channels if dim_reduction else lit['dim'],
        "corercf_size": CORERCF_SIZE[architecture] // (2 ** skip_layer),
    }
    net = FIReNet(nn.Sequential(*features), att_layer, lit_layer, reduction_layer, meta, runtime)
    
    if pretrained is not None:
        assert os.path.isfile(pretrained), pretrained
        ckpt = torch.load(pretrained, map_location='cpu')
        missing, unexpected = net.load_state_dict(ckpt['state_dict'], strict=False)
        # Only dim_reduction weights may be missing from the checkpoint, and only the
        # backbone classifier head ('fc') may be left unused
        assert all('dim_reduction' in key for key in missing), "Unexpected missing keys: {}".format(missing)
        assert all('fc' in key for key in unexpected), "Unexpected keys in checkpoint: {}".format(unexpected)
    return net
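
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original FIRe module).
# The `lit` options below are placeholders: only 'dim' is read in this file and
# the real LocalfeatureIntegrationTransformer is expected to need further
# options from the training configuration; the architecture, scales and
# feature count are likewise assumptions, not official FIRe defaults.
if __name__ == "__main__":
    net = init_network(
        architecture="resnet50",
        pretrained=None,                    # or a path to a FIRe checkpoint file
        skip_layer=0,
        dim_reduction=None,
        lit={"dim": 1024},                  # hypothetical minimal LIT options
        runtime={"training_scales": [1.0]},
    )
    net.eval()

    image = torch.randn(1, 3, 224, 224)     # dummy input batch
    with torch.no_grad():
        # Global descriptor: attention-weighted pooling of super-features over scales
        global_desc = net.forward_global(image, scales=[1.0])
        # Local super-features: keep the strongest `features_num` super-features
        local_out = net.forward_local(image, features_num=100, scales=[1.0])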