import torch.nn as nn


def get_param_groups(net, weight_decay, norm_suffix='weight_g', verbose=False):
    """Split the parameters of `net` into two optimizer param groups.

    The "normalized" group collects every parameter whose name ends with
    `norm_suffix` and overrides the optimizer's weight decay with
    `weight_decay`; the "unnormalized" group holds everything else and
    inherits all hyperparameters from the optimizer.

    Args:
        net (torch.nn.Module): Network to get parameters from
        weight_decay (float): Weight decay to apply to normalized weights.
        norm_suffix (str): Suffix to select weights that should be normalized.
            For WeightNorm, using 'weight_g' normalizes the scale variables.
        verbose (bool): Print out number of normalized and unnormalized parameters.
    """
    normalized, unnormalized = [], []
    for name, param in net.named_parameters():
        # Route each parameter into one of the two buckets by name suffix.
        target = normalized if name.endswith(norm_suffix) else unnormalized
        target.append(param)

    groups = [
        {'name': 'normalized', 'params': normalized, 'weight_decay': weight_decay},
        {'name': 'unnormalized', 'params': unnormalized},
    ]

    if verbose:
        print('{} normalized parameters'.format(len(normalized)))
        print('{} unnormalized parameters'.format(len(unnormalized)))

    return groups


class WNConv2d(nn.Module):
    """Weight-normalized 2d convolution.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels in the output.
        kernel_size (int): Side length of each convolutional kernel.
        padding (int): Padding to add on edges of input.
        bias (bool): Use bias in the convolution operation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, bias=True):
        super().__init__()
        # Wrap a plain Conv2d with weight normalization; this splits its
        # weight into magnitude ('weight_g') and direction ('weight_v').
        base = nn.Conv2d(in_channels, out_channels, kernel_size,
                         padding=padding, bias=bias)
        self.conv = nn.utils.weight_norm(base)

    def forward(self, x):
        return self.conv(x)