Spaces:
Build error
Build error
"""Module of the HOW method""" | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torchvision | |
class HOWNet(nn.Module):
    """Network for the HOW method

    :param torch.nn.Module features: Feature extractor applied to each rescaled input
        (it is called as a module and its parameters are optimized, e.g. an ``nn.Sequential``)
    :param torch.nn.Module attention: Attention layer, computed on the extracted features
    :param torch.nn.Module smoothing: Optional smoothing layer (``None`` to disable)
    :param torch.nn.Module dim_reduction: Optional dimensionality reduction layer (``None`` to
        disable); it is applied in the forward pass but gets learning rate 0 in
        :meth:`parameter_groups`
    :param dict meta: Metadata that are stored with the network; must contain ``backbone_dim``
        for :meth:`copy_excluding_dim_reduction`
    :param dict runtime: Runtime options that can be used as default for e.g. inference
    """

    def __init__(self, features, attention, smoothing, dim_reduction, meta, runtime):
        super().__init__()
        self.features = features
        self.attention = attention
        self.smoothing = smoothing
        self.dim_reduction = dim_reduction
        self.meta = meta
        self.runtime = runtime

    def copy_excluding_dim_reduction(self):
        """Return a copy of this network without the dim_reduction layer

        The copy shares the features, attention and smoothing modules (and therefore their
        weights) and the runtime dict with this network. Only ``meta`` is copied, with
        ``outputdim`` replaced by ``backbone_dim``.
        """
        meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
        return self.__class__(self.features, self.attention, self.smoothing, None, meta,
                              self.runtime)

    def copy_with_runtime(self, runtime):
        """Return a copy of this network with a different runtime dict

        All layers and the ``meta`` dict are shared with this network, not copied.
        """
        return self.__class__(self.features, self.attention, self.smoothing,
                              self.dim_reduction, self.meta, runtime)

    # Methods of nn.Module

    @staticmethod
    def _set_batchnorm_eval(mod):
        """Put ``mod`` in eval mode if its class name contains 'BatchNorm'"""
        if 'BatchNorm' in mod.__class__.__name__:
            # freeze running mean and std
            mod.eval()

    def train(self, mode=True):
        """Set training mode as :meth:`torch.nn.Module.train` does, but when training keep
        every BatchNorm layer in eval mode so that its running statistics stay frozen

        :param bool mode: ``True`` for training mode, ``False`` for evaluation mode
        :return: The value returned by ``nn.Module.train`` (this module)
        """
        res = super().train(mode)
        if mode:
            self.apply(HOWNet._set_batchnorm_eval)
        return res

    def parameter_groups(self, optimizer_opts):
        """Return torch parameter groups

        One group for each present (non-``None``) layer among features, attention and
        smoothing, plus, if a dimensionality reduction layer is present, a group for it with
        learning rate 0.

        :param dict optimizer_opts: Optimizer options; not used by this method
        :return: List of parameter-group dicts for a ``torch.optim`` optimizer
        """
        layers = [self.features, self.attention, self.smoothing]
        parameters = [{'params': layer.parameters()} for layer in layers if layer is not None]
        if self.dim_reduction is not None:
            # Do not update dimensionality reduction layer
            parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
        return parameters

    # Forward

    def features_attentions(self, x, *, scales):
        """Return a tuple (features, attentions) where each is a list containing requested scales

        For every scale ``s`` the input is bilinearly resized by ``s`` and passed through
        ``features``. The attention is computed on these features. Afterwards the features go
        through the optional smoothing and dimensionality reduction layers. All attentions are
        finally divided by their maximum over all scales, so the largest weight is 1.

        :param torch.Tensor x: Input batch (bilinear interpolation expects a 4-D tensor,
            presumably (batch, channels, H, W) - TODO confirm)
        :param scales: Iterable of scale factors
        :return: Tuple ``(feats, masks)`` of lists with one tensor per scale
        :raises ValueError: If ``scales`` is empty
        """
        feats = []
        masks = []
        for scale in scales:
            xs = nn.functional.interpolate(x, scale_factor=scale, mode='bilinear',
                                           align_corners=False)
            o = self.features(xs)
            m = self.attention(o)
            if self.smoothing is not None:
                o = self.smoothing(o)
            if self.dim_reduction is not None:
                o = self.dim_reduction(o)
            feats.append(o)
            masks.append(m)

        # Checked after the loop so that generators for ``scales`` are supported
        if not masks:
            raise ValueError("scales must contain at least one scale factor")

        # Normalize max weight to 1 (maximum taken over all scales jointly)
        # NOTE(review): if every attention value is 0 this divides by zero - confirm the
        # attention layer always produces a positive maximum
        mx = max(mask.max() for mask in masks)
        masks = [mask / mx for mask in masks]
        return feats, masks

    def __repr__(self):
        meta_str = "\n".join(" %s: %s" % item for item in self.meta.items())
        return "%s(meta={\n%s\n})" % (self.__class__.__name__, meta_str)

    def meta_repr(self):
        """Return meta representation (same as ``repr(self)``)"""
        return str(self)