File size: 1,540 Bytes
3eb682b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Utility function for weight initialization"""
import torch.nn as nn
from fvcore.nn.weight_init import c2_msra_fill
def init_weights(model, fc_init_std=0.01, zero_init_final_bn=True):
    """
    Performs ResNet style weight initialization, in place.

    Walks every submodule yielded by `model.modules()` and re-initializes
    the Conv3d, BatchNorm3d and Linear layers it finds; modules of any other
    type are left untouched. Nothing is returned.

    Args:
        model (nn.Module): the network whose submodules are initialized.
        fc_init_std (float): the expected standard deviation for fc layer.
        zero_init_final_bn (bool): if True, zero initialize the final bn for
            every bottleneck. A bn counts as "final" when it carries a
            truthy `transform_final_bn` attribute; every other bn gets a
            weight of 1.0.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv3d):
            # Follow the initialization method proposed in:
            # {He, Kaiming, et al.
            # "Delving deep into rectifiers: Surpassing human-level
            # performance on imagenet classification."
            # arXiv preprint arXiv:1502.01852 (2015)}
            c2_msra_fill(m)
        elif isinstance(m, nn.BatchNorm3d):
            # Only bns explicitly tagged with `transform_final_bn` are
            # zeroed, and only when the caller asked for it.
            zero_this_bn = zero_init_final_bn and getattr(
                m, "transform_final_bn", False
            )
            batchnorm_weight = 0.0 if zero_this_bn else 1.0
            # weight / bias may be None (e.g. a bn built with affine=False).
            if m.weight is not None:
                m.weight.data.fill_(batchnorm_weight)
            if m.bias is not None:
                m.bias.data.zero_()
        if isinstance(m, nn.Linear):
            m.weight.data.normal_(mean=0.0, std=fc_init_std)
            if m.bias is not None:
                m.bias.data.zero_()
|