Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch.nn as nn | |
| import torch.nn.utils as utils | |
def bits_per_dim(x, nll):
    """Convert a negative log-likelihood into bits per dimension.

    The NLL is divided by ``ln(2)``, which converts natural-log units (nats)
    into bits. It is also divided by the number of dimensions of a single
    example: the product of every dimension of ``x`` after the first.

    Args:
        x (torch.Tensor): Input to the model. Only its size is used; the
            first dimension is treated as the batch dimension and is excluded
            from the dimension count.
        nll (torch.Tensor): Scalar negative log-likelihood loss tensor,
            expected in nats (natural-log units).

    Returns:
        bpd (torch.Tensor): Bits per dimension implied if compressing ``x``.
    """
    # Number of values per example; the leading (batch) dimension is skipped.
    dim = np.prod(x.size()[1:])
    bpd = nll / (np.log(2) * dim)
    return bpd
def clip_grad_norm(optimizer, max_norm, norm_type=2):
    """Rescale, in place, the gradients of each parameter group of ``optimizer``.

    Clipping is applied group by group. Each entry of
    ``optimizer.param_groups`` gets its own total gradient norm, and only that
    group's gradients are rescaled against ``max_norm``. The norm is
    therefore not shared across groups.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer whose parameter groups
            should have their gradients clipped.
        max_norm (float): The maximum allowable gradient norm for each group.
        norm_type (int): The type of norm to use in computing gradient norms.
    """
    param_lists = (group['params'] for group in optimizer.param_groups)
    for params in param_lists:
        utils.clip_grad_norm_(params, max_norm, norm_type)
class NLLLoss(nn.Module):
    """Negative log-likelihood under an isotropic unit-variance Gaussian prior.

    The prior is a standard normal, N(0, I), over every entry of ``z``. The
    loss also subtracts ``log(k)`` per dimension. This accounts for each
    input entry taking one of ``k`` discrete values.

    Args:
        k (int or float): Number of discrete values in each input dimension.
            E.g., ``k`` is 256 for natural images.

    See Also:
        Equation (3) in the RealNVP paper: https://arxiv.org/abs/1605.08803
    """

    def __init__(self, k=256):
        super().__init__()
        self.k = k

    def forward(self, z, sldj):
        """Return the batch-mean negative log-likelihood.

        Args:
            z (torch.Tensor): Latent tensor with at least two dimensions. The
                first dimension is treated as the batch; all remaining
                dimensions are summed over per example.
            sldj (torch.Tensor): Presumably the per-example sum of
                log-determinants of the flow's Jacobian. It is added to the
                per-example log-likelihood of shape ``(batch,)``, so it must
                broadcast against that shape. NOTE(review): a ``(batch, 1)``
                tensor would silently broadcast to ``(batch, batch)``.

        Returns:
            torch.Tensor: Scalar negative log-likelihood averaged over the batch.
        """
        # Elementwise standard-normal log-density: -0.5 * (z^2 + log(2*pi)).
        prior_ll = -0.5 * (z ** 2 + np.log(2 * np.pi))
        # Sum over all non-batch dims, then subtract log(k) once per dimension
        # to account for each input entry taking one of k discrete values.
        num_dims = np.prod(z.size()[1:])
        prior_ll = prior_ll.flatten(1).sum(-1) - np.log(self.k) * num_dims
        ll = prior_ll + sldj
        nll = -ll.mean()
        return nll