# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import functools | |
import sonnet as snt | |
import tensorflow as tf | |
import numpy as np | |
import math | |
SQUARED_OBSERVATION = "squared"
ABS_OBSERVATION = "abs"
STANDARD_OBSERVATION = "standard"
OBSERVATION_TYPES = [SQUARED_OBSERVATION, ABS_OBSERVATION, STANDARD_OBSERVATION]

ROUND_TRANSITION = "round"
STANDARD_TRANSITION = "standard"
TRANSITION_TYPES = [ROUND_TRANSITION, STANDARD_TRANSITION]


class Q(object):

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               init_mu0_to_zero=False,
               graph_collection_name="Q_VARS"):
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.graph_collection_name = graph_collection_name
    initializers = []
    for t in xrange(num_timesteps):
      if t == 0 and init_mu0_to_zero:
        initializers.append(
            {"w": tf.zeros_initializer, "b": tf.zeros_initializer})
      else:
        initializers.append(
            {"w": tf.random_uniform_initializer(seed=random_seed),
             "b": tf.zeros_initializer})

    def custom_getter(getter, *args, **kwargs):
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    self.mus = [
        snt.Linear(output_size=state_size,
                   initializers=initializers[t],
                   name="q_mu_%d" % t,
                   custom_getter=custom_getter)
        for t in xrange(num_timesteps)
    ]
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in xrange(num_timesteps)
    ]

  def q_zt(self, observation, prev_state, t):
    batch_size = tf.shape(prev_state)[0]
    q_mu = self.mus[t](tf.concat([observation, prev_state], axis=1))
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
      tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0, 0])
      if t != 0:
        tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[1, 0])


class PreviousStateQ(Q):

  def q_zt(self, unused_observation, prev_state, t):
    batch_size = tf.shape(prev_state)[0]
    q_mu = self.mus[t](prev_state)
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
      tf.summary.scalar("q_mu/w_prev_state_%d" % t, f.w[0, 0])


class ObservationQ(Q):

  def q_zt(self, observation, prev_state, t):
    batch_size = tf.shape(prev_state)[0]
    q_mu = self.mus[t](observation)
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/b_%d" % t, f.b[0])
      tf.summary.scalar("q_mu/w_obs_%d" % t, f.w[0, 0])


class SimpleMeanQ(object):

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               init_mu0_to_zero=False,
               graph_collection_name="Q_VARS"):
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.graph_collection_name = graph_collection_name
    initializers = []
    for t in xrange(num_timesteps):
      if t == 0 and init_mu0_to_zero:
        initializers.append(tf.zeros_initializer)
      else:
        initializers.append(tf.random_uniform_initializer(seed=random_seed))
    self.mus = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_mu_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=initializers[t])
        for t in xrange(num_timesteps)
    ]
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in xrange(num_timesteps)
    ]

  def q_zt(self, unused_observation, prev_state, t):
    batch_size = tf.shape(prev_state)[0]
    q_mu = tf.tile(self.mus[t][tf.newaxis, :], [batch_size, 1])
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def summarize_weights(self):
    for t, sigma in enumerate(self.sigmas):
      tf.summary.scalar("q_sigma/%d" % t, sigma[0])
    for t, f in enumerate(self.mus):
      tf.summary.scalar("q_mu/%d" % t, f[0])


class R(object):

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               sigma_init=1.,
               random_seed=None,
               graph_collection_name="R_VARS"):
    self.dtype = dtype
    self.sigma_min = sigma_min
    initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
                    "b": tf.zeros_initializer}
    self.graph_collection_name = graph_collection_name

    def custom_getter(getter, *args, **kwargs):
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    self.mus = [
        snt.Linear(output_size=state_size,
                   initializers=initializers,
                   name="r_mu_%d" % t,
                   custom_getter=custom_getter)
        for t in xrange(num_timesteps)
    ]
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            # initializer=tf.random_uniform_initializer(seed=random_seed, maxval=100))
            initializer=tf.constant_initializer(sigma_init))
        for t in xrange(num_timesteps)
    ]

  def r_xn(self, z_t, t):
    batch_size = tf.shape(z_t)[0]
    r_mu = self.mus[t](z_t)
    r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
    return tf.contrib.distributions.Normal(
        loc=r_mu, scale=tf.sqrt(r_sigma))

  def summarize_weights(self):
    for t in range(len(self.mus) - 1):
      # self.mus[t] is a snt.Linear module, so summarize its bias rather than
      # indexing the module itself (which would raise a TypeError).
      tf.summary.scalar("r_mu/%d" % t, self.mus[t].b[0])
      tf.summary.scalar("r_sigma/%d" % t, self.sigmas[t][0])


class P(object):

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               variance=1.0,
               dtype=tf.float32,
               random_seed=None,
               trainable=True,
               init_bs_to_zero=False,
               graph_collection_name="P_VARS"):
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.graph_collection_name = graph_collection_name
    if init_bs_to_zero:
      initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)]
    else:
      initializers = [tf.random_uniform_initializer(seed=random_seed)
                      for _ in xrange(num_timesteps)]
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="p_b_%d" % (t + 1),
            initializer=initializers[t],
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            trainable=trainable) for t in xrange(num_timesteps)
    ]
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)

  def posterior(self, observation, prev_state, t):
    """Computes the true posterior p(z_t | z_{t-1}, z_n)."""
    # bs[0] is really b_1.
    # Bs[i] is the sum over k = i+1, ..., n of b_k.
    mu = observation - self.Bs[t]
    if t > 0:
      mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
    mu /= float(self.num_timesteps - t + 1)
    sigma = tf.ones_like(mu) * self.variance * (
        float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
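
  # An illustrative note (not part of the original code): with
  # z_t = z_{t-1} + b_t + eps_t, eps_t ~ N(0, variance), and
  # z_n = z_t + B_t + (n - t) further noise terms, where n = num_timesteps and
  # B_t = Bs[t], standard Gaussian conditioning gives
  #   E[z_t | z_{t-1}, z_n]   = ((z_n - B_t) + (n - t)(z_{t-1} + b_t)) / (n - t + 1)
  #   Var[z_t | z_{t-1}, z_n] = variance * (n - t) / (n - t + 1),
  # which is the mean and sigma computed by `posterior` above.
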
  def lookahead(self, state, t):
    """Computes the true lookahead distribution p(z_n | z_t)."""
    mu = state + self.Bs[t]
    sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t)
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def likelihood(self, observation):
    batch_size = tf.shape(observation)[0]
    mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(mu) * self.variance * (self.num_timesteps + 1)
    dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
    # Average over the batch and take the sum over the state size.
    return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t | z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      z_mu_p = prev_state + self.bs[t - 1]
    else:  # p(z_0) is Normal(0, 1).
      z_mu_p = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
    p_zt = tf.contrib.distributions.Normal(
        loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
    return p_zt

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n | z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu,
        scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))


class ShortChainNonlinearP(object):

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               variance=1.0,
               observation_variance=1.0,
               transition_type=STANDARD_TRANSITION,
               transition_dist=tf.contrib.distributions.Normal,
               dtype=tf.float32,
               random_seed=None):
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.observation_variance = observation_variance
    self.transition_type = transition_type
    self.transition_dist = transition_dist

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t | z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      if self.transition_type == ROUND_TRANSITION:
        loc = tf.round(prev_state)
        tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
      elif self.transition_type == STANDARD_TRANSITION:
        loc = prev_state
        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
    else:  # p(z_0) is Normal(0, 1).
      loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
      tf.logging.info("p(z_0) ~ N(0, %0.1f)" % self.variance)
    p_zt = self.transition_dist(
        loc=loc,
        scale=tf.sqrt(tf.ones_like(loc) * self.variance))
    return p_zt

  def generative(self, unused_obs, z_ni):
    """Computes the model's generative distribution p(x_i | z_{ni})."""
    if self.transition_type == ROUND_TRANSITION:
      loc = tf.round(z_ni)
    elif self.transition_type == STANDARD_TRANSITION:
      loc = z_ni
    generative_sigma_sq = tf.ones_like(loc) * self.observation_variance
    return self.transition_dist(
        loc=loc, scale=tf.sqrt(generative_sigma_sq))


class BimodalPriorP(object):

  def __init__(self,
               state_size,
               num_timesteps,
               mixing_coeff=0.5,
               prior_mode_mean=1,
               sigma_min=1e-5,
               variance=1.0,
               dtype=tf.float32,
               random_seed=None,
               trainable=True,
               init_bs_to_zero=False,
               graph_collection_name="P_VARS"):
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.mixing_coeff = mixing_coeff
    self.prior_mode_mean = prior_mode_mean
    if init_bs_to_zero:
      initializers = [tf.zeros_initializer for _ in xrange(num_timesteps)]
    else:
      initializers = [tf.random_uniform_initializer(seed=random_seed)
                      for _ in xrange(num_timesteps)]
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="b_%d" % (t + 1),
            initializer=initializers[t],
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, graph_collection_name],
            trainable=trainable) for t in xrange(num_timesteps)
    ]
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)
  def posterior(self, observation, prev_state, t):
    """Computes the true posterior p(z_t | z_{t-1}, z_n)."""
    # NOTE: This is currently wrong; fixing it would require refactoring
    # summarize_q, because the KL divergence is not defined for a mixture.
    # bs[0] is really b_1.
    # Bs[i] is the sum over k = i+1, ..., n of b_k.
    mu = observation - self.Bs[t]
    if t > 0:
      mu += (prev_state + self.bs[t - 1]) * float(self.num_timesteps - t)
    mu /= float(self.num_timesteps - t + 1)
    sigma = tf.ones_like(mu) * self.variance * (
        float(self.num_timesteps - t) / float(self.num_timesteps - t + 1))
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
  def lookahead(self, state, t):
    """Computes the true lookahead distribution p(z_n | z_t)."""
    mu = state + self.Bs[t]
    sigma = tf.ones_like(state) * self.variance * float(self.num_timesteps - t)
    return tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))

  def likelihood(self, observation):
    batch_size = tf.shape(observation)[0]
    sum_of_bs = tf.tile(
        tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(sum_of_bs) * self.variance * (self.num_timesteps + 1)
    mu_pos = (tf.ones([batch_size, self.state_size], dtype=self.dtype) *
              self.prior_mode_mean) + sum_of_bs
    mu_neg = (tf.ones([batch_size, self.state_size], dtype=self.dtype) *
              -self.prior_mode_mean) + sum_of_bs
    zn_pos = tf.contrib.distributions.Normal(
        loc=mu_pos,
        scale=tf.sqrt(sigma))
    zn_neg = tf.contrib.distributions.Normal(
        loc=mu_neg,
        scale=tf.sqrt(sigma))
    mode_probs = tf.convert_to_tensor(
        [self.mixing_coeff, 1 - self.mixing_coeff], dtype=tf.float64)
    mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :],
                         [batch_size, 1, 1])
    mode_selection_dist = tf.contrib.distributions.Categorical(probs=mode_probs)
    zn_dist = tf.contrib.distributions.Mixture(
        cat=mode_selection_dist,
        components=[zn_pos, zn_neg],
        validate_args=True)
    # Average over the batch and take the sum over the state size.
    return tf.reduce_mean(tf.reduce_sum(zn_dist.log_prob(observation), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t | z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      z_mu_p = prev_state + self.bs[t - 1]
      p_zt = tf.contrib.distributions.Normal(
          loc=z_mu_p, scale=tf.sqrt(tf.ones_like(z_mu_p) * self.variance))
      return p_zt
    else:  # p(z_0) is a mixture of two Normals.
      mu_pos = (tf.ones([batch_size, self.state_size], dtype=self.dtype) *
                self.prior_mode_mean)
      mu_neg = (tf.ones([batch_size, self.state_size], dtype=self.dtype) *
                -self.prior_mode_mean)
      z0_pos = tf.contrib.distributions.Normal(
          loc=mu_pos,
          scale=tf.sqrt(tf.ones_like(mu_pos) * self.variance))
      z0_neg = tf.contrib.distributions.Normal(
          loc=mu_neg,
          scale=tf.sqrt(tf.ones_like(mu_neg) * self.variance))
      mode_probs = tf.convert_to_tensor(
          [self.mixing_coeff, 1 - self.mixing_coeff], dtype=tf.float64)
      mode_probs = tf.tile(mode_probs[tf.newaxis, tf.newaxis, :],
                           [batch_size, 1, 1])
      mode_selection_dist = tf.contrib.distributions.Categorical(
          probs=mode_probs)
      z0_dist = tf.contrib.distributions.Mixture(
          cat=mode_selection_dist,
          components=[z0_pos, z0_neg],
          validate_args=False)
      return z0_dist

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n | z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu,
        scale=tf.sqrt(tf.ones_like(generative_p_mu) * self.variance))


class Model(object):

  def __init__(self,
               p,
               q,
               r,
               state_size,
               num_timesteps,
               dtype=tf.float32):
    self.p = p
    self.q = q
    self.r = r
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.dtype = dtype

  def zero_state(self, batch_size):
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def __call__(self, prev_state, observation, t):
    # Compute the q distribution over z, q(z_t | z_n, z_{t-1}).
    q_zt = self.q.q_zt(observation, prev_state, t)
    # Compute the p distribution over z, p(z_t | z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # Sample from q.
    zt = q_zt.sample()
    r_xn = self.r.r_xn(zt, t)
    # Calculate the logprobs and sum over the state size.
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    log_r_xn = tf.reduce_sum(r_xn.log_prob(observation), axis=1)
    # If we're at the last timestep, also calculate the logprob of the
    # observation.
    if t == self.num_timesteps - 1:
      generative_dist = self.p.generative(observation, zt)
      log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation),
                                      axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)


def create(state_size,
           num_timesteps,
           sigma_min=1e-5,
           r_sigma_init=1,
           variance=1.0,
           mixing_coeff=0.5,
           prior_mode_mean=1.0,
           dtype=tf.float32,
           random_seed=None,
           train_p=True,
           p_type="unimodal",
           q_type="normal",
           observation_variance=1.0,
           transition_type=STANDARD_TRANSITION,
           use_bs=True):
  if p_type == "unimodal":
    p = P(state_size,
          num_timesteps,
          sigma_min=sigma_min,
          variance=variance,
          dtype=dtype,
          random_seed=random_seed,
          trainable=train_p,
          init_bs_to_zero=not use_bs)
  elif p_type == "bimodal":
    p = BimodalPriorP(
        state_size,
        num_timesteps,
        mixing_coeff=mixing_coeff,
        prior_mode_mean=prior_mode_mean,
        sigma_min=sigma_min,
        variance=variance,
        dtype=dtype,
        random_seed=random_seed,
        trainable=train_p,
        init_bs_to_zero=not use_bs)
  elif "nonlinear" in p_type:
    if "cauchy" in p_type:
      trans_dist = tf.contrib.distributions.Cauchy
    else:
      trans_dist = tf.contrib.distributions.Normal
    p = ShortChainNonlinearP(
        state_size,
        num_timesteps,
        sigma_min=sigma_min,
        variance=variance,
        observation_variance=observation_variance,
        transition_type=transition_type,
        transition_dist=trans_dist,
        dtype=dtype,
        random_seed=random_seed)

  if q_type == "normal":
    q_class = Q
  elif q_type == "simple_mean":
    q_class = SimpleMeanQ
  elif q_type == "prev_state":
    q_class = PreviousStateQ
  elif q_type == "observation":
    q_class = ObservationQ

  q = q_class(state_size,
              num_timesteps,
              sigma_min=sigma_min,
              dtype=dtype,
              random_seed=random_seed,
              init_mu0_to_zero=not use_bs)
  r = R(state_size,
        num_timesteps,
        sigma_min=sigma_min,
        sigma_init=r_sigma_init,
        dtype=dtype,
        random_seed=random_seed)
  model = Model(p, q, r, state_size, num_timesteps, dtype=dtype)
  return model
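
# A minimal usage sketch (illustrative only, not part of the original code):
# build the short-chain model with `create` and unroll it for one trajectory.
# The batch size and the all-zeros observation are assumptions made for the
# example.
#
#   model = create(state_size=1, num_timesteps=5, p_type="unimodal",
#                  q_type="normal")
#   batch_size = 4
#   observation = tf.zeros([batch_size, 1])
#   state = model.zero_state(batch_size)
#   for t in range(model.num_timesteps):
#     (state, log_q_zt, log_p_zt,
#      log_p_x_given_z, log_r_xn) = model(state, observation, t)

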
class BackwardsModel(object):

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32):
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.bs = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="b_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]
    self.Bs = tf.cumsum(self.bs, reverse=True, axis=0)
    self.q_mus = [
        snt.Linear(output_size=state_size) for _ in xrange(num_timesteps)
    ]
    self.q_sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]
    self.r_mus = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_mu_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]
    self.r_sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            initializer=tf.zeros_initializer) for t in xrange(num_timesteps)
    ]

  def zero_state(self, batch_size):
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def posterior(self, unused_observation, prev_state, unused_t):
    # TODO(dieterichl): Correct this.
    return tf.contrib.distributions.Normal(
        loc=tf.zeros_like(prev_state), scale=tf.zeros_like(prev_state))

  def lookahead(self, state, unused_t):
    # TODO(dieterichl): Correct this.
    return tf.contrib.distributions.Normal(
        loc=tf.zeros_like(state), scale=tf.zeros_like(state))

  def q_zt(self, observation, next_state, t):
    """Computes the variational posterior q(z_t | z_{t+1}, z_n)."""
    t_backwards = self.num_timesteps - t - 1
    batch_size = tf.shape(next_state)[0]
    q_mu = self.q_mus[t_backwards](
        tf.concat([observation, next_state], axis=1))
    q_sigma = tf.maximum(
        tf.nn.softplus(self.q_sigmas[t_backwards]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    return q_zt

  def p_zt(self, prev_state, t):
    """Computes the model p(z_{t+1} | z_t)."""
    t_backwards = self.num_timesteps - t - 1
    z_mu_p = prev_state + self.bs[t_backwards]
    p_zt = tf.contrib.distributions.Normal(
        loc=z_mu_p, scale=tf.ones_like(z_mu_p))
    return p_zt

  def generative(self, unused_observation, z_nm1):
    """Computes the model's generative distribution p(z_n | z_{n-1})."""
    generative_p_mu = z_nm1 + self.bs[-1]
    return tf.contrib.distributions.Normal(
        loc=generative_p_mu, scale=tf.ones_like(generative_p_mu))

  def r(self, z_t, t):
    t_backwards = self.num_timesteps - t - 1
    batch_size = tf.shape(z_t)[0]
    r_mu = tf.tile(self.r_mus[t_backwards][tf.newaxis, :], [batch_size, 1])
    r_sigma = tf.maximum(
        tf.nn.softplus(self.r_sigmas[t_backwards]), self.sigma_min)
    r_sigma = tf.tile(r_sigma[tf.newaxis, :], [batch_size, 1])
    return tf.contrib.distributions.Normal(loc=r_mu, scale=tf.sqrt(r_sigma))

  def likelihood(self, observation):
    batch_size = tf.shape(observation)[0]
    mu = tf.tile(tf.reduce_sum(self.bs, axis=0)[tf.newaxis, :], [batch_size, 1])
    sigma = tf.ones_like(mu) * (self.num_timesteps + 1)
    dist = tf.contrib.distributions.Normal(loc=mu, scale=tf.sqrt(sigma))
    # Average over the batch and take the sum over the state size.
    return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observation), axis=1))

  def __call__(self, next_state, observation, t):
    # next_state is z_{t+1}.
    # Compute the q distribution over z, q(z_t | z_n, z_{t+1}).
    q_zt = self.q_zt(observation, next_state, t)
    # Sample from q.
    zt = q_zt.sample()
    # Compute the p distribution over z, p(z_{t+1} | z_t).
    p_zt = self.p_zt(zt, t)
    # Compute log p(z_{t+1} | z_t).
    if t == 0:
      log_p_zt = p_zt.log_prob(observation)
    else:
      log_p_zt = p_zt.log_prob(next_state)
    # Compute the r prior over z_t.
    r_zt = self.r(zt, t)
    log_r_zt = r_zt.log_prob(zt)
    # Compute the proposal density at z_t.
    log_q_zt = q_zt.log_prob(zt)
    # If we're at the last timestep, also calculate the logprob of the
    # observation.
    if t == self.num_timesteps - 1:
      p_z0_dist = tf.contrib.distributions.Normal(
          loc=tf.zeros_like(zt), scale=tf.ones_like(zt))
      z0_log_prob = p_z0_dist.log_prob(zt)
    else:
      z0_log_prob = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, z0_log_prob, log_r_zt)
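
# A usage sketch for BackwardsModel (illustrative, not part of the original
# code): the model is constructed directly and stepped with t increasing while
# the chain itself is traversed in reverse. The batch size and the zero
# observation are assumptions.
#
#   model = BackwardsModel(state_size=1, num_timesteps=5)
#   batch_size = 4
#   observation = tf.zeros([batch_size, 1])
#   state = model.zero_state(batch_size)
#   for t in range(model.num_timesteps):
#     (state, log_q_zt, log_p_zt,
#      z0_log_prob, log_r_zt) = model(state, observation, t)

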
class LongChainP(object):

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               variance=1.0,
               observation_variance=1.0,
               observation_type=STANDARD_OBSERVATION,
               transition_type=STANDARD_TRANSITION,
               dtype=tf.float32,
               random_seed=None):
    self.state_size = state_size
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = steps_per_obs * num_obs + 1
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.variance = variance
    self.observation_variance = observation_variance
    self.observation_type = observation_type
    self.transition_type = transition_type

  def likelihood(self, observations):
    """Computes the model's true likelihood of the observations.

    Args:
      observations: A [batch_size, m, state_size] Tensor representing each of
        the m observations.

    Returns:
      logprob: The true likelihood of the observations given the model.
    """
    raise ValueError("Likelihood is not defined for long-chain models")
    # batch_size = tf.shape(observations)[0]
    # mu = tf.zeros([batch_size, self.state_size, self.num_obs],
    #               dtype=self.dtype)
    # sigma = np.fromfunction(
    #     lambda i, j: 1 + self.steps_per_obs*np.minimum(i+1, j+1),
    #     [self.num_obs, self.num_obs])
    # sigma += np.eye(self.num_obs)
    # sigma = tf.convert_to_tensor(sigma * self.variance, dtype=self.dtype)
    # sigma = tf.tile(sigma[tf.newaxis, tf.newaxis, ...],
    #                 [batch_size, self.state_size, 1, 1])
    # dist = tf.contrib.distributions.MultivariateNormalFullCovariance(
    #     loc=mu,
    #     covariance_matrix=sigma)
    # Average over the batch and take the sum over the state size.
    # return tf.reduce_mean(tf.reduce_sum(dist.log_prob(observations), axis=1))

  def p_zt(self, prev_state, t):
    """Computes the model p(z_t | z_{t-1})."""
    batch_size = tf.shape(prev_state)[0]
    if t > 0:
      if self.transition_type == ROUND_TRANSITION:
        loc = tf.round(prev_state)
        tf.logging.info("p(z_%d | z_%d) ~ N(round(z_%d), %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
      elif self.transition_type == STANDARD_TRANSITION:
        loc = prev_state
        tf.logging.info("p(z_%d | z_%d) ~ N(z_%d, %0.1f)" %
                        (t, t - 1, t - 1, self.variance))
    else:  # p(z_0) is Normal(0, 1).
      loc = tf.zeros([batch_size, self.state_size], dtype=self.dtype)
      tf.logging.info("p(z_0) ~ N(0, %0.1f)" % self.variance)
    p_zt = tf.contrib.distributions.Normal(
        loc=loc,
        scale=tf.sqrt(tf.ones_like(loc) * self.variance))
    return p_zt

  def generative(self, z_ni, t):
    """Computes the model's generative distribution p(x_i | z_{ni})."""
    if self.observation_type == SQUARED_OBSERVATION:
      generative_mu = tf.square(z_ni)
      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d^2, %0.1f)" %
                      (t, t, t, self.variance))
    elif self.observation_type == ABS_OBSERVATION:
      generative_mu = tf.abs(z_ni)
      tf.logging.info("p(x_%d | z_%d) ~ N(|z_%d|, %0.1f)" %
                      (t, t, t, self.variance))
    elif self.observation_type == STANDARD_OBSERVATION:
      generative_mu = z_ni
      tf.logging.info("p(x_%d | z_%d) ~ N(z_%d, %0.1f)" %
                      (t, t, t, self.variance))
    generative_sigma_sq = (tf.ones_like(generative_mu) *
                           self.observation_variance)
    return tf.contrib.distributions.Normal(
        loc=generative_mu, scale=tf.sqrt(generative_sigma_sq))


class LongChainQ(object):

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None):
    self.state_size = state_size
    self.sigma_min = sigma_min
    self.dtype = dtype
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = num_obs * steps_per_obs + 1
    initializers = {
        "w": tf.random_uniform_initializer(seed=random_seed),
        "b": tf.zeros_initializer
    }
    self.mus = [
        snt.Linear(output_size=state_size, initializers=initializers)
        for t in xrange(self.num_timesteps)
    ]
    self.sigmas = [
        tf.get_variable(
            shape=[state_size],
            dtype=self.dtype,
            name="q_sigma_%d" % (t + 1),
            initializer=tf.random_uniform_initializer(seed=random_seed))
        for t in xrange(self.num_timesteps)
    ]

  def first_relevant_obs_index(self, t):
    return int(max((t - 1) / self.steps_per_obs, 0))
  def q_zt(self, observations, prev_state, t):
    """Computes a distribution over z_t.

    Args:
      observations: a [batch_size, num_observations, state_size] Tensor.
      prev_state: a [batch_size, state_size] Tensor.
      t: The current timestep, an int Tensor.
    """
    # Filter out past observations that are no longer needed.
    first_relevant_obs_index = int(
        math.floor(max(t - 1, 0) / self.steps_per_obs))
    num_relevant_observations = self.num_obs - first_relevant_obs_index
    observations = observations[:, first_relevant_obs_index:, :]
    batch_size = tf.shape(prev_state)[0]
    # Concatenate the prev state and observations along the second axis (the
    # axis that is neither the batch nor the state size axis), and then flatten
    # the result to [batch_size, (num_relevant_observations + 1) * state_size]
    # to feed it into the linear layer.
    q_input = tf.concat([observations, prev_state[:, tf.newaxis, :]], axis=1)
    q_input = tf.reshape(
        q_input,
        [batch_size, (num_relevant_observations + 1) * self.state_size])
    q_mu = self.mus[t](q_input)
    q_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    q_sigma = tf.tile(q_sigma[tf.newaxis, :], [batch_size, 1])
    q_zt = tf.contrib.distributions.Normal(loc=q_mu, scale=tf.sqrt(q_sigma))
    tf.logging.info(
        "q(z_{t} | z_{tm1}, x_{obsf}:{obst}) ~ N(Linear([z_{tm1},x_{obsf}:{obst}]), sigma_{t})".format(
            **{"t": t,
               "tm1": t - 1,
               "obsf": (first_relevant_obs_index + 1) * self.steps_per_obs,
               "obst": self.steps_per_obs * self.num_obs}))
    return q_zt

  def summarize_weights(self):
    pass


class LongChainR(object):

  def __init__(self,
               state_size,
               num_obs,
               steps_per_obs,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None):
    self.state_size = state_size
    self.dtype = dtype
    self.sigma_min = sigma_min
    self.steps_per_obs = steps_per_obs
    self.num_obs = num_obs
    self.num_timesteps = num_obs * steps_per_obs + 1
    self.sigmas = [
        tf.get_variable(
            shape=[self.num_future_obs(t)],
            dtype=self.dtype,
            name="r_sigma_%d" % (t + 1),
            # initializer=tf.random_uniform_initializer(seed=random_seed, maxval=100))
            initializer=tf.constant_initializer(1.0))
        for t in range(self.num_timesteps)
    ]

  def first_future_obs_index(self, t):
    return int(math.floor(t / self.steps_per_obs))

  def num_future_obs(self, t):
    return int(self.num_obs - self.first_future_obs_index(t))

  def r_xn(self, z_t, t):
    """Computes a distribution over future observations given the current
    latent state.

    The indexing in these messages is 1-indexed and inclusive. This is
    consistent with the latex documents.

    Args:
      z_t: [batch_size, state_size] Tensor.
      t: Current timestep.
    """
    tf.logging.info(
        "r(x_{start}:{end} | z_{t}) ~ N(z_{t}, sigma_{t})".format(
            **{"t": t,
               "start": (self.first_future_obs_index(t) + 1) * self.steps_per_obs,
               "end": self.num_timesteps - 1}))
    batch_size = tf.shape(z_t)[0]
    # The mean is the same for all future observations. This tiling results in
    # a [batch_size, num_future_obs, state_size] Tensor.
    r_mu = tf.tile(z_t[:, tf.newaxis, :], [1, self.num_future_obs(t), 1])
    # Compute the variance. It is shared across all state dimensions, so we
    # only need to tile sigma out to [batch_size, num_future_obs, state_size].
    r_sigma = tf.maximum(tf.nn.softplus(self.sigmas[t]), self.sigma_min)
    r_sigma = tf.tile(r_sigma[tf.newaxis, :, tf.newaxis],
                      [batch_size, 1, self.state_size])
    return tf.contrib.distributions.Normal(
        loc=r_mu, scale=tf.sqrt(r_sigma))

  def summarize_weights(self):
    pass


class LongChainModel(object):

  def __init__(self,
               p,
               q,
               r,
               state_size,
               num_obs,
               steps_per_obs,
               dtype=tf.float32,
               disable_r=False):
    self.p = p
    self.q = q
    self.r = r
    self.disable_r = disable_r
    self.state_size = state_size
    self.num_obs = num_obs
    self.steps_per_obs = steps_per_obs
    self.num_timesteps = steps_per_obs * num_obs + 1
    self.dtype = dtype

  def zero_state(self, batch_size):
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def next_obs_ind(self, t):
    return int(math.floor(max(t - 1, 0) / self.steps_per_obs))

  def __call__(self, prev_state, observations, t):
    """Computes the importance weight for the model system.

    Args:
      prev_state: [batch_size, state_size] Tensor.
      observations: [batch_size, num_observations, state_size] Tensor.
      t: The current timestep, an int.
    """
    # Compute the q distribution over z, q(z_t | z_n, z_{t-1}).
    q_zt = self.q.q_zt(observations, prev_state, t)
    # Compute the p distribution over z, p(z_t | z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # Sample from q and evaluate the logprobs, summing over the state size.
    zt = q_zt.sample()
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    if not self.disable_r and t < self.num_timesteps - 1:
      # Score the remaining observations using r.
      r_xn = self.r.r_xn(zt, t)
      log_r_xn = r_xn.log_prob(observations[:, self.next_obs_ind(t + 1):, :])
      # Sum over the state size and observations, leaving the batch index.
      log_r_xn = tf.reduce_sum(log_r_xn, axis=[1, 2])
    else:
      log_r_xn = tf.zeros_like(log_p_zt)
    if t != 0 and t % self.steps_per_obs == 0:
      generative_dist = self.p.generative(zt, t)
      log_p_x_given_z = generative_dist.log_prob(
          observations[:, self.next_obs_ind(t), :])
      log_p_x_given_z = tf.reduce_sum(log_p_x_given_z, axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z, log_r_xn)


def create(state_size,
           num_obs,
           steps_per_obs,
           sigma_min=1e-5,
           variance=1.0,
           observation_variance=1.0,
           observation_type=STANDARD_OBSERVATION,
           transition_type=STANDARD_TRANSITION,
           dtype=tf.float32,
           random_seed=None,
           disable_r=False):
  p = LongChainP(
      state_size,
      num_obs,
      steps_per_obs,
      sigma_min=sigma_min,
      variance=variance,
      observation_variance=observation_variance,
      observation_type=observation_type,
      transition_type=transition_type,
      dtype=dtype,
      random_seed=random_seed)
  q = LongChainQ(
      state_size,
      num_obs,
      steps_per_obs,
      sigma_min=sigma_min,
      dtype=dtype,
      random_seed=random_seed)
  r = LongChainR(
      state_size,
      num_obs,
      steps_per_obs,
      sigma_min=sigma_min,
      dtype=dtype,
      random_seed=random_seed)
  model = LongChainModel(
      p, q, r, state_size, num_obs, steps_per_obs,
      dtype=dtype,
      disable_r=disable_r)
  return model
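
# A usage sketch for the long-chain model (illustrative, not part of the
# original code): observations is a [batch_size, num_obs, state_size] Tensor
# and the model is stepped for steps_per_obs * num_obs + 1 timesteps. The batch
# size and the zero observations are assumptions.
#
#   model = create(state_size=1, num_obs=3, steps_per_obs=2)
#   batch_size = 4
#   observations = tf.zeros([batch_size, 3, 1])
#   state = model.zero_state(batch_size)
#   for t in range(model.num_timesteps):
#     (state, log_q_zt, log_p_zt,
#      log_p_x_given_z, log_r_xn) = model(state, observations, t)

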
class RTilde(object):

  def __init__(self,
               state_size,
               num_timesteps,
               sigma_min=1e-5,
               dtype=tf.float32,
               random_seed=None,
               graph_collection_name="R_TILDE_VARS"):
    self.dtype = dtype
    self.sigma_min = sigma_min
    initializers = {"w": tf.truncated_normal_initializer(seed=random_seed),
                    "b": tf.zeros_initializer}
    self.graph_collection_name = graph_collection_name

    def custom_getter(getter, *args, **kwargs):
      out = getter(*args, **kwargs)
      ref = tf.get_collection_ref(self.graph_collection_name)
      if out not in ref:
        ref.append(out)
      return out

    self.fns = [
        snt.Linear(output_size=2 * state_size,
                   initializers=initializers,
                   name="r_tilde_%d" % t,
                   custom_getter=custom_getter)
        for t in xrange(num_timesteps)
    ]

  def r_zt(self, z_t, observation, t):
    # out = self.fns[t](tf.stop_gradient(tf.concat([z_t, observation], axis=1)))
    out = self.fns[t](tf.concat([z_t, observation], axis=1))
    mu, raw_sigma_sq = tf.split(out, 2, axis=1)
    sigma_sq = tf.maximum(tf.nn.softplus(raw_sigma_sq), self.sigma_min)
    return mu, sigma_sq


class TDModel(object):

  def __init__(self,
               p,
               q,
               r_tilde,
               state_size,
               num_timesteps,
               dtype=tf.float32,
               disable_r=False):
    self.p = p
    self.q = q
    self.r_tilde = r_tilde
    self.disable_r = disable_r
    self.state_size = state_size
    self.num_timesteps = num_timesteps
    self.dtype = dtype

  def zero_state(self, batch_size):
    return tf.zeros([batch_size, self.state_size], dtype=self.dtype)

  def __call__(self, prev_state, observation, t):
    """Computes the importance weight for the model system.

    Args:
      prev_state: [batch_size, state_size] Tensor.
      observation: [batch_size, state_size] Tensor.
      t: The current timestep, an int.
    """
    # Compute the q distribution over z, q(z_t | z_n, z_{t-1}).
    q_zt = self.q.q_zt(observation, prev_state, t)
    # Compute the p distribution over z, p(z_t | z_{t-1}).
    p_zt = self.p.p_zt(prev_state, t)
    # Sample from q.
    zt = q_zt.sample()
    # If it isn't the last timestep, compute the distribution over the next z.
    if t < self.num_timesteps - 1:
      p_ztplus1 = self.p.p_zt(zt, t + 1)
    else:
      p_ztplus1 = None
    # Evaluate the logprobs, summing over the state size.
    log_q_zt = tf.reduce_sum(q_zt.log_prob(zt), axis=1)
    log_p_zt = tf.reduce_sum(p_zt.log_prob(zt), axis=1)
    if not self.disable_r and t < self.num_timesteps - 1:
      # Score the remaining observations using r.
      r_tilde_mu, r_tilde_sigma_sq = self.r_tilde.r_zt(zt, observation, t + 1)
    else:
      r_tilde_mu = None
      r_tilde_sigma_sq = None
    if t == self.num_timesteps - 1:
      generative_dist = self.p.generative(observation, zt)
      log_p_x_given_z = tf.reduce_sum(generative_dist.log_prob(observation),
                                      axis=1)
    else:
      log_p_x_given_z = tf.zeros_like(log_q_zt)
    return (zt, log_q_zt, log_p_zt, log_p_x_given_z,
            r_tilde_mu, r_tilde_sigma_sq, p_ztplus1)


def create(state_size,
           num_timesteps,
           sigma_min=1e-5,
           variance=1.0,
           dtype=tf.float32,
           random_seed=None,
           train_p=True,
           p_type="unimodal",
           q_type="normal",
           mixing_coeff=0.5,
           prior_mode_mean=1.0,
           observation_variance=1.0,
           transition_type=STANDARD_TRANSITION,
           use_bs=True):
  if p_type == "unimodal":
    p = P(state_size,
          num_timesteps,
          sigma_min=sigma_min,
          variance=variance,
          dtype=dtype,
          random_seed=random_seed,
          trainable=train_p,
          init_bs_to_zero=not use_bs)
  elif p_type == "bimodal":
    p = BimodalPriorP(
        state_size,
        num_timesteps,
        mixing_coeff=mixing_coeff,
        prior_mode_mean=prior_mode_mean,
        sigma_min=sigma_min,
        variance=variance,
        dtype=dtype,
        random_seed=random_seed,
        trainable=train_p,
        init_bs_to_zero=not use_bs)
  elif "nonlinear" in p_type:
    if "cauchy" in p_type:
      trans_dist = tf.contrib.distributions.Cauchy
    else:
      trans_dist = tf.contrib.distributions.Normal
    p = ShortChainNonlinearP(
        state_size,
        num_timesteps,
        sigma_min=sigma_min,
        variance=variance,
        observation_variance=observation_variance,
        transition_type=transition_type,
        transition_dist=trans_dist,
        dtype=dtype,
        random_seed=random_seed)

  if q_type == "normal":
    q_class = Q
  elif q_type == "simple_mean":
    q_class = SimpleMeanQ
  elif q_type == "prev_state":
    q_class = PreviousStateQ
  elif q_type == "observation":
    q_class = ObservationQ

  q = q_class(state_size,
              num_timesteps,
              sigma_min=sigma_min,
              dtype=dtype,
              random_seed=random_seed,
              init_mu0_to_zero=not use_bs)
  r_tilde = RTilde(
      state_size,
      num_timesteps,
      sigma_min=sigma_min,
      dtype=dtype,
      random_seed=random_seed)
  model = TDModel(p, q, r_tilde, state_size, num_timesteps, dtype=dtype)
  return model
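
# A usage sketch for the TD variant (illustrative, not part of the original
# code): the returned tuple additionally carries the r-tilde parameters and
# p(z_{t+1} | z_t), which are used to train r_tilde. The batch size and the
# zero observation are assumptions.
#
#   model = create(state_size=1, num_timesteps=5)
#   batch_size = 4
#   observation = tf.zeros([batch_size, 1])
#   state = model.zero_state(batch_size)
#   for t in range(model.num_timesteps):
#     (state, log_q_zt, log_p_zt, log_p_x_given_z,
#      r_tilde_mu, r_tilde_sigma_sq, p_ztplus1) = model(state, observation, t)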