Spaces:
Runtime error
Runtime error
File size: 1,994 Bytes
bc32eea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch.nn as nn
def get_param_groups(net, weight_decay, norm_suffix='weight_g', verbose=False):
    """Split the parameters of `net` into two optimizer parameter groups.

    The first group, named "normalized", holds every parameter whose name
    ends with `norm_suffix` and overrides the optimizer's weight decay with
    `weight_decay`. The second group, named "unnormalized", holds all other
    parameters and inherits every hyperparameter from the optimizer.

    Args:
        net (torch.nn.Module): Network whose parameters are grouped.
        weight_decay (float): Weight decay applied to the "normalized" group.
        norm_suffix (str): Name suffix selecting the "normalized" parameters.
            Because matching uses ``str.endswith``, a tuple of suffixes also
            works. For WeightNorm, 'weight_g' selects the scale variables.
        verbose (bool): If True, print how many parameters fall in each group.

    Returns:
        list[dict]: Two parameter-group dicts, "normalized" first and
        "unnormalized" second, in the format accepted by ``torch.optim``
        optimizers.
    """
    # Materialize once: named_parameters() is a generator and we scan it twice.
    named = list(net.named_parameters())
    norm_params = [param for pname, param in named if pname.endswith(norm_suffix)]
    unnorm_params = [param for pname, param in named if not pname.endswith(norm_suffix)]

    normalized = {'name': 'normalized', 'params': norm_params, 'weight_decay': weight_decay}
    # No 'weight_decay' key here, so the optimizer's default applies.
    unnormalized = {'name': 'unnormalized', 'params': unnorm_params}

    if verbose:
        print('{} normalized parameters'.format(len(norm_params)))
        print('{} unnormalized parameters'.format(len(unnorm_params)))

    return [normalized, unnormalized]
class WNConv2d(nn.Module):
    """Weight-normalized 2d convolution.

    Wraps an ``nn.Conv2d`` with ``nn.utils.weight_norm``. The convolution's
    kernel is then stored as two parameters, ``weight_g`` and ``weight_v``,
    instead of a single ``weight``.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels in the output.
        kernel_size (int): Side length of each convolutional kernel.
        padding (int): Padding to add on edges of input.
        bias (bool): Use bias in the convolution operation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, bias=True):
        super().__init__()
        base_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              padding=padding, bias=bias)
        # NOTE(review): newer PyTorch releases deprecate nn.utils.weight_norm in
        # favor of nn.utils.parametrizations.weight_norm. Switching would rename
        # the parameters so they no longer end in 'weight_g'. Check any code that
        # selects parameters by that suffix (and saved checkpoints) first.
        self.conv = nn.utils.weight_norm(base_conv)

    def forward(self, x):
        """Apply the weight-normalized convolution to `x` and return the result."""
        return self.conv(x)
|