import math

import torch
import torch.nn as nn

from models.flowplusplus import log_dist as logistic
from models.flowplusplus.nn import NN
from models.flowplusplus.transformer_nn import TransformerNN
class Coupling(nn.Module):
    """Mixture-of-Logistics Coupling layer in Flow++.

    Takes an (identity, change) split of the input: the identity part,
    together with a conditioning tensor, parameterizes an element-wise
    affine transform composed with the logit of a mixture-of-logistics CDF
    applied to the change part.

    Args:
        in_channels (int): Number of channels in the input.
        cond_dim (int): Number of channels in the conditioning tensor.
        out_channels (int): Number of channels produced by the transformation network.
        mid_channels (int): Number of channels in the transformation network.
        num_blocks (int): Number of residual blocks in the transformation network.
        num_components (int): Number of components in the mixture.
        drop_prob (float): Dropout probability.
        seq_length (int): Input sequence length for the transformer network.
        output_length (int): Output sequence length for the transformer network.
        use_attn (bool): Use attention in the NN blocks.
        use_logmix (bool): Apply the mixture-of-logistics CDF transform in
            addition to the affine transform.
        use_transformer_nn (bool): Use TransformerNN instead of NN as the
            transformation network.
        use_pos_emb (bool): Use absolute positional embeddings (transformer only).
        use_rel_pos_emb (bool): Use relative positional embeddings (transformer only).
        num_heads (int): Number of attention heads (transformer only).
        aux_channels (int): Number of channels in optional auxiliary input.
        concat_dims (bool): If True, concatenate the identity part and the
            condition along the channel dimension; otherwise encode the
            identity part to `cond_dim` channels and concatenate along dim 2.
    """
    def __init__(self, in_channels, cond_dim, out_channels, mid_channels,
                 num_blocks, num_components, drop_prob, seq_length, output_length,
                 use_attn=True, use_logmix=True, use_transformer_nn=False,
                 use_pos_emb=False, use_rel_pos_emb=False, num_heads=10,
                 aux_channels=None, concat_dims=True):
        super(Coupling, self).__init__()
        if use_transformer_nn:
            # When concatenating along channels, the network sees the identity
            # part itself; otherwise it sees `cond_dim` channels (the encoded
            # identity part stacked with the condition along dim 2).
            nn_in_channels = in_channels if concat_dims else cond_dim
            self.nn = TransformerNN(nn_in_channels, out_channels, mid_channels,
                                    num_blocks, num_heads, num_components,
                                    drop_prob=drop_prob, use_pos_emb=use_pos_emb,
                                    use_rel_pos_emb=use_rel_pos_emb,
                                    input_length=seq_length, concat_dims=concat_dims,
                                    output_length=output_length)
        else:
            self.nn = NN(in_channels, out_channels, mid_channels, num_blocks,
                         num_components, drop_prob, use_attn, aux_channels)
        if not concat_dims:
            # Project the identity part to the condition's channel count so the
            # two can be concatenated along the sequence dimension in forward().
            self.input_encoder = nn.Linear(in_channels, cond_dim)
        self.use_logmix = use_logmix
        # Rescaling trick: scale = sigmoid(a + offset) + (1 - sigmoid(offset))
        # equals 1 at a = 0 and stays bounded away from zero.
        self.offset = 2.0
        self.sigmoid_offset = 1 - 1 / (1 + math.exp(-self.offset))
        self.cond_dim = cond_dim
        self.concat_dims = concat_dims
    def forward(self, x, cond, sldj=None, reverse=False, aux=None):
        x_change, x_id = x
        if self.concat_dims:
            # Condition the network on (identity part, condition) stacked
            # along the channel dimension.
            x_id_cond = torch.cat((x_id, cond), dim=1)
        else:
            # Encode the identity part to cond_dim channels, then stack it
            # with the condition along dim 2.
            x_id_enc = self.input_encoder(x_id.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            x_id_cond = torch.cat((x_id_enc, cond), dim=2)
        # The network predicts the affine parameters (a, b) and the
        # mixture-of-logistics parameters (pi, mu, s).
        a, b, pi, mu, s = self.nn(x_id_cond, aux)
        scale = torch.sigmoid(a + self.offset) + self.sigmoid_offset
        if reverse:
            # Invert the affine transform, then the logit, then the mixture CDF.
            out = x_change / scale - b
            if self.use_logmix:
                out, scale_ldj = logistic.inverse(out, reverse=True)
                # out = out.clamp(1e-5, 1. - 1e-5)  # optional clamp for numerical stability
                out = logistic.mixture_inv_cdf(out, pi, mu, s)
                logistic_ldj = logistic.mixture_log_pdf(out, pi, mu, s)
                sldj = sldj - (torch.log(scale) + scale_ldj + logistic_ldj).flatten(1).sum(-1)
            else:
                sldj = sldj - torch.log(scale).flatten(1).sum(-1)
        else:
            if self.use_logmix:
                # Map x_change through the mixture CDF, then the logit,
                # accumulating the log-determinant of each step.
                out = logistic.mixture_log_cdf(x_change, pi, mu, s).exp()
                out, scale_ldj = logistic.inverse(out)
                logistic_ldj = logistic.mixture_log_pdf(x_change, pi, mu, s)
                sldj = sldj + (logistic_ldj + scale_ldj + torch.log(scale)).flatten(1).sum(-1)
            else:
                out = x_change
                sldj = sldj + torch.log(scale).flatten(1).sum(-1)
            out = (out + b) * scale

        x = (out, x_id)

        return x, sldj
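

# The block below is a minimal, self-contained sketch (not part of the layer)
# that checks the log-determinant bookkeeping of the element-wise transform
# used in Coupling.forward on a toy 1-D tensor. It assumes a single-component
# mixture so the logistic CDF/PDF have simple closed forms; the helpers
# `_log_cdf` and `_log_pdf` are hypothetical stand-ins, not the
# models.flowplusplus.log_dist API.
if __name__ == "__main__":
    torch.manual_seed(0)

    def _log_cdf(x, mu, s):
        # log CDF of a logistic(mu, exp(s)): log sigmoid((x - mu) / exp(s))
        return nn.functional.logsigmoid((x - mu) * torch.exp(-s))

    def _log_pdf(x, mu, s):
        # log PDF of a logistic(mu, exp(s)): z - s - 2 * softplus(z),
        # with z = (x - mu) / exp(s)
        z = (x - mu) * torch.exp(-s)
        return z - s - 2 * nn.functional.softplus(z)

    x = torch.randn(4, requires_grad=True)
    mu, s = torch.zeros(4), torch.zeros(4)       # standard logistic
    a, b = 0.3 * torch.ones(4), 0.1 * torch.ones(4)

    offset = 2.0
    scale = torch.sigmoid(a + offset) + (1 - 1 / (1 + math.exp(-offset)))

    # Forward direction, mirroring the use_logmix branch:
    # y = (logit(CDF(x)) + b) * scale.
    u = _log_cdf(x, mu, s).exp()
    logit_ldj = -torch.log(u) - torch.log1p(-u)  # log d(logit u)/du = -log(u(1-u))
    y = (torch.log(u) - torch.log1p(-u) + b) * scale

    # Analytic log|det J| via the chain rule, summed over dimensions.
    sldj = (_log_pdf(x, mu, s) + logit_ldj + torch.log(scale)).sum()

    # Autograd check: the transform is element-wise, so the gradient of
    # y.sum() w.r.t. x recovers the diagonal of the Jacobian.
    jac_diag = torch.autograd.grad(y.sum(), x)[0]
    print("analytic sldj:", sldj.item())
    print("autograd sldj:", torch.log(jac_diag).sum().item())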