# myTest01/models/moglow/models.py
# meng2003's picture
# Upload 85 files
# bc32eea
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from . import thops
from . import modules
from . import utils
from models.transformer import BasicTransformerModelCausal
def nan_throw(tensor, name="tensor"):
    """Report NaN and Inf entries of ``tensor`` on stdout.

    Despite the name, nothing is raised: the exception is disabled (see the
    commented-out line at the end), so this is a pure diagnostic helper.

    Args:
        tensor: Tensor to inspect.
        name: Label used as a prefix in the printed messages.

    Returns:
        None.
    """
    has_nan = bool(torch.isnan(tensor).any())
    if has_nan:
        print(name + " has nans")

    has_inf = bool(torch.isinf(tensor).any())
    if has_inf:
        print(name + " has infs")

    if has_nan or has_inf:
        # Dump the offending tensor so the bad entries can be located.
        print(name + ": " + str(tensor))
        # raise ValueError(name + ' contains nans or infs')
def f(in_channels, out_channels, hidden_channels, cond_channels, network_model, num_layers):
    """Build the network used inside a coupling layer.

    The network consumes ``in_channels + cond_channels`` features (the
    pass-through half of the split concatenated with the conditioning) and
    emits ``out_channels`` features.

    Args:
        in_channels: Number of feature channels of the pass-through half.
        out_channels: Number of output channels of the network.
        hidden_channels: Hidden width handed to the chosen architecture.
        cond_channels: Number of conditioning channels.
        network_model: One of ``"transformer"``, ``"LSTM"``, ``"GRU"``, ``"FF"``.
        num_layers: Layer count for the transformer / recurrent variants
            (not used by ``"FF"``, which always has two hidden layers).

    Returns:
        The constructed module.

    Raises:
        ValueError: If ``network_model`` is not one of the supported names.
    """
    net_in_channels = in_channels + cond_channels

    if network_model == "transformer":
        # NOTE(review): the literal 10 and input_length=70 are fixed here;
        # 10 is presumably the number of attention heads - confirm against
        # BasicTransformerModelCausal's signature.
        return BasicTransformerModelCausal(
            out_channels, net_in_channels, 10, hidden_channels, num_layers,
            use_pos_emb=True, input_length=70)
    if network_model == "LSTM":
        return modules.LSTM(net_in_channels, hidden_channels, out_channels, num_layers)
    if network_model == "GRU":
        return modules.GRU(net_in_channels, hidden_channels, out_channels, num_layers)
    if network_model == "FF":
        return nn.Sequential(
            nn.Linear(net_in_channels, hidden_channels), nn.ReLU(inplace=False),
            nn.Linear(hidden_channels, hidden_channels), nn.ReLU(inplace=False),
            modules.LinearZeroInit(hidden_channels, out_channels))

    # Fail loudly instead of silently returning None, which would only
    # surface later as an obscure error when the coupling net is called.
    raise ValueError(
        "network_model should be one of 'transformer', 'LSTM', 'GRU', 'FF', "
        "got {!r}".format(network_model))
