# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.trainer."""

import tensorflow as tf

from google.protobuf import text_format

from object_detection.core import losses
from object_detection.core import model
from object_detection.core import standard_fields as fields
from object_detection.legacy import trainer
from object_detection.protos import train_pb2


NUMBER_OF_CLASSES = 2

# TrainConfig (protobuf text format) shared by every test in this file: Adam
# with a constant learning rate, two data augmentation options, and two
# training steps. Individual tests append extra fields to this text before
# parsing it. Text format treats whitespace purely as a token separator, so
# the layout below is for readability only.
_BASE_TRAIN_CONFIG_TEXT_PROTO = """
optimizer {
  adam_optimizer {
    learning_rate {
      constant_learning_rate {
        learning_rate: 0.01
      }
    }
  }
}
data_augmentation_options {
  random_adjust_brightness {
    max_delta: 0.2
  }
}
data_augmentation_options {
  random_adjust_contrast {
    min_delta: 0.7
    max_delta: 1.1
  }
}
num_steps: 2
"""


def get_input_function():
  """Builds a dictionary of random test input tensors.

  Returns:
    A dict keyed by `fields.InputDataFields` names that holds one random
    32x32x3 float32 image, a constant string key, and a single groundtruth
    box with its class label and per-class scores.
  """
  image = tf.random_uniform([32, 32, 3], dtype=tf.float32)
  key = tf.constant('image_000000')
  class_label = tf.random_uniform(
      [1], minval=0, maxval=NUMBER_OF_CLASSES, dtype=tf.int32)
  box_label = tf.random_uniform(
      [1, 4], minval=0.4, maxval=0.6, dtype=tf.float32)
  multiclass_scores = tf.random_uniform(
      [1, NUMBER_OF_CLASSES], minval=0.4, maxval=0.6, dtype=tf.float32)

  return {
      fields.InputDataFields.image: image,
      fields.InputDataFields.key: key,
      fields.InputDataFields.groundtruth_classes: class_label,
      fields.InputDataFields.groundtruth_boxes: box_label,
      fields.InputDataFields.multiclass_scores: multiclass_scores,
  }


