File size: 5,121 Bytes
9651aac
 
 
 
 
 
 
 
 
 
 
 
4e8ced7
9651aac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e8ced7
9651aac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4b82b2
9651aac
 
 
 
 
 
 
 
 
 
 
 
 
4e8ced7
9651aac
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright (C) 2021-2022 Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

import os
import torch
from torch import nn
import torchvision

from how import layers

from lit import LocalfeatureIntegrationTransformer

from how.networks.how_net import HOWNet

class FIReNet(HOWNet):
    """HOW-derived network that aggregates backbone features into super-features with LIT.

    The HOWNet smoothing slot is always ``None``. A LocalfeatureIntegrationTransformer
    (``lit``) is applied to the backbone output before attention and the optional
    dimensionality reduction.
    """

    def __init__(self, features, attention, lit, dim_reduction, meta, runtime):
        # No smoothing layer for FIRe: the HOWNet smoothing argument is None.
        super().__init__(features, attention, None, dim_reduction, meta, runtime)
        self.lit = lit
        # Not read in this class; presumably consulted by callers/HOWNet -- TODO confirm.
        self.return_global = False

    def copy_excluding_dim_reduction(self):
        """Return a copy of this network without the dim_reduction layer"""
        # Without reduction, the output dimension is the (LIT) backbone dimension.
        meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
        return self.__class__(self.features, self.attention, self.lit, None, meta, self.runtime)

    def copy_with_runtime(self, runtime):
        """Return a copy of this network with a different runtime dict"""
        return self.__class__(self.features, self.attention, self.lit, self.dim_reduction, self.meta, runtime)

    def parameter_groups(self):
        """Return torch parameter groups.

        The backbone, attention, smoothing (if any) and LIT parameters use the default
        learning rate. The dimensionality reduction layer, when present, is frozen
        with ``lr=0.0``.
        """
        # Named `trainable` to avoid shadowing the module-level `layers` import.
        trainable = [self.features, self.attention, self.smoothing, self.lit]
        parameters = [{'params': module.parameters()} for module in trainable if module is not None]
        if self.dim_reduction is not None:
            # Do not update dimensionality reduction layer
            parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
        return parameters

    def get_superfeatures(self, x, *, scales):
        """Extract super-features from ``x`` at each requested scale.

        :param torch.Tensor x: Input image batch, resized with bilinear interpolation
            for every scale
        :param scales: Iterable of interpolation scale factors
        :return: Tuple ``(feats, attns, strengths)`` of three lists with one entry per
            scale, in the order of ``scales``:

            - ``feats``: super-features after optional smoothing and dim reduction
              (per the original docs, a tensor BxDxNx1)
            - ``attns``: LIT attention maps (per the original docs, a tensor BxNxHxW)
            - ``strengths``: output of the attention layer, computed on the LIT features
              before smoothing/reduction
        """
        feats = []
        attns = []
        strengths = []
        for s in scales:
            xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
            o = self.features(xs)
            o, attn = self.lit(o)
            # Strength is measured before smoothing / dimensionality reduction.
            strength = self.attention(o)
            if self.smoothing is not None:
                o = self.smoothing(o)
            if self.dim_reduction is not None:
                o = self.dim_reduction(o)
            feats.append(o)
            attns.append(attn)
            strengths.append(strength)
        return feats, attns, strengths

    def forward(self, x):
        """Return ``get_superfeatures`` output at ``runtime['training_scales']``."""
        return self.get_superfeatures(x, scales=self.runtime['training_scales'])
    
def _infer_feature_dim(features):
    """Return the channel count (dim 1) produced by ``features`` on a dummy image batch.

    The module is temporarily put in eval mode, so BatchNorm running statistics are
    not updated. The pass runs without gradients, and the original training mode is
    restored afterwards.

    :param nn.Module features: Feature extractor taking 3-channel images
    :return int: Number of output channels
    """
    was_training = features.training
    features.eval()
    try:
        with torch.no_grad():
            out = features(torch.zeros(1, 3, 224, 224))
    finally:
        features.train(was_training)
    return int(out.shape[1])


def init_network(architecture, pretrained, skip_layer, dim_reduction, lit, runtime):
    """Initialize FIRe network

    :param str architecture: Network backbone architecture (e.g. resnet18)

    :param str pretrained: path to a FIRe checkpoint file (None for using random initialization)

    :param int skip_layer: How many layers of blocks should be skipped (from the end)

    :param dict dim_reduction: Options for the dimensionality reduction layer

    :param dict lit: Options for the lit layer (must contain 'dim')

    :param dict runtime: Runtime options to be stored in the network

    :return FIReNet: Initialized network

    :raises ValueError: if the architecture is not supported

    :raises FileNotFoundError: if ``pretrained`` is given but is not an existing file

    :raises RuntimeError: if the checkpoint keys do not match the network beyond the
        tolerated missing 'dim_reduction' / unexpected 'fc' entries
    """
    # Take convolutional layers as features, always ends with ReLU to make last activations non-negative
    # Random init: when `pretrained` is given, the weights (including LIT) come from that checkpoint.
    net_in = getattr(torchvision.models, architecture)(pretrained=False)
    if architecture.startswith('alexnet') or architecture.startswith('vgg'):
        features = list(net_in.features.children())[:-1]
    elif architecture.startswith('resnet'):
        features = list(net_in.children())[:-2]
    elif architecture.startswith('densenet'):
        features = list(net_in.features.children()) + [nn.ReLU(inplace=True)]
    elif architecture.startswith('squeezenet'):
        features = list(net_in.features.children())
    else:
        raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))

    if skip_layer > 0:
        features = features[:-skip_layer]
    features = nn.Sequential(*features)
    # Measure the backbone width rather than hard-coding 2048 // 2**skip_layer. That
    # formula only holds for ResNet-50/101/152 and breaks e.g. resnet18/34 (512 channels).
    backbone_dim = _infer_feature_dim(features)

    att_layer = layers.attention.L2Attention()

    lit_layer = LocalfeatureIntegrationTransformer(**lit, input_dim=backbone_dim)

    reduction_layer = None
    if dim_reduction:
        # The reduction operates on LIT outputs, whose width is lit['dim'].
        reduction_layer = layers.dim_reduction.ConvDimReduction(**dim_reduction, input_dim=lit['dim'])

    meta = {
        "architecture": architecture,
        "backbone_dim": lit['dim'],
        "outputdim": reduction_layer.out_channels if dim_reduction else lit['dim'],
        "corercf_size": 32 // (2 ** skip_layer),
    }
    net = FIReNet(features, att_layer, lit_layer, reduction_layer, meta, runtime)

    if pretrained is not None:
        if not os.path.isfile(pretrained):
            raise FileNotFoundError('Pretrained checkpoint not found: {}'.format(pretrained))
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        ckpt = torch.load(pretrained, map_location='cpu')
        missing, unexpected = net.load_state_dict(ckpt['state_dict'], strict=False)
        # Tolerated mismatches: the checkpoint may lack dim-reduction weights, and may
        # carry 'fc' entries that this network does not have.
        bad_missing = [k for k in missing if 'dim_reduction' not in k]
        bad_unexpected = [k for k in unexpected if 'fc' not in k]
        if bad_missing or bad_unexpected:
            raise RuntimeError(
                'Loading did not go well: missing keys {}, unexpected keys {}'.format(
                    bad_missing, bad_unexpected))
    return net