# myTest01/models/moglow/modules.py
# Uploaded by meng2003 ("Upload 85 files"), revision bc32eea.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.linalg
import scipy.special
from . import thops
def nan_throw(tensor, name="tensor"):
    """
    Print a diagnostic when `tensor` contains NaN or infinite values.

    Despite the name this never raises: it prints which kind of invalid
    value was found ("<name> has nans" / "<name> has infs") followed by the
    tensor itself, and returns None.

    Args:
        tensor: tensor (or anything ``torch.as_tensor`` accepts) to inspect.
            ``None`` is accepted and ignored, so optional values such as an
            un-requested ``logdet`` can be passed without a guard.
        name: label used in the printed messages.
    """
    if tensor is None:
        return
    values = torch.as_tensor(tensor)
    has_nan = bool(torch.isnan(values).any())
    has_inf = bool(torch.isinf(values).any())
    if has_nan:
        print(name + " has nans")
    if has_inf:
        print(name + " has infs")
    if has_nan or has_inf:
        print(name + ": " + str(tensor))
        # Raising (ValueError) is intentionally left disabled here.
class _ActNorm(nn.Module):
"""
Activation Normalization
Initialize the bias and scale with a given minibatch,
so that the output per-channel have zero mean and unit variance for that.
After initialization, `bias` and `logs` will be trained as parameters.
"""
def __init__(self, num_features, scale=1.):
super().__init__()
# register mean and scale
size = [1, num_features, 1]
self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
self.num_features = num_features
self.scale = float(scale)
# self.inited = False
self.register_buffer('is_initialized', torch.zeros(1))
def _check_input_dim(self, input):
return NotImplemented
def initialize_parameters(self, input):
# print("HOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOo")
self._check_input_dim(input)
if not self.training:
return
assert input.device == self.bias.device
with torch.no_grad():
bias = thops.mean(input.clone(), dim=[0, 2], keepdim=True) * -1.0
vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2], keepdim=True)
logs = torch.log(self.scale/(torch.sqrt(vars)+1e-6))
self.bias.data.copy_(bias.data)
self.logs.data.copy_(logs.data)
# self.inited = True
self.is_initialized += 1.
def _center(self, input, reverse=False):
if not reverse:
return input + self.bias
else:
return input - self.bias
def _scale(self, input, logdet=None, reverse=False):
logs = self.logs
if not reverse:
input = input * torch.exp(logs)
else:
input = input * torch.exp(-logs)
if logdet is not None:
"""
logs is log_std of `mean of channels`
so we need to multiply timesteps
"""
dlogdet = thops.sum(logs) * thops.timesteps(input)
if reverse:
dlogdet *= -1
logdet = logdet + dlogdet
return input, logdet
def forward(self, input, logdet=None, reverse=False):
if not self.is_initialized:
self.initialize_parameters(input)
self._check_input_dim(input)
# no need to permute dims as old version
if not reverse:
# center and scale
input = self._center(input, reverse)
input, logdet = self._scale(input, logdet, reverse)
else:
# scale and center
input, logdet = self._scale(input, logdet, reverse)
input = self._center(input, reverse)
return input, logdet
class ActNorm2d(_ActNorm):
    """ActNorm for 3-D input laid out as (batch, channels, timesteps)."""

    def __init__(self, num_features, scale=1.):
        super().__init__(num_features, scale)

    def _check_input_dim(self, input):
        assert input.dim() == 3
        assert input.size(1) == self.num_features, (
            "[ActNorm]: input should be in shape as `BCT`,"
            " channels should be {} rather than {}".format(
                self.num_features, input.size()))
class LinearZeros(nn.Linear):
    """
    Zero-initialised linear layer followed by a learned per-output gain
    ``exp(logs * logscale_factor)``.

    Because weight and bias start at zero, the layer initially outputs 0.
    """

    def __init__(self, in_channels, out_channels, logscale_factor=3):
        super().__init__(in_channels, out_channels)
        self.logscale_factor = logscale_factor
        # one log-gain per output feature
        self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels)))
        for param in (self.weight, self.bias):
            param.data.zero_()

    def forward(self, input):
        gain = torch.exp(self.logs * self.logscale_factor)
        return super().forward(input) * gain
