"""Module of the HOW method"""

import torch
import torch.nn as nn


class HOWNet(nn.Module):
    """Network for the HOW method

    :param list features: A list of torch.nn.Module which act as feature extractor
    :param torch.nn.Module attention: Attention layer
    :param torch.nn.Module smoothing: Smoothing layer
    :param torch.nn.Module dim_reduction: Dimensionality reduction layer
    :param dict meta: Metadata stored with the network
    :param dict runtime: Runtime options used as defaults, e.g. for inference
    """

    def __init__(self, features, attention, smoothing, dim_reduction, meta, runtime):
        super().__init__()

        self.features = features
        self.attention = attention
        self.smoothing = smoothing
        self.dim_reduction = dim_reduction

        self.meta = meta
        self.runtime = runtime


    def copy_excluding_dim_reduction(self):
        """Return a copy of this network without the dim_reduction layer"""
        meta = {**self.meta, "outputdim": self.meta["backbone_dim"]}
        return self.__class__(self.features, self.attention, self.smoothing, None, meta, self.runtime)

    def copy_with_runtime(self, runtime):
        """Return a copy of this network with a different runtime dict"""
        return self.__class__(self.features, self.attention, self.smoothing, self.dim_reduction, self.meta, runtime)


    # Methods of nn.Module

    @staticmethod
    def _set_batchnorm_eval(mod):
        if 'BatchNorm' in mod.__class__.__name__:
            # keep running mean and variance frozen
            mod.eval()

    def train(self, mode=True):
        """Set train mode while keeping all batchnorm layers in eval mode"""
        res = super().train(mode)
        if mode:
            self.apply(HOWNet._set_batchnorm_eval)
        return res
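
    # Sketch of the intended behaviour (illustrative, not from the original
    # file): net.train() leaves BatchNorm layers frozen, e.g.
    #
    #     net.train()
    #     bn = next(m for m in net.modules() if isinstance(m, nn.BatchNorm2d))
    #     assert net.training and not bn.training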

    def parameter_groups(self, optimizer_opts):
        """Return torch parameter groups"""
        layers = [self.features, self.attention, self.smoothing]
        parameters = [{'params': x.parameters()} for x in layers if x is not None]
        if self.dim_reduction:
            # Include the dimensionality reduction layer with zero learning rate
            # so that it is never updated
            parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
        return parameters
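
    # Usage sketch (illustrative, not from the original file): the returned
    # groups can be passed directly to a torch optimizer; optimizer_opts is
    # accepted for interface compatibility and unused here, e.g.
    #
    #     optimizer = torch.optim.SGD(net.parameter_groups({}), lr=1e-2)
    #
    # where the zero learning rate on dim_reduction keeps that layer frozen.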


    # Forward

    def features_attentions(self, x, *, scales):
        """Return a tuple (features, attentions), each a list with one entry per requested scale"""
        feats = []
        masks = []
        for s in scales:
            # Rescale the input image, then extract features and the attention map
            xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
            o = self.features(xs)
            m = self.attention(o)
            if self.smoothing:
                o = self.smoothing(o)
            if self.dim_reduction:
                o = self.dim_reduction(o)
            feats.append(o)
            masks.append(m)

        # Normalize attention weights so that the maximum over all scales is 1
        mx = max(mask.max() for mask in masks)
        masks = [mask / mx for mask in masks]

        return feats, masks

    def __repr__(self):
        meta_str = "\n".join("    %s: %s" % x for x in self.meta.items())
        return "%s(meta={\n%s\n})" % (self.__class__.__name__, meta_str)

    def meta_repr(self):
        """Return meta representation"""
        return str(self)
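

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The backbone, the
# toy attention module and all hyperparameters below are illustrative
# assumptions; the HOW method defines its own trained components.
# ---------------------------------------------------------------------------

class _L2Attention(nn.Module):
    """Toy attention producing a per-location L2 norm of the feature map"""

    def forward(self, x):
        return x.pow(2).sum(dim=1, keepdim=True).sqrt()


if __name__ == "__main__":
    import torchvision

    # Truncated ResNet-18 as the feature extractor: drop avgpool and fc so the
    # output is a 512-channel feature map (weights=None is the torchvision
    # >= 0.13 API for an untrained model)
    backbone = torchvision.models.resnet18(weights=None)
    features = nn.Sequential(*list(backbone.children())[:-2])

    net = HOWNet(
        features=features,
        attention=_L2Attention(),
        smoothing=nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
        dim_reduction=nn.Conv2d(512, 128, kernel_size=1),
        meta={"backbone_dim": 512, "outputdim": 128},
        runtime={"scales": [1.0, 0.7071]},
    )
    net.eval()

    with torch.no_grad():
        image = torch.randn(1, 3, 224, 224)
        feats, masks = net.features_attentions(image, scales=net.runtime["scales"])

    # One (features, attentions) pair per scale, e.g. (1, 128, 7, 7) / (1, 1, 7, 7)
    for f, m in zip(feats, masks):
        print(tuple(f.shape), tuple(m.shape))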