import torch
import torch.nn as nn

from models.util import mean_dim


class _BaseNorm(nn.Module):
    """Base class for ActNorm (Glow) and PixNorm (Flow++).

    The mean and inv_std get initialized using the mean and variance of the
    first mini-batch. After the init, mean and inv_std are trainable
    parameters.

    Adapted from:
        https://github.com/openai/glow
    """
    def __init__(self, num_channels, height, width):
        super(_BaseNorm, self).__init__()

        # `num_channels` is the channel count *after* the two input halves
        # are concatenated along the channel axis in `forward`.
        self.register_buffer('is_initialized', torch.zeros(1))
        self.mean = nn.Parameter(torch.zeros(1, num_channels, height, width))
        self.inv_std = nn.Parameter(torch.zeros(1, num_channels, height, width))
        self.eps = 1e-6

    def initialize_parameters(self, x):
        if not self.training:
            return

        with torch.no_grad():
            mean, inv_std = self._get_moments(x)
            self.mean.data.copy_(mean.data)
            self.inv_std.data.copy_(inv_std.data)
            self.is_initialized += 1.

    def _center(self, x, reverse=False):
        if reverse:
            return x + self.mean
        else:
            return x - self.mean

    def _get_moments(self, x):
        raise NotImplementedError('Subclass of _BaseNorm must implement _get_moments')

    def _scale(self, x, sldj, reverse=False):
        raise NotImplementedError('Subclass of _BaseNorm must implement _scale')

    def forward(self, x, cond, ldj=None, reverse=False):
        # `x` is a pair of tensors; `cond` is unused.
        x = torch.cat(x, dim=1)
        if not self.is_initialized:
            print('Initializing norm layer!')
            self.initialize_parameters(x)

        if reverse:
            x, ldj = self._scale(x, ldj, reverse)
            x = self._center(x, reverse)
        else:
            x = self._center(x, reverse)
            x, ldj = self._scale(x, ldj, reverse)
        x = x.chunk(2, dim=1)

        return x, ldj
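

# Note on the shared interface (added commentary, inferred from the classes
# in this file): each norm layer takes `x` as a pair of tensors, concatenates
# them along the channel axis, applies an invertible affine map, and returns
# the result re-split via `chunk(2, dim=1)`. The per-sample log-det-Jacobian
# `ldj` grows in the forward direction and shrinks by the same amount in
# reverse; `cond` exists only for interface compatibility with conditional
# layers and is never read.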


class BatchNorm(nn.Module):
    """Invertible batch normalization.

    Running statistics are updated from each training mini-batch, but the
    normalization itself always uses the running statistics, so the same
    affine map (and hence the same log-det-Jacobian) is applied in training
    and in evaluation.
    """
    def __init__(self, num_channels, momentum=0.1):
        super(BatchNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(1, num_channels, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
        self.register_buffer('running_mean', torch.zeros(1, num_channels, 1, 1))
        self.register_buffer('running_var', torch.ones(1, num_channels, 1, 1))
        self.eps = 1e-5
        self.momentum = momentum
        self.register_buffer('is_initialized', torch.zeros(1))

    def _get_moments(self, x):
        """Update the running statistics from the current mini-batch."""
        mean = mean_dim(x.clone(), dim=[0, 2, 3], keepdims=True).detach()
        var = mean_dim((x.clone() - mean) ** 2, dim=[0, 2, 3], keepdims=True).detach()

        if not self.is_initialized:
            # First batch: adopt the batch statistics directly.
            self.running_mean.data.copy_(mean.data)
            self.running_var.data.copy_(var.data)
            self.is_initialized += 1.
        elif self.momentum < 1.0:
            # Exponential moving average of the batch statistics.
            self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
            self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
        else:
            # momentum == 1: always track the latest batch statistics.
            self.running_mean.data.copy_(mean.data)
            self.running_var.data.copy_(var.data)

    def forward(self, x, cond, ldj=None, reverse=False):
        # `x` is a pair of tensors; `cond` is unused.
        x = torch.cat(x, dim=1)
        if self.training:
            self._get_moments(x)
        inv_std = 1. / (self.running_var.sqrt() + self.eps)

        if reverse:
            # Undo the forward ops in reverse order.
            x = self._center(x, self.beta, reverse)
            x, ldj = self._scale(x, ldj, self.gamma, reverse)
            x, ldj = self._scale(x, ldj, inv_std, reverse)
            x = self._center(x, self.running_mean, reverse)
        else:
            x = self._center(x, self.running_mean, reverse)
            x, ldj = self._scale(x, ldj, inv_std, reverse)
            x, ldj = self._scale(x, ldj, self.gamma, reverse)
            x = self._center(x, self.beta, reverse)
        x = x.chunk(2, dim=1)

        return x, ldj

    def _center(self, x, centerer, reverse=False):
        if reverse:
            return x + centerer
        else:
            return x - centerer

    def _scale(self, x, sldj, scaler, reverse=False):
        # Per-channel scaling: the log-det is the sum of per-channel
        # log-scales times the number of spatial positions. Note that
        # `scaler` must stay positive for `log` to be defined; `gamma`
        # is initialized to ones but is not constrained.
        if reverse:
            x = x / scaler
            sldj = sldj - scaler.log().sum() * x.size(2) * x.size(3)
        else:
            x = x * scaler
            sldj = sldj + scaler.log().sum() * x.size(2) * x.size(3)

        return x, sldj


class ActNorm(_BaseNorm):
    """Activation Normalization used in Glow.

    The mean and inv_std get initialized using the mean and variance of the
    first mini-batch. After the init, mean and inv_std are trainable
    parameters.
    """
    def __init__(self, num_channels):
        super(ActNorm, self).__init__(num_channels, 1, 1)

    def _get_moments(self, x):
        mean = mean_dim(x.clone(), dim=[0, 2, 3], keepdims=True)
        var = mean_dim((x.clone() - mean) ** 2, dim=[0, 2, 3], keepdims=True)
        inv_std = 1. / (var.sqrt() + self.eps)

        return mean, inv_std

    def _scale(self, x, sldj, reverse=False):
        if reverse:
            x = x / self.inv_std
            sldj = sldj - self.inv_std.log().sum() * x.size(2) * x.size(3)
        else:
            x = x * self.inv_std
            sldj = sldj + self.inv_std.log().sum() * x.size(2) * x.size(3)

        return x, sldj


class PixNorm(_BaseNorm):
    """Pixel-wise Activation Normalization used in Flow++.

    Normalizes every activation independently (note this differs from the
    variant used in Glow, where each channel is normalized). The mean and
    stddev get initialized using the mean and stddev of the first mini-batch.
    After the init, `mean` and `inv_std` become trainable parameters.
    """
    def _get_moments(self, x):
        mean = torch.mean(x.clone(), dim=0, keepdim=True)
        var = torch.mean((x.clone() - mean) ** 2, dim=0, keepdim=True)
        inv_std = 1. / (var.sqrt() + self.eps)

        return mean, inv_std

    def _scale(self, x, sldj, reverse=False):
        if reverse:
            x = x / self.inv_std
            sldj = sldj - self.inv_std.log().sum()
        else:
            x = x * self.inv_std
            sldj = sldj + self.inv_std.log().sum()

        return x, sldj
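

# ---------------------------------------------------------------------------
# Minimal sanity-check sketch (added; not part of the original module). It
# verifies that ActNorm and BatchNorm are invertible: forward followed by
# reverse should recover the input, and the accumulated log-det-Jacobian
# terms should cancel. Assumptions: `models.util.mean_dim` is importable
# (this module already requires it), and the constructors take the channel
# count of the *concatenated* input.
if __name__ == '__main__':
    torch.manual_seed(0)

    batch, channels, height, width = 4, 8, 16, 16
    halves = (torch.randn(batch, channels, height, width),
              torch.randn(batch, channels, height, width))
    ldj = torch.zeros(batch)

    for norm in (ActNorm(2 * channels), BatchNorm(2 * channels)):
        # Train mode: triggers data-dependent init / running-stat updates.
        norm.train()
        y, ldj_fwd = norm(halves, cond=None, ldj=ldj)

        # Eval mode: stats are frozen, so reverse is the exact inverse.
        norm.eval()
        x_rec, ldj_rt = norm(y, cond=None, ldj=ldj_fwd, reverse=True)

        err = max((a - b).abs().max().item() for a, b in zip(halves, x_rec))
        print(f'{type(norm).__name__}: max reconstruction error = {err:.2e}, '
              f'ldj round-trip residual = {ldj_rt.abs().max().item():.2e}')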