import torch


def init_module_weights(
    module: torch.nn.Module, *, linear_std: float = 0.5
) -> None:
    """Initialize the parameters of a single module in place.

    The one-argument call form matches what ``torch.nn.Module.apply`` expects,
    so the function can be applied to every submodule of a model.

    Initialization scheme, by module type:

    * ``QuantizationModule``: ``weight_proj.weight`` ~ N(0, 1),
      ``weight_proj.bias`` zeroed, ``codebooks`` drawn from U(0, 1).
    * ``torch.nn.Linear``: weight ~ N(0, ``linear_std``), bias zeroed.
    * ``torch.nn.LayerNorm`` / ``torch.nn.GroupNorm``: weight set to 1 and
      bias set to 0; each is skipped when the layer was built without that
      affine parameter (the attribute is then ``None``).
    * ``torch.nn.Conv1d``: Kaiming-normal weight, bias zeroed.

    Modules of any other type are left untouched.

    Args:
        module: The module whose parameters are (re)initialized.
        linear_std: Standard deviation of the normal distribution used for
            ``torch.nn.Linear`` weights. Defaults to the previously
            hard-coded value of 0.5.
    """
    # Imported at call time rather than at module level, as in the original
    # code. NOTE(review): presumably to avoid a circular import with
    # src.model.modules -- confirm.
    from src.model.modules import QuantizationModule

    if isinstance(module, QuantizationModule):
        # gumbel softmax requires special init
        module.weight_proj.weight.data.normal_(mean=0.0, std=1)
        module.weight_proj.bias.data.zero_()
        torch.nn.init.uniform_(module.codebooks)
    elif isinstance(module, torch.nn.Linear):
        # Slightly different from the TF version which uses truncated_normal
        # for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=linear_std)
    elif isinstance(module, (torch.nn.LayerNorm, torch.nn.GroupNorm)):
        # Norm layers created with elementwise_affine=False / affine=False
        # (or LayerNorm(bias=False)) expose these attributes as None;
        # dereferencing `.data` on them would raise AttributeError.
        if module.bias is not None:
            module.bias.data.zero_()
        if module.weight is not None:
            module.weight.data.fill_(1.0)
    elif isinstance(module, torch.nn.Conv1d):
        torch.nn.init.kaiming_normal_(module.weight.data)

    # Biases of projection/convolution layers always start at zero, whatever
    # branch above handled the weights.
    if (
        isinstance(module, (torch.nn.Linear, torch.nn.Conv1d))
        and module.bias is not None
    ):
        module.bias.data.zero_()