Spaces:
Running
Running
# Copyright 2017 Google Inc. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import json | |
import random | |
import sys | |
import os | |
import numpy as np | |
import tensorflow as tf | |
import rebar | |
import datasets | |
import logger as L | |
# Python 2/3 compatibility: Python 3 has no `xrange`, so alias it to `range`.
try:
    xrange  # Python 2
except NameError:
    xrange = range  # Python 3

# Command-line flags read by the training code below.
tf.app.flags.DEFINE_string(
    'working_dir', '/tmp/rebar',
    'Directory where to save data, write logs, etc.')
tf.app.flags.DEFINE_string(
    'hparams', '',
    'Comma separated list of name=value pairs.')
tf.app.flags.DEFINE_integer(
    'eval_freq', 20,
    'How often to run the evaluation step.')
FLAGS = tf.flags.FLAGS

# Short alias for TensorFlow's file I/O module.
gfile = tf.gfile
def manual_scalar_summary(name, value):
    """Wrap a single scalar in a `tf.Summary` message.

    Args:
      name: Tag to store on the summary value.
      value: Number stored as the value's `simple_value`.

    Returns:
      A `tf.Summary` containing exactly one `tf.Summary.Value`.
    """
    return tf.Summary(
        value=[tf.Summary.Value(tag=name, simple_value=value)])
def eval(sbn, eval_xs, n_samples=100, batch_size=5):
    """Score `eval_xs` with `sbn` in mini-batches and average the results.

    Note that this function shadows the builtin `eval` inside this module.

    Args:
      sbn: Object providing `partial_eval(batch_xs, n_samples)`.
      eval_xs: Data to evaluate; it is sliced along its first axis and its
        row count is taken from `eval_xs.shape[0]`.
      n_samples: Forwarded unchanged to `sbn.partial_eval`.
      batch_size: Number of rows passed to each `partial_eval` call; the
        final batch may be shorter.

    Returns:
      `np.mean(..., axis=0)` over the per-batch results. Every batch carries
      equal weight, including a shorter final batch.
    """
    total = eval_xs.shape[0]
    batch_results = []
    start = 0
    while start < total:
        # Slicing clamps at the end of the data, so the last batch is simply
        # whatever rows remain.
        batch = eval_xs[start:start + batch_size]
        batch_results.append(sbn.partial_eval(batch, n_samples))
        start += batch_size
    return np.mean(batch_results, axis=0)
