SuperFeatures / fire_network.py
YannisK's picture
temp
4e8ced7
raw
history blame
5.12 kB
# Copyright (C) 2021-2022 Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
import os
import torch
from torch import nn
import torchvision
from how import layers
from lit import LocalfeatureIntegrationTransformer
from how.networks.how_net import HOWNet
class FIReNet(HOWNet):
    """HOW-based network whose local features go through a Local feature Integration Transformer (LIT).

    Sub-modules: ``features`` (convolutional backbone), ``lit`` (LIT layer returning features and
    attention maps), ``attention`` (per-feature strength) and an optional ``dim_reduction`` layer.
    """

    def __init__(self, features, attention, lit, dim_reduction, meta, runtime):
        # None is passed as HOWNet's third constructor argument (presumably the smoothing layer,
        # later read as ``self.smoothing``) — TODO confirm against HOWNet
        super().__init__(features, attention, None, dim_reduction, meta, runtime)
        self.lit = lit
        # Not read within this class; presumably consulted by callers — confirm
        self.return_global = False

    def copy_excluding_dim_reduction(self):
        """Return a copy of this network without the dim_reduction layer.

        All other sub-modules are shared with ``self`` (not deep-copied); ``meta['outputdim']``
        is reset to ``meta['backbone_dim']``.
        """
        meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
        return self.__class__(self.features, self.attention, self.lit, None, meta, self.runtime)

    def copy_with_runtime(self, runtime):
        """Return a copy of this network (sharing all sub-modules) with a different runtime dict."""
        return self.__class__(self.features, self.attention, self.lit, self.dim_reduction, self.meta, runtime)

    def parameter_groups(self):
        """Return torch optimizer parameter groups.

        Backbone, attention, smoothing (when present) and LIT parameters use the optimizer's
        default settings; dim_reduction parameters, when present, get ``lr=0.0`` so they are
        not updated.
        """
        # Named to avoid shadowing the module-level ``layers`` import
        modules = [self.features, self.attention, self.smoothing, self.lit]
        groups = [{'params': module.parameters()} for module in modules if module is not None]
        if self.dim_reduction is not None:
            # Do not update dimensionality reduction layer
            groups.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
        return groups

    def get_superfeatures(self, x, *, scales):
        """Extract super-features from ``x`` at every requested scale.

        :param torch.Tensor x: input batch, bilinearly rescaled by each scale factor
            (assumed BxCxHxW — TODO confirm)
        :param scales: iterable of scale factors
        :return: three lists ``(feats, attns, strengths)``, each with one entry per scale:
            ``feats`` -- super-features, a BxDxNx1 tensor;
            ``attns`` -- LIT attention maps, a BxNxHxW tensor;
            ``strengths`` -- output of ``self.attention`` on the LIT features
        """
        feats = []
        attns = []
        strengths = []
        for scale in scales:
            xs = nn.functional.interpolate(x, scale_factor=scale, mode='bilinear', align_corners=False)
            o, attn = self.lit(self.features(xs))
            # Strength is computed on the LIT output, before smoothing / dimensionality reduction
            strength = self.attention(o)
            if self.smoothing is not None:
                o = self.smoothing(o)
            if self.dim_reduction is not None:
                o = self.dim_reduction(o)
            feats.append(o)
            attns.append(attn)
            strengths.append(strength)
        return feats, attns, strengths

    def forward(self, x):
        """Return :meth:`get_superfeatures` evaluated at ``runtime['training_scales']``."""
        return self.get_superfeatures(x, scales=self.runtime['training_scales'])
