Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# -*- coding:UTF-8 -*- | |
import torch | |
import torch.nn as nn | |
import torch.nn.init as init | |
# Module families that share an initialisation recipe. Grouping them lets
# weight_init dispatch with one isinstance() check per recipe instead of one
# copy-pasted branch per class.
_NORMAL_WEIGHT_LAYERS = (nn.Conv1d, nn.ConvTranspose1d)
_XAVIER_WEIGHT_LAYERS = (
    nn.Conv2d,
    nn.Conv3d,
    nn.ConvTranspose2d,
    nn.ConvTranspose3d,
    nn.Linear,
)
_BATCH_NORM_LAYERS = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
_RECURRENT_LAYERS = (nn.LSTM, nn.LSTMCell, nn.GRU, nn.GRUCell)


def _normal_bias_(m):
    """Fill ``m.bias`` in place with ``init.normal_`` when the module has one."""
    # Layers constructed with bias=False store ``None`` in ``m.bias``.
    if m.bias is not None:
        init.normal_(m.bias.data)


def weight_init(m):
    """Initialise the parameters of a single module in place.

    Usage:
        model = Model()
        model.apply(weight_init)

    Recipes by module type:
      * Conv1d, ConvTranspose1d: weight and bias drawn with ``init.normal_``
        (default mean/std).
      * Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d, Linear: weight is
        Xavier-normal; bias drawn with ``init.normal_``.
      * BatchNorm1d/2d/3d: weight ~ normal(mean=1, std=0.02); bias set to 0.
      * LSTM, LSTMCell, GRU, GRUCell: every parameter with two or more
        dimensions is made orthogonal; 1-D parameters are drawn with
        ``init.normal_``.

    Any other module type is left untouched. Parameters that a module does
    not have (``bias=False`` layers, ``affine=False`` batch norms) are skipped
    rather than raising ``AttributeError``.

    Args:
        m: the module currently visited by ``nn.Module.apply``.
    """
    if isinstance(m, _NORMAL_WEIGHT_LAYERS):
        init.normal_(m.weight.data)
        _normal_bias_(m)
    elif isinstance(m, _XAVIER_WEIGHT_LAYERS):
        init.xavier_normal_(m.weight.data)
        # Previously unguarded for nn.Linear, which crashed on bias=False.
        _normal_bias_(m)
    elif isinstance(m, _BATCH_NORM_LAYERS):
        # affine=False batch norms have neither weight nor bias.
        if m.weight is not None:
            init.normal_(m.weight.data, mean=1, std=0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0)
    elif isinstance(m, _RECURRENT_LAYERS):
        for param in m.parameters():
            # orthogonal_ needs at least a 2-D tensor; 1-D params are biases.
            if param.dim() >= 2:
                init.orthogonal_(param.data)
            else:
                init.normal_(param.data)
if __name__ == "__main__":
    # Intentionally empty: this module only provides weight_init, which is
    # meant to be imported and passed to model.apply(...).
    pass