Spaces:
Running
Running
# Lint as: python2, python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to export object detection inference graph.""" | |
import os | |
import tensorflow.compat.v2 as tf | |
from object_detection.builders import model_builder | |
from object_detection.core import standard_fields as fields | |
from object_detection.data_decoders import tf_example_decoder | |
from object_detection.utils import config_util | |
def _decode_image(encoded_image_string_tensor):
  """Decodes one encoded image string into a 3-channel image tensor.

  Args:
    encoded_image_string_tensor: A scalar string tensor holding encoded image
      bytes in any format `tf.image.decode_image` understands.

  Returns:
    An image tensor with static shape [height, width, 3] (uint8, the
    `decode_image` default dtype).
  """
  # With the default expand_animations=True, GIF inputs decode to a 4-D
  # [num_frames, height, width, 3] tensor, which contradicts the rank-3
  # set_shape below and fails at runtime. Setting it to False decodes the
  # first GIF frame as a regular 3-D image.
  image_tensor = tf.image.decode_image(
      encoded_image_string_tensor, channels=3, expand_animations=False)
  image_tensor.set_shape((None, None, 3))
  return image_tensor
def _decode_tf_example(tf_example_string_tensor):
  """Parses one serialized tf.Example and returns its image tensor.

  Args:
    tf_example_string_tensor: A scalar string tensor holding a serialized
      tf.Example proto.

  Returns:
    The tensor stored under `fields.InputDataFields.image` in the dictionary
    produced by `TfExampleDecoder.decode`.
  """
  decoder = tf_example_decoder.TfExampleDecoder()
  decoded = decoder.decode(tf_example_string_tensor)
  return decoded[fields.InputDataFields.image]
class DetectionInferenceModule(tf.Module):
  """Detection Inference Module.

  Holds a detection model and provides the shared
  preprocess -> predict -> postprocess pipeline. Subclasses supply an
  input-type-specific `__call__` that feeds images into it.
  """

  def __init__(self, detection_model):
    """Initializes a module for detection.

    Args:
      detection_model: The detection model to use for inference. It must
        expose `preprocess`, `predict` and `postprocess` methods.
    """
    # Run tf.Module's initializer so this module gets its `name` and
    # `name_scope` attributes like any other tf.Module.
    super(DetectionInferenceModule, self).__init__()
    self._model = detection_model

  def _run_inference_on_images(self, image):
    """Cast image to float and run inference.

    Args:
      image: Image tensor, presumably [batch, height, width, 3]. The dense
        subclasses pass uint8 or float32 with batch size 1, and the
        string-decoding subclasses pass a decoded uint8 batch. It is cast to
        float32 here before preprocessing.

    Returns:
      Dictionary of detection tensors from `postprocess`. Every value is cast
      to float32 and the detection classes are shifted by `label_id_offset`.
    """
    label_id_offset = 1
    image = tf.cast(image, tf.float32)
    image, shapes = self._model.preprocess(image)
    prediction_dict = self._model.predict(image, shapes)
    detections = self._model.postprocess(prediction_dict, shapes)
    classes_field = fields.DetectionResultFields.detection_classes
    # Shift class ids by one. NOTE(review): presumably the model emits
    # 0-based ids while label maps are 1-based; confirm.
    detections[classes_field] = (
        tf.cast(detections[classes_field], tf.float32) + label_id_offset)
    # Export every output as float32 so the serving signature is uniform.
    for key, val in detections.items():
      detections[key] = tf.cast(val, tf.float32)
    return detections
class DetectionFromImageModule(DetectionInferenceModule):
  """Detection Inference Module for image inputs."""

  # `__call__` must be a tf.function with an input signature:
  # `export_inference_graph` calls `__call__.get_concrete_function()` with no
  # arguments, which a plain Python method does not support.
  @tf.function(
      input_signature=[
          tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.uint8)])
  def __call__(self, input_tensor):
    """Runs detection on a single uint8 image.

    Args:
      input_tensor: uint8 tensor of shape [1, height, width, 3].

    Returns:
      Dictionary of float32 detection tensors.
    """
    return self._run_inference_on_images(input_tensor)
class DetectionFromFloatImageModule(DetectionInferenceModule):
  """Detection Inference Module for float image inputs."""

  # `__call__` must be a tf.function with an input signature:
  # `export_inference_graph` calls `__call__.get_concrete_function()` with no
  # arguments, which a plain Python method does not support.
  @tf.function(
      input_signature=[
          tf.TensorSpec(shape=[1, None, None, 3], dtype=tf.float32)])
  def __call__(self, input_tensor):
    """Runs detection on a single float32 image.

    Args:
      input_tensor: float32 tensor of shape [1, height, width, 3].

    Returns:
      Dictionary of float32 detection tensors.
    """
    return self._run_inference_on_images(input_tensor)
