# Lint as: python2, python3
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exports trained model to TensorFlow frozen graph."""

import os

import tensorflow as tf
from tensorflow.contrib import quantize as contrib_quantize
from tensorflow.python.tools import freeze_graph
from deeplab import common
from deeplab import input_preprocess
from deeplab import model

slim = tf.contrib.slim
flags = tf.app.flags

FLAGS = flags.FLAGS
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path')

flags.DEFINE_string('export_path', None,
                    'Path to output Tensorflow frozen graph.')

flags.DEFINE_integer('num_classes', 21, 'Number of classes.')

flags.DEFINE_multi_integer('crop_size', [513, 513],
                           'Crop size [height, width].')

# For `xception_65`, use atrous_rates = [12, 24, 36] if output_stride = 8, or
# atrous_rates = [6, 12, 18] if output_stride = 16. For `mobilenet_v2`, use
# None. Note one could use different atrous_rates/output_stride during
# training/evaluation.
flags.DEFINE_multi_integer('atrous_rates', None,
                           'Atrous rates for atrous spatial pyramid pooling.')

flags.DEFINE_integer('output_stride', 8,
                     'The ratio of input to output spatial resolution.')

# Change to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] for multi-scale inference.
flags.DEFINE_multi_float('inference_scales', [1.0],
                         'The scales to resize images for inference.')

flags.DEFINE_bool('add_flipped_images', False,
                  'Add flipped images during inference or not.')

flags.DEFINE_integer(
    'quantize_delay_step', -1,
    'Steps to start quantized training. If < 0, will not quantize model.')

flags.DEFINE_bool('save_inference_graph', False,
                  'Save inference graph in text proto.')
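
# Example invocation (a sketch, not an official command): the checkpoint and
# output paths below are placeholders, and `model_variant` is a flag defined
# in deeplab/common.py rather than in this file. Multi-value flags such as
# `atrous_rates` and `crop_size` are passed by repeating the flag.
#
#   python deeplab/export_model.py \
#     --checkpoint_path=/path/to/model.ckpt \
#     --export_path=/path/to/frozen_inference_graph.pb \
#     --model_variant="xception_65" \
#     --atrous_rates=6 --atrous_rates=12 --atrous_rates=18 \
#     --output_stride=16 \
#     --num_classes=21 \
#     --crop_size=513 --crop_size=513 \
#     --inference_scales=1.0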

# Input name of the exported model.
_INPUT_NAME = 'ImageTensor'

# Output name of the exported predictions.
_OUTPUT_NAME = 'SemanticPredictions'
_RAW_OUTPUT_NAME = 'RawSemanticPredictions'

# Output name of the exported probabilities.
_OUTPUT_PROB_NAME = 'SemanticProbabilities'
_RAW_OUTPUT_PROB_NAME = 'RawSemanticProbabilities'


def _create_input_tensors():
  """Creates and prepares input tensors for DeepLab model.

  This method creates a 4-D uint8 image tensor 'ImageTensor' with shape
  [1, None, None, 3]. The actual input tensor name to use during inference is
  'ImageTensor:0'.

  Returns:
    image: Preprocessed 4-D float32 tensor with shape [1, crop_height,
      crop_width, 3].
    original_image_size: Original image shape tensor [height, width].
    resized_image_size: Resized image shape tensor [height, width].
  """
  # input_preprocess takes 4-D image tensor as input.
  input_image = tf.placeholder(tf.uint8, [1, None, None, 3], name=_INPUT_NAME)
  original_image_size = tf.shape(input_image)[1:3]

  # Squeeze the dimension in axis=0 since `preprocess_image_and_label` assumes
  # image to be 3-D.
  image = tf.squeeze(input_image, axis=0)
  resized_image, image, _ = input_preprocess.preprocess_image_and_label(
      image,
      label=None,
      crop_height=FLAGS.crop_size[0],
      crop_width=FLAGS.crop_size[1],
      min_resize_value=FLAGS.min_resize_value,
      max_resize_value=FLAGS.max_resize_value,
      resize_factor=FLAGS.resize_factor,
      is_training=False,
      model_variant=FLAGS.model_variant)
  resized_image_size = tf.shape(resized_image)[:2]

  # Expand the dimension in axis=0, since the following operations assume the
  # image to be 4-D.
  image = tf.expand_dims(image, 0)

  return image, original_image_size, resized_image_size


def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.logging.info('Prepare to export model to: %s', FLAGS.export_path)

  with tf.Graph().as_default():
    image, image_size, resized_image_size = _create_input_tensors()

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: FLAGS.num_classes},
        crop_size=FLAGS.crop_size,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride)

    if tuple(FLAGS.inference_scales) == (1.0,):
      tf.logging.info('Exported model performs single-scale inference.')
      predictions = model.predict_labels(
          image,
          model_options=model_options,
          image_pyramid=FLAGS.image_pyramid)
    else:
      tf.logging.info('Exported model performs multi-scale inference.')
      if FLAGS.quantize_delay_step >= 0:
        raise ValueError(
            'Quantize mode is not supported with multi-scale test.')
      predictions = model.predict_labels_multi_scale(
          image,
          model_options=model_options,
          eval_scales=FLAGS.inference_scales,
          add_flipped_images=FLAGS.add_flipped_images)
    raw_predictions = tf.identity(
        tf.cast(predictions[common.OUTPUT_TYPE], tf.float32),
        _RAW_OUTPUT_NAME)
    raw_probabilities = tf.identity(
        predictions[common.OUTPUT_TYPE + model.PROB_SUFFIX],
        _RAW_OUTPUT_PROB_NAME)

    # Crop the valid regions from the predictions.
    semantic_predictions = raw_predictions[
        :, :resized_image_size[0], :resized_image_size[1]]
    semantic_probabilities = raw_probabilities[
        :, :resized_image_size[0], :resized_image_size[1]]

    # Resize back the prediction to the original image size.
    def _resize_label(label, label_size):
      # Expand dimension of label to [1, height, width, 1] for resize
      # operation.
      label = tf.expand_dims(label, 3)
      resized_label = tf.image.resize_images(
          label,
          label_size,
          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
          align_corners=True)
      return tf.cast(tf.squeeze(resized_label, 3), tf.int32)

    semantic_predictions = _resize_label(semantic_predictions, image_size)
    semantic_predictions = tf.identity(semantic_predictions, name=_OUTPUT_NAME)

    semantic_probabilities = tf.image.resize_bilinear(
        semantic_probabilities, image_size, align_corners=True,
        name=_OUTPUT_PROB_NAME)

    if FLAGS.quantize_delay_step >= 0:
      contrib_quantize.create_eval_graph()

    saver = tf.train.Saver(tf.all_variables())

    dirname = os.path.dirname(FLAGS.export_path)
    tf.gfile.MakeDirs(dirname)
    graph_def = tf.get_default_graph().as_graph_def(add_shapes=True)
    freeze_graph.freeze_graph_with_def_protos(
        graph_def,
        saver.as_saver_def(),
        FLAGS.checkpoint_path,
        _OUTPUT_NAME + ',' + _OUTPUT_PROB_NAME,
        restore_op_name=None,
        filename_tensor_name=None,
        output_graph=FLAGS.export_path,
        clear_devices=True,
        initializer_nodes=None)

    if FLAGS.save_inference_graph:
      tf.train.write_graph(graph_def, dirname, 'inference_graph.pbtxt')


if __name__ == '__main__':
  flags.mark_flag_as_required('checkpoint_path')
  flags.mark_flag_as_required('export_path')
  tf.app.run()
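

# A minimal sketch (not part of this script) of how the exported frozen graph
# could be loaded and run using the tensor names defined above
# ('ImageTensor:0' and 'SemanticPredictions:0'); the graph path and the input
# image array are hypothetical placeholders.
#
#   import numpy as np
#   import tensorflow as tf
#
#   graph = tf.Graph()
#   with tf.gfile.GFile('/path/to/frozen_inference_graph.pb', 'rb') as f:
#     graph_def = tf.GraphDef()
#     graph_def.ParseFromString(f.read())
#   with graph.as_default():
#     tf.import_graph_def(graph_def, name='')
#
#   with tf.Session(graph=graph) as sess:
#     # `image` is a uint8 array of shape [1, height, width, 3].
#     image = np.zeros([1, 513, 513, 3], dtype=np.uint8)
#     seg_map = sess.run('SemanticPredictions:0',
#                        feed_dict={'ImageTensor:0': image})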