class FlowStep(nn.Module):
    """One flow step: actnorm, then a channel permutation, then a coupling.

    The coupling network is conditioned on ``cond``, which is concatenated to
    the pass-through half of the split along dim 1. Inputs are indexed as 3-D
    tensors by the permutes below.
    NOTE(review): the layout looks like (batch, channels, time) - confirm
    against the callers.

    Attributes are set in ``__init__``; exactly one of ``invconv``,
    ``shuffle`` or ``reverse`` exists, depending on ``flow_permutation``.
    """

    FlowCoupling = ["additive", "affine"]
    NetworkModel = ["transformer", "LSTM", "GRU", "FF"]
    # Dispatch table keyed by permutation name. Each entry applies the
    # matching sub-module of ``obj`` and returns ``(z, logdet)``; only the
    # invertible convolution receives and returns the log-determinant itself.
    FlowPermutation = {
        "reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet),
        "shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet),
        "invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
    }

    def __init__(self, in_channels, hidden_channels, cond_channels,
                 actnorm_scale=1.0,
                 flow_permutation="invconv",
                 flow_coupling="additive",
                 network_model="LSTM",
                 num_layers=2,
                 LU_decomposed=False):
        """Build the three sub-layers of the step.

        Args:
            in_channels: Number of channels of the data flowing through.
            hidden_channels: Hidden width of the coupling network.
            cond_channels: Number of conditioning channels.
            actnorm_scale: Scale argument forwarded to ``ActNorm2d``.
            flow_permutation: Key of ``FlowStep.FlowPermutation``.
            flow_coupling: ``"additive"`` or ``"affine"``.
            network_model: Architecture name, see ``FlowStep.NetworkModel``.
            num_layers: Layer count forwarded to the coupling network.
            LU_decomposed: Forwarded to ``InvertibleConv1x1``.

        Raises:
            AssertionError: If a configuration string is not supported.
        """
        # check configures
        assert flow_coupling in FlowStep.FlowCoupling,\
            "flow_coupling should be in `{}`".format(FlowStep.FlowCoupling)
        assert network_model in FlowStep.NetworkModel,\
            "network_model should be in `{}`".format(FlowStep.NetworkModel)
        assert flow_permutation in FlowStep.FlowPermutation,\
            "flow_permutation should be in `{}`".format(
                FlowStep.FlowPermutation.keys())
        super().__init__()
        self.flow_permutation = flow_permutation
        self.flow_coupling = flow_coupling
        self.network_model = network_model
        # 1. actnorm
        self.actnorm = modules.ActNorm2d(in_channels, actnorm_scale)
        # 2. permute
        if flow_permutation == "invconv":
            self.invconv = modules.InvertibleConv1x1(
                in_channels, LU_decomposed=LU_decomposed)
        elif flow_permutation == "shuffle":
            self.shuffle = modules.Permute2d(in_channels, shuffle=True)
        else:
            self.reverse = modules.Permute2d(in_channels, shuffle=False)
        # 3. coupling: the network reads the first half (in_channels // 2)
        # and produces values for the remaining channels; affine coupling
        # needs twice as many outputs (a shift and a scale per channel).
        if flow_coupling == "additive":
            self.f = f(in_channels // 2, in_channels - in_channels // 2,
                       hidden_channels, cond_channels, network_model, num_layers)
        elif flow_coupling == "affine":
            print("affine: in_channels = " + str(in_channels))
            self.f = f(in_channels // 2, 2 * (in_channels - in_channels // 2),
                       hidden_channels, cond_channels, network_model, num_layers)
            print("Flowstep affine layer: " + str(in_channels))

    def init_lstm_hidden(self):
        """Reset the hidden state of a recurrent (LSTM/GRU) coupling network."""
        if self.network_model == "LSTM" or self.network_model == "GRU":
            self.f.init_hidden()

    def forward(self, input, cond, logdet=None, reverse=False):
        """Run the step forwards (default) or backwards (``reverse=True``).

        Returns:
            Tuple ``(z, cond, logdet)``; ``cond`` is passed through unchanged.
        """
        if not reverse:
            return self.normal_flow(input, cond, logdet)
        else:
            return self.reverse_flow(input, cond, logdet)

    def _affine_shift_scale(self, z1_cond):
        """Run the coupling network and split its output into raw shift/scale.

        The same axis layout must be used by the forward and the reverse pass,
        otherwise the reverse pass does not invert the forward one. The
        transformer is fed with dim 2 first, every other network with dims 1
        and 2 swapped; the output is permuted back to the input layout.
        """
        if self.network_model == "transformer":
            h = self.f(z1_cond.permute(2, 0, 1)).permute(1, 2, 0)
        else:
            h = self.f(z1_cond.permute(0, 2, 1)).permute(0, 2, 1)
        return thops.split_feature(h, "cross")

    def normal_flow(self, input, cond, logdet):
        """Forward (data -> latent) direction of the step."""
        # 1. actnorm
        z, logdet = self.actnorm(input, logdet=logdet, reverse=False)
        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, False)
        # 3. coupling
        z1, z2 = thops.split_feature(z, "split")
        z1_cond = torch.cat((z1, cond), dim=1)
        if self.flow_coupling == "additive":
            # NOTE(review): unlike the affine branch, the network is called
            # without permuting the axes here - confirm this is intended.
            z2 = z2 + self.f(z1_cond)
        elif self.flow_coupling == "affine":
            shift, scale = self._affine_shift_scale(z1_cond)
            # Squash the raw scale into (0, 1) and keep it strictly positive
            # so that log(scale) and the later division stay finite.
            scale = torch.sigmoid(scale + 2.) + 1e-6
            z2 = z2 + shift
            z2 = z2 * scale
            logdet = thops.sum(torch.log(scale), dim=[1, 2]) + logdet
        z = thops.cat_feature(z1, z2)
        return z, cond, logdet

    def reverse_flow(self, input, cond, logdet):
        """Reverse (latent -> data) direction; undoes ``normal_flow``."""
        # 1. coupling
        z1, z2 = thops.split_feature(input, "split")
        z1_cond = torch.cat((z1, cond), dim=1)
        if self.flow_coupling == "additive":
            z2 = z2 - self.f(z1_cond)
        elif self.flow_coupling == "affine":
            shift, scale = self._affine_shift_scale(z1_cond)
            nan_throw(shift, "shift")
            nan_throw(scale, "scale")
            nan_throw(z2, "z2 unscaled")
            scale = torch.sigmoid(scale + 2.) + 1e-6
            # Inverse of "add shift, then multiply": divide, then subtract.
            z2 = z2 / scale
            z2 = z2 - shift
            logdet = -thops.sum(torch.log(scale), dim=[1, 2]) + logdet
        z = thops.cat_feature(z1, z2)
        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, True)
        nan_throw(z, "z permute_" + str(self.flow_permutation))
        # 3. actnorm
        z, logdet = self.actnorm(z, logdet=logdet, reverse=True)
        return z, cond, logdet