class Conv2d(nn.Conv2d):
    """
    2-D convolution with string padding ("same" / "valid"), N(0, weight_std)
    weight initialisation and, optionally, an ActNorm2d on the output used
    instead of a bias.

    NOTE(review): ActNorm2d in this module asserts 3-D (B, C, T) input, while
    this convolution produces 4-D output, so ``do_actnorm=True`` looks
    unusable as written; confirm before relying on it.
    """
    # "same" padding is exact for odd kernels with stride 1.
    pad_dict = {
        "same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)],
        "valid": lambda kernel, stride: [0 for _ in kernel],
    }

    @staticmethod
    def get_padding(padding, kernel_size, stride):
        """
        Translate a padding spec into explicit per-dimension padding.

        Args:
            padding: "same" / "valid" (case-insensitive), or any non-string
                value, which is returned unchanged.
            kernel_size, stride: int or 2-element sequence.
        Returns:
            A list of two padding sizes, or `padding` itself if not a string.
        Raises:
            ValueError: for an unsupported padding name.
        """
        if not isinstance(padding, str):
            return padding
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size, kernel_size]
        if isinstance(stride, int):
            stride = [stride, stride]
        padding = padding.lower()
        try:
            pad_fn = Conv2d.pad_dict[padding]
        except KeyError:
            raise ValueError("{} is not supported".format(padding)) from None
        return pad_fn(kernel_size, stride)

    # Tuple defaults (instead of shared mutable lists); nn.Conv2d converts
    # both forms to tuples anyway.
    def __init__(self, in_channels, out_channels,
                 kernel_size=(3, 3), stride=(1, 1),
                 padding="same", do_actnorm=True, weight_std=0.05):
        padding = Conv2d.get_padding(padding, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, bias=(not do_actnorm))
        self.weight.data.normal_(mean=0.0, std=weight_std)
        if not do_actnorm:
            self.bias.data.zero_()
        else:
            self.actnorm = ActNorm2d(out_channels)
        self.do_actnorm = do_actnorm

    def forward(self, input):
        x = super().forward(input)
        if self.do_actnorm:
            x, _ = self.actnorm(x)
        return x
class Conv2dZeros(nn.Conv2d):
    """
    Zero-initialised 2-D convolution followed by a learned per-output-channel
    gain ``exp(logs * logscale_factor)``; it outputs 0 until trained.
    """

    # Tuple defaults (instead of shared mutable lists); nn.Conv2d converts
    # both forms to tuples anyway.
    def __init__(self, in_channels, out_channels,
                 kernel_size=(3, 3), stride=(1, 1),
                 padding="same", logscale_factor=3):
        padding = Conv2d.get_padding(padding, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride, padding)
        self.logscale_factor = logscale_factor
        # (C_out, 1, 1) so it broadcasts over the spatial dims of the output.
        self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1)))
        self.weight.data.zero_()
        self.bias.data.zero_()

    def forward(self, input):
        output = super().forward(input)
        return output * torch.exp(self.logs * self.logscale_factor)
class LinearNormInit(nn.Linear):
    """nn.Linear with weights drawn from N(0, weight_std) and a zero bias."""

    def __init__(self, in_channels, out_channels, weight_std=0.05):
        super().__init__(in_channels, out_channels)
        with torch.no_grad():
            self.weight.normal_(mean=0.0, std=weight_std)
            self.bias.zero_()
class LinearZeroInit(nn.Linear):
    """nn.Linear whose weight and bias both start at zero (initial output is 0)."""

    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels, out_channels)
        with torch.no_grad():
            for param in (self.weight, self.bias):
                param.zero_()
