import torch


def init_module_weights(module):
    """Initialize the weights of a single submodule.

    Intended to be passed to ``torch.nn.Module.apply`` so that every
    submodule of a model is initialized according to its type.
    """
    # Imported locally rather than at module level (e.g. to avoid a
    # circular import with src.model.modules).
    from src.model.modules import QuantizationModule

    if isinstance(module, QuantizationModule):
        # Quantizer: normal projection weights, zero projection bias,
        # and uniformly initialized codebook entries.
        module.weight_proj.weight.data.normal_(mean=0.0, std=1.0)
        module.weight_proj.bias.data.zero_()
        torch.nn.init.uniform_(module.codebooks)
    elif isinstance(module, torch.nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.5)
    elif isinstance(module, (torch.nn.LayerNorm, torch.nn.GroupNorm)):
        # Normalization layers start out as the identity transform.
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    elif isinstance(module, torch.nn.Conv1d):
        torch.nn.init.kaiming_normal_(module.weight.data)

    # Linear and Conv1d layers may be constructed without a bias term.
    if (
        isinstance(module, (torch.nn.Linear, torch.nn.Conv1d))
        and module.bias is not None
    ):
        module.bias.data.zero_()
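

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the function is
    # written for torch.nn.Module.apply, which visits every submodule
    # recursively, so one call initializes every layer. TinyEncoder below is a
    # hypothetical stand-in model used only for illustration.
    class TinyEncoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv1d(1, 16, kernel_size=3)
            self.norm = torch.nn.LayerNorm(16)
            self.proj = torch.nn.Linear(16, 32)

    model = TinyEncoder()
    model.apply(init_module_weights)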