|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""A script to train, evaluate, or sample from sequential latent variable models.

"""
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import tensorflow as tf |
|
|
|
from fivo import ghmm_runners |
|
from fivo import runners |
|
|
|
|
|
# Shared flags.
tf.app.flags.DEFINE_enum("mode", "train",
                         ["train", "eval", "sample"],
                         "The mode of the binary.")
tf.app.flags.DEFINE_enum("model", "vrnn",
                         ["vrnn", "ghmm", "srnn"],
                         "Model choice.")
tf.app.flags.DEFINE_integer("latent_size", 64,
                            "The size of the latent state of the model.")
tf.app.flags.DEFINE_enum("dataset_type", "pianoroll",
                         ["pianoroll", "speech", "pose"],
                         "The type of dataset.")
tf.app.flags.DEFINE_string("dataset_path", "",
                           "Path to load the dataset from.")
tf.app.flags.DEFINE_integer("data_dimension", None,
                            "The dimension of each vector in the data sequence. "
                            "Defaults to 88 for pianoroll datasets and 200 for speech "
                            "datasets. Should not need to be changed except for "
                            "testing.")
tf.app.flags.DEFINE_integer("batch_size", 4,
                            "Batch size.")
tf.app.flags.DEFINE_integer("num_samples", 4,
                            "The number of samples (or particles) for multisample "
                            "algorithms.")
tf.app.flags.DEFINE_string("logdir", "/tmp/smc_vi",
                           "The directory to keep checkpoints and summaries in.")
tf.app.flags.DEFINE_integer("random_seed", None,
                            "A random seed for seeding the TensorFlow graph.")
tf.app.flags.DEFINE_integer("parallel_iterations", 30,
                            "The number of parallel iterations to use for the while "
                            "loop that computes the bounds.")

# Training flags.
tf.app.flags.DEFINE_enum("bound", "fivo",
                         ["elbo", "iwae", "fivo", "fivo-aux"],
                         "The bound to optimize.")
tf.app.flags.DEFINE_boolean("normalize_by_seq_len", True,
                            "If true, normalize the loss by the number of timesteps "
                            "per sequence.")
tf.app.flags.DEFINE_float("learning_rate", 0.0002,
                          "The learning rate for ADAM.")
tf.app.flags.DEFINE_integer("max_steps", int(1e9),
                            "The number of gradient update steps to train for.")
tf.app.flags.DEFINE_integer("summarize_every", 50,
                            "The number of steps between summaries.")
tf.app.flags.DEFINE_enum("resampling_type", "multinomial",
                         ["multinomial", "relaxed"],
                         "The resampling strategy to use for training.")
tf.app.flags.DEFINE_float("relaxed_resampling_temperature", 0.5,
                          "The relaxation temperature for relaxed resampling.")
tf.app.flags.DEFINE_enum("proposal_type", "filtering",
                         ["prior", "filtering", "smoothing",
                          "true-filtering", "true-smoothing"],
                         "The type of proposal to use. true-filtering and true-smoothing "
                         "are only available for the GHMM. The specific implementation "
                         "of each proposal type is left to model-writers.")

# Distributed training flags.
tf.app.flags.DEFINE_string("master", "",
                           "The BNS name of the TensorFlow master to use.")
tf.app.flags.DEFINE_integer("task", 0,
                            "Task id of the replica running the training.")
tf.app.flags.DEFINE_integer("ps_tasks", 0,
                            "Number of tasks in the ps job. If 0 no ps job is used.")
tf.app.flags.DEFINE_boolean("stagger_workers", True,
                            "If true, bring one worker online every 1000 steps.")

# Evaluation flags.
tf.app.flags.DEFINE_enum("split", "train",
                         ["train", "test", "valid"],
                         "Split to evaluate the model on.")

# Sampling flags.
tf.app.flags.DEFINE_integer("sample_length", 50,
                            "The number of timesteps to sample for.")
tf.app.flags.DEFINE_integer("prefix_length", 25,
                            "The number of timesteps to condition the model on "
                            "before sampling.")
tf.app.flags.DEFINE_string("sample_out_dir", None,
                           "The directory to write the samples to. "
                           "Defaults to logdir.")

# GHMM flags.
tf.app.flags.DEFINE_float("variance", 0.1,
                          "The variance of the ghmm.")
# Help text previously read "run the gmp for", a typo for the GHMM.
tf.app.flags.DEFINE_integer("num_timesteps", 5,
                            "The number of timesteps to run the ghmm for.")

FLAGS = tf.app.flags.FLAGS

# Defaults applied by main() to --data_dimension for the vrnn/srnn models when
# the flag is left unset; they match the values quoted in its help text.
PIANOROLL_DEFAULT_DATA_DIMENSION = 88
SPEECH_DEFAULT_DATA_DIMENSION = 200
|
|
|
|
|
def main(unused_argv):
  """Dispatches to the runner selected by --model and --mode.

  For the "vrnn" and "srnn" models, fills in a default --data_dimension for
  the pianoroll and speech datasets when none was given. It then calls
  runners.run_train, runners.run_eval, or runners.run_sample. For the "ghmm"
  model it calls ghmm_runners.run_train or ghmm_runners.run_eval.

  Args:
    unused_argv: Command-line arguments left over after flag parsing; ignored.

  Raises:
    ValueError: If --mode is not supported by the "ghmm" model (only "train"
      and "eval" are dispatched for it).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.model in ["vrnn", "srnn"]:
    if FLAGS.data_dimension is None:
      if FLAGS.dataset_type == "pianoroll":
        FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
      elif FLAGS.dataset_type == "speech":
        FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION
      # NOTE(review): "pose" gets no default here, so data_dimension stays None
      # unless passed explicitly -- confirm the runners handle that case.
    if FLAGS.mode == "train":
      runners.run_train(FLAGS)
    elif FLAGS.mode == "eval":
      runners.run_eval(FLAGS)
    elif FLAGS.mode == "sample":
      runners.run_sample(FLAGS)
  elif FLAGS.model == "ghmm":
    if FLAGS.mode == "train":
      ghmm_runners.run_train(FLAGS)
    elif FLAGS.mode == "eval":
      ghmm_runners.run_eval(FLAGS)
    else:
      # The GHMM path dispatches only train/eval; fail loudly rather than
      # letting the binary exit without doing anything.
      raise ValueError(
          "Mode '%s' is not supported for the ghmm model; use 'train' or "
          "'eval'." % FLAGS.mode)
|
|
|
if __name__ == "__main__":
  # Parse the flags defined above, then hand control to main().
  tf.app.run(main=main)
|
|