class FlowNet(nn.Module):
    """A stack of ``K`` identically configured ``FlowStep`` layers.

    ``output_shapes`` records one ``[-1, x_channels, 1]`` entry per step.
    """

    def __init__(self, x_channels, hidden_channels, cond_channels, K,
                 actnorm_scale=1.0,
                 flow_permutation="invconv",
                 flow_coupling="additive",
                 network_model="LSTM",
                 num_layers=2,
                 LU_decomposed=False):
        """Create the ``K`` flow steps; all arguments are forwarded to them."""
        super().__init__()
        self.layers = nn.ModuleList()
        self.output_shapes = []
        self.K = K

        # Every step shares the same configuration.
        step_kwargs = dict(
            in_channels=x_channels,
            hidden_channels=hidden_channels,
            cond_channels=cond_channels,
            actnorm_scale=actnorm_scale,
            flow_permutation=flow_permutation,
            flow_coupling=flow_coupling,
            network_model=network_model,
            num_layers=num_layers,
            LU_decomposed=LU_decomposed,
        )
        for _ in range(K):
            self.layers.append(FlowStep(**step_kwargs))
            self.output_shapes.append([-1, x_channels, 1])

    def init_lstm_hidden(self):
        """Reset the recurrent state of every ``FlowStep`` in the stack."""
        for step in self.layers:
            if isinstance(step, FlowStep):
                step.init_lstm_hidden()

    def forward(self, z, cond, logdet=0., reverse=False, eps_std=None):
        """Push ``z`` through the stack.

        Forward mode applies the steps in order and returns ``(z, logdet)``.
        Reverse mode applies them in reverse order, starts every step with a
        log-determinant of 0 and returns only ``z``. ``eps_std`` is accepted
        but not used here.
        """
        if reverse:
            for step in reversed(self.layers):
                # The log-determinant is not tracked when sampling.
                z, cond, _ = step(z, cond, logdet=0, reverse=True)
            return z

        for step in self.layers:
            z, cond, logdet = step(z, cond, logdet, reverse=False)
        return z, logdet
