NCTCMumbai's picture
Upload 2583 files
97b6013 verified
raw
history blame
No virus
6.48 kB
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A script to run training for sequential latent variable models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from fivo import ghmm_runners
from fivo import runners
# Shared flags.
tf.app.flags.DEFINE_enum("mode", "train",
["train", "eval", "sample"],
"The mode of the binary.")
tf.app.flags.DEFINE_enum("model", "vrnn",
["vrnn", "ghmm", "srnn"],
"Model choice.")
tf.app.flags.DEFINE_integer("latent_size", 64,
"The size of the latent state of the model.")
tf.app.flags.DEFINE_enum("dataset_type", "pianoroll",
["pianoroll", "speech", "pose"],
"The type of dataset.")
tf.app.flags.DEFINE_string("dataset_path", "",
"Path to load the dataset from.")
tf.app.flags.DEFINE_integer("data_dimension", None,
"The dimension of each vector in the data sequence. "
"Defaults to 88 for pianoroll datasets and 200 for speech "
"datasets. Should not need to be changed except for "
"testing.")
tf.app.flags.DEFINE_integer("batch_size", 4,
"Batch size.")
tf.app.flags.DEFINE_integer("num_samples", 4,
"The number of samples (or particles) for multisample "
"algorithms.")
tf.app.flags.DEFINE_string("logdir", "/tmp/smc_vi",
"The directory to keep checkpoints and summaries in.")
tf.app.flags.DEFINE_integer("random_seed", None,
"A random seed for seeding the TensorFlow graph.")
tf.app.flags.DEFINE_integer("parallel_iterations", 30,
"The number of parallel iterations to use for the while "
"loop that computes the bounds.")
# Training flags.
tf.app.flags.DEFINE_enum("bound", "fivo",
["elbo", "iwae", "fivo", "fivo-aux"],
"The bound to optimize.")
tf.app.flags.DEFINE_boolean("normalize_by_seq_len", True,
"If true, normalize the loss by the number of timesteps "
"per sequence.")
tf.app.flags.DEFINE_float("learning_rate", 0.0002,
"The learning rate for ADAM.")
tf.app.flags.DEFINE_integer("max_steps", int(1e9),
"The number of gradient update steps to train for.")
tf.app.flags.DEFINE_integer("summarize_every", 50,
"The number of steps between summaries.")
tf.app.flags.DEFINE_enum("resampling_type", "multinomial",
["multinomial", "relaxed"],
"The resampling strategy to use for training.")
tf.app.flags.DEFINE_float("relaxed_resampling_temperature", 0.5,
"The relaxation temperature for relaxed resampling.")
tf.app.flags.DEFINE_enum("proposal_type", "filtering",
["prior", "filtering", "smoothing",
"true-filtering", "true-smoothing"],
"The type of proposal to use. true-filtering and true-smoothing "
"are only available for the GHMM. The specific implementation "
"of each proposal type is left to model-writers.")
# Distributed training flags.
tf.app.flags.DEFINE_string("master", "",
"The BNS name of the TensorFlow master to use.")
tf.app.flags.DEFINE_integer("task", 0,
"Task id of the replica running the training.")
tf.app.flags.DEFINE_integer("ps_tasks", 0,
"Number of tasks in the ps job. If 0 no ps job is used.")
tf.app.flags.DEFINE_boolean("stagger_workers", True,
"If true, bring one worker online every 1000 steps.")
# Evaluation flags.
tf.app.flags.DEFINE_enum("split", "train",
["train", "test", "valid"],
"Split to evaluate the model on.")
# Sampling flags.
tf.app.flags.DEFINE_integer("sample_length", 50,
"The number of timesteps to sample for.")
tf.app.flags.DEFINE_integer("prefix_length", 25,
"The number of timesteps to condition the model on "
"before sampling.")
tf.app.flags.DEFINE_string("sample_out_dir", None,
"The directory to write the samples to. "
"Defaults to logdir.")
# GHMM flags.
tf.app.flags.DEFINE_float("variance", 0.1,
"The variance of the ghmm.")
tf.app.flags.DEFINE_integer("num_timesteps", 5,
"The number of timesteps to run the gmp for.")
FLAGS = tf.app.flags.FLAGS
PIANOROLL_DEFAULT_DATA_DIMENSION = 88
SPEECH_DEFAULT_DATA_DIMENSION = 200
def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO)
if FLAGS.model in ["vrnn", "srnn"]:
if FLAGS.data_dimension is None:
if FLAGS.dataset_type == "pianoroll":
FLAGS.data_dimension = PIANOROLL_DEFAULT_DATA_DIMENSION
elif FLAGS.dataset_type == "speech":
FLAGS.data_dimension = SPEECH_DEFAULT_DATA_DIMENSION
if FLAGS.mode == "train":
runners.run_train(FLAGS)
elif FLAGS.mode == "eval":
runners.run_eval(FLAGS)
elif FLAGS.mode == "sample":
runners.run_sample(FLAGS)
elif FLAGS.model == "ghmm":
if FLAGS.mode == "train":
ghmm_runners.run_train(FLAGS)
elif FLAGS.mode == "eval":
ghmm_runners.run_eval(FLAGS)
if __name__ == "__main__":
tf.app.run(main)