class Permute2d(nn.Module):
    """
    Fixed channel permutation for 3-D (batch, channels, timesteps) input.

    By default the channel order is reversed; with ``shuffle=True`` a random
    permutation drawn from NumPy's global RNG is used instead. The inverse
    permutation is kept alongside so ``reverse=True`` undoes the mapping.
    """

    def __init__(self, num_channels, shuffle):
        super().__init__()
        self.num_channels = num_channels
        # np.long was removed in NumPy 1.24; int64 is a valid index dtype
        # for both NumPy and torch tensor indexing.
        self.indices = np.arange(self.num_channels - 1, -1, -1).astype(np.int64)
        self.indices_inverse = np.zeros(self.num_channels, dtype=np.int64)
        self._update_inverse()
        if shuffle:
            self.reset_indices()

    def _update_inverse(self):
        """Recompute `indices_inverse` (in place) so it inverts `indices`."""
        self.indices_inverse[self.indices] = np.arange(self.num_channels)

    def reset_indices(self):
        """Draw a new random permutation (in place) and refresh its inverse."""
        np.random.shuffle(self.indices)
        self._update_inverse()

    def forward(self, input, reverse=False):
        assert len(input.size()) == 3
        if not reverse:
            return input[:, self.indices, :]
        else:
            return input[:, self.indices_inverse, :]
class InvertibleConv1x1(nn.Module):
    """
    Invertible 1x1 convolution: a learned C x C channel-mixing matrix W
    applied with ``F.conv1d`` (kernel size 1).

    W starts as a random orthogonal matrix (Q from a QR of a Gaussian
    matrix). With ``LU_decomposed=True`` W is parameterised as
    ``P @ L @ (U + diag(sign_s * exp(log_s)))``, where P and sign_s are fixed
    buffers, L is unit lower-triangular and U strictly upper-triangular, so
    log|det W| reduces to ``sum(log_s)``.
    """

    def __init__(self, num_channels, LU_decomposed=False):
        super().__init__()
        w_shape = [num_channels, num_channels]
        # Random orthogonal initialisation.
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
        if not LU_decomposed:
            self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
        else:
            np_p, np_l, np_u = scipy.linalg.lu(w_init)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(np.abs(np_s))
            np_u = np.triu(np_u, k=1)
            l_mask = np.tril(np.ones(w_shape, dtype=np.float32), -1)
            eye = np.eye(*w_shape, dtype=np.float32)
            # P and sign(s) are fixed: buffers, not parameters.
            self.register_buffer('p', torch.Tensor(np_p.astype(np.float32)))
            self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(np.float32)))
            self.l = nn.Parameter(torch.Tensor(np_l.astype(np.float32)))
            self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(np.float32)))
            self.u = nn.Parameter(torch.Tensor(np_u.astype(np.float32)))
            # Plain tensors (not registered): moved to the input's device in
            # get_weight() on every call and not stored in the state_dict.
            self.l_mask = torch.Tensor(l_mask)
            self.eye = torch.Tensor(eye)
        self.w_shape = w_shape
        self.LU = LU_decomposed
        # Cache for the reverse direction (see forward()).
        self.first_pass = True
        self.saved_weight = None
        self.saved_dlogdet = None

    def get_weight(self, input, reverse):
        """
        Build the conv1d kernel of shape (C, C, 1) and its log-det term.

        Args:
            input: tensor the kernel will be applied to; used for its device
                and, via ``thops.timesteps``, to scale the log-determinant.
            reverse: if True, return the kernel of W^-1 instead of W.
        Returns:
            (weight, dlogdet) with dlogdet = log|det W| * timesteps (the same
            sign in both directions; forward() adds or subtracts it).
        """
        w_shape = self.w_shape
        if not self.LU:
            timesteps = thops.timesteps(input)
            dlogdet = torch.slogdet(self.weight)[1] * timesteps
            if not reverse:
                weight = self.weight.view(w_shape[0], w_shape[1], 1)
            else:
                # Inverted in float64, then cast back to float32.
                weight = torch.inverse(self.weight.double()).float()\
                    .view(w_shape[0], w_shape[1], 1)
            return weight, dlogdet
        else:
            self.p = self.p.to(input.device)
            self.sign_s = self.sign_s.to(input.device)
            self.l_mask = self.l_mask.to(input.device)
            self.eye = self.eye.to(input.device)
            l = self.l * self.l_mask + self.eye
            u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
            dlogdet = thops.sum(self.log_s) * thops.timesteps(input)
            if not reverse:
                w = torch.matmul(self.p, torch.matmul(l, u))
            else:
                l = torch.inverse(l.double()).float()
                u = torch.inverse(u.double()).float()
                w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
            return w.view(w_shape[0], w_shape[1], 1), dlogdet

    def forward(self, input, logdet=None, reverse=False):
        """
        Apply W (or W^-1 when ``reverse=True``) across channels.

        log-det = log|abs(|W|)| * timesteps; it is added in the forward
        direction and subtracted in reverse. `logdet` stays None if None.

        In reverse, the inverted kernel and its log-det are computed on the
        first call only and reused afterwards. NOTE(review): the cache is
        never invalidated in this class (nothing sets `first_pass` back to
        True), and the cached log-det uses the timestep count of that first
        input; reset `first_pass = True` if weights or sequence length change.
        """
        if not reverse:
            weight, dlogdet = self.get_weight(input, reverse)
        else:
            if self.first_pass:
                # Cache the log-det together with the weight, regardless of
                # whether `logdet` was requested, so later calls always have it.
                self.saved_weight, self.saved_dlogdet = self.get_weight(input, reverse)
                self.first_pass = False
            weight = self.saved_weight
            dlogdet = self.saved_dlogdet
        nan_throw(weight, "weight")
        nan_throw(dlogdet, "dlogdet")
        if not reverse:
            z = F.conv1d(input, weight)
            if logdet is not None:
                logdet = logdet + dlogdet
            return z, logdet
        nan_throw(input, "InConv input")
        z = F.conv1d(input, weight)
        nan_throw(z, "InConv z")
        if logdet is not None:
            nan_throw(logdet, "InConv logdet")
            logdet = logdet - dlogdet
        return z, logdet
