Spaces:
Running
Running
File size: 927 Bytes
aea73e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
# -*- coding: utf-8 -*-
# Helper functions for models
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
from torch import nn
def reset_weights(model: nn.Module) -> None:
    """Reset the parameters of the model to avoid weight leakage.

    Every submodule of ``model`` (at any nesting depth, e.g. layers inside an
    ``nn.Sequential``) that exposes a callable ``reset_parameters`` method has
    it invoked. The root ``model`` itself is not reset, matching the original
    behaviour of only touching sub-layers.

    Args:
        model (nn.Module): PyTorch Model
    """
    # Walk the full module tree instead of only direct children(): container
    # modules such as nn.Sequential have no reset_parameters of their own, so
    # iterating children() alone would silently skip the layers nested inside.
    for layer in model.modules():
        # modules() yields the root first; skip it so a model whose own
        # reset_parameters delegates to this helper cannot recurse forever.
        if layer is model:
            continue
        reset = getattr(layer, "reset_parameters", None)
        if callable(reset):
            reset()
def initialize_weights(module: nn.Module) -> None:
    """Initialize module weights in place.

    Submodules are visited recursively via ``module.modules()``:

    - ``nn.Linear``: weight drawn with Xavier (Glorot) normal init, bias zeroed.
    - ``nn.BatchNorm1d``: weight set to 1, bias set to 0.

    Parameters that do not exist on a layer (``nn.Linear(..., bias=False)``,
    ``nn.BatchNorm1d(..., affine=False)``) are skipped instead of raising.

    Args:
        module (nn.Module): Model
    """
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            # bias is None when the layer was built with bias=False
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm1d):
            # weight/bias are None when the layer was built with affine=False
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
|