import numpy as np
import torch.nn as nn
import torch.nn.utils as utils


def bits_per_dim(x, nll):
    """Get the bits per dimension implied by using a model with negative
    log-likelihood `nll` to compress `x`.

    Args:
        x (torch.Tensor): Input to the model. Only used for its dimensions.
        nll (torch.Tensor): Scalar negative log-likelihood loss tensor.

    Returns:
        bpd (torch.Tensor): Bits per dimension implied if compressing `x`.
    """
    dim = np.prod(x.size()[1:])
    bpd = nll / (np.log(2) * dim)

    return bpd


def clip_grad_norm(optimizer, max_norm, norm_type=2):
    """Clip the norm of the gradients for all parameters under `optimizer`.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer whose parameters'
            gradients will be clipped.
        max_norm (float): The maximum allowable norm of gradients.
        norm_type (int): The type of norm to use in computing gradient norms.
    """
    for group in optimizer.param_groups:
        utils.clip_grad_norm_(group['params'], max_norm, norm_type)


class NLLLoss(nn.Module):
    """Negative log-likelihood loss assuming an isotropic Gaussian prior
    with unit variance.

    Args:
        k (int or float): Number of discrete values in each input dimension.
            E.g., `k` is 256 for natural images.

    See Also:
        Equation (3) in the RealNVP paper: https://arxiv.org/abs/1605.08803
    """
    def __init__(self, k=256):
        super(NLLLoss, self).__init__()
        self.k = k

    def forward(self, z, sldj):
        # Per-entry log-density of the latents under a standard Gaussian prior.
        prior_ll = -0.5 * (z ** 2 + np.log(2 * np.pi))
        # Sum over all non-batch dimensions and subtract the dequantization
        # correction for `k` discrete values per input dimension.
        prior_ll = prior_ll.flatten(1).sum(-1) \
            - np.log(self.k) * np.prod(z.size()[1:])
        # Total log-likelihood adds the sum of log-determinants of the
        # Jacobians (`sldj`) from the flow transformations.
        ll = prior_ll + sldj
        nll = -ll.mean()

        return nll
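

if __name__ == '__main__':
    # Minimal usage sketch (an assumption for illustration, not part of the
    # original module): in a real training loop, a flow model would map a
    # batch `x` to latents `z` and a per-example sum of log-det-Jacobians
    # `sldj`; random tensors stand in for those outputs here, so the printed
    # numbers are only illustrative.
    import torch

    x = torch.rand(8, 3, 32, 32)       # fake batch of 8 RGB 32x32 images
    z = torch.randn_like(x)            # stand-in for the flow's latent output
    sldj = torch.zeros(x.size(0))      # stand-in for sum of log-det-Jacobians

    loss_fn = NLLLoss(k=256)
    nll = loss_fn(z, sldj)             # scalar negative log-likelihood
    bpd = bits_per_dim(x, nll)         # convert NLL to bits per dimension
    print(f'nll: {nll.item():.3f}, bits/dim: {bpd.item():.3f}')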