File size: 360 Bytes
b213d84
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
# Copyright (c) Facebook, Inc. and its affiliates.

from torch import nn


def initialize_module_params(module: nn.Module) -> None:
    """Re-initialize the bias and weight parameters of *module* in place.

    Every parameter yielded by ``module.named_parameters()`` is dispatched on
    its name:

    * a name containing ``"bias"`` is filled with zeros;
    * otherwise, a name containing ``"weight"`` whose tensor has at least two
      dimensions is drawn from a Kaiming-normal distribution
      (``mode="fan_out"``, ``nonlinearity="relu"``);
    * anything else is left untouched.

    Args:
        module: The module whose parameters (including those of its
            submodules) are re-initialized.
    """
    for name, param in module.named_parameters():
        if "bias" in name:
            nn.init.constant_(param, 0)
        elif "weight" in name and param.dim() >= 2:
            # Kaiming init needs fan-in/fan-out, which is undefined for
            # tensors with fewer than 2 dims (e.g. the 1-D scale of a
            # normalization layer) and would raise ValueError. Such weights
            # keep their current values instead of aborting the whole pass.
            nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")