"Hooks provide extensibility at the model level."
from ..torch_core import *
from ..callback import *
from ..basic_train import *
from ..basic_data import *

__all__ = ['ActivationStats', 'Hook', 'HookCallback', 'Hooks', 'hook_output', 'hook_outputs',
           'model_sizes', 'num_features_model', 'model_summary', 'dummy_eval', 'dummy_batch']

class Hook():
    "Create a hook on `m` with `hook_func`."
    def __init__(self, m:nn.Module, hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
        self.hook_func,self.detach,self.stored = hook_func,detach,None
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False

    def hook_fn(self, module:nn.Module, input:Tensors, output:Tensors):
        "Applies `hook_func` to `module`, `input`, `output`."
        if self.detach:
            # detach so stored activations don't keep the computation graph alive
            input  = tuple(o.detach() for o in input ) if is_listy(input ) else input.detach()
            output = tuple(o.detach() for o in output) if is_listy(output) else output.detach()
        self.stored = self.hook_func(module, input, output)

    def remove(self):
        "Remove the hook from the model."
        if not self.removed:
            self.hook.remove()
            self.removed=True

    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()
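
# Example (illustrative sketch, not part of the library API; a plain `nn.Linear`
# stands in for any module you want to inspect):
#   lin = nn.Linear(10, 2)
#   with Hook(lin, lambda m, i, o: o.mean()) as h:
#       lin(torch.randn(4, 10))   # the forward pass triggers `hook_fn`
#       print(h.stored)           # mean of the detached output, as returned by `hook_func`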

class Hooks():
    "Create several hooks on the modules in `ms` with `hook_func`."
    def __init__(self, ms:Collection[nn.Module], hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
        self.hooks = [Hook(m, hook_func, is_forward, detach) for m in ms]

    def __getitem__(self,i:int)->Hook: return self.hooks[i]
    def __len__(self)->int: return len(self.hooks)
    def __iter__(self): return iter(self.hooks)
    @property
    def stored(self): return [o.stored for o in self]

    def remove(self):
        "Remove the hooks from the model."
        for h in self.hooks: h.remove()

    def __enter__(self, *args): return self
    def __exit__ (self, *args): self.remove()

def _hook_inner(m,i,o): return o if isinstance(o,Tensor) else o if is_listy(o) else list(o)

def hook_output (module:nn.Module, detach:bool=True, grad:bool=False)->Hook:
    "Return a `Hook` that stores activations of `module` in `self.stored`"
    return Hook(module, _hook_inner, detach=detach, is_forward=not grad)

def hook_outputs(modules:Collection[nn.Module], detach:bool=True, grad:bool=False)->Hooks:
    "Return `Hooks` that store activations of all `modules` in `self.stored`"
    return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad)
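
# Example (illustrative sketch; a small `nn.Sequential` stands in for a real model):
#   model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
#   with hook_outputs(model) as hooks:
#       model(torch.randn(4, 10))
#       shapes = [h.stored.shape for h in hooks]   # one activation shape per layer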

class HookCallback(LearnerCallback):
    "Callback that can be used to register hooks on `modules`. Implement the corresponding function in `self.hook`."
    def __init__(self, learn:Learner, modules:Sequence[nn.Module]=None, do_remove:bool=True):
        super().__init__(learn)
        self.modules,self.do_remove = modules,do_remove

    def on_train_begin(self, **kwargs):
        "Register the `Hooks` on `self.modules`."
        if not self.modules:
            self.modules = [m for m in flatten_model(self.learn.model)
                            if hasattr(m, 'weight')]
        self.hooks = Hooks(self.modules, self.hook)

    def on_train_end(self, **kwargs):
        "Remove the `Hooks`."
        if self.do_remove: self.remove()

    def remove(self): 
        if getattr(self, 'hooks', None): self.hooks.remove()
    def __del__(self): self.remove()
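
# Example (illustrative sketch): a minimal subclass only needs to implement `hook`;
# whatever it returns is collected per module in `self.hooks.stored`.
#   class MaxActivations(HookCallback):
#       def hook(self, m, i, o): return o.max().item()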

class ActivationStats(HookCallback):
    "Callback that record the mean and std of activations."

    def on_train_begin(self, **kwargs):
        "Initialize stats."
        super().on_train_begin(**kwargs)
        self.stats = []

    def hook(self, m:nn.Module, i:Tensors, o:Tensors)->Tuple[float,float]:
        "Take the mean and std of `o`."
        return o.mean().item(),o.std().item()
    def on_batch_end(self, train, **kwargs):
        "Take the stored results and puts it in `self.stats`"
        if train: self.stats.append(self.hooks.stored)
    def on_train_end(self, **kwargs):
        "Polish the final result."
        super().on_train_end(**kwargs)
        self.stats = tensor(self.stats).permute(2,1,0)
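
# Example (illustrative sketch; `data` and `model` are assumed to already exist):
#   learn = Learner(data, model, callback_fns=ActivationStats)
#   learn.fit(1)
#   learn.activation_stats.stats   # tensor of shape (2, n_modules, n_batches): means and stds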

def dummy_batch(m: nn.Module, size:tuple=(64,64))->Tensor:
    "Create a dummy batch to go through `m` with `size`."
    ch_in = in_channels(m)
    return one_param(m).new(1, ch_in, *size).requires_grad_(False).uniform_(-1.,1.)

def dummy_eval(m:nn.Module, size:tuple=(64,64)):
    "Pass a `dummy_batch` in evaluation mode in `m` with `size`."
    return m.eval()(dummy_batch(m, size))

def model_sizes(m:nn.Module, size:tuple=(64,64))->Sizes:
    "Pass a dummy input through the model `m` to get the various sizes of activations."
    with hook_outputs(m) as hooks:
        dummy_eval(m, size)  # the forward pass fills each hook's `stored` activation
        return [o.stored.shape for o in hooks]

def num_features_model(m:nn.Module)->int:
    "Return the number of output features for model `m`."
    sz = 64
    while True:
        try: return model_sizes(m, size=(sz,sz))[-1][1]
        except Exception:
            sz *= 2  # the dummy input may be too small for the model; retry with a larger size
            if sz > 2048: raise
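
# Example (illustrative sketch; `body` stands for an assumed convolutional backbone,
# e.g. one produced by `create_body` in fastai.vision):
#   model_sizes(body, size=(224, 224))   # list of activation shapes, one per module
#   num_features_model(body)             # channel count of the last activation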

def total_params(m:nn.Module)->Tuple[int,bool]:
    "Return the number of parameters of `m` and whether its weight is trainable."
    params, trainable = 0, False
    if hasattr(m, "weight") and hasattr(m.weight, "size"):
        params += m.weight.numel()
        trainable = m.weight.requires_grad
    if hasattr(m, "bias") and hasattr(m.bias, "size"): params += m.bias.numel()
    return params, trainable

def hook_params(modules:Collection[nn.Module])->Hooks:
    return Hooks(modules, lambda m, i, o: total_params(m))

def params_size(m: Union[nn.Module,Learner], size: tuple = (3, 64, 64))->Tuple[Sizes,List[int],List[bool]]:
    "Pass a dummy input through the model to get the various sizes. Return (output sizes, parameter counts, trainable flags)."
    if isinstance(m, Learner):
        if m.data.is_empty:
            raise Exception("This is an empty `Learner` and `Learner.summary` requires some data to pass through the model.")
        ds_type = DatasetType.Train if m.data.train_dl else (DatasetType.Valid if m.data.valid_dl else DatasetType.Test)
        x = m.data.one_batch(ds_type=ds_type, detach=False, denorm=False)[0]
        x = [o[:1] for o in x]  if is_listy(x) else x[:1]
        m = m.model
    elif isinstance(m, nn.Module): x = next(m.parameters()).new(1, *size)
    else: raise TypeError('You should either pass in a Learner or nn.Module')
    with hook_outputs(flatten_model(m)) as hook_o:
        with hook_params(flatten_model(m)) as hook_p:
            x = m.eval()(*x) if is_listy(x) else m.eval()(x)
            output_size = [((o.stored.shape[1:]) if o.stored is not None else None) for o in hook_o]
            params = [(o.stored if o.stored is not None else (None,None)) for o in hook_p]
    params, trainables = map(list,zip(*params))
    return output_size, params, trainables
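
# Example (illustrative sketch; `learn` is an assumed, already-built `Learner`):
#   sizes, params, trainables = params_size(learn)   # one entry per flattened layer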

def get_layer_name(layer:nn.Module)->str:
    return str(layer.__class__).split(".")[-1].split("'")[0]

def layers_info(m:Union[nn.Module,Learner]) -> Collection[namedtuple]:
    func = lambda m:list(map(get_layer_name, flatten_model(m)))
    layers_names = func(m.model) if isinstance(m, Learner) else func(m)
    layers_sizes, layers_params, layers_trainable = params_size(m)
    layer_info = namedtuple('Layer_Information', ['Layer', 'OutputSize', 'Params', 'Trainable'])
    return list(map(layer_info, layers_names, layers_sizes, layers_params, layers_trainable))

def model_summary(m:Learner, n:int=70):
    "Print a summary of `m` using a output text width of `n` chars"
    info = layers_info(m)
    header = ["Layer (type)", "Output Shape", "Param #", "Trainable"]
    res = m.model.__class__.__name__ + "\n"
    res += "=" * n + "\n"
    res += f"{header[0]:<20} {header[1]:<20} {header[2]:<10} {header[3]:<10}\n"
    res += "=" * n + "\n"
    total_params = 0
    total_trainable_params = 0
    for layer, size, params, trainable in info:
        if size is None: continue
        total_params += int(params)
        total_trainable_params += int(params) * trainable
        size, trainable = str(list(size)), str(trainable)
        res += f"{layer:<20} {size:<20} {int(params):<10,} {trainable:<10}\n"
        res += "_" * n + "\n"
    res += f"\nTotal params: {total_params:,}\n"
    res += f"Total trainable params: {total_trainable_params:,}\n"
    res += f"Total non-trainable params: {total_params - total_trainable_params:,}\n"
           
    res += f"Optimized with {str(m.opt_func)[25:-1].replace('>', '')}\n"
    if m.true_wd: res += f"Using true weight decay as discussed in https://www.fast.ai/2018/07/02/adam-weight-decay/ \n"
    if "wd" in str(m.opt_func) or "weight_decay" in str(m.opt_func): res += f"\x1b[1;31m Specifying weight decay in the optimizer has no effect, Learner will overwrite \x1b[0m \n"
    if "lr" in str(m.opt_func) or "learning_rate" in str(m.opt_func): res += f"\x1b[1;31m Specifying lr in the optimizer has no effect, pass it to fit or the defaults.lr will apply \x1b[0m \n" 
    res += f"Loss function : {m.loss_func.__class__.__name__}\n"
    res += "=" * n + "\n"
    res += "Callbacks functions applied \n"
    res += "\n".join([f"    {cbs.__class__.__name__}" for cbs in m.callbacks])

    return PrettyString(res)

Learner.summary = model_summary
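
# Example (illustrative sketch; `learn` is an assumed `Learner` with data attached):
#   print(learn.summary())   # layer names, output shapes, parameter counts and trainability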