# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main function to train various object detection models."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import functools
import pprint

# pylint: disable=g-bad-import-order
import tensorflow as tf

from absl import app
from absl import flags
from absl import logging
# pylint: enable=g-bad-import-order

from official.modeling.hyperparams import params_dict
from official.modeling.training import distributed_executor as executor
from official.utils import hyperparams_flags
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.detection.configs import factory as config_factory
from official.vision.detection.dataloader import input_reader
from official.vision.detection.dataloader import mode_keys as ModeKeys
from official.vision.detection.executor.detection_executor import DetectionDistributedExecutor
from official.vision.detection.modeling import factory as model_factory

hyperparams_flags.initialize_common_flags()
flags_core.define_log_steps()

flags.DEFINE_bool('enable_xla', default=False, help='Enable XLA for GPU')

flags.DEFINE_string(
    'mode', default='train', help='Mode to run: `train` or `eval`.')

flags.DEFINE_string(
    'model',
    default='retinanet',
    help='Model to run: `retinanet`, `mask_rcnn` or `shapemask`.')

flags.DEFINE_string('training_file_pattern', None,
                    'Location of the train data.')

flags.DEFINE_string('eval_file_pattern', None, 'Location of the eval data.')

flags.DEFINE_string(
    'checkpoint_path', None,
    'The checkpoint path to eval. Only used in eval_once mode.')

FLAGS = flags.FLAGS
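
# Example invocation, as a sketch: the file pattern and model directory below
# are placeholders. `--mode`, `--model`, and `--training_file_pattern` are
# defined above; `--model_dir` and `--strategy_type` come from the common
# flags registered by `hyperparams_flags.initialize_common_flags()`.
#
#   python -m official.vision.detection.main \
#     --model=retinanet \
#     --mode=train \
#     --strategy_type=mirrored \
#     --training_file_pattern=/path/to/train-*.tfrecord \
#     --model_dir=/tmp/retinanet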


def run_executor(params,
                 mode,
                 checkpoint_path=None,
                 train_input_fn=None,
                 eval_input_fn=None,
                 callbacks=None,
                 prebuilt_strategy=None):
  """Runs the object detection model on distribution strategy defined by the user."""
  if params.architecture.use_bfloat16:
    # Run the model under the mixed bfloat16 policy when the config asks for it.
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  model_builder = model_factory.model_generator(params)

  if prebuilt_strategy is not None:
    strategy = prebuilt_strategy
  else:
    strategy_config = params.strategy_config
    distribution_utils.configure_cluster(strategy_config.worker_hosts,
                                         strategy_config.task_index)
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=params.strategy_type,
        num_gpus=strategy_config.num_gpus,
        all_reduce_alg=strategy_config.all_reduce_alg,
        num_packs=strategy_config.num_packs,
        tpu_address=strategy_config.tpu)

  # Estimate the worker count from the replica count (eight replicas per
  # worker, rounded up) and treat two or more workers as a multi-host job.
  num_workers = int(strategy.num_replicas_in_sync + 7) // 8
  is_multi_host = (int(num_workers) >= 2)

  if mode == 'train':

    def _model_fn(params):
      return model_builder.build_model(params, mode=ModeKeys.TRAIN)

    logging.info(
        'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
        strategy.num_replicas_in_sync, num_workers, is_multi_host)

    dist_executor = DetectionDistributedExecutor(
        strategy=strategy,
        params=params,
        model_fn=_model_fn,
        loss_fn=model_builder.build_loss_fn,
        is_multi_host=is_multi_host,
        predict_post_process_fn=model_builder.post_processing,
        trainable_variables_filter=model_builder
        .make_filter_trainable_variables_fn())

    if is_multi_host:
      # Multi-host jobs build the input pipeline per host, so the input fn
      # receives the per-replica batch size instead of the global one.
      train_input_fn = functools.partial(
          train_input_fn,
          batch_size=params.train.batch_size // strategy.num_replicas_in_sync)

    return dist_executor.train(
        train_input_fn=train_input_fn,
        model_dir=params.model_dir,
        iterations_per_loop=params.train.iterations_per_loop,
        total_steps=params.train.total_steps,
        init_checkpoint=model_builder.make_restore_checkpoint_fn(),
        custom_callbacks=callbacks,
        save_config=True)
  elif mode == 'eval' or mode == 'eval_once':

    def _model_fn(params):
      return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)

    logging.info('Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
                 strategy.num_replicas_in_sync, num_workers, is_multi_host)

    if is_multi_host:
      eval_input_fn = functools.partial(
          eval_input_fn,
          batch_size=params.eval.batch_size // strategy.num_replicas_in_sync)

    dist_executor = DetectionDistributedExecutor(
        strategy=strategy,
        params=params,
        model_fn=_model_fn,
        loss_fn=model_builder.build_loss_fn,
        is_multi_host=is_multi_host,
        predict_post_process_fn=model_builder.post_processing,
        trainable_variables_filter=model_builder
        .make_filter_trainable_variables_fn())

    if mode == 'eval':
      results = dist_executor.evaluate_from_model_dir(
          model_dir=params.model_dir,
          eval_input_fn=eval_input_fn,
          eval_metric_fn=model_builder.eval_metrics,
          eval_timeout=params.eval.eval_timeout,
          min_eval_interval=params.eval.min_eval_interval,
          total_steps=params.train.total_steps)
    else:
      # Run evaluation once for a single checkpoint.
      if not checkpoint_path:
        raise ValueError('checkpoint_path cannot be empty.')
      if tf.io.gfile.isdir(checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
      summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
      results, _ = dist_executor.evaluate_checkpoint(
          checkpoint_path=checkpoint_path,
          eval_input_fn=eval_input_fn,
          eval_metric_fn=model_builder.eval_metrics,
          summary_writer=summary_writer)
    for k, v in results.items():
      logging.info('Final eval metric %s: %f', k, v)
    return results
  else:
    raise ValueError('Mode not found: %s.' % mode)
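

# A minimal sketch of driving `run_executor` directly, bypassing the flag
# parsing in `run()` below, e.g. to evaluate one checkpoint from another
# program. The file pattern, checkpoint path, and model directory are
# placeholders, and it assumes the default retinanet config validates without
# the strategy overrides that `run()` applies:
#
#   params = config_factory.config_generator('retinanet')
#   params.override({'model_dir': '/tmp/retinanet'}, is_strict=False)
#   params.validate()
#   params.lock()
#   eval_input_fn = input_reader.InputFn(
#       file_pattern='/path/to/val-*.tfrecord',
#       params=params,
#       mode=input_reader.ModeKeys.PREDICT_WITH_GT,
#       batch_size=params.eval.batch_size,
#       num_examples=params.eval.eval_samples)
#   run_executor(params, 'eval_once',
#                checkpoint_path='/path/to/checkpoints',
#                eval_input_fn=eval_input_fn,
#                prebuilt_strategy=tf.distribute.OneDeviceStrategy('/gpu:0'))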


def run(callbacks=None):
  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

  params = config_factory.config_generator(FLAGS.model)

  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=True)

  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)
  params.override(
      {
          'strategy_type': FLAGS.strategy_type,
          'model_dir': FLAGS.model_dir,
          'strategy_config': executor.strategy_flags_dict(),
      },
      is_strict=False)

  # Make sure use_tpu and strategy_type are in sync.
  params.use_tpu = (params.strategy_type == 'tpu')

  if not params.use_tpu:
    params.override({
        'architecture': {
            'use_bfloat16': False,
        },
        'norm_activation': {
            'use_sync_bn': False,
        },
    }, is_strict=True)

  params.validate()
  params.lock()
  pp = pprint.PrettyPrinter()
  params_str = pp.pformat(params.as_dict())
  logging.info('Model Parameters: %s', params_str)

  train_input_fn = None
  eval_input_fn = None
  training_file_pattern = (
      FLAGS.training_file_pattern or params.train.train_file_pattern)
  eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
  if not training_file_pattern and not eval_file_pattern:
    raise ValueError('Must provide at least one of training_file_pattern and '
                     'eval_file_pattern.')

  if training_file_pattern:
    # Use global batch size for single host.
    train_input_fn = input_reader.InputFn(
        file_pattern=training_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.TRAIN,
        batch_size=params.train.batch_size)

  if eval_file_pattern:
    eval_input_fn = input_reader.InputFn(
        file_pattern=eval_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.PREDICT_WITH_GT,
        batch_size=params.eval.batch_size,
        num_examples=params.eval.eval_samples)

  if callbacks is None:
    callbacks = []

  if FLAGS.log_steps:
    callbacks.append(
        keras_utils.TimeHistory(
            batch_size=params.train.batch_size,
            log_steps=FLAGS.log_steps,
        ))

  return run_executor(
      params,
      FLAGS.mode,
      checkpoint_path=FLAGS.checkpoint_path,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      callbacks=callbacks)


def main(argv):
  del argv  # Unused.

  run()


if __name__ == '__main__':
  tf.config.set_soft_device_placement(True)
  app.run(main)