class DetectionFromEncodedImageModule(DetectionInferenceModule):
  """Detection Inference Module for encoded image string inputs."""

  # `__call__` must be a tf.function with an input signature:
  # `export_inference_graph` calls `__call__.get_concrete_function()` with no
  # arguments, which a plain Python method does not support.
  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
  def __call__(self, input_tensor):
    """Decodes a batch of encoded images and runs detection on it.

    Args:
      input_tensor: 1-D string tensor; each element is one encoded image.

    Returns:
      Dictionary of float32 detection tensors.
    """
    # Decode on CPU. map_fn stacks the per-element results, so the decoded
    # images in one batch presumably must share the same height and width.
    with tf.device('cpu:0'):
      image = tf.map_fn(
          _decode_image,
          elems=input_tensor,
          dtype=tf.uint8,
          parallel_iterations=32,
          back_prop=False)
    return self._run_inference_on_images(image)
class DetectionFromTFExampleModule(DetectionInferenceModule):
  """Detection Inference Module for TF.Example inputs."""

  # `__call__` must be a tf.function with an input signature:
  # `export_inference_graph` calls `__call__.get_concrete_function()` with no
  # arguments, which a plain Python method does not support.
  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
  def __call__(self, input_tensor):
    """Parses a batch of serialized tf.Examples and runs detection on it.

    Args:
      input_tensor: 1-D string tensor; each element is one serialized
        tf.Example.

    Returns:
      Dictionary of float32 detection tensors.
    """
    # Parse on CPU. map_fn stacks the per-element results, so the images in
    # one batch presumably must share the same height and width.
    with tf.device('cpu:0'):
      image = tf.map_fn(
          _decode_tf_example,
          elems=input_tensor,
          dtype=tf.uint8,
          parallel_iterations=32,
          back_prop=False)
    return self._run_inference_on_images(image)
# Maps each supported `input_type` string to the inference module class that
# consumes that kind of input; `export_inference_graph` looks it up here.
DETECTION_MODULE_MAP = dict(
    image_tensor=DetectionFromImageModule,
    encoded_image_string_tensor=DetectionFromEncodedImageModule,
    tf_example=DetectionFromTFExampleModule,
    float_image_tensor=DetectionFromFloatImageModule,
)
def export_inference_graph(input_type,
                           pipeline_config,
                           trained_checkpoint_dir,
                           output_directory):
  """Exports inference graph for the model specified in the pipeline config.

  This function creates `output_directory` if it does not already exist,
  which will hold a copy of the pipeline config with filename `pipeline.config`,
  and two subdirectories named `checkpoint` and `saved_model`
  (containing the exported checkpoint and SavedModel respectively).

  Args:
    input_type: Type of input for the graph. Must be a key of
      `DETECTION_MODULE_MAP`, i.e. one of ['image_tensor',
      'encoded_image_string_tensor', 'tf_example', 'float_image_tensor'].
    pipeline_config: pipeline_pb2.TrainAndEvalPipelineConfig proto.
    trained_checkpoint_dir: Path to the directory containing the trained
      checkpoint(s); the latest checkpoint found there is restored.
    output_directory: Path to write outputs.

  Raises:
    ValueError: if input_type is invalid.
  """
  # Validate cheaply before building the model and restoring weights.
  if input_type not in DETECTION_MODULE_MAP:
    raise ValueError('Unrecognized `input_type`: {!r}. Expected one of {}.'
                     .format(input_type, sorted(DETECTION_MODULE_MAP)))

  output_checkpoint_directory = os.path.join(output_directory, 'checkpoint')
  output_saved_model_directory = os.path.join(output_directory, 'saved_model')

  detection_model = model_builder.build(pipeline_config.model,
                                        is_training=False)
  ckpt = tf.train.Checkpoint(model=detection_model)
  manager = tf.train.CheckpointManager(
      ckpt, trained_checkpoint_dir, max_to_keep=1)
  # expect_partial() silences warnings about checkpoint values that this
  # inference-only object graph never uses (presumably training state).
  status = ckpt.restore(manager.latest_checkpoint).expect_partial()

  detection_module = DETECTION_MODULE_MAP[input_type](detection_model)
  # Getting the concrete function traces the graph and forces variables to
  # be constructed --- only after this can we save the checkpoint and
  # saved model.
  concrete_function = detection_module.__call__.get_concrete_function()
  # Variables now exist, so the restore can be checked against them.
  status.assert_existing_objects_matched()

  exported_checkpoint_manager = tf.train.CheckpointManager(
      ckpt, output_checkpoint_directory, max_to_keep=1)
  exported_checkpoint_manager.save(checkpoint_number=0)
  tf.saved_model.save(detection_module,
                      output_saved_model_directory,
                      signatures=concrete_function)
  config_util.save_pipeline_config(pipeline_config, output_directory)