# kevinwang676's picture
# Upload 93 files
# 9016314 verified
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
from .networks import dense_nn, cond_dense_nn
class CondVAE(object):
    """Conditional VAE with a diagonal-Gaussian posterior and decoder.

    Networks are built inside ``tf.compat.v1.variable_scope(name,
    reuse=AUTO_REUSE)``, so repeated ``enc``/``dec`` calls can reuse the same
    variables.

    Hyper-parameters read from ``hps`` (accessed as ``hps[key]``):
        hid_dimensions: size of the latent code.
        enc_dense_hids: hidden-layer argument forwarded to the encoder network.
        dimension: size of the decoded output.
        dec_dense_hids: hidden-layer argument forwarded to the decoder network.
    """

    def __init__(self, hps, name="cvae"):
        """Store the hyper-parameters and the variable-scope name."""
        self.hps = hps
        self.name = name

    def _net(self, x, cond, hids, out_dim, scope):
        """Apply ``dense_nn`` to ``x``, or ``cond_dense_nn`` when ``cond`` is given."""
        # The boolean flag is forwarded unchanged; its meaning is defined in
        # ``.networks`` (TODO confirm).
        if cond is None:
            return dense_nn(x, hids, out_dim, False, scope)
        return cond_dense_nn(x, cond, hids, out_dim, False, scope)

    @staticmethod
    def _loc_scale(params, dim):
        """Split ``params`` [B, 2*dim] into (loc, softplus(raw_scale)), each [B, dim]."""
        # NOTE(review): softplus output can get arbitrarily close to 0, which
        # may make log-prob / KL numerically unstable; consider a small floor.
        return params[:, :dim], tf.nn.softplus(params[:, dim:])

    def enc(self, x, cond=None):
        """Encode ``x`` into a reparameterized latent sample.

        Args:
            x: input tensor, [B, C].
            cond: optional conditioning tensor; when given, ``cond_dense_nn``
                is used instead of ``dense_nn``.

        Returns:
            A tuple ``(neg_kl, z)``:
              neg_kl: the *negative* KL divergence -KL(q(z|x) || N(0, I)),
                  one value per row of ``x``.
              z: [B, hid_dimensions] sample ``loc + eps * scale`` with
                  ``eps`` drawn from the standard-normal prior.
        """
        hid = self.hps['hid_dimensions']
        batch = tf.shape(input=x)[0]
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            prior = tfd.MultivariateNormalDiag(tf.zeros(hid), tf.ones(hid))
            params = self._net(x, cond, self.hps['enc_dense_hids'], 2 * hid, "enc")
            loc, scale = self._loc_scale(params, hid)
            posterior = tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)
            # Sign is flipped: callers receive -KL (presumably to add directly
            # to a log-likelihood term).
            neg_kl = -tfd.kl_divergence(posterior, prior)
            # Reparameterization trick: standard-normal noise shifted and
            # scaled by the posterior parameters.
            eps = prior.sample(batch)
            z = loc + eps * scale
        return neg_kl, z

    def dec(self, x, cond=None):
        """Decode ``x`` into a diagonal-Gaussian output distribution.

        Args:
            x: latent tensor, [B, C].
            cond: optional conditioning tensor; when given, ``cond_dense_nn``
                is used instead of ``dense_nn``.

        Returns:
            ``tfd.MultivariateNormalDiag`` whose loc and softplus-transformed
            scale are each [B, hps['dimension']].
        """
        dim = self.hps['dimension']
        with tf.compat.v1.variable_scope(self.name, reuse=tf.compat.v1.AUTO_REUSE):
            params = self._net(x, cond, self.hps['dec_dense_hids'], 2 * dim, "dec")
            loc, scale = self._loc_scale(params, dim)
            return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale)