Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch | |
| import numpy as np | |
| import scipy.linalg | |
class InvertibleConv1x1(nn.Module):
    """Invertible 1x1 convolution with an optional LU-parameterised weight.

    The square ``num_channels x num_channels`` weight is initialised to a
    random orthogonal matrix. With ``LU_decomposed=True`` it is stored as
    ``W = P @ L @ U``: ``P`` is a fixed permutation buffer, ``L`` is
    unit-lower-triangular, and ``U`` is upper-triangular with diagonal
    ``sign_s * exp(log_s)``. In that form ``log|det W|`` is simply
    ``log_s.sum()``.

    Args:
        num_channels (int): Number of channels of the concatenated input.
        LU_decomposed (bool): Use the LU parameterisation when True,
            otherwise learn the dense weight matrix directly.

    Note:
        The inverse weight used by reverse passes is cached. The cache is
        rebuilt when ``first_pass`` is True, when the input's spatial size
        or device changes, and after every ``train()`` / ``eval()`` call.
        If parameters are modified in any other way between reverse passes,
        set ``first_pass = True`` manually.
    """

    def __init__(self, num_channels, LU_decomposed=True):
        super().__init__()
        w_shape = [num_channels, num_channels]
        # Q factor of a Gaussian matrix: a random orthogonal initial weight.
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)

        if not LU_decomposed:
            self.weight = nn.Parameter(torch.from_numpy(w_init))
        else:
            np_p, np_l, np_u = scipy.linalg.lu(w_init)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(np.abs(np_s))
            # Keep only the strictly upper part; the diagonal lives in log_s.
            np_u = np.triu(np_u, k=1)
            l_mask = np.tril(np.ones(w_shape, dtype=np.float32), -1)
            eye = np.eye(*w_shape, dtype=np.float32)

            self.register_buffer(
                'p', torch.from_numpy(np_p.astype(np.float32)))
            self.register_buffer(
                'sign_s', torch.from_numpy(np_sign_s.astype(np.float32)))
            self.l = nn.Parameter(torch.from_numpy(np_l.astype(np.float32)))
            self.log_s = nn.Parameter(
                torch.from_numpy(np_log_s.astype(np.float32)))
            self.u = nn.Parameter(torch.from_numpy(np_u.astype(np.float32)))
            # Constant helpers: registered as buffers so that .to()/.cuda()
            # move them with the module, but non-persistent so the
            # state_dict keys stay exactly as before.
            self.register_buffer(
                'l_mask', torch.from_numpy(l_mask), persistent=False)
            self.register_buffer(
                'eye', torch.from_numpy(eye), persistent=False)

        self.w_shape = w_shape
        self.LU = LU_decomposed

        # Reverse-pass cache (see the class docstring for invalidation).
        self.first_pass = True
        self.saved_weight = None
        self.saved_dsldj = None
        self._saved_pixels = None

    def train(self, mode=True):
        """Switch training mode and drop the cached inverse weight.

        Parameters usually change between mode switches, so the next
        reverse pass recomputes the inverse instead of reusing a stale one.
        """
        self.first_pass = True
        return super().train(mode)

    def get_weight(self, input, reverse):
        """Build the convolution kernel and the total log-determinant.

        Args:
            input (Tensor): Tensor whose dimensions 2 and 3 are read as the
                spatial size and whose device is used for the computation.
            reverse (bool): Return the inverse weight when True.

        Returns:
            tuple: ``(weight, dlogdet)`` where ``weight`` is viewed as
            ``(C, C, 1, 1)`` and ``dlogdet`` is ``log|det W|`` multiplied by
            ``input.size(2) * input.size(3)``.
        """
        rows, cols = self.w_shape
        pixels = input.size(2) * input.size(3)

        if not self.LU:
            dlogdet = torch.slogdet(self.weight)[1] * pixels
            if not reverse:
                weight = self.weight
            else:
                # Invert in float64 for accuracy, then cast back.
                weight = torch.inverse(self.weight.double()).float()
            return weight.view(rows, cols, 1, 1), dlogdet

        # Local device moves only (no-ops when already in place); the
        # module state is not mutated during the forward pass.
        device = input.device
        p = self.p.to(device)
        sign_s = self.sign_s.to(device)
        l_mask = self.l_mask.to(device)
        eye = self.eye.to(device)

        lower = self.l * l_mask + eye
        upper = self.u * l_mask.t() + torch.diag(
            sign_s * torch.exp(self.log_s))
        dlogdet = self.log_s.sum() * pixels

        if not reverse:
            w = torch.matmul(p, torch.matmul(lower, upper))
        else:
            lower_inv = torch.inverse(lower.double()).float()
            upper_inv = torch.inverse(upper.double()).float()
            # p is a permutation matrix, so its inverse is its transpose.
            w = torch.matmul(upper_inv, torch.matmul(lower_inv, p.t()))
        return w.view(rows, cols, 1, 1), dlogdet

    def _cached_inverse(self, x):
        """Return ``(inverse_weight, dsldj)`` for ``x``, recomputing if stale."""
        pixels = x.size(2) * x.size(3)
        stale = (
            self.first_pass
            or self.saved_weight is None
            or self.saved_dsldj is None
            or self._saved_pixels != pixels
            or self.saved_weight.device != x.device
        )
        if stale:
            weight, dsldj = self.get_weight(x, True)
            # Always keep the log-det, even when the caller is not tracking
            # sldj on this call; a later call may need it.
            self.saved_weight = weight
            self.saved_dsldj = dsldj
            self._saved_pixels = pixels
            self.first_pass = False
        return self.saved_weight, self.saved_dsldj

    def forward(self, x, cond, sldj=None, reverse=False):
        """Apply the 1x1 convolution (or its inverse).

        log-det = log|det(W)| * pixels

        Args:
            x: Sequence of tensors that is concatenated along dim 1.
            cond: Unused by this layer.
            sldj (Tensor, optional): Running sum of log-determinants. Left
                as None when not provided.
            reverse (bool): Apply the inverse transform when True.

        Returns:
            tuple: ``(chunks, sldj)`` where ``chunks`` is the result split
            into two along dim 1, and ``sldj`` has the log-det added
            (forward) or subtracted (reverse).
        """
        x = torch.cat(x, dim=1)
        if not reverse:
            weight, dsldj = self.get_weight(x, reverse)
        else:
            weight, dsldj = self._cached_inverse(x)

        x = F.conv2d(x, weight)
        if sldj is not None:
            sldj = sldj - dsldj if reverse else sldj + dsldj

        return x.chunk(2, dim=1), sldj
