Spaces:
Running
Running
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
r"""Run training. | |

Choose a training algorithm and task(s), then follow these examples.

Run synchronous policy gradient training locally:

CONFIG="agent=c(algorithm='pg'),env=c(task='reverse')"
OUT_DIR="/tmp/bf_pg_local"
rm -rf $OUT_DIR
bazel run -c opt single_task:run -- \
    --alsologtostderr \
    --config="$CONFIG" \
    --max_npe=0 \
    --logdir="$OUT_DIR" \
    --summary_interval=1 \
    --model_v=0
learning/brain/tensorboard/tensorboard.sh --port 12345 --logdir "$OUT_DIR"


Run genetic algorithm locally:

CONFIG="agent=c(algorithm='ga'),env=c(task='reverse')"
OUT_DIR="/tmp/bf_ga_local"
rm -rf $OUT_DIR
bazel run -c opt single_task:run -- \
    --alsologtostderr \
    --config="$CONFIG" \
    --max_npe=0 \
    --logdir="$OUT_DIR"


Run uniform random search locally:

CONFIG="agent=c(algorithm='rand'),env=c(task='reverse')"
OUT_DIR="/tmp/bf_rand_local"
rm -rf $OUT_DIR
bazel run -c opt single_task:run -- \
    --alsologtostderr \
    --config="$CONFIG" \
    --max_npe=0 \
    --logdir="$OUT_DIR"
""" | |
from absl import app | |
from absl import flags | |
from absl import logging | |
from single_task import defaults # brain coder | |
from single_task import ga_train # brain coder | |
from single_task import pg_train # brain coder | |
FLAGS = flags.FLAGS

# Experiment configuration string and output location.
flags.DEFINE_string(name='config', default='', help='Configuration.')
flags.DEFINE_string(
    name='logdir',
    default=None,
    help='Absolute path where to write results.')

# Identity of this worker and the total number of workers.
flags.DEFINE_integer(name='task_id', default=0, help='ID for this worker.')
flags.DEFINE_integer(
    name='num_workers', default=1, help='How many workers there are.')

# Training budget and number of repetitions.
flags.DEFINE_integer(
    name='max_npe',
    default=0,
    help=(
        'NPE = number of programs executed. Maximum number of programs to '
        'execute in each run. Training will complete when this threshold is '
        'reached. Set to 0 for unlimited training.'))
flags.DEFINE_integer(
    name='num_repetitions',
    default=1,
    help=(
        'Number of times the same experiment will be run (globally across '
        'all workers). Each run is independent.'))

# Logging threshold; main() passes this value to logging.set_verbosity.
flags.DEFINE_string(
    name='log_level',
    default='INFO',
    help=(
        'The threshold for what messages will be logged. One of DEBUG, INFO, '
        'WARN, ERROR, or FATAL.'))
# Registering a new algorithm takes three steps:
#   1) In the BUILD file, add the algorithm as a dependency of this build rule.
#   2) Import the algorithm's module at the top of this file.
#   3) Add an entry to the mapping below. Each key is an algorithm name (the
#      name used to select the algorithm in the config); each value is the
#      module that defines the training and tuning functions. The expected
#      functions are listed in the docstring of `get_namespace`.
ALGORITHM_REGISTRATION = dict(
    pg=pg_train,
    ga=ga_train,
    rand=ga_train,
)
def get_namespace(config_string):
    """Resolve the module implementing the algorithm named in a config string.

    To support an additional algorithm type, extend this lookup (see
    `ALGORITHM_REGISTRATION`). Each registered module is expected to provide:
      run_training: Runs the main training loop.
      define_tuner_hparam_space: Returns the hparam tuning space for the
        algorithm.
      write_hparams_to_config: Tuning helper that writes the hparams chosen for
        tuning to the Config object.

    pg_train.py and ga_train.py show the function signatures and
    implementations.

    Args:
      config_string: String representation of a Config object. It is parsed
        into a Config to find out which algorithm was selected.

    Returns:
      A `(algorithm_namespace, config)` tuple: the module registered for the
      algorithm given in the config, and the Config object obtained by parsing
      `config_string`.

    Raises:
      ValueError: If config.agent.algorithm is not one of the registered
        algorithms.
    """
    parsed_config = defaults.default_config_with_updates(config_string)
    algorithm_name = parsed_config.agent.algorithm
    if algorithm_name in ALGORITHM_REGISTRATION:
        return ALGORITHM_REGISTRATION[algorithm_name], parsed_config
    raise ValueError('Unknown algorithm type "%s"' % (algorithm_name,))
def main(argv):
    """Validate the worker flags and run training for the configured algorithm.

    Args:
      argv: Positional command-line arguments handed over by `app.run`; unused.

    Raises:
      ValueError: If `num_workers` is not positive, or if `task_id` is negative
        or not strictly less than `num_workers`. `get_namespace` also raises
        it when the configured algorithm is not registered.
    """
    del argv  # Unused.
    logging.set_verbosity(FLAGS.log_level)

    if FLAGS.num_workers <= 0:
        raise ValueError('num_workers flag must be greater than 0.')
    if FLAGS.task_id < 0:
        raise ValueError('task_id flag must be greater than or equal to 0.')
    if FLAGS.task_id >= FLAGS.num_workers:
        raise ValueError(
            'task_id flag must be strictly less than num_workers flag.')

    algorithm_module, _ = get_namespace(FLAGS.config)
    # The worker with task_id 0 is passed is_chief=True.
    algorithm_module.run_training(is_chief=FLAGS.task_id == 0)


if __name__ == '__main__':
    # Flag validators are only evaluated while flags are being parsed (or when
    # a flag value is assigned afterwards). The requirement therefore has to be
    # registered before app.run() parses the command line; registering it
    # inside main() happened after parsing and never rejected a missing
    # --logdir.
    flags.mark_flag_as_required('logdir')
    app.run(main)