"""Module of the HOW method""" import numpy as np import torch import torch.nn as nn import torchvision class HOWNet(nn.Module): """Network for the HOW method :param list features: A list of torch.nn.Module which act as feature extractor :param torch.nn.Module attention: Attention layer :param torch.nn.Module smoothing: Smoothing layer :param torch.nn.Module dim_reduction: Dimensionality reduction layer :param dict meta: Metadata that are stored with the network :param dict runtime: Runtime options that can be used as default for e.g. inference """ def __init__(self, features, attention, smoothing, dim_reduction, meta, runtime): super().__init__() self.features = features self.attention = attention self.smoothing = smoothing self.dim_reduction = dim_reduction self.meta = meta self.runtime = runtime def copy_excluding_dim_reduction(self): """Return a copy of this network without the dim_reduction layer""" meta = {**self.meta, "outputdim": self.meta['backbone_dim']} return self.__class__(self.features, self.attention, self.smoothing, None, meta, self.runtime) def copy_with_runtime(self, runtime): """Return a copy of this network with a different runtime dict""" return self.__class__(self.features, self.attention, self.smoothing, self.dim_reduction, self.meta, runtime) # Methods of nn.Module @staticmethod def _set_batchnorm_eval(mod): if mod.__class__.__name__.find('BatchNorm') != -1: # freeze running mean and std mod.eval() def train(self, mode=True): res = super().train(mode) if mode: self.apply(HOWNet._set_batchnorm_eval) return res def parameter_groups(self, optimizer_opts): """Return torch parameter groups""" layers = [self.features, self.attention, self.smoothing] parameters = [{'params': x.parameters()} for x in layers if x is not None] if self.dim_reduction: # Do not update dimensionality reduction layer parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0}) return parameters # Forward def features_attentions(self, x, *, scales): """Return a tuple (features, attentions) where each is a list containing requested scales""" feats = [] masks = [] for s in scales: xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False) o = self.features(xs) m = self.attention(o) if self.smoothing: o = self.smoothing(o) if self.dim_reduction: o = self.dim_reduction(o) feats.append(o) masks.append(m) # Normalize max weight to 1 mx = max(x.max() for x in masks) masks = [x/mx for x in masks] return feats, masks def __repr__(self): meta_str = "\n".join(" %s: %s" % x for x in self.meta.items()) return "%s(meta={\n%s\n})" % (self.__class__.__name__, meta_str) def meta_repr(self): """Return meta representation""" return str(self)