Spaces:
Runtime error
Runtime error
File size: 3,894 Bytes
bc32eea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import math
import torch
import torch.nn as nn
from models.flowplusplus import log_dist as logistic
from models.flowplusplus.nn import NN
from models.flowplusplus.transformer_nn import TransformerNN
class Coupling(nn.Module):
    """Mixture-of-Logistics coupling layer adapted from Flow++.

    Only ``x_change`` is transformed; ``x_id`` passes through unchanged and,
    together with the conditioning tensor ``cond``, drives the transformation
    network that produces the coupling parameters ``(a, b, pi, mu, s)``.
    With ``use_logmix`` the forward transform is
    mixture-of-logistics CDF -> logit -> affine ``(v + b) * scale``;
    without it, only the affine step is applied.

    Args:
        in_channels (int): Input channel count given to the transformation
            network (always for ``NN``; for ``TransformerNN`` only when
            ``concat_dims`` is True). Also the input width of
            ``input_encoder`` when ``concat_dims`` is False.
        cond_dim (int): Conditioning width. Used as the ``TransformerNN``
            input width and the ``input_encoder`` output width when
            ``concat_dims`` is False.
        out_channels (int): Output channel count of the transformation network.
        mid_channels (int): Number of channels in the transformation network.
        num_blocks (int): Number of residual blocks in the transformation network.
        num_components (int): Number of components in the mixture.
        drop_prob (float): Dropout probability.
        seq_length (int): ``input_length`` forwarded to ``TransformerNN``.
        output_length (int): ``output_length`` forwarded to ``TransformerNN``.
        use_attn (bool): Use attention in the NN blocks (``NN`` only).
        use_logmix (bool): Apply the mixture-of-logistics CDF and logit
            before the affine step.
        use_transformer_nn (bool): Use ``TransformerNN`` instead of ``NN``.
        use_pos_emb (bool): Forwarded to ``TransformerNN``.
        use_rel_pos_emb (bool): Forwarded to ``TransformerNN``.
        num_heads (int): Number of attention heads for ``TransformerNN``.
        aux_channels (int): Number of channels in optional auxiliary input
            (``NN`` only).
        concat_dims (bool): If True, ``x_id`` and ``cond`` are concatenated
            along dim 1. Otherwise ``x_id`` is first projected to
            ``cond_dim`` features and concatenated with ``cond`` along dim 2.
    """

    def __init__(self, in_channels, cond_dim, out_channels, mid_channels, num_blocks,
                 num_components, drop_prob, seq_length, output_length,
                 use_attn=True, use_logmix=True, use_transformer_nn=False,
                 use_pos_emb=False, use_rel_pos_emb=False, num_heads=10,
                 aux_channels=None, concat_dims=True):
        super(Coupling, self).__init__()
        if use_transformer_nn:
            nn_in_channels = in_channels if concat_dims else cond_dim
            self.nn = TransformerNN(nn_in_channels, out_channels, mid_channels, num_blocks,
                                    num_heads, num_components, drop_prob=drop_prob,
                                    use_pos_emb=use_pos_emb, use_rel_pos_emb=use_rel_pos_emb,
                                    input_length=seq_length, concat_dims=concat_dims,
                                    output_length=output_length)
        else:
            # NOTE(review): with concat_dims=False this branch still receives
            # in_channels, although forward() feeds it cond_dim-channel
            # features concatenated along dim 2 -- confirm this combination
            # is actually supported.
            self.nn = NN(in_channels, out_channels, mid_channels, num_blocks,
                         num_components, drop_prob, use_attn, aux_channels)
        if not concat_dims:
            self.input_encoder = nn.Linear(in_channels, cond_dim)
        self.use_logmix = use_logmix
        self.offset = 2.0
        # 1 - sigmoid(offset): makes scale == 1 exactly when a == 0, and keeps
        # scale inside (sigmoid_offset, 1 + sigmoid_offset), i.e. strictly > 0.
        self.sigmoid_offset = 1 - 1 / (1 + math.exp(-self.offset))
        self.cond_dim = cond_dim
        self.concat_dims = concat_dims

    def forward(self, x, cond, sldj=None, reverse=False, aux=None):
        """Apply the coupling transform, or its inverse when ``reverse`` is True.

        Args:
            x (tuple): ``(x_change, x_id)``; only ``x_change`` is transformed.
            cond (torch.Tensor): Conditioning tensor concatenated with ``x_id``
                (or its encoding) as input to the transformation network.
            sldj (torch.Tensor, optional): Running sum of log-determinants, one
                entry per batch element. If None, accumulation starts at zero.
            reverse (bool): If True, invert the transform.
            aux: Optional auxiliary input forwarded to the transformation network.

        Returns:
            tuple: ``((out, x_id), sldj)``. In the forward direction ``sldj``
            has this layer's per-sample log-determinant added. In reverse it
            has the same quantity subtracted.
        """
        x_change, x_id = x
        if sldj is None:
            # ``None - tensor`` would raise; start the accumulator at zero.
            sldj = x_change.new_zeros(x_change.size(0))

        if self.concat_dims:
            x_id_cond = torch.cat((x_id, cond), dim=1)
        else:
            # Move dim 1 last so the Linear acts on it, project to cond_dim,
            # move it back, then concatenate with cond along dim 2.
            x_id_enc = self.input_encoder(x_id.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            x_id_cond = torch.cat((x_id_enc, cond), dim=2)

        a, b, pi, mu, s = self.nn(x_id_cond, aux)
        scale = torch.sigmoid(a + self.offset) + self.sigmoid_offset

        if reverse:
            out = x_change / scale - b
            if self.use_logmix:
                out, sigmoid_ldj = logistic.inverse(out, reverse=True)
                # NOTE(review): upstream Flow++ clamps ``out`` to
                # [1e-5, 1 - 1e-5] here; the clamp is disabled in this file.
                # If mixture_inv_cdf rejects values outside (0, 1), saturated
                # sigmoids can make sampling fail -- confirm.
                out = logistic.mixture_inv_cdf(out, pi, mu, s)
                logistic_ldj = logistic.mixture_log_pdf(out, pi, mu, s)
                # Assuming logistic.inverse(v, reverse=True) returns
                # log sigmoid'(v) (as in upstream Flow++ log_dist), that term
                # is the NEGATIVE of the forward logit log-det. It therefore
                # enters with the opposite sign to log(scale) and logistic_ldj,
                # so that reverse subtracts exactly what forward adds.
                sldj = sldj - (torch.log(scale) - sigmoid_ldj + logistic_ldj).flatten(1).sum(-1)
            else:
                sldj = sldj - torch.log(scale).flatten(1).sum(-1)
        else:
            if self.use_logmix:
                out = logistic.mixture_log_cdf(x_change, pi, mu, s).exp()
                out, logit_ldj = logistic.inverse(out)
                logistic_ldj = logistic.mixture_log_pdf(x_change, pi, mu, s)
                sldj = sldj + (logistic_ldj + logit_ldj + torch.log(scale)).flatten(1).sum(-1)
            else:
                out = x_change
                sldj = sldj + torch.log(scale).flatten(1).sum(-1)
            out = (out + b) * scale

        x = (out, x_id)
        return x, sldj
|