SuperFeatures / fire_network.py
# Copyright (C) 2021-2022 Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

import os
import torch
from torch import nn
import torchvision
from cirtorch.networks import imageretrievalnet
from how import layers
from how.layers import functional as HF
from lit import LocalfeatureIntegrationTransformer
from how.networks.how_net import HOWNet, CORERCF_SIZE


class FIReNet(HOWNet):
    """HOW network extended with a Local feature Integration Transformer (LIT) that produces super-features (FIRe)"""
def __init__(self, features, attention, lit, dim_reduction, meta, runtime):
super().__init__(features, attention, None, dim_reduction, meta, runtime)
self.lit = lit
        self.return_global = False

    def copy_excluding_dim_reduction(self):
"""Return a copy of this network without the dim_reduction layer"""
meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
        return self.__class__(self.features, self.attention, self.lit, None, meta, self.runtime)

    def copy_with_runtime(self, runtime):
"""Return a copy of this network with a different runtime dict"""
        return self.__class__(self.features, self.attention, self.lit, self.dim_reduction, self.meta, runtime)

    def parameter_groups(self):
"""Return torch parameter groups"""
layers = [self.features, self.attention, self.smoothing, self.lit]
parameters = [{'params': x.parameters()} for x in layers if x is not None]
if self.dim_reduction:
# Do not update dimensionality reduction layer
parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
return parameters
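
    # Example (illustrative, not part of the original file): the groups returned
    # above can be passed directly to a torch optimizer, e.g.
    #   optimizer = torch.optim.SGD(net.parameter_groups(), lr=1e-3, momentum=0.9)
    # so the dim_reduction layer stays frozen (lr=0.0) while all other parameters
    # use the optimizer defaults; the lr and momentum values here are assumptions.
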
    def get_superfeatures(self, x, *, scales):
        """Return three lists (features, attention maps, strengths), each containing one entry per requested scale.

        Each features entry is a tensor of shape BxDxNx1 and each attention-map entry is a
        tensor of shape BxNxHxW; the strengths are the L2 attention maps used for pooling
        and local-feature selection.
        """
feats = []
attns = []
strengths = []
for s in scales:
xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
o = self.features(xs)
o, attn = self.lit(o)
strength = self.attention(o)
if self.smoothing:
o = self.smoothing(o)
if self.dim_reduction:
o = self.dim_reduction(o)
feats.append(o)
attns.append(attn)
strengths.append(strength)
        return feats, attns, strengths

    def forward(self, x):
        """Return the global descriptor if return_global is set, otherwise the super-features"""
if self.return_global:
return self.forward_global(x, scales=self.runtime['training_scales'])
        return self.get_superfeatures(x, scales=self.runtime['training_scales'])

    def forward_global(self, x, *, scales):
"""Return global descriptor"""
feats, _, strengths = self.get_superfeatures(x, scales=scales)
        return HF.weighted_spoc(feats, strengths)

    def forward_local(self, x, *, features_num, scales):
"""Return selected super features"""
feats, _, strengths = self.get_superfeatures(x, scales=scales)
return HF.how_select_local(feats, strengths, scales=scales, features_num=features_num)
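

# Usage sketch (illustrative, not part of the original file): given a network
# built by init_network() below, global and local super-feature descriptors can
# be extracted roughly as follows; the scale list and features_num are assumed
# values for illustration:
#
#   net.eval()
#   with torch.no_grad():
#       global_desc = net.forward_global(img, scales=[1.0, 0.707, 0.5])
#       local_feats = net.forward_local(img, features_num=1000, scales=[1.0, 0.707, 0.5])
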
def init_network(architecture, pretrained, skip_layer, dim_reduction, lit, runtime):
"""Initialize FIRe network
:param str architecture: Network backbone architecture (e.g. resnet18)
    :param str pretrained: Path to a pretrained model checkpoint file (None for random initialization)
:param int skip_layer: How many layers of blocks should be skipped (from the end)
:param dict dim_reduction: Options for the dimensionality reduction layer
:param dict lit: Options for the lit layer
:param dict runtime: Runtime options to be stored in the network
    :return FIReNet: Initialized network
"""
    # Take convolutional layers as features; the stack always ends with ReLU so the last activations are non-negative
    net_in = getattr(torchvision.models, architecture)(pretrained=False)  # trained weights (including the LIT module) are loaded from the checkpoint below instead
if architecture.startswith('alexnet') or architecture.startswith('vgg'):
features = list(net_in.features.children())[:-1]
elif architecture.startswith('resnet'):
features = list(net_in.children())[:-2]
elif architecture.startswith('densenet'):
features = list(net_in.features.children()) + [nn.ReLU(inplace=True)]
elif architecture.startswith('squeezenet'):
features = list(net_in.features.children())
else:
raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))
if skip_layer > 0:
features = features[:-skip_layer]
    backbone_dim = imageretrievalnet.OUTPUT_DIM[architecture] // (2 ** skip_layer)  # e.g. resnet50: 2048 // 2**skip_layer
att_layer = layers.attention.L2Attention()
lit_layer = LocalfeatureIntegrationTransformer(**lit, input_dim=backbone_dim)
reduction_layer = None
if dim_reduction:
reduction_layer = layers.dim_reduction.ConvDimReduction(**dim_reduction, input_dim=lit['dim'])
meta = {
"architecture": architecture,
"backbone_dim": lit['dim'],
"outputdim": reduction_layer.out_channels if dim_reduction else lit['dim'],
"corercf_size": CORERCF_SIZE[architecture] // (2 ** skip_layer),
}
net = FIReNet(nn.Sequential(*features), att_layer, lit_layer, reduction_layer, meta, runtime)
if pretrained is not None:
assert os.path.isfile(pretrained), pretrained
ckpt = torch.load(pretrained, map_location='cpu')
        missing, unexpected = net.load_state_dict(ckpt['state_dict'], strict=False)
        assert all('dim_reduction' in a for a in missing), "missing keys other than dim_reduction: {}".format(missing)
        assert all('fc' in a for a in unexpected), "unexpected keys other than fc: {}".format(unexpected)
return net
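

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file). All option values
    # below are assumptions for illustration: LocalfeatureIntegrationTransformer may
    # require keyword arguments beyond "dim", and the training scales are arbitrary.
    net = init_network(
        architecture="resnet50",
        pretrained=None,                      # random initialization; pass a checkpoint path otherwise
        skip_layer=0,
        dim_reduction=None,                   # skip the dimensionality reduction layer
        lit={"dim": 1024},                    # assumed minimal LIT options
        runtime={"training_scales": [1.0]},
    )
    net.eval()
    with torch.no_grad():
        feats, attns, strengths = net(torch.rand(1, 3, 224, 224))
    print(len(feats), feats[0].shape)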