# NOTE(review): removed non-Python web-scrape residue that preceded the module
# docstring ("YannisK's picture" / "temp" / commit hash "1cdf366") — it made
# the file syntactically invalid.
"""Module of the HOW method"""
import numpy as np
import torch
import torch.nn as nn
import torchvision
class HOWNet(nn.Module):
    """Network for the HOW method.

    :param torch.nn.Module features: Feature extractor (called directly as a
        module on the input; despite older docs it is not a plain list)
    :param torch.nn.Module attention: Attention layer producing per-location weights
    :param torch.nn.Module smoothing: Smoothing layer, or None to skip smoothing
    :param torch.nn.Module dim_reduction: Dimensionality reduction layer, or None to skip
    :param dict meta: Metadata that are stored with the network
    :param dict runtime: Runtime options that can be used as default for e.g. inference
    """

    def __init__(self, features, attention, smoothing, dim_reduction, meta, runtime):
        super().__init__()
        self.features = features
        self.attention = attention
        self.smoothing = smoothing
        self.dim_reduction = dim_reduction
        self.meta = meta
        self.runtime = runtime

    def copy_excluding_dim_reduction(self):
        """Return a copy of this network without the dim_reduction layer.

        The copy's meta ``outputdim`` is reset to ``backbone_dim`` because the
        features are no longer projected to a lower dimension.
        """
        meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
        return self.__class__(self.features, self.attention, self.smoothing, None, meta, self.runtime)

    def copy_with_runtime(self, runtime):
        """Return a copy of this network with a different runtime dict."""
        return self.__class__(self.features, self.attention, self.smoothing, self.dim_reduction, self.meta, runtime)

    # Methods of nn.Module

    @staticmethod
    def _set_batchnorm_eval(mod):
        """Put any BatchNorm module into eval mode (freezes running mean and std)."""
        if 'BatchNorm' in mod.__class__.__name__:
            mod.eval()

    def train(self, mode=True):
        """Set training mode as usual, but keep all BatchNorm layers frozen in eval mode.

        :param bool mode: True for training mode, False for evaluation mode
        :return: self (as returned by ``nn.Module.train``)
        """
        res = super().train(mode)
        if mode:
            # Re-freeze BatchNorm statistics after the recursive mode switch
            self.apply(HOWNet._set_batchnorm_eval)
        return res

    def parameter_groups(self, optimizer_opts):
        """Return torch parameter groups for the optimizer.

        The dim_reduction layer, if present, is included with ``lr: 0.0`` so it
        is tracked by the optimizer but never updated.

        :param optimizer_opts: Unused here; kept for interface compatibility
        :return: List of parameter-group dicts
        """
        layers = [self.features, self.attention, self.smoothing]
        parameters = [{'params': x.parameters()} for x in layers if x is not None]
        # Explicit None-check (consistent with the filter above): container
        # modules define __len__, so bare truthiness would skip an empty one.
        if self.dim_reduction is not None:
            # Do not update dimensionality reduction layer
            parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
        return parameters

    # Forward

    def features_attentions(self, x, *, scales):
        """Return a tuple (features, attentions) where each is a list over the requested scales.

        :param torch.Tensor x: Input batch; assumed (N, C, H, W) as required by
            bilinear interpolation -- TODO confirm against callers
        :param scales: Iterable of scale factors to resize the input by
        :return: Tuple of two lists (one entry per scale): features after
            optional smoothing/dim_reduction, and attention masks normalized so
            the global maximum weight across all scales is 1
        """
        feats = []
        masks = []
        for scale in scales:
            xs = nn.functional.interpolate(x, scale_factor=scale, mode='bilinear', align_corners=False)
            o = self.features(xs)
            m = self.attention(o)
            if self.smoothing is not None:
                o = self.smoothing(o)
            if self.dim_reduction is not None:
                o = self.dim_reduction(o)
            feats.append(o)
            masks.append(m)
        # Normalize max weight to 1 across all scales
        # (use a distinct loop variable; the original shadowed the input 'x')
        peak = max(mask.max() for mask in masks)
        masks = [mask / peak for mask in masks]
        return feats, masks

    def __repr__(self):
        meta_str = "\n".join(" %s: %s" % x for x in self.meta.items())
        return "%s(meta={\n%s\n})" % (self.__class__.__name__, meta_str)

    def meta_repr(self):
        """Return meta representation."""
        return str(self)