# Stateful multi-layer LSTM with a zero-initialised linear read-out.
class LSTM(nn.Module):
    """
    Multi-layer LSTM (batch_first) followed by a zero-initialised linear layer.

    The recurrent state returned by each forward call is kept in
    ``self.hidden`` and fed back in on the next call, until ``init_hidden()``
    is called; the call after that starts from nn.LSTM's default zero state.
    """

    def __init__(self, input_dim, hidden_dim, output_dim=1, num_layers=2, dropout=0.0):
        super(LSTM, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # `dropout` used to be accepted but silently ignored. nn.LSTM applies
        # it to the outputs of every layer except the last, in training mode.
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers,
                            batch_first=True, dropout=dropout)
        self.linear = LinearZeroInit(self.hidden_dim, output_dim)
        # When True, the next forward() ignores any stored hidden state.
        self.do_init = True

    def init_hidden(self):
        """Make the next forward() start from a zero hidden state."""
        self.do_init = True

    def forward(self, input):
        """
        Args:
            input: (batch, seq_len, input_dim) tensor (batch_first=True).
        Returns:
            (batch, seq_len, output_dim) tensor. Side effect: stores the
            final (h_n, c_n), each (num_layers, batch, hidden_dim), in
            ``self.hidden``.
        """
        if self.do_init:
            lstm_out, self.hidden = self.lstm(input)
            self.do_init = False
        else:
            lstm_out, self.hidden = self.lstm(input, self.hidden)
        return self.linear(lstm_out)
