"Hooks provide extensibility at the model level."
from ..torch_core import *
from ..callback import *
from ..basic_train import *
from ..basic_data import *

__all__ = ['ActivationStats', 'Hook', 'HookCallback', 'Hooks', 'hook_output', 'hook_outputs',
           'model_sizes', 'num_features_model', 'model_summary', 'dummy_eval', 'dummy_batch']

class Hook():
    "Create a hook on `m` with `hook_func`."
    def __init__(self, m:nn.Module, hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
        self.hook_func,self.detach,self.stored = hook_func,detach,None
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False
    def hook_fn(self, module:nn.Module, input:Tensors, output:Tensors):
        "Applies `hook_func` to `module`, `input`, `output`."
        if self.detach:
            input  = (o.detach() for o in input ) if is_listy(input ) else input.detach()
            output = (o.detach() for o in output) if is_listy(output) else output.detach()
        self.stored = self.hook_func(module, input, output)
    def remove(self):
        "Remove the hook from the model."
        if not self.removed:
            self.hook.remove()
            self.removed=True
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()
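
# Usage sketch (not part of the library): a `Hook` can be used as a context manager so it is
# removed as soon as the block exits. `model`, `x` and the chosen layer are assumptions made
# only for this illustration.
#
#   def get_shape(module, input, output): return output.shape
#   with Hook(model[0], get_shape) as h:
#       model.eval()(x)        # the forward pass triggers `hook_fn`
#       print(h.stored)        # whatever `hook_func` returned, here the output shape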

class Hooks():
    "Create several hooks on the modules in `ms` with `hook_func`."
    def __init__(self, ms:Collection[nn.Module], hook_func:HookFunc, is_forward:bool=True, detach:bool=True):
        self.hooks = [Hook(m, hook_func, is_forward, detach) for m in ms]
    def __getitem__(self,i:int)->Hook: return self.hooks[i]
    def __len__(self)->int: return len(self.hooks)
    def __iter__(self): return iter(self.hooks)
    @property
    def stored(self): return [o.stored for o in self]
    def remove(self):
        "Remove the hooks from the model."
        for h in self.hooks: h.remove()
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.remove()

def _hook_inner(m,i,o): return o if isinstance(o,Tensor) else o if is_listy(o) else list(o)

def hook_output(module:nn.Module, detach:bool=True, grad:bool=False)->Hook:
    "Return a `Hook` that stores activations of `module` in `self.stored`"
    return Hook(module, _hook_inner, detach=detach, is_forward=not grad)

def hook_outputs(modules:Collection[nn.Module], detach:bool=True, grad:bool=False)->Hooks:
    "Return `Hooks` that store activations of all `modules` in `self.stored`"
    return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad)
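
# Usage sketch (illustrative only): grab the activations of several modules at once with
# `hook_outputs`; `learn` and `xb` are assumed to exist and the module indices are arbitrary.
#
#   ms = [learn.model[0], learn.model[1]]
#   with hook_outputs(ms, detach=True) as hooks:
#       learn.model.eval()(xb)
#       acts = hooks.stored    # list of activation tensors, one entry per hooked module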

class HookCallback(LearnerCallback):
    "Callback that can be used to register hooks on `modules`. Implement the corresponding function in `self.hook`."
    def __init__(self, learn:Learner, modules:Sequence[nn.Module]=None, do_remove:bool=True):
        super().__init__(learn)
        self.modules,self.do_remove = modules,do_remove
    def on_train_begin(self, **kwargs):
        "Register the `Hooks` on `self.modules`."
        if not self.modules:
            self.modules = [m for m in flatten_model(self.learn.model)
                            if hasattr(m, 'weight')]
        self.hooks = Hooks(self.modules, self.hook)
    def on_train_end(self, **kwargs):
        "Remove the `Hooks`."
        if self.do_remove: self.remove()
    def remove(self):
        if getattr(self, 'hooks', None): self.hooks.remove()
    def __del__(self): self.remove()
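
# Usage sketch (hypothetical subclass, not an existing fastai callback): a `HookCallback`
# subclass only needs to provide `hook`; registration and removal of the `Hooks` are handled
# by `on_train_begin`/`on_train_end` above.
#
#   class MaxActivations(HookCallback):
#       def on_train_begin(self, **kwargs):
#           super().on_train_begin(**kwargs)
#           self.maxes = []
#       def hook(self, m, i, o): return o.abs().max().item()
#       def on_batch_end(self, train, **kwargs):
#           if train: self.maxes.append(self.hooks.stored)
#
#   # `data` and `model` are assumed: learn = Learner(data, model, callback_fns=[MaxActivations])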

class ActivationStats(HookCallback):
    "Callback that records the mean and std of activations."
    def on_train_begin(self, **kwargs):
        "Initialize stats."
        super().on_train_begin(**kwargs)
        self.stats = []
    def hook(self, m:nn.Module, i:Tensors, o:Tensors)->Tuple[Rank0Tensor,Rank0Tensor]:
        "Take the mean and std of `o`."
        return o.mean().item(),o.std().item()
    def on_batch_end(self, train, **kwargs):
        "Take the stored results and put them in `self.stats`."
        if train: self.stats.append(self.hooks.stored)
    def on_train_end(self, **kwargs):
        "Polish the final result."
        super().on_train_end(**kwargs)
        self.stats = tensor(self.stats).permute(2,1,0)
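
# Usage sketch (assuming an existing `data` and `model`): attach `ActivationStats` through
# `callback_fns`; after training, `stats` is a tensor of shape (2, n_modules, n_batches)
# where index 0 holds the means and index 1 the stds.
#
#   learn = Learner(data, model, callback_fns=[ActivationStats])
#   learn.fit(1)
#   means, stds = learn.activation_stats.stats[0], learn.activation_stats.stats[1]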

def dummy_batch(m:nn.Module, size:tuple=(64,64))->Tensor:
    "Create a dummy batch to go through `m` with `size`."
    ch_in = in_channels(m)
    return one_param(m).new(1, ch_in, *size).requires_grad_(False).uniform_(-1.,1.)

def dummy_eval(m:nn.Module, size:tuple=(64,64)):
    "Pass a `dummy_batch` in evaluation mode in `m` with `size`."
    m.eval()
    return m(dummy_batch(m, size))

def model_sizes(m:nn.Module, size:tuple=(64,64))->Sizes:
    "Pass a dummy input through the model `m` to get the various sizes of activations."
    with hook_outputs(m) as hooks:
        x = dummy_eval(m, size)
        return [o.stored.shape for o in hooks]

def num_features_model(m:nn.Module)->int:
    "Return the number of output features for `m`."
    sz = 64
    while True:
        try: return model_sizes(m, size=(sz,sz))[-1][1]
        except Exception as e:
            sz *= 2
            if sz > 2048: raise
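
# Usage sketch: on a small convolutional stack, `model_sizes` reports the activation shape
# of each child module for a dummy (1, ch_in, 64, 64) batch, and `num_features_model` picks
# the channel count of the last one. The toy model below is made up for illustration.
#
#   body = nn.Sequential(nn.Conv2d(3, 8, 3, stride=2, padding=1), nn.ReLU(),
#                        nn.Conv2d(8, 16, 3, stride=2, padding=1))
#   model_sizes(body)        # [torch.Size([1, 8, 32, 32]), torch.Size([1, 8, 32, 32]), torch.Size([1, 16, 16, 16])]
#   num_features_model(body) # 16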

def total_params(m:nn.Module)->Tuple[int,bool]:
    "Return the number of parameters of `m` and whether its weight is trainable."
    params, trainable = 0, False
    if hasattr(m, "weight") and hasattr(m.weight, "size"):
        params += m.weight.numel()
        trainable = m.weight.requires_grad
    if hasattr(m, "bias") and hasattr(m.bias, "size"): params += m.bias.numel()
    return params, trainable

def hook_params(modules:Collection[nn.Module])->Hooks:
    "Return `Hooks` that store the parameter count and trainable flag of each of `modules`."
    return Hooks(modules, lambda m, i, o: total_params(m))
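
# Usage sketch: `total_params` counts the weight and bias elements of a single module and
# reports whether its weight requires gradients; `hook_params` stores these per-module counts
# during a forward pass, which is what `params_size` below relies on.
#
#   total_params(nn.Linear(10, 2))   # (22, True): 10*2 weights + 2 biases, trainable
#   total_params(nn.ReLU())          # (0, False): no parameters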

def params_size(m:Union[nn.Module,Learner], size:tuple=(3,64,64))->Tuple[Sizes,List[int],List[bool]]:
    "Pass a dummy input through the model to get the output size, parameter count and trainable flag of every layer."
    if isinstance(m, Learner):
        if m.data.is_empty:
            raise Exception("This is an empty `Learner` and `Learner.summary` requires some data to pass through the model.")
        ds_type = DatasetType.Train if m.data.train_dl else (DatasetType.Valid if m.data.valid_dl else DatasetType.Test)
        x = m.data.one_batch(ds_type=ds_type, detach=False, denorm=False)[0]
        x = [o[:1] for o in x] if is_listy(x) else x[:1]
        m = m.model
    elif isinstance(m, nn.Module): x = next(m.parameters()).new(1, *size)
    else: raise TypeError('You should either pass in a Learner or nn.Module')
    with hook_outputs(flatten_model(m)) as hook_o:
        with hook_params(flatten_model(m)) as hook_p:
            x = m.eval()(*x) if is_listy(x) else m.eval()(x)
            output_size = [((o.stored.shape[1:]) if o.stored is not None else None) for o in hook_o]
            params = [(o.stored if o.stored is not None else (None,None)) for o in hook_p]
            params, trainables = map(list,zip(*params))
    return output_size, params, trainables

def get_layer_name(layer:nn.Module)->str:
    "Return the class name of `layer`."
    return str(layer.__class__).split(".")[-1].split("'")[0]

def layers_info(m:Collection[nn.Module])->Collection[namedtuple]:
    "Return, for each layer of `m`, a namedtuple with its name, output size, parameter count and trainable flag."
    func = lambda m:list(map(get_layer_name, flatten_model(m)))
    layers_names = func(m.model) if isinstance(m, Learner) else func(m)
    layers_sizes, layers_params, layers_trainable = params_size(m)
    layer_info = namedtuple('Layer_Information', ['Layer', 'OutputSize', 'Params', 'Trainable'])
    return list(map(layer_info, layers_names, layers_sizes, layers_params, layers_trainable))

def model_summary(m:Learner, n:int=70):
    "Print a summary of `m` using an output text width of `n` chars."
    info = layers_info(m)
    header = ["Layer (type)", "Output Shape", "Param #", "Trainable"]
    res = m.model.__class__.__name__ + "\n"
    res += "=" * n + "\n"
    res += f"{header[0]:<20} {header[1]:<20} {header[2]:<10} {header[3]:<10}\n"
    res += "=" * n + "\n"
    total_params = 0
    total_trainable_params = 0
    for layer, size, params, trainable in info:
        if size is None: continue
        total_params += int(params)
        total_trainable_params += int(params) * trainable
        size, trainable = str(list(size)), str(trainable)
        res += f"{layer:<20} {size:<20} {int(params):<10,} {trainable:<10}\n"
    res += "_" * n + "\n"
    res += f"\nTotal params: {total_params:,}\n"
    res += f"Total trainable params: {total_trainable_params:,}\n"
    res += f"Total non-trainable params: {total_params - total_trainable_params:,}\n"
    res += f"Optimized with {str(m.opt_func)[25:-1].replace('>', '')}\n"
    if m.true_wd: res += f"Using true weight decay as discussed in https://www.fast.ai/2018/07/02/adam-weight-decay/ \n"
    if "wd" in str(m.opt_func) or "weight_decay" in str(m.opt_func): res += f"\x1b[1;31m Specifying weight decay in the optimizer has no effect, Learner will overwrite \x1b[0m \n"
    if "lr" in str(m.opt_func) or "learning_rate" in str(m.opt_func): res += f"\x1b[1;31m Specifying lr in the optimizer has no effect, pass it to fit or the defaults.lr will apply \x1b[0m \n"
    res += f"Loss function: {m.loss_func.__class__.__name__}\n"
    res += "=" * n + "\n"
    res += "Callback functions applied \n"
    res += "\n".join([f" {cbs.__class__.__name__}" for cbs in m.callbacks])
    return PrettyString(res)

Learner.summary = model_summary
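
# Usage sketch (assuming an existing fastai `learn` object with data): once this module is
# imported, `Learner.summary` is available and prints one row per layer (name, output shape,
# parameter count, trainable flag) followed by parameter totals, the optimizer, the loss
# function and the attached callbacks.
#
#   learn.summary()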