class InvConv(nn.Module):
    """Invertible 1x1 Convolution for 2D inputs. Originally described in Glow
    (https://arxiv.org/abs/1807.03039). Does not support LU-decomposed version.

    Args:
        num_channels (int): Number of channels in the input and output.
        random_init (bool): Initialize with a random orthogonal matrix.
            Otherwise initialize with noisy identity.
    """

    def __init__(self, num_channels, random_init=False):
        super().__init__()
        self.num_channels = num_channels

        if random_init:
            # Initialize with a random orthogonal matrix
            w_init = np.random.randn(self.num_channels, self.num_channels)
            w_init = np.linalg.qr(w_init)[0]
        else:
            # Initialize as identity permutation with some noise
            w_init = np.eye(self.num_channels, self.num_channels) \
                + 1e-3 * np.random.randn(self.num_channels, self.num_channels)
        self.weight = nn.Parameter(torch.from_numpy(w_init.astype(np.float32)))

    def forward(self, x, cond, sldj=None, reverse=False):
        """Apply the 1x1 convolution (or its inverse).

        Args:
            x: Sequence of tensors that is concatenated along dim 1.
            cond: Unused by this layer.
            sldj (Tensor, optional): Running sum of log-determinants. When
                None, no log-determinant is computed and None is returned
                in its place.
            reverse (bool): Apply the inverse transform when True.

        Returns:
            tuple: ``(chunks, sldj)`` where ``chunks`` is the result split
            into two along dim 1, and ``sldj`` has
            ``log|det W| * x.size(2) * x.size(3)`` added (forward) or
            subtracted (reverse).
        """
        x = torch.cat(x, dim=1)

        if reverse:
            # Invert in float64 for accuracy, then cast back.
            weight = torch.inverse(self.weight.double()).float()
        else:
            weight = self.weight

        if sldj is not None:
            ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)
            sldj = sldj - ldj if reverse else sldj + ldj

        weight = weight.view(self.num_channels, self.num_channels, 1, 1)
        x = F.conv2d(x, weight)
        x = x.chunk(2, dim=1)

        return x, sldj