Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
def get_param_groups(net, weight_decay, norm_suffix='weight_g', verbose=False):
    """Split the parameters of `net` into two optimizer parameter groups.

    Parameters whose name (as reported by ``net.named_parameters()``) ends
    with `norm_suffix` go into a group named "normalized". That group sets its
    own ``weight_decay``, which overrides the optimizer's default for those
    parameters. Every other parameter goes into a group named "unnormalized".
    That group sets no options of its own, so it inherits the optimizer's
    defaults.

    Args:
        net (torch.nn.Module): Network whose parameters are grouped.
        weight_decay (float): Weight decay applied to the "normalized" group.
        norm_suffix (str): Name suffix that selects the "normalized"
            parameters. With WeightNorm, 'weight_g' selects the scale
            variables.
        verbose (bool): If True, print how many parameters landed in each
            group.

    Returns:
        list[dict]: ``[normalized_group, unnormalized_group]``, in that order,
        ready to pass to a ``torch.optim`` optimizer constructor.
    """
    # One pass over the parameters. The key is whether the name matches the
    # suffix, and registration order is preserved within each bucket.
    buckets = {True: [], False: []}
    for param_name, param in net.named_parameters():
        buckets[param_name.endswith(norm_suffix)].append(param)

    normalized = buckets[True]
    unnormalized = buckets[False]

    if verbose:
        print('{} normalized parameters'.format(len(normalized)))
        print('{} unnormalized parameters'.format(len(unnormalized)))

    return [
        {'name': 'normalized', 'params': normalized, 'weight_decay': weight_decay},
        {'name': 'unnormalized', 'params': unnormalized},
    ]
class WNConv2d(nn.Module):
    """2D convolution whose kernel is reparameterized by weight normalization.

    A plain ``nn.Conv2d`` is built and wrapped with ``nn.utils.weight_norm``.
    The wrapper replaces the conv's ``weight`` parameter with ``weight_g`` and
    ``weight_v``.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels in the output.
        kernel_size (int): Side length of each convolutional kernel.
        padding (int): Padding to add on edges of input.
        bias (bool): Use bias in the convolution operation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, bias=True):
        super().__init__()
        plain_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                               padding=padding, bias=bias)
        # The attribute name ``conv`` is kept so parameter names stay
        # ``conv.weight_g`` / ``conv.weight_v``. Suffix-based grouping such as
        # get_param_groups' default norm_suffix relies on those names.
        self.conv = nn.utils.weight_norm(plain_conv)

    def forward(self, x):
        """Return the weight-normalized convolution of ``x``."""
        return self.conv(x)