import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.linalg
import scipy.special

from . import thops

def nan_throw(tensor, name="tensor"):
    stop = False
    if (tensor != tensor).any():
        print(name + " has nans")
        stop = True
    if torch.isinf(tensor).any():
        print(name + " has infs")
        stop = True
    if stop:
        print(name + ": " + str(tensor))
        # raise ValueError(name + ' contains nans or infs')

class _ActNorm(nn.Module):
    """
    Activation Normalization

    Initialize the bias and scale with a given minibatch, so that the output
    per channel has zero mean and unit variance for that minibatch.
    After initialization, `bias` and `logs` are trained as regular parameters.
    """
    def __init__(self, num_features, scale=1.):
        super().__init__()
        # register mean and scale
        size = [1, num_features, 1]
        self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
        self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
        self.num_features = num_features
        self.scale = float(scale)
        self.register_buffer('is_initialized', torch.zeros(1))

    def _check_input_dim(self, input):
        raise NotImplementedError

    def initialize_parameters(self, input):
        self._check_input_dim(input)
        if not self.training:
            return
        assert input.device == self.bias.device
        with torch.no_grad():
            bias = thops.mean(input.clone(), dim=[0, 2], keepdim=True) * -1.0
            vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2], keepdim=True)
            logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6))
            self.bias.data.copy_(bias.data)
            self.logs.data.copy_(logs.data)
            self.is_initialized += 1.

    def _center(self, input, reverse=False):
        if not reverse:
            return input + self.bias
        else:
            return input - self.bias

    def _scale(self, input, logdet=None, reverse=False):
        logs = self.logs
        if not reverse:
            input = input * torch.exp(logs)
        else:
            input = input * torch.exp(-logs)
        if logdet is not None:
            """
            logs is a per-channel log-scale, so the log-determinant is the sum
            over channels multiplied by the number of timesteps.
            """
            dlogdet = thops.sum(logs) * thops.timesteps(input)
            if reverse:
                dlogdet *= -1
            logdet = logdet + dlogdet
        return input, logdet

    def forward(self, input, logdet=None, reverse=False):
        if not self.is_initialized:
            self.initialize_parameters(input)
        self._check_input_dim(input)
        # no need to permute dims as in the old version
        if not reverse:
            # center and scale
            input = self._center(input, reverse)
            input, logdet = self._scale(input, logdet, reverse)
        else:
            # scale and center
            input, logdet = self._scale(input, logdet, reverse)
            input = self._center(input, reverse)
        return input, logdet

class ActNorm2d(_ActNorm):
    def __init__(self, num_features, scale=1.):
        super().__init__(num_features, scale)

    def _check_input_dim(self, input):
        assert len(input.size()) == 3
        assert input.size(1) == self.num_features, (
            "[ActNorm]: input should be in shape as `BCT`,"
            " channels should be {} rather than {}".format(
                self.num_features, input.size(1)))
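

# Illustrative sketch (not part of the original module): how ActNorm2d behaves on
# its first training-mode forward pass. The shapes below are assumptions chosen
# for the example; the layer expects `BCT` input (batch, channels, timesteps).
def _demo_actnorm2d():
    layer = ActNorm2d(num_features=8)
    layer.train()
    x = torch.randn(4, 8, 16)                    # (batch=4, channels=8, timesteps=16)
    y, logdet = layer(x, logdet=torch.zeros(4))  # data-dependent init happens here
    # After initialization each channel of `y` has roughly zero mean and unit
    # variance over the batch and time dimensions, and `logdet` has been
    # incremented by sum(logs) * timesteps.
    return y, logdet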

class LinearZeros(nn.Linear):
    def __init__(self, in_channels, out_channels, logscale_factor=3):
        super().__init__(in_channels, out_channels)
        self.logscale_factor = logscale_factor
        # set logs parameter
        self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels)))
        # init
        self.weight.data.zero_()
        self.bias.data.zero_()

    def forward(self, input):
        output = super().forward(input)
        return output * torch.exp(self.logs * self.logscale_factor)

class Conv2d(nn.Conv2d):
    pad_dict = {
        "same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)],
        "valid": lambda kernel, stride: [0 for _ in kernel]
    }

    @staticmethod
    def get_padding(padding, kernel_size, stride):
        # make padding
        if isinstance(padding, str):
            if isinstance(kernel_size, int):
                kernel_size = [kernel_size, kernel_size]
            if isinstance(stride, int):
                stride = [stride, stride]
            padding = padding.lower()
            try:
                padding = Conv2d.pad_dict[padding](kernel_size, stride)
            except KeyError:
                raise ValueError("{} is not supported".format(padding))
        return padding

    def __init__(self, in_channels, out_channels,
                 kernel_size=[3, 3], stride=[1, 1],
                 padding="same", do_actnorm=True, weight_std=0.05):
        padding = Conv2d.get_padding(padding, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, bias=(not do_actnorm))
        # init weight with std
        self.weight.data.normal_(mean=0.0, std=weight_std)
        if not do_actnorm:
            self.bias.data.zero_()
        else:
            self.actnorm = ActNorm2d(out_channels)
        self.do_actnorm = do_actnorm

    def forward(self, input):
        x = super().forward(input)
        if self.do_actnorm:
            x, _ = self.actnorm(x)
        return x
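

# Illustrative sketch (not part of the original module): the "same" padding rule
# above gives per-dimension padding so stride-1 convolutions preserve spatial
# size. do_actnorm is disabled here because ActNorm2d in this module expects
# 3-D `BCT` input, while nn.Conv2d outputs are 4-D. Values are assumptions.
def _demo_same_padding():
    pad = Conv2d.get_padding("same", kernel_size=3, stride=1)
    assert pad == [1, 1]                          # ((3 - 1) * 1 + 1) // 2 == 1
    conv = Conv2d(in_channels=3, out_channels=16, do_actnorm=False)
    x = torch.randn(2, 3, 32, 32)
    y = conv(x)                                   # (2, 16, 32, 32): H and W preserved
    return y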

class Conv2dZeros(nn.Conv2d):
    def __init__(self, in_channels, out_channels,
                 kernel_size=[3, 3], stride=[1, 1],
                 padding="same", logscale_factor=3):
        padding = Conv2d.get_padding(padding, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride, padding)
        # logscale_factor
        self.logscale_factor = logscale_factor
        self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1)))
        # init
        self.weight.data.zero_()
        self.bias.data.zero_()

    def forward(self, input):
        output = super().forward(input)
        return output * torch.exp(self.logs * self.logscale_factor)

class LinearNormInit(nn.Linear):
    def __init__(self, in_channels, out_channels, weight_std=0.05):
        super().__init__(in_channels, out_channels)
        # init
        self.weight.data.normal_(mean=0.0, std=weight_std)
        self.bias.data.zero_()


class LinearZeroInit(nn.Linear):
    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels, out_channels)
        # init
        self.weight.data.zero_()
        self.bias.data.zero_()

class Permute2d(nn.Module):
    def __init__(self, num_channels, shuffle):
        super().__init__()
        self.num_channels = num_channels
        print(num_channels)
        # np.long was removed from NumPy; use a concrete integer dtype instead
        self.indices = np.arange(self.num_channels - 1, -1, -1).astype(np.int64)
        self.indices_inverse = np.zeros((self.num_channels), dtype=np.int64)
        print(self.indices_inverse.shape)
        for i in range(self.num_channels):
            self.indices_inverse[self.indices[i]] = i
        if shuffle:
            self.reset_indices()

    def reset_indices(self):
        np.random.shuffle(self.indices)
        for i in range(self.num_channels):
            self.indices_inverse[self.indices[i]] = i

    def forward(self, input, reverse=False):
        assert len(input.size()) == 3
        if not reverse:
            return input[:, self.indices, :]
        else:
            return input[:, self.indices_inverse, :]
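

# Illustrative sketch (not part of the original module): Permute2d applies a fixed
# channel permutation, and the stored inverse indices undo it exactly. The tensor
# shape is an assumption for the example.
def _demo_permute2d():
    perm = Permute2d(num_channels=6, shuffle=True)
    x = torch.randn(2, 6, 10)                     # (batch, channels, timesteps)
    y = perm(x, reverse=False)
    x_rec = perm(y, reverse=True)
    assert torch.equal(x, x_rec)                  # permutation followed by its inverse
    return y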

class InvertibleConv1x1(nn.Module):
    def __init__(self, num_channels, LU_decomposed=False):
        super().__init__()
        w_shape = [num_channels, num_channels]
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32)
        if not LU_decomposed:
            # Sample a random orthogonal matrix:
            self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
        else:
            np_p, np_l, np_u = scipy.linalg.lu(w_init)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)
            np_log_s = np.log(np.abs(np_s))
            np_u = np.triu(np_u, k=1)
            l_mask = np.tril(np.ones(w_shape, dtype=np.float32), -1)
            eye = np.eye(*w_shape, dtype=np.float32)
            self.register_buffer('p', torch.Tensor(np_p.astype(np.float32)))
            self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(np.float32)))
            self.l = nn.Parameter(torch.Tensor(np_l.astype(np.float32)))
            self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(np.float32)))
            self.u = nn.Parameter(torch.Tensor(np_u.astype(np.float32)))
            self.l_mask = torch.Tensor(l_mask)
            self.eye = torch.Tensor(eye)
        self.w_shape = w_shape
        self.LU = LU_decomposed
        self.first_pass = True
        self.saved_weight = None
        self.saved_dlogdet = None

    def get_weight(self, input, reverse):
        w_shape = self.w_shape
        if not self.LU:
            timesteps = thops.timesteps(input)
            dlogdet = torch.slogdet(self.weight)[1] * timesteps
            if not reverse:
                weight = self.weight.view(w_shape[0], w_shape[1], 1)
            else:
                weight = torch.inverse(self.weight.double()).float() \
                    .view(w_shape[0], w_shape[1], 1)
            return weight, dlogdet
        else:
            self.p = self.p.to(input.device)
            self.sign_s = self.sign_s.to(input.device)
            self.l_mask = self.l_mask.to(input.device)
            self.eye = self.eye.to(input.device)
            l = self.l * self.l_mask + self.eye
            u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
            dlogdet = thops.sum(self.log_s) * thops.timesteps(input)
            if not reverse:
                w = torch.matmul(self.p, torch.matmul(l, u))
            else:
                l = torch.inverse(l.double()).float()
                u = torch.inverse(u.double()).float()
                w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
            return w.view(w_shape[0], w_shape[1], 1), dlogdet

    def forward(self, input, logdet=None, reverse=False):
        """
        log-det = log|det(W)| * timesteps
        """
        if not reverse:
            weight, dlogdet = self.get_weight(input, reverse)
        else:
            # In the reverse direction the (expensive) inverse weight is
            # computed once and cached for subsequent calls.
            if self.first_pass:
                weight, dlogdet = self.get_weight(input, reverse)
                self.saved_weight = weight
                self.saved_dlogdet = dlogdet
                self.first_pass = False
            else:
                weight = self.saved_weight
                dlogdet = self.saved_dlogdet
        nan_throw(weight, "weight")
        nan_throw(dlogdet, "dlogdet")
        if not reverse:
            z = F.conv1d(input, weight)
            if logdet is not None:
                logdet = logdet + dlogdet
            return z, logdet
        else:
            nan_throw(input, "InConv input")
            z = F.conv1d(input, weight)
            nan_throw(z, "InConv z")
            if logdet is not None:
                nan_throw(logdet, "InConv logdet")
                logdet = logdet - dlogdet
            return z, logdet
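

# Illustrative sketch (not part of the original module): a forward pass through
# InvertibleConv1x1 followed by a reverse pass reconstructs the input up to
# numerical precision, and the log-determinant contributions cancel. Shapes and
# tolerances are assumptions for the example.
def _demo_invertible_conv1x1():
    conv = InvertibleConv1x1(num_channels=8, LU_decomposed=True)
    x = torch.randn(2, 8, 16)                     # (batch, channels, timesteps)
    logdet0 = torch.zeros(2)
    z, logdet_fwd = conv(x, logdet=logdet0, reverse=False)
    x_rec, logdet_rev = conv(z, logdet=logdet_fwd, reverse=True)
    assert torch.allclose(x, x_rec, atol=1e-4)
    assert torch.allclose(logdet_rev, logdet0, atol=1e-4)
    return z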

# Here we define our model as a class
class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim=1, num_layers=2, dropout=0.0):
        super(LSTM, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # Define the LSTM layer
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers, batch_first=True)
        # Define the output layer
        self.linear = LinearZeroInit(self.hidden_dim, output_dim)
        # do_init
        self.do_init = True

    def init_hidden(self):
        # Mark the hidden state for re-initialization on the next forward pass
        self.do_init = True

    def forward(self, input):
        # Forward pass through the LSTM layer.
        # shape of lstm_out: (batch_size, seq_len, hidden_dim)
        # shape of self.hidden: (h, c), where h and c both
        # have shape (num_layers, batch_size, hidden_dim).
        if self.do_init:
            lstm_out, self.hidden = self.lstm(input)
            self.do_init = False
        else:
            lstm_out, self.hidden = self.lstm(input, self.hidden)
        # Final layer
        y_pred = self.linear(lstm_out)
        return y_pred
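

# Illustrative sketch (not part of the original module): the LSTM wrapper keeps its
# hidden state between calls until init_hidden() is called again, so a sequence can
# be fed one chunk at a time. Shapes are assumptions for the example; note that the
# zero-initialized output layer makes the initial predictions all zeros.
def _demo_lstm_wrapper():
    net = LSTM(input_dim=10, hidden_dim=32, output_dim=4, num_layers=2)
    net.init_hidden()
    chunk1 = torch.randn(2, 5, 10)                # (batch, timesteps, features)
    chunk2 = torch.randn(2, 5, 10)
    y1 = net(chunk1)                              # hidden state is created here
    y2 = net(chunk2)                              # continues from the stored state
    return y1.shape, y2.shape                     # both (2, 5, 4)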

# Here we define our model as a class
class GRU(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim=1, num_layers=2, dropout=0.0):
        super(GRU, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # Define the GRU layer
        self.gru = nn.GRU(self.input_dim, self.hidden_dim, self.num_layers, batch_first=True)
        # Define the output layer
        self.linear = LinearZeroInit(self.hidden_dim, output_dim)
        # do_init
        self.do_init = True

    def init_hidden(self):
        # Mark the hidden state for re-initialization on the next forward pass
        self.do_init = True

    def forward(self, input):
        # Forward pass through the GRU layer.
        # shape of gru_out: (batch_size, seq_len, hidden_dim)
        # shape of self.hidden: (num_layers, batch_size, hidden_dim)
        if self.do_init:
            gru_out, self.hidden = self.gru(input)
            self.do_init = False
        else:
            gru_out, self.hidden = self.gru(input, self.hidden)
        # Final layer
        y_pred = self.linear(gru_out)
        return y_pred

class GaussianDiag:
    Log2PI = float(np.log(2 * np.pi))

    @staticmethod
    def likelihood(x):
        """
        lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + k*ln(2*PI) }
        Here Mu = 0 and Var = 1 (standard normal), so this reduces to the
        element-wise density -1/2 * ( x^2 + ln(2*PI) ).
        """
        return -0.5 * ((x ** 2) + GaussianDiag.Log2PI)

    @staticmethod
    def logp(x):
        likelihood = GaussianDiag.likelihood(x)
        return thops.sum(likelihood, dim=[1, 2])

    @staticmethod
    def sample(z_shape, eps_std=None, device=None):
        eps_std = eps_std or 1
        eps = torch.normal(mean=torch.zeros(z_shape),
                           std=torch.ones(z_shape) * eps_std)
        eps = eps.to(device)
        return eps
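

# Illustrative sketch (not part of the original module): GaussianDiag acts as a
# standard-normal base distribution; logp sums the per-element log-density over
# the channel and time dimensions. Shapes are assumptions for the example.
def _demo_gaussian_diag():
    z = GaussianDiag.sample([2, 8, 16], eps_std=1.0, device="cpu")
    ll = GaussianDiag.logp(z)                     # one log-likelihood per batch element
    manual = thops.sum(-0.5 * (z ** 2 + GaussianDiag.Log2PI), dim=[1, 2])
    assert torch.allclose(ll, manual)
    return ll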

class StudentT:
    def __init__(self, df, d):
        self.df = df
        self.d = d
        self.norm_const = scipy.special.loggamma(0.5 * (df + d)) \
            - scipy.special.loggamma(0.5 * df) - 0.5 * d * np.log(np.pi * df)

    def logp(self, x):
        """
        Multivariate Student-t log-density.

        output:
            the log-density summed over the given elements
        """
        x_norms = thops.sum(x ** 2, dim=[1])
        likelihood = self.norm_const - 0.5 * (self.df + self.d) * torch.log(1 + (1 / self.df) * x_norms)
        return thops.sum(likelihood, dim=[1])

    def sample(self, z_shape, eps_std=None, device=None):
        """
        Generate random variables from a multivariate t distribution.

        Parameters
        ----------
        z_shape : sequence of int
            shape (batch, channels, timesteps) of the sample to draw
        eps_std : float, optional
            standard deviation of the underlying Gaussian noise
        device : torch.device, optional
            device to place the returned sample on

        Returns
        -------
        Tensor of shape `z_shape`; along the channel dimension each vector is an
        independent draw from a multivariate t distribution with `df` degrees of
        freedom.
        """
        eps_std = eps_std or 1
        x_shape = torch.Size((z_shape[0], 1, z_shape[2]))
        x = np.random.chisquare(self.df, x_shape) / self.df
        x = np.tile(x, (1, z_shape[1], 1))
        x = torch.Tensor(x.astype(np.float32))
        z = torch.normal(mean=torch.zeros(z_shape), std=torch.ones(z_shape) * eps_std)
        return (z / torch.sqrt(x)).to(device)
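

# Illustrative sketch (not part of the original module): StudentT draws one
# chi-square scale per (batch, timestep) position and divides a Gaussian sample by
# its square root, giving t-distributed vectors along the channel dimension.
# Shapes and the degrees of freedom are assumptions for the example.
def _demo_student_t():
    dist = StudentT(df=50, d=8)                   # d matches the channel dimension
    z = dist.sample([2, 8, 16], eps_std=1.0, device="cpu")
    ll = dist.logp(z)                             # log-density summed over channels and time
    return z.shape, ll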

class Split2d(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        print("Split2d num_channels:" + str(num_channels))
        self.num_channels = num_channels
        self.conv = Conv2dZeros(num_channels // 2, num_channels)

    def split2d_prior(self, z):
        h = self.conv(z)
        return thops.split_feature(h, "cross")

    def forward(self, input, cond, logdet=0., reverse=False, eps_std=None):
        if not reverse:
            z1, z2 = thops.split_feature(input, "split")
            # mean, logs = self.split2d_prior(z1)
            logdet = GaussianDiag.logp(z2) + logdet
            return z1, cond, logdet
        else:
            z1 = input
            # mean, logs = self.split2d_prior(z1)
            z2_shape = list(z1.shape)
            z2_shape[1] = self.num_channels - z1.shape[1]
            z2 = GaussianDiag.sample(z2_shape, eps_std, device=input.device)
            z = thops.cat_feature(z1, z2)
            return z, cond, logdet
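

# Illustrative sketch (not part of the original module), assuming the usual
# Glow-style thops helpers that split and concatenate along the channel dimension:
# the forward pass keeps half the channels and scores the rest under the
# standard-normal prior, while the reverse pass re-samples the discarded half, so
# reconstruction is stochastic. Shapes are assumptions for the example.
def _demo_split2d():
    split = Split2d(num_channels=8)
    x = torch.randn(2, 8, 16)                     # (batch, channels, timesteps)
    cond = torch.randn(2, 4, 16)                  # conditioning passes through unchanged
    z1, cond_out, logdet = split(x, cond, logdet=torch.zeros(2), reverse=False)
    z, _, _ = split(z1, cond_out, logdet=logdet, reverse=True, eps_std=1.0)
    return z1.shape, z.shape                      # (2, 4, 16) and (2, 8, 16)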

def squeeze2d(input, factor=2):
    assert factor >= 1 and isinstance(factor, int)
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert H % factor == 0, "{}".format((H, W))
    x = input.view(B, C, H // factor, factor, W, 1)
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
    x = x.view(B, C * factor, H // factor, W)
    return x


def unsqueeze2d(input, factor=2):
    assert factor >= 1 and isinstance(factor, int)
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert C % factor == 0, "{}".format(C)
    x = input.view(B, C // factor, factor, 1, H, W)
    x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
    x = x.view(B, C // factor, H * factor, W)
    return x
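

# Illustrative sketch (not part of the original module): this squeeze variant folds
# the height dimension into the channels (the width axis is untouched), and
# unsqueeze2d undoes it exactly. Shapes are assumptions for the example.
def _demo_squeeze2d():
    x = torch.randn(2, 3, 8, 5)                   # (B, C, H, W)
    y = squeeze2d(x, factor=2)                    # (2, 6, 4, 5): C * factor, H // factor
    x_rec = unsqueeze2d(y, factor=2)
    assert torch.equal(x, x_rec)
    return y.shape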

class SqueezeLayer(nn.Module):
    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, input, cond=None, logdet=None, reverse=False):
        if not reverse:
            output = squeeze2d(input, self.factor)
            cond_out = squeeze2d(cond, self.factor)
            return output, cond_out, logdet
        else:
            output = unsqueeze2d(input, self.factor)
            cond_output = unsqueeze2d(cond, self.factor)
            return output, cond_output, logdet

    def squeeze_cond(self, cond):
        cond_out = squeeze2d(cond, self.factor)
        return cond_out