import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import scipy.linalg


class InvertibleConv1x1(nn.Module):
    """Invertible 1x1 Convolution with an optional LU-decomposed weight.

    Originally described in Glow (https://arxiv.org/abs/1807.03039).

    Args:
        num_channels (int): Number of channels in the input and output.
        LU_decomposed (bool): Parameterize the weight by its LU decomposition,
            so the log-determinant is the sum of log|diag(U)|.
    """

    def __init__(self, num_channels, LU_decomposed=True):
        super().__init__()
        w_shape = [num_channels, num_channels]
        # Sample a random orthogonal matrix as the initial weight.
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
        if not LU_decomposed:
            self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
        else:
            # Decompose W = P L U. P and the signs of diag(U) stay fixed;
            # L, the strict upper triangle of U, and log|diag(U)| are learned.
            np_p, np_l, np_u = scipy.linalg.lu(w_init)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(np.abs(np_s))
            np_u = np.triu(np_u, k=1)
            l_mask = np.tril(np.ones(w_shape, dtype=np.float32), -1)
            eye = np.eye(*w_shape, dtype=np.float32)

            self.register_buffer('p', torch.Tensor(np_p.astype(np.float32)))
            self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(np.float32)))
            self.l = nn.Parameter(torch.Tensor(np_l.astype(np.float32)))
            self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(np.float32)))
            self.u = nn.Parameter(torch.Tensor(np_u.astype(np.float32)))
            self.l_mask = torch.Tensor(l_mask)
            self.eye = torch.Tensor(eye)
        self.w_shape = w_shape
        self.LU = LU_decomposed
        # Cache for the inverse weight computed on the first reverse pass.
        self.first_pass = True
        self.saved_weight = None
        self.saved_dsldj = None

    def get_weight(self, input, reverse):
        w_shape = self.w_shape
        if not self.LU:
            dlogdet = torch.slogdet(self.weight)[1] * input.size(2) * input.size(3)
            if not reverse:
                weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
            else:
                weight = torch.inverse(self.weight.double()).float() \
                    .view(w_shape[0], w_shape[1], 1, 1)
            return weight, dlogdet
        else:
            # l_mask and eye are plain tensors (not buffers), so move them,
            # along with the buffers, to the input's device before use.
            self.p = self.p.to(input.device)
            self.sign_s = self.sign_s.to(input.device)
            self.l_mask = self.l_mask.to(input.device)
            self.eye = self.eye.to(input.device)

            # Rebuild L (unit lower-triangular) and U (upper-triangular with
            # diagonal sign_s * exp(log_s)) from the masked parameters.
            l = self.l * self.l_mask + self.eye
            u = self.u * self.l_mask.transpose(0, 1).contiguous() \
                + torch.diag(self.sign_s * torch.exp(self.log_s))
            dlogdet = self.log_s.sum() * input.size(2) * input.size(3)
            if not reverse:
                w = torch.matmul(self.p, torch.matmul(l, u))
            else:
                l = torch.inverse(l.double()).float()
                u = torch.inverse(u.double()).float()
                w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
            return w.view(w_shape[0], w_shape[1], 1, 1), dlogdet

    def forward(self, x, cond, sldj=None, reverse=False):
        """log-det = log|det(W)| * num_pixels"""
        x = torch.cat(x, dim=1)
        if not reverse:
            weight, dsldj = self.get_weight(x, reverse)
        else:
            # Compute the inverse weight once and reuse it on later reverse passes.
            if self.first_pass:
                weight, dsldj = self.get_weight(x, reverse)
                self.saved_weight = weight
                if sldj is not None:
                    self.saved_dsldj = dsldj
                self.first_pass = False
            else:
                weight = self.saved_weight
                if sldj is not None:
                    dsldj = self.saved_dsldj

        x = F.conv2d(x, weight)
        if sldj is not None:
            if not reverse:
                sldj = sldj + dsldj
            else:
                sldj = sldj - dsldj

        x = x.chunk(2, dim=1)
        return x, sldj
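
# ---------------------------------------------------------------------------
# Minimal sanity-check sketch (not part of the original module). It assumes
# the calling convention of forward(): the input is a pair of tensors that is
# concatenated along the channel dimension, `cond` is unused here, and `sldj`
# carries a per-sample running sum of log-determinants. Shapes, the seed, the
# tolerance, and the helper name are illustrative choices.
def _check_invertible_conv1x1():
    torch.manual_seed(0)
    conv = InvertibleConv1x1(num_channels=4, LU_decomposed=True)
    x1, x2 = torch.randn(2, 2, 8, 8), torch.randn(2, 2, 8, 8)
    sldj = torch.zeros(2)

    # Forward adds log|det(W)| * H * W to sldj; reverse subtracts the same
    # amount, so a round trip should restore both the input and sldj.
    (y1, y2), sldj = conv((x1, x2), cond=None, sldj=sldj)
    (z1, z2), sldj = conv((y1, y2), cond=None, sldj=sldj, reverse=True)

    assert torch.allclose(torch.cat([z1, z2], dim=1),
                          torch.cat([x1, x2], dim=1), atol=1e-4)
    assert torch.allclose(sldj, torch.zeros(2), atol=1e-4)
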
""" def __init__(self, num_channels, random_init=False): super(InvConv, self).__init__() self.num_channels = num_channels if random_init: # Initialize with a random orthogonal matrix w_init = np.random.randn(self.num_channels, self.num_channels) w_init = np.linalg.qr(w_init)[0] else: # Initialize as identity permutation with some noise w_init = np.eye(self.num_channels, self.num_channels) \ + 1e-3 * np.random.randn(self.num_channels, self.num_channels) self.weight = nn.Parameter(torch.from_numpy(w_init.astype(np.float32))) def forward(self, x, cond, sldj, reverse=False): x = torch.cat(x, dim=1) ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3) if reverse: weight = torch.inverse(self.weight.double()).float() sldj = sldj - ldj else: weight = self.weight sldj = sldj + ldj weight = weight.view(self.num_channels, self.num_channels, 1, 1) x = F.conv2d(x, weight) x = x.chunk(2, dim=1) return x, sldj