# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import random
import sys
import os
import numpy as np
import tensorflow as tf
import rebar
import datasets
import logger as L
try:
  xrange  # Python 2
except NameError:
  xrange = range  # Python 3
gfile = tf.gfile
tf.app.flags.DEFINE_string("working_dir", "/tmp/rebar",
"""Directory where to save data, write logs, etc.""")
tf.app.flags.DEFINE_string('hparams', '',
'''Comma separated list of name=value pairs.''')
tf.app.flags.DEFINE_integer('eval_freq', 20,
'''How often to run the evaluation step.''')
FLAGS = tf.flags.FLAGS
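

# Build a tf.Summary protobuf by hand so that scalar values computed in plain
# Python (outside the TensorFlow graph) can still be written to TensorBoard.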
def manual_scalar_summary(name, value):
  value = tf.Summary.Value(tag=name, simple_value=value)
  summary_str = tf.Summary(value=[value])
  return summary_str
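

# Run sbn.partial_eval over eval_xs in small batches, drawing n_samples
# samples per batch, and average the per-batch results.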
def eval(sbn, eval_xs, n_samples=100, batch_size=5):
  n = eval_xs.shape[0]
  i = 0
  res = []
  while i < n:
    batch_xs = eval_xs[i:min(i+batch_size, n)]
    res.append(sbn.partial_eval(batch_xs, n_samples))
    i += batch_size
  res = np.mean(res, axis=0)
  return res
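

# Training loop: repeatedly shuffle the training set, fit the model batch by
# batch with sbn.partial_fit, log training ELBO, gradient-variance and
# temperature statistics, and periodically evaluate on the validation and
# test sets until training_steps is exceeded.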
def train(sbn, train_xs, valid_xs, test_xs, training_steps, debug=False):
  # Flatten the hparams into a single string (name_value pairs joined by '.')
  # used to name the log/checkpoint directory for this run.
  hparams = sorted(sbn.hparams.values().items())
  hparams = (map(str, x) for x in hparams)
  hparams = ('_'.join(x) for x in hparams)
  hparams_str = '.'.join(hparams)

  logger = L.Logger()

  # Create the experiment name from the hparams
  experiment_name = ([str(sbn.hparams.n_hidden) for i in xrange(sbn.hparams.n_layer)] +
                     [str(sbn.hparams.n_input)])
  if sbn.hparams.nonlinear:
    experiment_name = '~'.join(experiment_name)
  else:
    experiment_name = '-'.join(experiment_name)
  experiment_name = 'SBN_%s' % experiment_name
  rowkey = {'experiment': experiment_name,
            'model': hparams_str}

  # Create summary writer
  summ_dir = os.path.join(FLAGS.working_dir, hparams_str)
  summary_writer = tf.summary.FileWriter(
      summ_dir, flush_secs=15, max_queue=100)

  # The Supervisor handles checkpointing and session recovery; summaries are
  # written manually below, so the automatic summary_op is disabled.
  sv = tf.train.Supervisor(logdir=os.path.join(FLAGS.working_dir, hparams_str),
                           save_summaries_secs=0,
                           save_model_secs=1200,
                           summary_op=None,
                           recovery_wait_secs=30,
                           global_step=sbn.global_step)
  with sv.managed_session() as sess:
    # Dump hparams to file
    with gfile.Open(os.path.join(FLAGS.working_dir,
                                 hparams_str,
                                 'hparams.json'),
                    'w') as out:
      json.dump(sbn.hparams.values(), out)

    sbn.initialize(sess)
    batch_size = sbn.hparams.batch_size
    scores = []
    n = train_xs.shape[0]
    index = list(range(n))  # Materialize so random.shuffle works on Python 3
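    # Each iteration of the outer loop is one pass (epoch) over the shuffled
    # training data.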
    while not sv.should_stop():
      lHats = []
      grad_variances = []
      temperatures = []
      random.shuffle(index)
      i = 0
      while i < n:
        batch_index = index[i:min(i+batch_size, n)]
        batch_xs = train_xs[batch_index, :]
        if sbn.hparams.dynamic_b:
          # Dynamically binarize the batch data
          batch_xs = (np.random.rand(*batch_xs.shape) < batch_xs).astype(float)
        lHat, grad_variance, step, temperature = sbn.partial_fit(batch_xs,
                                                                 sbn.hparams.n_samples)
        if debug:
          print(i, lHat)
          if i > 100:
            return
        lHats.append(lHat)
        grad_variances.append(grad_variance)
        temperatures.append(temperature)
        i += batch_size
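
      # Average the per-batch gradient-variance estimates over the epoch and
      # report them on a log scale.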
      grad_variances = np.log(np.mean(grad_variances, axis=0)).tolist()
      summary_strings = []
      if isinstance(grad_variances, list):
        # Key each variance estimate by its loss name.
        grad_variances = dict(zip([k for (k, v) in sbn.losses],
                                  map(float, grad_variances)))
        rowkey['step'] = step
        logger.log(rowkey, {'step': step,
                            'train': np.mean(lHats, axis=0)[0],
                            'grad_variances': grad_variances,
                            'temperature': np.mean(temperatures), })
        grad_variances = '\n'.join(map(str, sorted(grad_variances.items())))
      else:
        rowkey['step'] = step
        logger.log(rowkey, {'step': step,
                            'train': np.mean(lHats, axis=0)[0],
                            'grad_variance': grad_variances,
                            'temperature': np.mean(temperatures), })
        summary_strings.append(manual_scalar_summary("log grad variance", grad_variances))
      print('Step %d: %s\n%s' % (step, str(np.mean(lHats, axis=0)), str(grad_variances)))
      # Every few epochs compute test and validation scores
      epoch = int(step / (train_xs.shape[0] / sbn.hparams.batch_size))
      if epoch % FLAGS.eval_freq == 0:
        valid_res = eval(sbn, valid_xs)
        test_res = eval(sbn, test_xs)
        print('\nValid %d: %s' % (step, str(valid_res)))
        print('Test %d: %s\n' % (step, str(test_res)))
        logger.log(rowkey, {'step': step,
                            'valid': valid_res[0],
                            'test': test_res[0]})
        logger.flush()  # Flush infrequently
      # Create summaries
      summary_strings.extend([
          manual_scalar_summary("Train ELBO", np.mean(lHats, axis=0)[0]),
          manual_scalar_summary("Temperature", np.mean(temperatures)),
      ])
      for summ_str in summary_strings:
        summary_writer.add_summary(summ_str, global_step=step)
      summary_writer.flush()
      sys.stdout.flush()

      scores.append(np.mean(lHats, axis=0))

      if step > training_steps:
        break

    return scores
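

# Entry point: parse hparams from the command line, load the dataset, build
# the model class named by hparams.model from the rebar module, and train it.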
def main():
  # Parse hyperparams
  hparams = rebar.default_hparams
  hparams.parse(FLAGS.hparams)
  print(hparams.values())

  train_xs, valid_xs, test_xs = datasets.load_data(hparams)
  mean_xs = np.mean(train_xs, axis=0)  # Compute mean centering on training

  training_steps = 2000000
  model = getattr(rebar, hparams.model)
  sbn = model(hparams, mean_xs=mean_xs)

  scores = train(sbn, train_xs, valid_xs, test_xs,
                 training_steps=training_steps, debug=False)


if __name__ == '__main__':
  main()