# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main function to train various object detection models."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import functools
import pprint

# pylint: disable=g-bad-import-order
import tensorflow as tf

from absl import app
from absl import flags
from absl import logging
# pylint: enable=g-bad-import-order

from official.modeling.hyperparams import params_dict
from official.modeling.training import distributed_executor as executor
from official.utils import hyperparams_flags
from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils
from official.utils.misc import keras_utils
from official.vision.detection.configs import factory as config_factory
from official.vision.detection.dataloader import input_reader
from official.vision.detection.dataloader import mode_keys as ModeKeys
from official.vision.detection.executor.detection_executor import DetectionDistributedExecutor
from official.vision.detection.modeling import factory as model_factory

hyperparams_flags.initialize_common_flags()
flags_core.define_log_steps()

flags.DEFINE_bool('enable_xla', default=False, help='Enable XLA for GPU')

flags.DEFINE_string(
    'mode', default='train', help='Mode to run: `train` or `eval`.')

flags.DEFINE_string(
    'model', default='retinanet',
    help='Model to run: `retinanet`, `mask_rcnn` or `shapemask`.')

flags.DEFINE_string('training_file_pattern', None,
                    'Location of the train data.')
flags.DEFINE_string('eval_file_pattern', None, 'Location of the eval data.')
flags.DEFINE_string(
    'checkpoint_path', None,
    'The checkpoint path to eval. Only used in eval_once mode.')

FLAGS = flags.FLAGS


def run_executor(params,
                 mode,
                 checkpoint_path=None,
                 train_input_fn=None,
                 eval_input_fn=None,
                 callbacks=None,
                 prebuilt_strategy=None):
  """Runs the object detection model on distribution strategy defined by the user."""
  if params.architecture.use_bfloat16:
    policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
        'mixed_bfloat16')
    tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

  model_builder = model_factory.model_generator(params)

  if prebuilt_strategy is not None:
    strategy = prebuilt_strategy
  else:
    strategy_config = params.strategy_config
    distribution_utils.configure_cluster(strategy_config.worker_hosts,
                                         strategy_config.task_index)
    strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy=params.strategy_type,
        num_gpus=strategy_config.num_gpus,
        all_reduce_alg=strategy_config.all_reduce_alg,
        num_packs=strategy_config.num_packs,
        tpu_address=strategy_config.tpu)
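  # Ceiling-divide the replica count by 8 to estimate the number of hosts,
  # assuming at most 8 accelerators per host (e.g. an 8-core TPU or an
  # 8-GPU machine).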
  num_workers = int(strategy.num_replicas_in_sync + 7) // 8
  is_multi_host = (int(num_workers) >= 2)

  if mode == 'train':

    def _model_fn(params):
      return model_builder.build_model(params, mode=ModeKeys.TRAIN)

    logging.info(
        'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
        strategy.num_replicas_in_sync, num_workers, is_multi_host)

    dist_executor = DetectionDistributedExecutor(
        strategy=strategy,
        params=params,
        model_fn=_model_fn,
        loss_fn=model_builder.build_loss_fn,
        is_multi_host=is_multi_host,
        predict_post_process_fn=model_builder.post_processing,
        trainable_variables_filter=model_builder
        .make_filter_trainable_variables_fn())
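    # With multiple hosts, each host builds its own input pipeline, so the
    # global batch size is divided down to a per-replica batch size here.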
    if is_multi_host:
      train_input_fn = functools.partial(
          train_input_fn,
          batch_size=params.train.batch_size // strategy.num_replicas_in_sync)

    return dist_executor.train(
        train_input_fn=train_input_fn,
        model_dir=params.model_dir,
        iterations_per_loop=params.train.iterations_per_loop,
        total_steps=params.train.total_steps,
        init_checkpoint=model_builder.make_restore_checkpoint_fn(),
        custom_callbacks=callbacks,
        save_config=True)
  elif mode == 'eval' or mode == 'eval_once':

    def _model_fn(params):
      return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)

    logging.info('Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
                 strategy.num_replicas_in_sync, num_workers, is_multi_host)
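    # Mirror the training path: in multi-host mode the eval input function
    # also receives a per-replica batch size.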
    if is_multi_host:
      eval_input_fn = functools.partial(
          eval_input_fn,
          batch_size=params.eval.batch_size // strategy.num_replicas_in_sync)

    dist_executor = DetectionDistributedExecutor(
        strategy=strategy,
        params=params,
        model_fn=_model_fn,
        loss_fn=model_builder.build_loss_fn,
        is_multi_host=is_multi_host,
        predict_post_process_fn=model_builder.post_processing,
        trainable_variables_filter=model_builder
        .make_filter_trainable_variables_fn())

    if mode == 'eval':
      results = dist_executor.evaluate_from_model_dir(
          model_dir=params.model_dir,
          eval_input_fn=eval_input_fn,
          eval_metric_fn=model_builder.eval_metrics,
          eval_timeout=params.eval.eval_timeout,
          min_eval_interval=params.eval.min_eval_interval,
          total_steps=params.train.total_steps)
    else:
      # Run evaluation once for a single checkpoint.
      if not checkpoint_path:
        raise ValueError('checkpoint_path cannot be empty.')
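      # If a directory was passed, resolve it to the most recent checkpoint
      # inside that directory.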
      if tf.io.gfile.isdir(checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
      summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
      results, _ = dist_executor.evaluate_checkpoint(
          checkpoint_path=checkpoint_path,
          eval_input_fn=eval_input_fn,
          eval_metric_fn=model_builder.eval_metrics,
          summary_writer=summary_writer)
    for k, v in results.items():
      logging.info('Final eval metric %s: %f', k, v)
    return results
  else:
    raise ValueError('Mode not found: %s.' % mode)


def run(callbacks=None):
  keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

  params = config_factory.config_generator(FLAGS.model)
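  # Configuration is layered: the model's base config, then the YAML file from
  # --config_file, then --params_override, and finally the explicit flag
  # values below.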
  params = params_dict.override_params_dict(
      params, FLAGS.config_file, is_strict=True)
  params = params_dict.override_params_dict(
      params, FLAGS.params_override, is_strict=True)

  params.override(
      {
          'strategy_type': FLAGS.strategy_type,
          'model_dir': FLAGS.model_dir,
          'strategy_config': executor.strategy_flags_dict(),
      },
      is_strict=False)

  # Make sure use_tpu and strategy_type are in sync.
  params.use_tpu = (params.strategy_type == 'tpu')
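  # bfloat16 and synchronized batch norm are only used on TPU in this script;
  # turn them off for CPU/GPU strategies.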
  if not params.use_tpu:
    params.override({
        'architecture': {
            'use_bfloat16': False,
        },
        'norm_activation': {
            'use_sync_bn': False,
        },
    }, is_strict=True)

  params.validate()
  params.lock()
  pp = pprint.PrettyPrinter()
  params_str = pp.pformat(params.as_dict())
  logging.info('Model Parameters: %s', params_str)

  train_input_fn = None
  eval_input_fn = None
  training_file_pattern = FLAGS.training_file_pattern or params.train.train_file_pattern
  eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
  if not training_file_pattern and not eval_file_pattern:
    raise ValueError('Must provide at least one of training_file_pattern and '
                     'eval_file_pattern.')

  if training_file_pattern:
    # Use global batch size for single host.
    train_input_fn = input_reader.InputFn(
        file_pattern=training_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.TRAIN,
        batch_size=params.train.batch_size)

  if eval_file_pattern:
    eval_input_fn = input_reader.InputFn(
        file_pattern=eval_file_pattern,
        params=params,
        mode=input_reader.ModeKeys.PREDICT_WITH_GT,
        batch_size=params.eval.batch_size,
        num_examples=params.eval.eval_samples)

  if callbacks is None:
    callbacks = []

  if FLAGS.log_steps:
    callbacks.append(
        keras_utils.TimeHistory(
            batch_size=params.train.batch_size,
            log_steps=FLAGS.log_steps,
        ))

  return run_executor(
      params,
      FLAGS.mode,
      checkpoint_path=FLAGS.checkpoint_path,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      callbacks=callbacks)


def main(argv):
  del argv  # Unused.
  run()


if __name__ == '__main__':
  tf.config.set_soft_device_placement(True)
  app.run(main)
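# Example invocation, as a rough sketch; the model directory and file pattern
# below are placeholders, not real paths:
#
#   python main.py \
#     --model=retinanet \
#     --mode=train \
#     --strategy_type=mirrored \
#     --model_dir=/tmp/retinanet_model \
#     --training_file_pattern='/data/coco/train-*.tfrecord'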