class Glow(nn.Module):
    """Conditional flow model: a ``FlowNet`` plus a prior over the latent.

    ``opt`` must provide ``dhid``, ``glow_K``, ``actnorm_scale``,
    ``flow_permutation``, ``flow_coupling``, ``network_model``,
    ``num_layers``, ``LU_decomposed``, ``batch_size`` and ``flow_dist``
    (plus ``flow_dist_param`` when ``flow_dist == "studentT"``).
    """

    def __init__(self, x_channels, cond_channels, opt):
        """Build the flow and the prior distribution.

        Raises:
            ValueError: If ``opt.flow_dist`` is neither ``"normal"`` nor
                ``"studentT"``.
        """
        super().__init__()
        # Validate up front: otherwise ``self.distribution`` would simply be
        # missing and the failure would only show up at the first use.
        if opt.flow_dist not in ("normal", "studentT"):
            raise ValueError(
                "flow_dist should be 'normal' or 'studentT', got {!r}".format(
                    opt.flow_dist))
        self.flow = FlowNet(x_channels=x_channels,
                            hidden_channels=opt.dhid,
                            cond_channels=cond_channels,
                            K=opt.glow_K,
                            actnorm_scale=opt.actnorm_scale,
                            flow_permutation=opt.flow_permutation,
                            flow_coupling=opt.flow_coupling,
                            network_model=opt.network_model,
                            num_layers=opt.num_layers,
                            LU_decomposed=opt.LU_decomposed)
        self.opt = opt
        # Template shape for sampling the latent; the last entry is replaced
        # by the requested output length in ``reverse_flow``. The full batch
        # size is used (no per-device split).
        self.z_shape = [opt.batch_size, x_channels, 1]
        if opt.flow_dist == "normal":
            self.distribution = modules.GaussianDiag()
        else:
            self.distribution = modules.StudentT(opt.flow_dist_param, x_channels)

    def init_lstm_hidden(self):
        """Reset the recurrent state of all flow steps."""
        self.flow.init_lstm_hidden()

    def forward(self, x=None, cond=None, z=None,
                eps_std=None, reverse=False, output_length=1):
        """Encode ``x`` (default) or, with ``reverse=True``, generate from ``z``.

        Returns:
            ``(z, nll)`` in forward mode, the generated ``x`` in reverse mode.
        """
        if not reverse:
            return self.normal_flow(x, cond)
        else:
            return self.reverse_flow(z, cond, eps_std, output_length=output_length)

    def normal_flow(self, x, cond):
        """Encode ``x`` and return the latent and the negative log-likelihood."""
        # NOTE(review): presumably the size of dimension 2 of x - confirm in thops.
        n_timesteps = thops.timesteps(x)
        # One log-determinant accumulator per batch element.
        logdet = torch.zeros_like(x[:, 0, 0])
        # encode
        z, objective = self.flow(x, cond, logdet=logdet, reverse=False)
        # prior
        objective += self.distribution.logp(z)
        # Negate and rescale by log(2) * n_timesteps.
        nll = (-objective) / float(np.log(2.) * n_timesteps)
        return z, nll

    def reverse_flow(self, z, cond, eps_std, output_length=1):
        """Generate ``x`` from ``z`` (sampled from the prior when ``z`` is None).

        Runs without gradient tracking.
        """
        with torch.no_grad():
            # Work on a copy so the template stored on the instance is not
            # modified by a sampling call.
            z_shape = list(self.z_shape)
            z_shape[-1] = output_length
            if z is None:
                z = self.distribution.sample(z_shape, eps_std, device=cond.device)
            x = self.flow(z, cond, eps_std=eps_std, reverse=True)
        return x

    def set_actnorm_init(self, inited=True):
        """Set ``inited`` on every sub-module whose class name contains "ActNorm"."""
        for m in self.modules():
            if "ActNorm" in m.__class__.__name__:
                m.inited = inited

    @staticmethod
    def loss_generative(nll):
        """Generative loss: the mean of the per-sample NLL values."""
        return torch.mean(nll)