Spaces:
Running
Running
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Creates and runs TF2 object detection models.

##################################
NOTE: This module has not been fully tested; please bear with us while we iron
out the kinks.
##################################

When a TPU device is available, this binary uses TPUStrategy. Otherwise, it uses
GPUs with MirroredStrategy/MultiWorkerMirroredStrategy.

For local training/evaluation run:
PIPELINE_CONFIG_PATH=path/to/pipeline.config
MODEL_DIR=/tmp/model_outputs
NUM_TRAIN_STEPS=10000
SAMPLE_1_OF_N_EVAL_EXAMPLES=1
python model_main_tf2.py -- \
  --model_dir=$MODEL_DIR --num_train_steps=$NUM_TRAIN_STEPS \
  --sample_1_of_n_eval_examples=$SAMPLE_1_OF_N_EVAL_EXAMPLES \
  --pipeline_config_path=$PIPELINE_CONFIG_PATH \
  --alsologtostderr
"""
from absl import flags | |
import tensorflow.compat.v2 as tf | |
from object_detection import model_hparams | |
from object_detection import model_lib_v2 | |
# Command-line flags for this binary. Help strings are built from adjacent
# string literals, so each fragment must carry its own separating space.
flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
                    'file.')
flags.DEFINE_integer('num_train_steps', None, 'Number of train steps.')
flags.DEFINE_bool('eval_on_train_data', False, 'Enable evaluating on train '
                  'data (only supported in distributed training).')
flags.DEFINE_integer('sample_1_of_n_eval_examples', None, 'Will sample one of '
                     'every n eval input examples, where n is provided.')
# The help text refers to the flag by its real name, `eval_on_train_data`.
flags.DEFINE_integer('sample_1_of_n_eval_on_train_examples', 5, 'Will sample '
                     'one of every n train input examples for evaluation, '
                     'where n is provided. This is only used if '
                     '`eval_on_train_data` is True.')
flags.DEFINE_string(
    'hparams_overrides', None, 'Hyperparameter overrides, '
    'represented as a string containing comma-separated '
    'hparam_name=value pairs.')
flags.DEFINE_string(
    'model_dir', None, 'Path to output model directory '
    'where event and checkpoint files will be written.')
flags.DEFINE_string(
    'checkpoint_dir', None, 'Path to directory holding a checkpoint. If '
    '`checkpoint_dir` is provided, this binary operates in eval-only mode, '
    'writing resulting metrics to `model_dir`.')
# Trailing space after "an" keeps the concatenated help text readable.
flags.DEFINE_integer('eval_timeout', 3600, 'Number of seconds to wait for an '
                     'evaluation checkpoint before exiting.')
flags.DEFINE_integer(
    'num_workers', 1, 'When num_workers > 1, training uses '
    'MultiWorkerMirroredStrategy. When num_workers = 1 it uses '
    'MirroredStrategy.')

FLAGS = flags.FLAGS
def main(unused_argv):
  """Runs continuous evaluation or distributed training, based on flags.

  When `--checkpoint_dir` is set, evaluation is run via
  `model_lib_v2.eval_continuously`. Otherwise a distribution strategy is
  chosen (TPUStrategy if TPU devices are visible, MultiWorkerMirroredStrategy
  if `--num_workers` > 1, MirroredStrategy otherwise) and
  `model_lib_v2.train_loop` is run inside that strategy's scope.

  Args:
    unused_argv: Unused.
  """
  # NOTE(review): if flags are already parsed by the time main() runs, these
  # registrations may not be enforced retroactively -- confirm.
  flags.mark_flag_as_required('model_dir')
  flags.mark_flag_as_required('pipeline_config_path')
  tf.config.set_soft_device_placement(True)

  if FLAGS.checkpoint_dir:
    # Eval-only mode.
    model_lib_v2.eval_continuously(
        hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
        pipeline_config_path=FLAGS.pipeline_config_path,
        model_dir=FLAGS.model_dir,
        train_steps=FLAGS.num_train_steps,
        sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
        sample_1_of_n_eval_on_train_examples=(
            FLAGS.sample_1_of_n_eval_on_train_examples),
        checkpoint_dir=FLAGS.checkpoint_dir,
        wait_interval=300, timeout=FLAGS.eval_timeout)
    return

  # This file defines no `use_tpu` flag, so `FLAGS.use_tpu` cannot be read
  # here. The same TPU-visibility check that selects the strategy is used to
  # tell the training loop whether it is running on TPU.
  use_tpu = bool(tf.config.get_visible_devices('TPU'))
  if use_tpu:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)
  elif FLAGS.num_workers > 1:
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
  else:
    strategy = tf.compat.v2.distribute.MirroredStrategy()

  with strategy.scope():
    model_lib_v2.train_loop(
        hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
        pipeline_config_path=FLAGS.pipeline_config_path,
        model_dir=FLAGS.model_dir,
        train_steps=FLAGS.num_train_steps,
        use_tpu=use_tpu)
if __name__ == '__main__':
  # `tf` is imported as `tensorflow.compat.v2`, which has no `tf.app`; the
  # app runner is reached through the v1 compatibility endpoint instead.
  tf.compat.v1.app.run()