class FakeDetectionModel(model.DetectionModel):
  """A simple (and poor) DetectionModel for use in test."""

  def __init__(self):
    super(FakeDetectionModel, self).__init__(num_classes=NUMBER_OF_CLASSES)
    self._classification_loss = losses.WeightedSigmoidClassificationLoss()
    self._localization_loss = losses.WeightedSmoothL1LocalizationLoss()

  def preprocess(self, inputs):
    """Input preprocessing, resizes images to 28x28.

    Args:
      inputs: a tensor of images whose static shape is fully known; the
        tensor is passed directly to `tf.image.resize_images`.

    Returns:
      preprocessed_inputs: `inputs` resized to 28x28 by
        `tf.image.resize_images`.
      true_image_shapes: a nested Python list holding
        `inputs.shape[:-1].as_list()` repeated `inputs.shape[-1]` times.
    """
    # NOTE(review): the DetectionModel contract used elsewhere in this class
    # describes true_image_shapes as [batch, 3] rows of
    # [height, width, channels]. What is built here is every dimension except
    # the last one, repeated once per element of the last dimension. Nothing
    # in this fake model reads the value, so the behaviour is left untouched;
    # confirm against the caller before relying on it.
    leading_dims = inputs.shape[:-1].as_list()
    true_image_shapes = [list(leading_dims) for _ in range(inputs.shape[-1])]
    return tf.image.resize_images(inputs, [28, 28]), true_image_shapes

  def predict(self, preprocessed_inputs, true_image_shapes):
    """Prediction tensors from inputs tensor.

    Args:
      preprocessed_inputs: the image tensor produced by `preprocess`.
      true_image_shapes: unused by this fake model.

    Returns:
      prediction_dict: a dictionary holding prediction tensors to be passed
        to the Loss or Postprocess functions. It contains
        'class_predictions_with_background', reshaped to
        [-1, 1, num_classes], and 'box_encodings', reshaped to [-1, 1, 4].
    """
    flattened_inputs = tf.contrib.layers.flatten(preprocessed_inputs)
    class_prediction = tf.contrib.layers.fully_connected(
        flattened_inputs, self._num_classes)
    box_prediction = tf.contrib.layers.fully_connected(flattened_inputs, 4)

    return {
        'class_predictions_with_background': tf.reshape(
            class_prediction, [-1, 1, self._num_classes]),
        'box_encodings': tf.reshape(box_prediction, [-1, 1, 4]),
    }

  def postprocess(self, prediction_dict, true_image_shapes, **params):
    """Convert predicted output tensors to final detections. Unused.

    Args:
      prediction_dict: a dictionary holding prediction tensors. Unused.
      true_image_shapes: unused by this fake model.
      **params: Additional keyword arguments for specific implementations of
        DetectionModel. Unused.

    Returns:
      detections: a dictionary with empty (None) fields.
    """
    return {
        'detection_boxes': None,
        'detection_scores': None,
        'detection_classes': None,
        'num_detections': None,
    }

  def loss(self, prediction_dict, true_image_shapes):
    """Compute scalar loss tensors with respect to provided groundtruth.

    Calling this function requires that groundtruth tensors have been
    provided via the provide_groundtruth function.

    Args:
      prediction_dict: a dictionary holding predicted tensors; the
        'box_encodings' and 'class_predictions_with_background' entries are
        read.
      true_image_shapes: unused by this fake model.

    Returns:
      a dictionary mapping strings (loss names) to scalar tensors
        representing loss values.
    """
    # Fetch the groundtruth box list once; it supplies both the regression
    # targets and the number of rows in the weights tensor.
    groundtruth_boxes = self.groundtruth_lists(fields.BoxListFields.boxes)
    batch_reg_targets = tf.stack(groundtruth_boxes)
    batch_cls_targets = tf.stack(
        self.groundtruth_lists(fields.BoxListFields.classes))
    # Every groundtruth entry gets the same unit weight.
    weights = tf.constant(
        1.0, dtype=tf.float32, shape=[len(groundtruth_boxes), 1])

    location_losses = self._localization_loss(
        prediction_dict['box_encodings'], batch_reg_targets,
        weights=weights)
    cls_losses = self._classification_loss(
        prediction_dict['class_predictions_with_background'],
        batch_cls_targets, weights=weights)

    loss_dict = {
        'localization_loss': tf.reduce_sum(location_losses),
        'classification_loss': tf.reduce_sum(cls_losses),
    }
    return loss_dict

  def regularization_losses(self):
    """Returns a list of regularization losses for this model.

    Returns a list of regularization losses for this model that the
    estimator needs to use during training/optimization.

    Returns:
      A list of regularization loss tensors. This fake model defines none,
      so the list is always empty.
    """
    # Return an empty list rather than None so the result honours the
    # documented "list" contract and can be iterated or concatenated safely.
    return []

  def restore_map(self, fine_tune_checkpoint_type='detection'):
    """Returns a map of variables to load from a foreign checkpoint.

    Args:
      fine_tune_checkpoint_type: whether to restore from a full detection
        checkpoint (with compatible variable names) or to restore from a
        classification checkpoint for initialization prior to training.
        Valid values: `detection`, `classification`. Default 'detection'.
        This fake model ignores the value.

    Returns:
      A dict mapping variable names to variables, covering every variable in
      `tf.global_variables()`.
    """
    return {var.op.name: var for var in tf.global_variables()}

  def updates(self):
    """Returns a list of update operators for this model.

    Returns a list of update operators for this model that must be executed
    at each training step. The estimator's train op needs to have a control
    dependency on these updates.

    Returns:
      A list of update operators. This fake model defines none, so the list
      is always empty.
    """
    # Return an empty list rather than None to match the documented contract.
    return []


class TrainerTest(tf.test.TestCase):
  """Smoke tests that run the legacy trainer for two steps."""

  def _train_two_steps(self, extra_config_text_proto=''):
    """Parses the shared train config and runs `trainer.train` with it.

    Args:
      extra_config_text_proto: optional TrainConfig fields, in protobuf text
        format, appended to the shared base configuration before parsing.
    """
    train_config = train_pb2.TrainConfig()
    text_format.Merge(
        _BASE_TRAIN_CONFIG_TEXT_PROTO + extra_config_text_proto,
        train_config)

    trainer.train(
        create_tensor_dict_fn=get_input_function,
        create_model_fn=FakeDetectionModel,
        train_config=train_config,
        master='',
        task=0,
        num_clones=1,
        worker_replicas=1,
        clone_on_cpu=True,
        ps_tasks=0,
        worker_job_name='worker',
        is_chief=True,
        train_dir=self.get_temp_dir())

  def test_configure_trainer_and_train_two_steps(self):
    self._train_two_steps()

  def test_configure_trainer_with_multiclass_scores_and_train_two_steps(self):
    self._train_two_steps('use_multiclass_scores: true\n')


if __name__ == '__main__':
  tf.test.main()