# Copyright (C) 2021-2022 Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

import os
import torch
from torch import nn
import torchvision

from how import layers
from lit import LocalfeatureIntegrationTransformer
from how.networks.how_net import HOWNet


class FIReNet(HOWNet):
    """HOWNet variant that inserts a LIT module between the backbone and the attention.

    HOWNet's third constructor argument (presumably the smoothing layer, since
    ``self.smoothing`` is read below) is always passed as ``None``.
    """

    def __init__(self, features, attention, lit, dim_reduction, meta, runtime):
        super().__init__(features, attention, None, dim_reduction, meta, runtime)
        self.lit = lit
        # NOTE(review): consumed outside this file; meaning not visible here.
        self.return_global = False

    def copy_excluding_dim_reduction(self):
        """Return a copy of this network without the dim_reduction layer"""
        # Without dim reduction the output dimension falls back to the LIT output width.
        meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
        return self.__class__(self.features, self.attention, self.lit, None, meta, self.runtime)

    def copy_with_runtime(self, runtime):
        """Return a copy of this network with a different runtime dict"""
        return self.__class__(self.features, self.attention, self.lit, self.dim_reduction,
                              self.meta, runtime)

    def parameter_groups(self):
        """Return torch parameter groups for an optimizer.

        The dimensionality-reduction layer, when present, is put in its own group
        with a learning rate of 0.0 so it is never updated.
        """
        # Named ``modules`` (not ``layers``) so the module-level ``how.layers`` import
        # is not shadowed.
        modules = [self.features, self.attention, self.smoothing, self.lit]
        parameters = [{'params': m.parameters()} for m in modules if m is not None]
        if self.dim_reduction is not None:
            # Do not update dimensionality reduction layer
            parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
        return parameters

    def get_superfeatures(self, x, *, scales):
        """Extract super-features, LIT attention maps and strengths at several scales.

        :param torch.Tensor x: image batch; must be 4-D for bilinear interpolation,
            presumably BxCxHxW -- TODO confirm
        :param scales: iterable of scale factors applied to ``x`` before the backbone
        :return: tuple ``(feats, attns, strengths)`` of lists, one entry per scale:
            ``feats`` -- LIT output after optional smoothing and dim reduction
            (documented by the original author as BxDxNx1);
            ``attns`` -- LIT attention maps (documented as BxNxHxW);
            ``strengths`` -- output of ``self.attention`` on the LIT features,
            computed before smoothing / dim reduction.
        """
        feats = []
        attns = []
        strengths = []
        for s in scales:
            xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
            o = self.features(xs)
            o, attn = self.lit(o)
            strength = self.attention(o)
            if self.smoothing is not None:
                o = self.smoothing(o)
            if self.dim_reduction is not None:
                o = self.dim_reduction(o)
            feats.append(o)
            attns.append(attn)
            strengths.append(strength)
        return feats, attns, strengths

    def forward(self, x):
        """Return super-features at the scales listed in ``runtime['training_scales']``."""
        return self.get_superfeatures(x, scales=self.runtime['training_scales'])


def _infer_backbone_dim(features):
    """Return the channel count produced by the backbone ``features``.

    A single dummy forward pass is run in eval mode under ``torch.no_grad()`` so that
    BatchNorm running statistics are not modified; the original train/eval mode is
    restored afterwards. 64x64 is small but large enough for every architecture
    accepted by ``init_network`` (even with layers skipped).
    """
    was_training = features.training
    features.eval()
    try:
        with torch.no_grad():
            out = features(torch.zeros(1, 3, 64, 64))
    finally:
        features.train(was_training)
    return out.shape[1]


def init_network(architecture, pretrained, skip_layer, dim_reduction, lit, runtime):
    """Initialize FIRe network

    :param str architecture: Network backbone architecture (e.g. resnet18)
    :param str pretrained: Path to a local FIRe checkpoint containing a 'state_dict'
        entry (None for using random initialization)
    :param int skip_layer: How many layers of blocks should be skipped (from the end)
    :param dict dim_reduction: Options for the dimensionality reduction layer
    :param dict lit: Options for the lit layer; must contain 'dim', the LIT output width
    :param dict runtime: Runtime options to be stored in the network
    :return FIReNet: Initialized network
    """
    # Take convolutional layers as features, always ends with ReLU to make last activations non-negative.
    # ImageNet weights are not downloaded: the FIRe checkpoint (which also holds the
    # LIT module) provides the trained backbone weights instead.
    net_in = getattr(torchvision.models, architecture)(pretrained=False)
    if architecture.startswith('alexnet') or architecture.startswith('vgg'):
        features = list(net_in.features.children())[:-1]
    elif architecture.startswith('resnet'):
        features = list(net_in.children())[:-2]
    elif architecture.startswith('densenet'):
        features = list(net_in.features.children()) + [nn.ReLU(inplace=True)]
    elif architecture.startswith('squeezenet'):
        features = list(net_in.features.children())
    else:
        raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))

    if skip_layer > 0:
        features = features[:-skip_layer]
    features = nn.Sequential(*features)
    # Measured rather than hard-coded (2048 // 2**skip_layer only holds for bottleneck
    # ResNets such as resnet50/101/152; resnet18/34, VGG, DenseNet, ... differ).
    backbone_dim = _infer_backbone_dim(features)

    att_layer = layers.attention.L2Attention()
    lit_layer = LocalfeatureIntegrationTransformer(**lit, input_dim=backbone_dim)

    reduction_layer = None
    if dim_reduction:
        reduction_layer = layers.dim_reduction.ConvDimReduction(**dim_reduction, input_dim=lit['dim'])

    meta = {
        "architecture": architecture,
        # The "backbone" seen by downstream layers is the LIT output, not the CNN output.
        "backbone_dim": lit['dim'],
        "outputdim": reduction_layer.out_channels if reduction_layer is not None else lit['dim'],
        "corercf_size": 32 // (2 ** skip_layer),
    }
    net = FIReNet(features, att_layer, lit_layer, reduction_layer, meta, runtime)

    if pretrained is not None:
        assert os.path.isfile(pretrained), pretrained
        # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
        ckpt = torch.load(pretrained, map_location='cpu')
        missing, unexpected = net.load_state_dict(ckpt['state_dict'], strict=False)
        # Only keys containing 'dim_reduction' may be absent from the checkpoint, and only
        # keys containing 'fc' may be left over unused.
        assert all('dim_reduction' in a for a in missing), "Loading did not go well"
        assert all('fc' in a for a in unexpected), "Loading did not go well"
    return net