# Stateful multi-layer GRU with a zero-initialised linear read-out.
class GRU(nn.Module):
    """
    Multi-layer GRU (batch_first) followed by a zero-initialised linear layer.

    The recurrent state returned by each forward call is kept in
    ``self.hidden`` and fed back in on the next call, until ``init_hidden()``
    is called; the call after that starts from nn.GRU's default zero state.
    """

    def __init__(self, input_dim, hidden_dim, output_dim=1, num_layers=2, dropout=0.0):
        super(GRU, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # `dropout` used to be accepted but silently ignored. nn.GRU applies
        # it to the outputs of every layer except the last, in training mode.
        self.gru = nn.GRU(self.input_dim, self.hidden_dim, self.num_layers,
                          batch_first=True, dropout=dropout)
        self.linear = LinearZeroInit(self.hidden_dim, output_dim)
        # When True, the next forward() ignores any stored hidden state.
        self.do_init = True

    def init_hidden(self):
        """Make the next forward() start from a zero hidden state."""
        self.do_init = True

    def forward(self, input):
        """
        Args:
            input: (batch, seq_len, input_dim) tensor (batch_first=True).
        Returns:
            (batch, seq_len, output_dim) tensor. Side effect: stores the
            final hidden state h_n, (num_layers, batch, hidden_dim), in
            ``self.hidden``.
        """
        if self.do_init:
            gru_out, self.hidden = self.gru(input)
            self.do_init = False
        else:
            gru_out, self.hidden = self.gru(input, self.hidden)
        return self.linear(gru_out)
class GaussianDiag:
    """Standard normal N(0, I) prior: element-wise log-density and sampling."""

    Log2PI = float(np.log(2 * np.pi))

    @staticmethod
    def likelihood(x):
        """
        Element-wise log-density of a standard normal:
            ln N(x; 0, 1) = -1/2 * (x^2 + ln(2*pi))
        """
        return -0.5 * (((x) ** 2) + GaussianDiag.Log2PI)

    @staticmethod
    def logp(x):
        """Sum the element-wise log-density of `x` over dims 1 and 2."""
        likelihood = GaussianDiag.likelihood(x)
        return thops.sum(likelihood, dim=[1, 2])

    @staticmethod
    def sample(z_shape, eps_std=None, device=None):
        """
        Draw a tensor of shape `z_shape` from N(0, eps_std^2).

        Args:
            z_shape: output shape.
            eps_std: standard deviation; None means 1. An explicit 0 is
                honoured (all-zero sample) instead of silently becoming 1.
            device: device to move the sample to (sampling happens on CPU).
        """
        if eps_std is None:
            eps_std = 1
        eps = torch.normal(mean=torch.zeros(z_shape),
                           std=torch.ones(z_shape) * eps_std)
        eps = eps.to(device)
        return eps
class StudentT:
    """
    Multivariate Student-t prior with `df` degrees of freedom over `d`
    dimensions, zero location and identity scale matrix.
    """

    def __init__(self, df, d):
        self.df = df
        self.d = d
        # log( Gamma((df+d)/2) / (Gamma(df/2) * (pi*df)^(d/2)) )
        self.norm_const = (scipy.special.loggamma(0.5 * (df + d))
                           - scipy.special.loggamma(0.5 * df)
                           - 0.5 * d * np.log(np.pi * df))

    def logp(self, x):
        """
        Summed multivariate-t log-density of `x`.

        Squared values are summed over dim 1 (the `d` dimensions), giving
        ||x||^2 per remaining position; the per-position log-densities are
        then summed over (the new) dim 1. Assumes x is
        (batch, d, timesteps) -> result per batch item; TODO confirm.
        """
        x_norms = thops.sum(((x) ** 2), dim=[1])
        likelihood = self.norm_const - 0.5 * (self.df + self.d) * torch.log(1 + (1 / self.df) * x_norms)
        return thops.sum(likelihood, dim=[1])

    def sample(self, z_shape, eps_std=None, device=None):
        """
        Draw samples of shape `z_shape` (3 entries) from the multivariate t.

        Each t draw is z / sqrt(w) with z ~ N(0, eps_std^2) and
        w ~ chi2(df) / df. One `w` is drawn per (z_shape[0], z_shape[2])
        position and shared across dim 1, so the dim-1 entries at a position
        form one multivariate draw.

        Args:
            z_shape: output shape (indexable with 3 entries).
            eps_std: scale of the Gaussian part; None means 1.
            device: device the result is moved to (sampling is on CPU).
        Uses both NumPy's and torch's global RNGs.
        """
        if eps_std is None:
            eps_std = 1
        x_shape = torch.Size((z_shape[0], 1, z_shape[2]))
        x = np.random.chisquare(self.df, x_shape) / self.df
        x = np.tile(x, (1, z_shape[1], 1))
        x = torch.Tensor(x.astype(np.float32))
        z = torch.normal(mean=torch.zeros(z_shape), std=torch.ones(z_shape) * eps_std)
        return (z / torch.sqrt(x)).to(device)
class Split2d(nn.Module):
    """
    Factor-out layer.

    Forward: splits the channels with ``thops.split_feature(input, "split")``
    (presumably into two halves), keeps the first part and adds the
    GaussianDiag log-density of the second part to `logdet`.
    Reverse: re-creates the missing ``num_channels - C_in`` channels by
    sampling from GaussianDiag and concatenates them after the input.
    `cond` is passed through unchanged in both directions.
    """

    def __init__(self, num_channels):
        super().__init__()
        print("Split2d num_channels:" + str(num_channels))
        self.num_channels = num_channels
        # Only used by split2d_prior(), which forward() does not call; it is
        # still a registered submodule, so its parameters are in state_dict.
        self.conv = Conv2dZeros(num_channels // 2, num_channels)

    def split2d_prior(self, z):
        return thops.split_feature(self.conv(z), "cross")

    def forward(self, input, cond, logdet=0., reverse=False, eps_std=None):
        if reverse:
            kept = input
            missing_shape = list(kept.shape)
            missing_shape[1] = self.num_channels - kept.shape[1]
            sampled = GaussianDiag.sample(missing_shape, eps_std, device=input.device)
            return thops.cat_feature(kept, sampled), cond, logdet
        kept, dropped = thops.split_feature(input, "split")
        return kept, cond, GaussianDiag.logp(dropped) + logdet
def squeeze2d(input, factor=2):
    """
    Fold the height axis into the channel axis.

    A (B, C, H, W) tensor becomes (B, C * factor, H // factor, W): row
    ``h * factor + k`` of channel ``c`` moves to row ``h`` of channel
    ``c * factor + k``. The width axis is untouched. ``factor == 1`` returns
    the input object itself.
    """
    assert factor >= 1 and isinstance(factor, int)
    if factor == 1:
        return input
    batch, channels = input.size(0), input.size(1)
    height, width = input.size(2), input.size(3)
    assert height % factor == 0 , "{}".format((height, width))
    folded = input.view(batch, channels, height // factor, factor, width)
    folded = folded.transpose(2, 3).contiguous()
    return folded.view(batch, channels * factor, height // factor, width)
def unsqueeze2d(input, factor=2):
    """
    Inverse of squeeze2d: unfold groups of channels back into the height axis.

    A (B, C, H, W) tensor becomes (B, C // factor, H * factor, W): channel
    ``c * factor + k`` row ``h`` moves to channel ``c`` row
    ``h * factor + k``. ``factor == 1`` returns the input object itself.
    """
    assert factor >= 1 and isinstance(factor, int)
    if factor == 1:
        return input
    batch, channels = input.size(0), input.size(1)
    height, width = input.size(2), input.size(3)
    assert channels % (factor) == 0, "{}".format(channels)
    unfolded = input.view(batch, channels // factor, factor, height, width)
    unfolded = unfolded.transpose(2, 3).contiguous()
    return unfolded.view(batch, channels // factor, height * factor, width)
class SqueezeLayer(nn.Module):
    """
    Applies squeeze2d (forward) / unsqueeze2d (reverse) to the input and to
    the optional conditioning tensor, passing `logdet` through unchanged.
    """

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, input, cond = None, logdet=None, reverse=False):
        """
        Returns:
            (output, cond_out, logdet). `cond_out` is None when `cond` is
            None (previously the default ``cond=None`` crashed here).
        """
        if not reverse:
            output = squeeze2d(input, self.factor)
            cond_out = None if cond is None else squeeze2d(cond, self.factor)
            return output, cond_out, logdet
        else:
            output = unsqueeze2d(input, self.factor)
            cond_output = None if cond is None else unsqueeze2d(cond, self.factor)
            return output, cond_output, logdet

    def squeeze_cond(self, cond):
        """Squeeze only the conditioning tensor; None is passed through."""
        if cond is None:
            return None
        cond_out = squeeze2d(cond, self.factor)
        return cond_out