def train(sbn, train_xs, valid_xs, test_xs, training_steps, debug=False):
    """Run the training loop for `sbn`, logging and evaluating as it goes.

    Each pass of the outer loop shuffles the training rows, feeds them to
    `sbn.partial_fit` in mini-batches, then logs the averaged training
    statistics, writes scalar summaries and -- when the epoch number is a
    multiple of `FLAGS.eval_freq` -- scores the validation and test sets with
    this module's `eval`.

    Args:
      sbn: Model object. This function reads its `hparams`, `global_step`
        and `losses`, and calls `initialize` and `partial_fit` on it.
      train_xs: Training data, indexed as `train_xs[rows, :]`.
      valid_xs: Validation data handed to `eval`.
      test_xs: Test data handed to `eval`.
      training_steps: Training stops once the step reported by
        `partial_fit` exceeds this value.
      debug: If true, print every batch result and return early (with
        `None`) as soon as the batch offset exceeds 100.

    Returns:
      A list with one entry per completed pass over the data: the mean
      (axis 0) of that pass's `partial_fit` results. `None` if a debug run
      is cut short.
    """
    # Encode the sorted hyperparameters as "name_value.name_value..."; the
    # string names the run directory and is logged as the model key.
    hparams_str = '.'.join(
        '_'.join(map(str, item))
        for item in sorted(sbn.hparams.values().items()))

    logger = L.Logger()

    # Create the experiment name from the hparams: one hidden size per layer
    # followed by the input size, joined by '~' (nonlinear) or '-'.
    layer_sizes = ([str(sbn.hparams.n_hidden)] * sbn.hparams.n_layer +
                   [str(sbn.hparams.n_input)])
    separator = '~' if sbn.hparams.nonlinear else '-'
    experiment_name = 'SBN_%s' % separator.join(layer_sizes)
    rowkey = {'experiment': experiment_name,
              'model': hparams_str}

    # Summaries, checkpoints and hparams.json all live in one run directory.
    run_dir = os.path.join(FLAGS.working_dir, hparams_str)

    # Create summary writer
    summary_writer = tf.summary.FileWriter(
        run_dir, flush_secs=15, max_queue=100)

    try:
        sv = tf.train.Supervisor(logdir=run_dir,
                                 save_summaries_secs=0,
                                 save_model_secs=1200,
                                 summary_op=None,
                                 recovery_wait_secs=30,
                                 global_step=sbn.global_step)
        with sv.managed_session() as sess:
            # Dump hparams to file
            with gfile.Open(os.path.join(run_dir, 'hparams.json'),
                            'w') as out:
                json.dump(sbn.hparams.values(), out)

            sbn.initialize(sess)
            batch_size = sbn.hparams.batch_size
            scores = []
            n = train_xs.shape[0]
            # Must be a real list: `random.shuffle` permutes in place and
            # cannot do so on a Python 3 `range` object.
            index = list(range(n))

            while not sv.should_stop():
                lHats = []
                grad_variances = []
                temperatures = []
                random.shuffle(index)

                for i in range(0, n, batch_size):
                    # Slicing clamps at the end, so the last batch may be
                    # shorter than `batch_size`.
                    batch_index = index[i:i + batch_size]
                    batch_xs = train_xs[batch_index, :]

                    if sbn.hparams.dynamic_b:
                        # Dynamically binarize the batch data
                        batch_xs = (np.random.rand(*batch_xs.shape) <
                                    batch_xs).astype(float)

                    lHat, grad_variance, step, temperature = sbn.partial_fit(
                        batch_xs, sbn.hparams.n_samples)
                    if debug:
                        print(i, lHat)
                        if i > 100:
                            return
                    lHats.append(lHat)
                    grad_variances.append(grad_variance)
                    temperatures.append(temperature)

                # Per-pass averages, computed once and reused below.
                mean_lHat = np.mean(lHats, axis=0)
                mean_temperature = np.mean(temperatures)
                grad_variances = np.log(
                    np.mean(grad_variances, axis=0)).tolist()

                summary_strings = []
                rowkey['step'] = step
                if isinstance(grad_variances, list):
                    # `tolist()` gave a list, i.e. several variances: key
                    # them by the first element of each `sbn.losses` pair.
                    grad_variances = dict(
                        zip([k for (k, _) in sbn.losses],
                            map(float, grad_variances)))
                    logger.log(rowkey, {'step': step,
                                        'train': mean_lHat[0],
                                        'grad_variances': grad_variances,
                                        'temperature': mean_temperature, })
                    # `items()` rather than the Python 2-only `iteritems()`.
                    grad_variances = '\n'.join(
                        map(str, sorted(grad_variances.items())))
                else:
                    # `tolist()` gave a plain scalar: a single variance.
                    logger.log(rowkey, {'step': step,
                                        'train': mean_lHat[0],
                                        'grad_variance': grad_variances,
                                        'temperature': mean_temperature, })
                    summary_strings.append(
                        manual_scalar_summary("log grad variance",
                                              grad_variances))

                print('Step %d: %s\n%s' % (step, str(mean_lHat),
                                           str(grad_variances)))

                # Every few epochs compute test and validation scores
                epoch = int(step / (n / batch_size))
                if epoch % FLAGS.eval_freq == 0:
                    valid_res = eval(sbn, valid_xs)
                    test_res = eval(sbn, test_xs)

                    print('\nValid %d: %s' % (step, str(valid_res)))
                    print('Test %d: %s\n' % (step, str(test_res)))
                    logger.log(rowkey, {'step': step,
                                        'valid': valid_res[0],
                                        'test': test_res[0]})
                    logger.flush()  # Flush infrequently

                # Create summaries
                summary_strings.extend([
                    manual_scalar_summary("Train ELBO", mean_lHat[0]),
                    manual_scalar_summary("Temperature", mean_temperature),
                ])
                for summ_str in summary_strings:
                    summary_writer.add_summary(summ_str, global_step=step)
                summary_writer.flush()
                sys.stdout.flush()

                scores.append(mean_lHat)

                if step > training_steps:
                    break

            return scores
    finally:
        # Release the event file on every exit path, including the early
        # debug return and exceptions.
        summary_writer.close()
def main():
    """Parse hyperparameters, load the data, build the model and train it."""
    # Start from the defaults in `rebar` and apply overrides from --hparams.
    hparams = rebar.default_hparams
    hparams.parse(FLAGS.hparams)
    print(hparams.values())

    train_xs, valid_xs, test_xs = datasets.load_data(hparams)

    # Mean of the training data, handed to the model as `mean_xs`
    # (used for mean centering).
    mean_xs = np.mean(train_xs, axis=0)

    # `hparams.model` names the model class to look up on the `rebar` module.
    model_cls = getattr(rebar, hparams.model)
    sbn = model_cls(hparams, mean_xs=mean_xs)

    # The returned per-pass scores are not used here.
    train(sbn, train_xs, valid_xs, test_xs,
          training_steps=2000000, debug=False)
# Start training only when run as a script, not when imported as a module.
if __name__ == "__main__":
    main()