r"""Train RL agent on coding tasks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import cPickle
import cProfile
import marshal
import os
import time
from absl import flags
from absl import logging
import tensorflow as tf
# internal session lib import
from single_task import data # brain coder
from single_task import defaults # brain coder
from single_task import pg_agent as agent_lib # brain coder
from single_task import results_lib # brain coder
FLAGS = flags.FLAGS
flags.DEFINE_string(
'master', '',
'URL of the TensorFlow master to use.')
flags.DEFINE_integer(
'ps_tasks', 0,
'Number of parameter server tasks. Only set to 0 for '
'single worker training.')
flags.DEFINE_integer(
'summary_interval', 10,
'How often to write summaries.')
flags.DEFINE_integer(
    'summary_tasks', 16,
    'If greater than 0, only tasks 0 through summary_tasks - 1 '
    'will write summaries. If 0, no tasks will write '
    'summaries.')
flags.DEFINE_bool(
    'stop_on_success', True,
    'If True, training will stop as soon as a solution is found. '
    'If False, training will continue until another stopping '
    'condition (such as max NPE) is reached.')
flags.DEFINE_bool(
'do_profiling', False,
'If True, cProfile profiler will run and results will be '
'written to logdir. WARNING: Results will not be written if '
    'the code crashes. Make sure it exits successfully.')
flags.DEFINE_integer('model_v', 0, 'Model verbosity level.')
flags.DEFINE_bool(
'delayed_graph_cleanup', True,
'If true, container for n-th run will not be reset until the (n+1)-th run '
'is complete. This greatly reduces the chance that a worker is still '
'using the n-th container when it is cleared.')
def define_tuner_hparam_space(hparam_space_type):
"""Define tunable hparams for grid search."""
if hparam_space_type not in ('pg', 'pg-topk', 'topk', 'is'):
raise ValueError('Hparam space is not valid: "%s"' % hparam_space_type)
# Discrete hparam space is stored as a dict from hparam name to discrete
# values.
hparam_space = {}
if hparam_space_type in ('pg', 'pg-topk', 'is'):
# Add a floating point parameter named learning rate.
hparam_space['lr'] = [1e-5, 1e-4, 1e-3]
hparam_space['entropy_beta'] = [0.005, 0.01, 0.05, 0.10]
else: # 'topk'
# Add a floating point parameter named learning rate.
hparam_space['lr'] = [1e-5, 1e-4, 1e-3]
hparam_space['entropy_beta'] = [0.0, 0.005, 0.01, 0.05, 0.10]
if hparam_space_type in ('topk', 'pg-topk'):
# topk tuning will be enabled.
hparam_space['topk'] = [10]
hparam_space['topk_loss_hparam'] = [1.0, 10.0, 50.0, 200.0]
elif hparam_space_type == 'is':
# importance sampling tuning will be enabled.
hparam_space['replay_temperature'] = [0.25, 0.5, 1.0, 2.0]
hparam_space['alpha'] = [0.5, 0.75, 63/64.]
return hparam_space
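# Illustrative sketch (not part of the original module): the returned hparam
# space is a plain dict of name -> candidate values, so an external tuner
# could enumerate it as a grid, e.g.:
#
#   import itertools
#   space = define_tuner_hparam_space('pg-topk')
#   names = sorted(space)
#   for values in itertools.product(*(space[name] for name in names)):
#     assignment = dict(zip(names, values))
#     # e.g. {'entropy_beta': 0.005, 'lr': 1e-05, 'topk': 10, ...}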
def write_hparams_to_config(config, hparams, hparam_space_type):
"""Write hparams given by the tuner into the Config object."""
if hparam_space_type not in ('pg', 'pg-topk', 'topk', 'is'):
raise ValueError('Hparam space is not valid: "%s"' % hparam_space_type)
config.agent.lr = hparams.lr
config.agent.entropy_beta = hparams.entropy_beta
if hparam_space_type in ('topk', 'pg-topk'):
# topk tuning will be enabled.
config.agent.topk = hparams.topk
config.agent.topk_loss_hparam = hparams.topk_loss_hparam
elif hparam_space_type == 'is':
# importance sampling tuning will be enabled.
config.agent.replay_temperature = hparams.replay_temperature
config.agent.alpha = hparams.alpha
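# Illustrative sketch: `hparams` is expected to expose the tuned values as
# attributes (for example a tf.contrib.training.HParams object). A minimal
# hand-rolled stand-in for local experimentation might look like:
#
#   import collections
#   FakeHParams = collections.namedtuple(
#       'FakeHParams', ['lr', 'entropy_beta', 'topk', 'topk_loss_hparam'])
#   hparams = FakeHParams(
#       lr=1e-4, entropy_beta=0.01, topk=10, topk_loss_hparam=50.0)
#   config = defaults.default_config_with_updates(FLAGS.config)
#   write_hparams_to_config(config, hparams, hparam_space_type='pg-topk')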
def make_initialized_variable(value, name, shape=None, dtype=tf.float32):
"""Create a tf.Variable with a constant initializer.
Args:
value: Constant value to initialize the variable with. This is the value
that the variable starts with.
name: Name of the variable in the TF graph.
shape: Shape of the variable. If None, variable will be a scalar.
dtype: Data type of the variable. Should be a TF dtype. Defaults to
tf.float32.
Returns:
tf.Variable instance.
"""
if shape is None:
shape = []
return tf.get_variable(
name=name, shape=shape, initializer=tf.constant_initializer(value),
dtype=dtype, trainable=False)
class AsyncTrainer(object):
"""Manages graph creation and training.
This async trainer creates a global model on the parameter server, and a local
model (for this worker). Gradient updates are sent to the global model, and
the updated weights are synced to the local copy.
"""
def __init__(self, config, task_id, ps_tasks, num_workers, is_chief=True,
summary_writer=None,
dtype=tf.float32,
summary_interval=1,
run_number=0,
logging_dir='/tmp', model_v=0):
self.config = config
self.data_manager = data.DataManager(
config, run_number=run_number,
do_code_simplification=not FLAGS.stop_on_success)
self.task_id = task_id
self.ps_tasks = ps_tasks
self.is_chief = is_chief
if ps_tasks == 0:
assert task_id == 0, 'No parameter servers specified. Expecting 1 task.'
      assert num_workers == 1, (
          'No parameter servers specified. Expecting exactly 1 worker.')
worker_device = '/job:localhost/replica:%d/task:0/cpu:0' % task_id
# worker_device = '/cpu:0'
# ps_device = '/cpu:0'
else:
assert num_workers > 0, 'There must be at least 1 training worker.'
worker_device = '/job:worker/replica:%d/task:0/cpu:0' % task_id
# ps_device = '/job:ps/replica:0/task:0/cpu:0'
logging.info('worker_device: %s', worker_device)
logging_file = os.path.join(
logging_dir, 'solutions_%d.txt' % task_id)
experience_replay_file = os.path.join(
logging_dir, 'replay_buffer_%d.pickle' % task_id)
self.topk_file = os.path.join(
logging_dir, 'topk_buffer_%d.pickle' % task_id)
tf.get_variable_scope().set_use_resource(True)
# global model
with tf.device(tf.train.replica_device_setter(ps_tasks,
ps_device='/job:ps/replica:0',
worker_device=worker_device)):
with tf.variable_scope('global'):
global_model = agent_lib.LMAgent(config, dtype=dtype, is_local=False)
global_params_dict = {p.name: p
for p in global_model.sync_variables}
self.global_model = global_model
self.global_step = make_initialized_variable(
0, 'global_step', dtype=tf.int64)
self.global_best_reward = make_initialized_variable(
-10.0, 'global_best_reward', dtype=tf.float64)
self.is_best_model = make_initialized_variable(
False, 'is_best_model', dtype=tf.bool)
self.reset_is_best_model = self.is_best_model.assign(False)
self.global_best_reward_placeholder = tf.placeholder(
tf.float64, [], name='global_best_reward_placeholder')
self.assign_global_best_reward_op = tf.group(
self.global_best_reward.assign(
self.global_best_reward_placeholder),
self.is_best_model.assign(True))
def assign_global_best_reward_fn(session, reward):
reward = round(reward, 10)
best_reward = round(session.run(self.global_best_reward), 10)
is_best = reward > best_reward
if is_best:
session.run(self.assign_global_best_reward_op,
{self.global_best_reward_placeholder: reward})
return is_best
self.assign_global_best_reward_fn = assign_global_best_reward_fn
# Any worker will set to true when it finds a solution.
self.found_solution_flag = make_initialized_variable(
False, 'found_solution_flag', dtype=tf.bool)
self.found_solution_op = self.found_solution_flag.assign(True)
self.run_number = make_initialized_variable(
run_number, 'run_number', dtype=tf.int32)
# Store a solution when found.
self.code_solution_variable = tf.get_variable(
'code_solution', [], tf.string,
initializer=tf.constant_initializer(''))
self.code_solution_ph = tf.placeholder(
tf.string, [], name='code_solution_ph')
self.code_solution_assign_op = self.code_solution_variable.assign(
self.code_solution_ph)
def assign_code_solution_fn(session, code_solution_string):
session.run(self.code_solution_assign_op,
{self.code_solution_ph: code_solution_string})
self.assign_code_solution_fn = assign_code_solution_fn
# Count all programs sampled from policy. This does not include
# programs sampled from replay buffer.
# This equals NPE (number of programs executed). Only programs sampled
# from the policy need to be executed.
self.program_count = make_initialized_variable(
0, 'program_count', dtype=tf.int64)
# local model
with tf.device(worker_device):
with tf.variable_scope('local'):
self.model = model = agent_lib.LMAgent(
config,
task_id=task_id,
logging_file=logging_file,
experience_replay_file=experience_replay_file,
dtype=dtype,
global_best_reward_fn=self.assign_global_best_reward_fn,
found_solution_op=self.found_solution_op,
assign_code_solution_fn=self.assign_code_solution_fn,
program_count=self.program_count,
stop_on_success=FLAGS.stop_on_success,
verbose_level=model_v)
local_params = model.trainable_variables
local_params_dict = {p.name: p for p in local_params}
# Pull global params to local model.
def _global_to_local_scope(name):
assert name.startswith('global/')
return 'local' + name[6:]
sync_dict = {
local_params_dict[_global_to_local_scope(p_name)]: p
for p_name, p in global_params_dict.items()}
self.sync_op = tf.group(*[v_local.assign(v_global)
for v_local, v_global
in sync_dict.items()])
# Pair local gradients with global params.
grad_var_dict = {
gradient: sync_dict[local_var]
for local_var, gradient in model.gradients_dict.items()}
# local model
model.make_summary_ops() # Don't put summaries under 'local' scope.
with tf.variable_scope('local'):
self.train_op = model.optimizer.apply_gradients(
grad_var_dict.items(), global_step=self.global_step)
self.local_init_op = tf.variables_initializer(
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
tf.get_variable_scope().name))
self.local_step = 0
self.last_summary_time = time.time()
self.summary_interval = summary_interval
self.summary_writer = summary_writer
self.cached_global_step = -1
self.cached_global_npe = -1
logging.info('summary_interval: %d', self.summary_interval)
# Load top-k buffer.
if self.model.top_episodes is not None and tf.gfile.Exists(self.topk_file):
try:
with tf.gfile.FastGFile(self.topk_file, 'r') as f:
self.model.top_episodes = cPickle.loads(f.read())
logging.info(
'Loaded top-k buffer from disk with %d items. Location: "%s"',
len(self.model.top_episodes), self.topk_file)
except (cPickle.UnpicklingError, EOFError) as e:
logging.warn(
'Failed to load existing top-k buffer from disk. Removing bad file.'
'\nLocation: "%s"\nException: %s', self.topk_file, str(e))
tf.gfile.Remove(self.topk_file)
def initialize(self, session):
"""Run initialization ops."""
session.run(self.local_init_op)
session.run(self.sync_op)
self.cached_global_step, self.cached_global_npe = session.run(
[self.global_step, self.program_count])
def update_global_model(self, session):
"""Run an update step.
1) Asynchronously copy global weights to local model.
2) Call into local model's update_step method, which does the following:
a) Sample batch of programs from policy.
b) Compute rewards.
c) Compute gradients and update the global model asynchronously.
3) Write tensorboard summaries to disk.
Args:
session: tf.Session instance.
"""
session.run(self.sync_op) # Copy weights from global to local.
with session.as_default():
result = self.model.update_step(
session, self.data_manager.sample_rl_batch(), self.train_op,
self.global_step)
global_step = result.global_step
global_npe = result.global_npe
summaries = result.summaries_list
self.cached_global_step = global_step
self.cached_global_npe = global_npe
self.local_step += 1
if self.summary_writer and self.local_step % self.summary_interval == 0:
if not isinstance(summaries, (tuple, list)):
summaries = [summaries]
summaries.append(self._local_step_summary())
if self.is_chief:
(global_best_reward,
found_solution_flag,
program_count) = session.run(
[self.global_best_reward,
self.found_solution_flag,
self.program_count])
summaries.append(
tf.Summary(
value=[tf.Summary.Value(
tag='model/best_reward',
simple_value=global_best_reward)]))
summaries.append(
tf.Summary(
value=[tf.Summary.Value(
tag='model/solution_found',
simple_value=int(found_solution_flag))]))
summaries.append(
tf.Summary(
value=[tf.Summary.Value(
tag='model/program_count',
simple_value=program_count)]))
for s in summaries:
self.summary_writer.add_summary(s, global_step)
self.last_summary_time = time.time()
def _local_step_summary(self):
"""Compute number of local steps per time increment."""
dt = time.time() - self.last_summary_time
steps_per_time = self.summary_interval / float(dt)
return tf.Summary(value=[
tf.Summary.Value(
tag='local_step/per_sec',
simple_value=steps_per_time),
tf.Summary.Value(
tag='local_step/step',
simple_value=self.local_step)])
def maybe_save_best_model(self, session, saver, checkpoint_file):
"""Check if this model got the highest reward and save to disk if so."""
if self.is_chief and session.run(self.is_best_model):
logging.info('Saving best model to "%s"', checkpoint_file)
saver.save(session, checkpoint_file)
session.run(self.reset_is_best_model)
def save_replay_buffer(self):
"""Save replay buffer to disk.
Call this periodically so that training can recover if jobs go down.
"""
if self.model.experience_replay is not None:
logging.info('Saving experience replay buffer to "%s".',
self.model.experience_replay.save_file)
self.model.experience_replay.incremental_save(True)
def delete_replay_buffer(self):
"""Delete replay buffer from disk.
Call this at the end of training to clean up. Replay buffer can get very
large.
"""
if self.model.experience_replay is not None:
logging.info('Deleting experience replay buffer at "%s".',
self.model.experience_replay.save_file)
tf.gfile.Remove(self.model.experience_replay.save_file)
def save_topk_buffer(self):
"""Save top-k buffer to disk.
Call this periodically so that training can recover if jobs go down.
"""
if self.model.top_episodes is not None:
logging.info('Saving top-k buffer to "%s".', self.topk_file)
# Overwrite previous data each time.
with tf.gfile.FastGFile(self.topk_file, 'w') as f:
f.write(cPickle.dumps(self.model.top_episodes))
@contextlib.contextmanager
def managed_session(sv, master='', config=None,
start_standard_services=True,
close_summary_writer=True,
max_wait_secs=7200):
# Same as Supervisor.managed_session, but with configurable timeout.
try:
sess = sv.prepare_or_wait_for_session(
master=master, config=config,
start_standard_services=start_standard_services,
max_wait_secs=max_wait_secs)
yield sess
except tf.errors.DeadlineExceededError:
raise
except Exception as e: # pylint: disable=broad-except
sv.request_stop(e)
finally:
try:
      # Request all the threads to stop and wait for them to do so. Any
      # exception raised by the threads is raised again from stop().
      # Blocked enqueue/dequeue threads that are not checking for
      # `should_stop()` will be stopped when we close the session further
      # down.
sv.stop(close_summary_writer=close_summary_writer)
finally:
# Close the session to finish up all pending calls. We do not care
# about exceptions raised when closing. This takes care of
# blocked enqueue/dequeue calls.
try:
sess.close()
except Exception: # pylint: disable=broad-except
# Silently ignore exceptions raised by close().
pass
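# Illustrative sketch: `managed_session` mirrors
# tf.train.Supervisor.managed_session but exposes `max_wait_secs`, so a worker
# that cannot reach the parameter server fails with DeadlineExceededError
# instead of blocking forever. Hypothetical usage:
#
#   sv = tf.train.Supervisor(is_chief=True, logdir='/tmp/train')
#   with managed_session(sv, master='', max_wait_secs=60) as session:
#     while not sv.should_stop():
#       session.run(train_op)  # `train_op` stands in for the caller's update op.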
def train(config, is_chief, tuner=None, run_dir=None, run_number=0,
results_writer=None):
"""Run training loop.
Args:
config: config_lib.Config instance containing global config (agent and env).
is_chief: True if this worker is chief. Chief worker manages writing some
data to disk and initialization of the global model.
tuner: A tuner instance. If not tuning, leave as None.
run_dir: Directory where all data for this run will be written. If None,
run_dir = FLAGS.logdir. Set this argument when doing multiple runs.
    run_number: Which run this is.
    results_writer: Manages writing training results to disk. Results are a
      dict of metric names and values.
Returns:
The trainer object used to run training updates.
"""
logging.info('Will run asynchronous training.')
if run_dir is None:
run_dir = FLAGS.logdir
train_dir = os.path.join(run_dir, 'train')
best_model_checkpoint = os.path.join(train_dir, 'best.ckpt')
events_dir = '%s/events_%d' % (run_dir, FLAGS.task_id)
logging.info('Events directory: %s', events_dir)
logging_dir = os.path.join(run_dir, 'logs')
if not tf.gfile.Exists(logging_dir):
tf.gfile.MakeDirs(logging_dir)
status_file = os.path.join(logging_dir, 'status.txt')
if FLAGS.summary_tasks and FLAGS.task_id < FLAGS.summary_tasks:
summary_writer = tf.summary.FileWriter(events_dir)
else:
summary_writer = None
  # Run the profiler for this task if requested.
if FLAGS.do_profiling:
logging.info('Profiling enabled')
profiler = cProfile.Profile()
profiler.enable()
else:
profiler = None
trainer = AsyncTrainer(
config, FLAGS.task_id, FLAGS.ps_tasks, FLAGS.num_workers,
is_chief=is_chief,
summary_interval=FLAGS.summary_interval,
summary_writer=summary_writer,
logging_dir=logging_dir,
run_number=run_number,
model_v=FLAGS.model_v)
variables_to_save = [v for v in tf.global_variables()
if v.name.startswith('global')]
global_init_op = tf.variables_initializer(variables_to_save)
saver = tf.train.Saver(variables_to_save)
var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)
logging.info('Trainable vars:')
for v in var_list:
logging.info(' %s, %s, %s', v.name, v.device, v.get_shape())
logging.info('All vars:')
for v in tf.global_variables():
logging.info(' %s, %s, %s', v.name, v.device, v.get_shape())
def init_fn(unused_sess):
logging.info('No checkpoint found. Initialized global params.')
sv = tf.train.Supervisor(is_chief=is_chief,
logdir=train_dir,
saver=saver,
summary_op=None,
init_op=global_init_op,
init_fn=init_fn,
summary_writer=summary_writer,
ready_op=tf.report_uninitialized_variables(
variables_to_save),
ready_for_local_init_op=None,
global_step=trainer.global_step,
save_model_secs=30,
save_summaries_secs=30)
# Add a thread that periodically checks if this Trial should stop
# based on an early stopping policy.
if tuner:
sv.Loop(60, tuner.check_for_stop, (sv.coord,))
last_replay_save_time = time.time()
global_step = -1
logging.info(
'Starting session. '
      'If this hangs, we\'re most likely waiting to connect '
'to the parameter server. One common cause is that the parameter '
'server DNS name isn\'t resolving yet, or is misspecified.')
should_retry = True
supervisor_deadline_exceeded = False
while should_retry:
try:
with managed_session(
sv, FLAGS.master, max_wait_secs=60) as session, session.as_default():
should_retry = False
do_training = True
try:
trainer.initialize(session)
if session.run(trainer.run_number) != run_number:
# If we loaded existing model from disk, and the saved run number is
# different, throw an exception.
raise RuntimeError(
'Expecting to be on run %d, but is actually on run %d. '
'run_dir: "%s"'
% (run_number, session.run(trainer.run_number), run_dir))
global_step = trainer.cached_global_step
logging.info('Starting training at step=%d', global_step)
while do_training:
trainer.update_global_model(session)
if is_chief:
trainer.maybe_save_best_model(
session, saver, best_model_checkpoint)
global_step = trainer.cached_global_step
global_npe = trainer.cached_global_npe
if time.time() - last_replay_save_time >= 30:
trainer.save_replay_buffer()
trainer.save_topk_buffer()
last_replay_save_time = time.time()
# Stopping conditions.
if tuner and tuner.should_trial_stop():
logging.info('Tuner requested early stopping. Finishing.')
do_training = False
if is_chief and FLAGS.stop_on_success:
found_solution = session.run(trainer.found_solution_flag)
if found_solution:
do_training = False
logging.info('Solution found. Finishing.')
if FLAGS.max_npe and global_npe >= FLAGS.max_npe:
# Max NPE (number of programs executed) reached.
logging.info('Max NPE reached. Finishing.')
do_training = False
if sv.should_stop():
logging.info('Supervisor issued stop. Finishing.')
do_training = False
except tf.errors.NotFoundError:
# Catch "Error while reading resource variable".
# The chief worker likely destroyed the container, so do not retry.
logging.info('Caught NotFoundError. Quitting.')
do_training = False
should_retry = False
break
except tf.errors.InternalError as e:
# Catch "Invalid variable reference."
if str(e).startswith('Invalid variable reference.'):
# The chief worker likely destroyed the container, so do not
# retry.
logging.info(
'Caught "InternalError: Invalid variable reference.". '
'Quitting.')
do_training = False
should_retry = False
break
else:
# Pass exception through.
raise
# Exited training loop. Write results to disk.
if is_chief and results_writer:
assert not should_retry
with tf.gfile.FastGFile(status_file, 'w') as f:
f.write('done')
(program_count,
found_solution,
code_solution,
best_reward,
global_step) = session.run(
[trainer.program_count,
trainer.found_solution_flag,
trainer.code_solution_variable,
trainer.global_best_reward,
trainer.global_step])
results_dict = {
'max_npe': FLAGS.max_npe,
'batch_size': config.batch_size,
'max_batches': FLAGS.max_npe // config.batch_size,
'npe': program_count,
'max_global_repetitions': FLAGS.num_repetitions,
'max_local_repetitions': FLAGS.num_repetitions,
'code_solution': code_solution,
'best_reward': best_reward,
'num_batches': global_step,
'found_solution': found_solution,
'task': trainer.data_manager.task_name,
'global_rep': run_number}
logging.info('results_dict: %s', results_dict)
results_writer.append(results_dict)
except tf.errors.AbortedError:
# Catch "Graph handle is not found" error due to preempted jobs.
      logging.info('Caught AbortedError. Retrying.')
should_retry = True
except tf.errors.DeadlineExceededError:
supervisor_deadline_exceeded = True
should_retry = False
if is_chief:
logging.info('This is chief worker. Stopping all workers.')
sv.stop()
if supervisor_deadline_exceeded:
logging.info('Supervisor timed out. Quitting.')
else:
logging.info('Reached %s steps. Worker stopped.', global_step)
# Dump profiling.
"""
How to use profiling data.
Download the profiler dump to your local machine, say to PROF_FILE_PATH.
In a separate script, run something like the following:
import pstats
p = pstats.Stats(PROF_FILE_PATH)
p.strip_dirs().sort_stats('cumtime').print_stats()
This will sort by 'cumtime', which "is the cumulative time spent in this and
all subfunctions (from invocation till exit)."
https://docs.python.org/2/library/profile.html#instant-user-s-manual
""" # pylint: disable=pointless-string-statement
if profiler:
prof_file = os.path.join(run_dir, 'task_%d.prof' % FLAGS.task_id)
logging.info('Done profiling.\nDumping to "%s".', prof_file)
profiler.create_stats()
with tf.gfile.Open(prof_file, 'w') as f:
f.write(marshal.dumps(profiler.stats))
return trainer
def run_training(config=None, tuner=None, logdir=None, trial_name=None,
is_chief=True):
"""Do all training runs.
This is the top level training function for policy gradient based models.
Run this from the main function.
Args:
config: config_lib.Config instance containing global config (agent and
environment hparams). If None, config will be parsed from FLAGS.config.
tuner: A tuner instance. Leave as None if not tuning.
logdir: Parent directory where all data from all runs will be written. If
None, FLAGS.logdir will be used.
trial_name: If tuning, set this to a unique string that identifies this
trial. If `tuner` is not None, this also must be set.
is_chief: True if this worker is the chief.
Returns:
List of results dicts which were written to disk. Each training run gets a
results dict. Results dict contains metrics, i.e. (name, value) pairs which
give information about the training run.
Raises:
ValueError: If results dicts read from disk contain invalid data.
"""
if not config:
# If custom config is not given, get it from flags.
config = defaults.default_config_with_updates(FLAGS.config)
if not logdir:
logdir = FLAGS.logdir
if not tf.gfile.Exists(logdir):
tf.gfile.MakeDirs(logdir)
assert FLAGS.num_repetitions > 0
results = results_lib.Results(logdir)
results_list, _ = results.read_all()
logging.info('Starting experiment. Directory: "%s"', logdir)
if results_list:
    if results_list[0]['max_npe'] != FLAGS.max_npe:
      raise ValueError(
          'Cannot resume training. Max-NPE changed. Was %s, now %s'
          % (results_list[0]['max_npe'], FLAGS.max_npe))
    if results_list[0]['max_global_repetitions'] != FLAGS.num_repetitions:
      raise ValueError(
          'Cannot resume training. Number of repetitions changed. Was %s, '
          'now %s'
          % (results_list[0]['max_global_repetitions'],
             FLAGS.num_repetitions))
while len(results_list) < FLAGS.num_repetitions:
run_number = len(results_list)
rep_container_name = trial_name if trial_name else 'container'
if FLAGS.num_repetitions > 1:
rep_dir = os.path.join(logdir, 'run_%d' % run_number)
rep_container_name = rep_container_name + '_run_' + str(run_number)
else:
rep_dir = logdir
logging.info(
'Starting repetition %d (%d out of %d)', run_number, run_number + 1,
FLAGS.num_repetitions)
# Train will write result to disk.
with tf.container(rep_container_name):
trainer = train(config, is_chief, tuner, rep_dir, run_number, results)
logging.info('Done training.')
if is_chief:
# Destroy current container immediately (clears current graph).
logging.info('Clearing shared variables.')
tf.Session.reset(FLAGS.master, containers=[rep_container_name])
logging.info('Shared variables cleared.')
# Delete replay buffer on disk.
assert trainer
trainer.delete_replay_buffer()
else:
# Give chief worker time to clean up.
sleep_sec = 30.0
logging.info('Sleeping for %s sec.', sleep_sec)
time.sleep(sleep_sec)
tf.reset_default_graph()
logging.info('Default graph reset.')
# Expecting that train wrote new result to disk before returning.
results_list, _ = results.read_all()
return results_list
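# Illustrative sketch: a separate launcher (e.g. a run.py that also defines
# FLAGS.config, FLAGS.logdir, FLAGS.max_npe, etc.) is expected to parse flags
# and call run_training. A minimal single-worker entry point might look like:
#
#   from absl import app
#
#   def main(argv):
#     del argv  # Unused.
#     results = run_training()  # Uses FLAGS.config and FLAGS.logdir.
#     logging.info('Finished %d repetition(s).', len(results))
#
#   if __name__ == '__main__':
#     app.run(main)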