File size: 2,511 Bytes
d60982d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
#!/usr/bin/env python
# -*- coding:UTF-8 -*-
import torch
import torch.nn as nn
import torch.nn.init as init
def weight_init(m):
    """Initialize the parameters of a single ``torch.nn`` module in place.

    Intended to be passed to ``Module.apply`` so that it is called once for
    every submodule of a model.

    Usage:
        model = Model()
        model.apply(weight_init)

    Initialization scheme, by module type:
        * Conv1d / ConvTranspose1d: weight and bias ~ N(0, 1).
        * Conv2d / Conv3d / ConvTranspose2d / ConvTranspose3d:
          Xavier-normal weight, bias ~ N(0, 1).
        * BatchNorm1d / 2d / 3d: weight ~ N(1, 0.02), bias = 0.
        * Linear: Xavier-normal weight, bias ~ N(0, 1).
        * LSTM / LSTMCell / GRU / GRUCell: orthogonal init for parameters
          with two or more dimensions, N(0, 1) for the rest.

    Modules of any other type are left untouched.

    Args:
        m: The module whose parameters should be (re)initialized.

    Returns:
        None. Parameters are modified in place.
    """
    conv_normal = (nn.Conv1d, nn.ConvTranspose1d)
    conv_xavier = (
        nn.Conv2d,
        nn.Conv3d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
    )
    batch_norms = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    recurrent = (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)

    if isinstance(m, conv_normal):
        init.normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, conv_xavier):
        init.xavier_normal_(m.weight.data)
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, batch_norms):
        # With affine=False both weight and bias are None; skip them
        # instead of failing on `.data`.
        if m.weight is not None:
            init.normal_(m.weight.data, mean=1, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        init.xavier_normal_(m.weight.data)
        # Linear(bias=False) has no bias parameter to initialize.
        if m.bias is not None:
            init.normal_(m.bias.data)
    elif isinstance(m, recurrent):
        for param in m.parameters():
            # Weight matrices are >= 2-D; bias vectors are 1-D.
            if len(param.shape) >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
# Script entry point: this module only provides `weight_init`, so running it
# directly intentionally does nothing.
if __name__ == '__main__':
    pass