def _backbone_layers(architecture, skip_layer):
    """Build the torchvision backbone and return its convolutional modules, minus the last ``skip_layer``."""
    if not architecture.startswith(('alexnet', 'vgg', 'resnet', 'densenet', 'squeezenet')):
        raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))
    # Take convolutional layers as features, always ends with ReLU to make last activations non-negative.
    # No ImageNet weights: trained weights (including the LIT module) come from the FIRe checkpoint instead.
    net_in = getattr(torchvision.models, architecture)(pretrained=False)
    if architecture.startswith(('alexnet', 'vgg')):
        # Drop the last module of .features so the stack ends with a ReLU
        features = list(net_in.features.children())[:-1]
    elif architecture.startswith('resnet'):
        # Drop the trailing pooling and fc classifier modules
        features = list(net_in.children())[:-2]
    elif architecture.startswith('densenet'):
        # Append a ReLU so the output is non-negative
        features = list(net_in.features.children()) + [nn.ReLU(inplace=True)]
    else:  # squeezenet
        features = list(net_in.features.children())
    if skip_layer > 0:
        features = features[:-skip_layer]
    return features


def _infer_output_channels(features, probe_size=64):
    """Return the number of channels ``features`` produces, measured on a dummy RGB image.

    The probe runs in eval mode without gradients, so BatchNorm running statistics are not
    modified; the module's previous train/eval mode is restored afterwards.
    """
    was_training = features.training
    features.eval()
    try:
        with torch.no_grad():
            out = features(torch.zeros(1, 3, probe_size, probe_size))
    finally:
        features.train(was_training)
    return int(out.shape[1])


def _load_pretrained(net, pretrained):
    """Load the ``state_dict`` of the checkpoint file ``pretrained`` into ``net`` and validate the keys."""
    if not os.path.isfile(pretrained):
        raise FileNotFoundError(pretrained)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints
    ckpt = torch.load(pretrained, map_location='cpu')
    missing, unexpected = net.load_state_dict(ckpt['state_dict'], strict=False)
    # Only dim_reduction weights may be absent, and only fc (classifier) weights may be extra
    bad_missing = [k for k in missing if 'dim_reduction' not in k]
    bad_unexpected = [k for k in unexpected if 'fc' not in k]
    if bad_missing or bad_unexpected:
        raise RuntimeError("Loading did not go well: missing keys {}, unexpected keys {}".format(
            bad_missing, bad_unexpected))


def init_network(architecture, pretrained, skip_layer, dim_reduction, lit, runtime):
    """Initialize FIRe network

    :param str architecture: torchvision backbone name starting with alexnet, vgg, resnet,
        densenet or squeezenet (e.g. resnet18)
    :param str pretrained: path to a FIRe checkpoint file whose 'state_dict' is loaded
        (None for using random initialization)
    :param int skip_layer: How many layers of blocks should be skipped (from the end)
    :param dict dim_reduction: Options for the dimensionality reduction layer (falsy to disable it)
    :param dict lit: Options for the lit layer; must contain 'dim'
    :param dict runtime: Runtime options to be stored in the network
    :return FIReNet: Initialized network
    :raises ValueError: if the architecture is not supported
    :raises FileNotFoundError: if ``pretrained`` is given but is not an existing file
    :raises RuntimeError: if the checkpoint keys do not match the network as expected
    """
    features = nn.Sequential(*_backbone_layers(architecture, skip_layer))
    # Measured rather than hard-coded: 2048 // 2**skip_layer only holds for bottleneck ResNets
    backbone_dim = _infer_output_channels(features)
    att_layer = layers.attention.L2Attention()
    lit_layer = LocalfeatureIntegrationTransformer(**lit, input_dim=backbone_dim)
    reduction_layer = None
    if dim_reduction:
        # The reduction layer consumes LIT outputs, hence input_dim=lit['dim']
        reduction_layer = layers.dim_reduction.ConvDimReduction(**dim_reduction, input_dim=lit['dim'])
    meta = {
        "architecture": architecture,
        "backbone_dim": lit['dim'],
        "outputdim": reduction_layer.out_channels if reduction_layer is not None else lit['dim'],
        # Assumes a total stride of 32 halved per skipped block (ResNet-style) — TODO confirm for other backbones
        "corercf_size": 32 // (2 ** skip_layer),
    }
    net = FIReNet(features, att_layer, lit_layer, reduction_layer, meta, runtime)
    if pretrained is not None:
        _load_pretrained(net, pretrained)
    return net