pat229988's picture
Upload 653 files
9a393e2
raw history blame
No virus
29.4 kB
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model input function for tf-learn object detection model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow as tf
from object_detection.builders import dataset_builder
from object_detection.builders import image_resizer_builder
from object_detection.builders import model_builder
from object_detection.builders import preprocessor_builder
from object_detection.core import preprocessor
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder
from object_detection.protos import eval_pb2
from object_detection.protos import input_reader_pb2
from object_detection.protos import model_pb2
from object_detection.protos import train_pb2
from object_detection.utils import config_util
from object_detection.utils import ops as util_ops
from object_detection.utils import shape_utils
# Key in the features dictionary that holds the hashed source id.
HASH_KEY = 'hash'
# Number of buckets the source id is hashed into (2**31).
HASH_BINS = 2 ** 31
# Name of the receiver tensor that is fed a serialized example when serving.
SERVING_FED_EXAMPLE_KEY = 'serialized_example'

# Registry of the callables used to assemble the input pipeline, keyed by name.
# The input functions below look the dataset builder up through this map.
INPUT_BUILDER_UTIL_MAP = dict(dataset_build=dataset_builder.build)
def transform_input_data(tensor_dict,
                         model_preprocess_fn,
                         image_resizer_fn,
                         num_classes,
                         data_augmentation_fn=None,
                         merge_multiple_boxes=False,
                         retain_original_image=False,
                         use_multiclass_scores=False,
                         use_bfloat16=False):
  """A single function that is responsible for all input data transformations.

  Data transformation functions are applied in the following order.
  1. If key fields.InputDataFields.image_additional_channels is present in
     tensor_dict, the additional channels will be merged into
     fields.InputDataFields.image.
  2. data_augmentation_fn (optional): applied on tensor_dict.
  3. model_preprocess_fn: applied only on image tensor in tensor_dict.
  4. image_resizer_fn: applied on original image and instance mask tensor in
     tensor_dict.
  5. one_hot_encoding: applied to classes tensor in tensor_dict.
  6. merge_multiple_boxes (optional): when groundtruth boxes are exactly the
     same they can be merged into a single box with an associated k-hot class
     label.

  The caller's `tensor_dict` is not modified: all transformations are applied
  to a shallow copy, which is what gets returned.

  Args:
    tensor_dict: dictionary containing input tensors keyed by
      fields.InputDataFields.
    model_preprocess_fn: model's preprocess function to apply on image tensor.
      This function must take in a 4-D float tensor and return a 4-D preprocess
      float tensor and a tensor containing the true image shape.
    image_resizer_fn: image resizer function to apply on groundtruth instance
      masks. This function must take a 3-D float tensor of an image and a 3-D
      tensor of instance masks and return a resized version of these along with
      the true shapes.
    num_classes: number of max classes to one-hot (or k-hot) encode the class
      labels.
    data_augmentation_fn: (optional) data augmentation function to apply on
      input `tensor_dict`.
    merge_multiple_boxes: (optional) whether to merge multiple groundtruth boxes
      and classes for a given image if the boxes are exactly the same.
    retain_original_image: (optional) whether to retain original image in the
      output dictionary.
    use_multiclass_scores: whether to use multiclass scores as
      class targets instead of one-hot encoding of `groundtruth_classes`.
    use_bfloat16: (optional) a bool, whether to use bfloat16 in training.

  Returns:
    A dictionary keyed by fields.InputDataFields containing the tensors obtained
    after applying all the transformations.
  """
  # Work on a shallow copy so that none of the key assignments below leak into
  # the dictionary owned by the caller.
  tensor_dict = dict(tensor_dict)

  # Reshape flattened multiclass scores tensor into a 2D tensor of shape
  # [num_boxes, num_classes].
  if fields.InputDataFields.multiclass_scores in tensor_dict:
    tensor_dict[fields.InputDataFields.multiclass_scores] = tf.reshape(
        tensor_dict[fields.InputDataFields.multiclass_scores], [
            tf.shape(tensor_dict[fields.InputDataFields.groundtruth_boxes])[0],
            num_classes
        ])

  if fields.InputDataFields.groundtruth_boxes in tensor_dict:
    tensor_dict = util_ops.filter_groundtruth_with_nan_box_coordinates(
        tensor_dict)
    tensor_dict = util_ops.filter_unrecognized_classes(tensor_dict)

  # The retained original image is taken before the additional channels are
  # concatenated and before any augmentation; it is only resized.
  if retain_original_image:
    tensor_dict[fields.InputDataFields.original_image] = tf.cast(
        image_resizer_fn(tensor_dict[fields.InputDataFields.image], None)[0],
        tf.uint8)

  if fields.InputDataFields.image_additional_channels in tensor_dict:
    channels = tensor_dict[fields.InputDataFields.image_additional_channels]
    tensor_dict[fields.InputDataFields.image] = tf.concat(
        [tensor_dict[fields.InputDataFields.image], channels], axis=2)

  # Apply data augmentation ops.
  if data_augmentation_fn is not None:
    tensor_dict = data_augmentation_fn(tensor_dict)

  # Apply model preprocessing ops and resize instance masks.
  image = tensor_dict[fields.InputDataFields.image]
  # The preprocess function works on batches, so add a batch dimension here
  # and squeeze it away again afterwards.
  preprocessed_resized_image, true_image_shape = model_preprocess_fn(
      tf.expand_dims(tf.to_float(image), axis=0))
  if use_bfloat16:
    preprocessed_resized_image = tf.cast(
        preprocessed_resized_image, tf.bfloat16)
  tensor_dict[fields.InputDataFields.image] = tf.squeeze(
      preprocessed_resized_image, axis=0)
  tensor_dict[fields.InputDataFields.true_image_shape] = tf.squeeze(
      true_image_shape, axis=0)
  if fields.InputDataFields.groundtruth_instance_masks in tensor_dict:
    masks = tensor_dict[fields.InputDataFields.groundtruth_instance_masks]
    # Masks are resized against the image as it was before preprocessing.
    _, resized_masks, _ = image_resizer_fn(image, masks)
    if use_bfloat16:
      resized_masks = tf.cast(resized_masks, tf.bfloat16)
    tensor_dict[fields.InputDataFields.
                groundtruth_instance_masks] = resized_masks

  # Transform groundtruth classes to one hot encodings.
  label_offset = 1
  zero_indexed_groundtruth_classes = tensor_dict[
      fields.InputDataFields.groundtruth_classes] - label_offset
  tensor_dict[fields.InputDataFields.groundtruth_classes] = tf.one_hot(
      zero_indexed_groundtruth_classes, num_classes)

  if use_multiclass_scores:
    tensor_dict[fields.InputDataFields.groundtruth_classes] = tensor_dict[
        fields.InputDataFields.multiclass_scores]
  # The multiclass scores are never part of the output, whether used or not.
  tensor_dict.pop(fields.InputDataFields.multiclass_scores, None)

  if fields.InputDataFields.groundtruth_confidences in tensor_dict:
    groundtruth_confidences = tensor_dict[
        fields.InputDataFields.groundtruth_confidences]
    # Map the confidences to the one-hot encoding of classes
    tensor_dict[fields.InputDataFields.groundtruth_confidences] = (
        tf.reshape(groundtruth_confidences, [-1, 1]) *
        tensor_dict[fields.InputDataFields.groundtruth_classes])
  else:
    # Without explicit confidences every box gets confidence 1, so the
    # per-class confidences are simply the class encoding itself.
    groundtruth_confidences = tf.ones_like(
        zero_indexed_groundtruth_classes, dtype=tf.float32)
    tensor_dict[fields.InputDataFields.groundtruth_confidences] = (
        tensor_dict[fields.InputDataFields.groundtruth_classes])

  if merge_multiple_boxes:
    merged_boxes, merged_classes, merged_confidences, _ = (
        util_ops.merge_boxes_with_multiple_labels(
            tensor_dict[fields.InputDataFields.groundtruth_boxes],
            zero_indexed_groundtruth_classes,
            groundtruth_confidences,
            num_classes))
    merged_classes = tf.cast(merged_classes, tf.float32)
    tensor_dict[fields.InputDataFields.groundtruth_boxes] = merged_boxes
    tensor_dict[fields.InputDataFields.groundtruth_classes] = merged_classes
    tensor_dict[fields.InputDataFields.groundtruth_confidences] = (
        merged_confidences)

  # Record the box count last so that it reflects filtering and merging.
  if fields.InputDataFields.groundtruth_boxes in tensor_dict:
    tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = tf.shape(
        tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]

  return tensor_dict
def pad_input_data_to_static_shapes(tensor_dict, max_num_boxes, num_classes,
                                    spatial_image_shape=None):
  """Pads input tensors to static shapes.

  In case num_additional_channels > 0, we assume that the additional channels
  have already been concatenated to the base image.

  Args:
    tensor_dict: Tensor dictionary of input data
    max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
      padding.
    num_classes: Number of classes in the dataset needed to compute shapes for
      padding.
    spatial_image_shape: A list (or tuple) of two integers of the form
      [height, width] containing expected spatial shape of the image. When it
      is empty, None or [-1, -1], the spatial dimensions are left as None.

  Returns:
    A new dictionary with the same keys as `tensor_dict`, in which every tensor
    has been padded or clipped to its static shape.

  Raises:
    ValueError: If we detect that additional channels have not been
      concatenated to the image yet.
    KeyError: If `tensor_dict` contains a key for which no padding shape is
      defined.
  """
  # Compare as a list so that a tuple such as (-1, -1) is also recognised as
  # the "unknown spatial size" marker.
  if not spatial_image_shape or list(spatial_image_shape) == [-1, -1]:
    height, width = None, None
  else:
    height, width = spatial_image_shape  # pylint: disable=unpacking-non-sequence

  num_additional_channels = 0
  if fields.InputDataFields.image_additional_channels in tensor_dict:
    num_additional_channels = tensor_dict[
        fields.InputDataFields.image_additional_channels].shape[2].value

  # We assume that if num_additional_channels > 0, then it has already been
  # concatenated to the base image (but not the ground truth).
  num_channels = 3
  if fields.InputDataFields.image in tensor_dict:
    num_channels = tensor_dict[fields.InputDataFields.image].shape[2].value

  if num_additional_channels:
    # A merged image must have strictly more channels than the additional
    # channels alone.
    if num_additional_channels >= num_channels:
      raise ValueError(
          'Image must be already concatenated with additional channels.')

    # The retained original image is expected to have a different channel count
    # than the merged image.
    if (fields.InputDataFields.original_image in tensor_dict and
        tensor_dict[fields.InputDataFields.original_image].shape[2].value ==
        num_channels):
      raise ValueError(
          'Image must be already concatenated with additional channels.')

  padding_shapes = {
      fields.InputDataFields.image: [
          height, width, num_channels
      ],
      fields.InputDataFields.original_image_spatial_shape: [2],
      fields.InputDataFields.image_additional_channels: [
          height, width, num_additional_channels
      ],
      fields.InputDataFields.source_id: [],
      fields.InputDataFields.filename: [],
      fields.InputDataFields.key: [],
      fields.InputDataFields.groundtruth_difficult: [max_num_boxes],
      fields.InputDataFields.groundtruth_boxes: [max_num_boxes, 4],
      fields.InputDataFields.groundtruth_classes: [max_num_boxes, num_classes],
      fields.InputDataFields.groundtruth_instance_masks: [
          max_num_boxes, height, width
      ],
      fields.InputDataFields.groundtruth_is_crowd: [max_num_boxes],
      fields.InputDataFields.groundtruth_group_of: [max_num_boxes],
      fields.InputDataFields.groundtruth_area: [max_num_boxes],
      fields.InputDataFields.groundtruth_weights: [max_num_boxes],
      fields.InputDataFields.groundtruth_confidences: [
          max_num_boxes, num_classes
      ],
      fields.InputDataFields.num_groundtruth_boxes: [],
      fields.InputDataFields.groundtruth_label_types: [max_num_boxes],
      fields.InputDataFields.groundtruth_label_weights: [max_num_boxes],
      fields.InputDataFields.true_image_shape: [3],
      fields.InputDataFields.groundtruth_image_classes: [num_classes],
      fields.InputDataFields.groundtruth_image_confidences: [num_classes],
  }

  # The remaining shapes depend on dimensions read from the tensors themselves.
  if fields.InputDataFields.original_image in tensor_dict:
    padding_shapes[fields.InputDataFields.original_image] = [
        height, width, tensor_dict[fields.InputDataFields.
                                   original_image].shape[2].value
    ]
  if fields.InputDataFields.groundtruth_keypoints in tensor_dict:
    tensor_shape = (
        tensor_dict[fields.InputDataFields.groundtruth_keypoints].shape)
    padding_shape = [max_num_boxes, tensor_shape[1].value,
                     tensor_shape[2].value]
    padding_shapes[fields.InputDataFields.groundtruth_keypoints] = padding_shape
  if fields.InputDataFields.groundtruth_keypoint_visibilities in tensor_dict:
    tensor_shape = tensor_dict[fields.InputDataFields.
                               groundtruth_keypoint_visibilities].shape
    padding_shape = [max_num_boxes, tensor_shape[1].value]
    padding_shapes[fields.InputDataFields.
                   groundtruth_keypoint_visibilities] = padding_shape

  padded_tensor_dict = {}
  for tensor_name in tensor_dict:
    padded_tensor_dict[tensor_name] = shape_utils.pad_or_clip_nd(
        tensor_dict[tensor_name], padding_shapes[tensor_name])

  # Make sure that the number of groundtruth boxes now reflects the
  # padded/clipped tensors.
  if fields.InputDataFields.num_groundtruth_boxes in padded_tensor_dict:
    padded_tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = (
        tf.minimum(
            padded_tensor_dict[fields.InputDataFields.num_groundtruth_boxes],
            max_num_boxes))
  return padded_tensor_dict
def augment_input_data(tensor_dict, data_augmentation_options):
  """Applies data augmentation ops to input tensors.

  The caller's `tensor_dict` is left untouched; the batched image needed by the
  preprocessor is written to a shallow copy.

  Args:
    tensor_dict: A dictionary of input tensors keyed by fields.InputDataFields.
    data_augmentation_options: A list of tuples, where each tuple contains a
      function and a dictionary that contains arguments and their values.
      Usually, this is the output of core/preprocessor.build.

  Returns:
    A dictionary of tensors obtained by applying data augmentation ops to the
    input tensor dictionary.
  """
  # Copy first: the image entry is replaced below by a float tensor with an
  # extra leading dimension, which must not become visible to the caller.
  tensor_dict = dict(tensor_dict)
  tensor_dict[fields.InputDataFields.image] = tf.expand_dims(
      tf.to_float(tensor_dict[fields.InputDataFields.image]), 0)

  # Only the fields that are actually present are wired into the argument map.
  include_instance_masks = (fields.InputDataFields.groundtruth_instance_masks
                            in tensor_dict)
  include_keypoints = (fields.InputDataFields.groundtruth_keypoints
                       in tensor_dict)
  include_label_weights = (fields.InputDataFields.groundtruth_weights
                           in tensor_dict)
  include_label_confidences = (fields.InputDataFields.groundtruth_confidences
                               in tensor_dict)
  include_multiclass_scores = (fields.InputDataFields.multiclass_scores in
                               tensor_dict)
  tensor_dict = preprocessor.preprocess(
      tensor_dict, data_augmentation_options,
      func_arg_map=preprocessor.get_default_func_arg_map(
          include_label_weights=include_label_weights,
          include_label_confidences=include_label_confidences,
          include_multiclass_scores=include_multiclass_scores,
          include_instance_masks=include_instance_masks,
          include_keypoints=include_keypoints))

  # Drop the leading dimension that was added for the preprocessor.
  tensor_dict[fields.InputDataFields.image] = tf.squeeze(
      tensor_dict[fields.InputDataFields.image], axis=0)
  return tensor_dict
def _get_labels_dict(input_dict):
  """Builds the labels dictionary from the groundtruth entries of input_dict.

  Args:
    input_dict: dictionary of tensors keyed by fields.InputDataFields.

  Returns:
    A dictionary holding the required groundtruth entries plus whichever
    optional groundtruth entries are present in `input_dict`.

  Raises:
    KeyError: if one of the required groundtruth keys is missing.
  """
  input_fields = fields.InputDataFields

  # Required entries are indexed directly so that a missing one fails loudly.
  labels_dict = {
      key: input_dict[key] for key in (
          input_fields.num_groundtruth_boxes,
          input_fields.groundtruth_boxes,
          input_fields.groundtruth_classes,
          input_fields.groundtruth_weights,
      )
  }

  # Optional entries are copied over only when available.
  optional_label_keys = (
      input_fields.groundtruth_confidences,
      input_fields.groundtruth_keypoints,
      input_fields.groundtruth_instance_masks,
      input_fields.groundtruth_area,
      input_fields.groundtruth_is_crowd,
      input_fields.groundtruth_difficult,
  )
  labels_dict.update(
      (key, input_dict[key]) for key in optional_label_keys
      if key in input_dict)

  # The difficult flags are handed out as int32.
  difficult_key = input_fields.groundtruth_difficult
  if difficult_key in labels_dict:
    labels_dict[difficult_key] = tf.cast(labels_dict[difficult_key], tf.int32)
  return labels_dict
def _replace_empty_string_with_random_number(string_tensor):
  """Substitutes a random number string for an empty string tensor.

  The replacement is a random int64 drawn with `maxval=2**63 - 1` and
  converted to a string.

  Args:
    string_tensor: A tf.tensor of dtype string.

  Returns:
    A tf.tensor of dtype string: the random number string when `string_tensor`
    equals the empty string, otherwise `string_tensor` itself.
  """
  blank = tf.constant('', dtype=tf.string, name='EmptyString')
  random_number = tf.random_uniform(
      shape=[], maxval=2**63 - 1, dtype=tf.int64)
  fallback_id = tf.as_string(random_number)

  def _use_fallback():
    return fallback_id

  def _keep_input():
    return string_tensor

  is_blank = tf.equal(string_tensor, blank)
  return tf.cond(is_blank, true_fn=_use_fallback, false_fn=_keep_input)
def _get_features_dict(input_dict):
  """Builds the features dictionary from input_dict.

  Args:
    input_dict: dictionary of tensors keyed by fields.InputDataFields. It must
      contain the source id, the image, the true image shape and the original
      image spatial shape.

  Returns:
    A dictionary with the image, its true shape, its original spatial shape, a
    hash of the source id under HASH_KEY and, when present in `input_dict`, the
    original image.
  """
  input_fields = fields.InputDataFields

  # An empty source id is swapped for a random one before hashing.
  source_id = _replace_empty_string_with_random_number(
      input_dict[input_fields.source_id])
  source_id_hash = tf.cast(
      tf.string_to_hash_bucket_fast(source_id, HASH_BINS), tf.int32)

  features = {}
  features[input_fields.image] = input_dict[input_fields.image]
  features[HASH_KEY] = source_id_hash
  features[input_fields.true_image_shape] = input_dict[
      input_fields.true_image_shape]
  features[input_fields.original_image_spatial_shape] = input_dict[
      input_fields.original_image_spatial_shape]

  if input_fields.original_image in input_dict:
    features[input_fields.original_image] = input_dict[
        input_fields.original_image]
  return features
def create_train_input_fn(train_config, train_input_config,
                          model_config):
  """Creates a train `input` function for `Estimator`.

  Args:
    train_config: A train_pb2.TrainConfig.
    train_input_config: An input_reader_pb2.InputReader.
    model_config: A model_pb2.DetectionModel.

  Returns:
    `input_fn` for `Estimator` in TRAIN mode.
  """

  def _train_input_fn(params=None):
    """Returns `features` and `labels` tensor dictionaries for training.

    Args:
      params: Parameter dictionary passed from the estimator. If it contains a
        'batch_size' entry, that value is used; otherwise the batch size is
        taken from `train_config`.

    Returns:
      A tf.data.Dataset that holds (features, labels) tuple.

      features: Dictionary of feature tensors.
        features[fields.InputDataFields.image] is a [batch_size, H, W, C]
          float32 tensor with preprocessed images.
        features[HASH_KEY] is a [batch_size] int32 tensor representing unique
          identifiers for the images.
        features[fields.InputDataFields.true_image_shape] is a [batch_size, 3]
          int32 tensor representing the true image shapes, as preprocessed
          images could be padded.
        features[fields.InputDataFields.original_image] (optional) is a
          [batch_size, H, W, C] float32 tensor with original images.
      labels: Dictionary of groundtruth tensors.
        labels[fields.InputDataFields.num_groundtruth_boxes] is a [batch_size]
          int32 tensor indicating the number of groundtruth boxes.
        labels[fields.InputDataFields.groundtruth_boxes] is a
          [batch_size, num_boxes, 4] float32 tensor containing the corners of
          the groundtruth boxes.
        labels[fields.InputDataFields.groundtruth_classes] is a
          [batch_size, num_boxes, num_classes] float32 one-hot tensor of
          classes.
        labels[fields.InputDataFields.groundtruth_weights] is a
          [batch_size, num_boxes] float32 tensor containing groundtruth weights
          for the boxes.
        -- Optional --
        labels[fields.InputDataFields.groundtruth_instance_masks] is a
          [batch_size, num_boxes, H, W] float32 tensor containing only binary
          values, which represent instance masks for objects.
        labels[fields.InputDataFields.groundtruth_keypoints] is a
          [batch_size, num_boxes, num_keypoints, 2] float32 tensor containing
          keypoints for each box.

    Raises:
      TypeError: if the `train_config`, `train_input_config` or `model_config`
        are not of the correct type.
    """
    if not isinstance(train_config, train_pb2.TrainConfig):
      raise TypeError('For training mode, the `train_config` must be a '
                      'train_pb2.TrainConfig.')
    if not isinstance(train_input_config, input_reader_pb2.InputReader):
      raise TypeError('The `train_input_config` must be a '
                      'input_reader_pb2.InputReader.')
    if not isinstance(model_config, model_pb2.DetectionModel):
      raise TypeError('The `model_config` must be a '
                      'model_pb2.DetectionModel.')

    def transform_and_pad_input_data_fn(tensor_dict):
      """Combines transform and pad operation."""
      data_augmentation_options = [
          preprocessor_builder.build(step)
          for step in train_config.data_augmentation_options
      ]
      data_augmentation_fn = functools.partial(
          augment_input_data,
          data_augmentation_options=data_augmentation_options)
      model = model_builder.build(model_config, is_training=True)
      image_resizer_config = config_util.get_image_resizer_config(model_config)
      image_resizer_fn = image_resizer_builder.build(image_resizer_config)
      num_classes = config_util.get_number_of_classes(model_config)
      transform_data_fn = functools.partial(
          transform_input_data, model_preprocess_fn=model.preprocess,
          image_resizer_fn=image_resizer_fn,
          num_classes=num_classes,
          data_augmentation_fn=data_augmentation_fn,
          merge_multiple_boxes=train_config.merge_multiple_label_boxes,
          retain_original_image=train_config.retain_original_images,
          use_multiclass_scores=train_config.use_multiclass_scores,
          use_bfloat16=train_config.use_bfloat16)
      tensor_dict = pad_input_data_to_static_shapes(
          tensor_dict=transform_data_fn(tensor_dict),
          max_num_boxes=train_input_config.max_number_of_boxes,
          num_classes=num_classes,
          spatial_image_shape=config_util.get_spatial_image_size(
              image_resizer_config))
      return (_get_features_dict(tensor_dict), _get_labels_dict(tensor_dict))

    # `params` may carry unrelated entries only; fall back to the configured
    # batch size instead of failing with a KeyError in that case.
    batch_size = train_config.batch_size
    if params and 'batch_size' in params:
      batch_size = params['batch_size']

    dataset = INPUT_BUILDER_UTIL_MAP['dataset_build'](
        train_input_config,
        transform_input_data_fn=transform_and_pad_input_data_fn,
        batch_size=batch_size)
    return dataset

  return _train_input_fn
def create_eval_input_fn(eval_config, eval_input_config, model_config):
  """Creates an eval `input` function for `Estimator`.

  Args:
    eval_config: An eval_pb2.EvalConfig.
    eval_input_config: An input_reader_pb2.InputReader.
    model_config: A model_pb2.DetectionModel.

  Returns:
    `input_fn` for `Estimator` in EVAL mode.
  """

  def _eval_input_fn(params=None):
    """Returns `features` and `labels` tensor dictionaries for evaluation.

    Args:
      params: Parameter dictionary passed from the estimator. If it contains a
        'batch_size' entry, that value is used; otherwise the batch size is
        taken from `eval_config`.

    Returns:
      A tf.data.Dataset that holds (features, labels) tuple. The shapes below
      are written for a batch size of 1.

      features: Dictionary of feature tensors.
        features[fields.InputDataFields.image] is a [1, H, W, C] float32 tensor
          with preprocessed images.
        features[HASH_KEY] is a [1] int32 tensor representing unique
          identifiers for the images.
        features[fields.InputDataFields.true_image_shape] is a [1, 3]
          int32 tensor representing the true image shapes, as preprocessed
          images could be padded.
        features[fields.InputDataFields.original_image] is a [1, H', W', C]
          float32 tensor with the original image.
      labels: Dictionary of groundtruth tensors.
        labels[fields.InputDataFields.groundtruth_boxes] is a [1, num_boxes, 4]
          float32 tensor containing the corners of the groundtruth boxes.
        labels[fields.InputDataFields.groundtruth_classes] is a
          [num_boxes, num_classes] float32 one-hot tensor of classes.
        labels[fields.InputDataFields.groundtruth_area] is a [1, num_boxes]
          float32 tensor containing object areas.
        labels[fields.InputDataFields.groundtruth_is_crowd] is a [1, num_boxes]
          bool tensor indicating if the boxes enclose a crowd.
        labels[fields.InputDataFields.groundtruth_difficult] is a [1, num_boxes]
          int32 tensor indicating if the boxes represent difficult instances.
        -- Optional --
        labels[fields.InputDataFields.groundtruth_instance_masks] is a
          [1, num_boxes, H, W] float32 tensor containing only binary values,
          which represent instance masks for objects.

    Raises:
      TypeError: if the `eval_config`, `eval_input_config` or `model_config`
        are not of the correct type.
    """
    params = params or {}
    if not isinstance(eval_config, eval_pb2.EvalConfig):
      # The expected type lives in eval_pb2, which is what the message names.
      raise TypeError('For eval mode, the `eval_config` must be a '
                      'eval_pb2.EvalConfig.')
    if not isinstance(eval_input_config, input_reader_pb2.InputReader):
      raise TypeError('The `eval_input_config` must be a '
                      'input_reader_pb2.InputReader.')
    if not isinstance(model_config, model_pb2.DetectionModel):
      raise TypeError('The `model_config` must be a '
                      'model_pb2.DetectionModel.')

    def transform_and_pad_input_data_fn(tensor_dict):
      """Combines transform and pad operation."""
      num_classes = config_util.get_number_of_classes(model_config)
      model = model_builder.build(model_config, is_training=False)
      image_resizer_config = config_util.get_image_resizer_config(model_config)
      image_resizer_fn = image_resizer_builder.build(image_resizer_config)

      # No augmentation is applied during evaluation.
      transform_data_fn = functools.partial(
          transform_input_data, model_preprocess_fn=model.preprocess,
          image_resizer_fn=image_resizer_fn,
          num_classes=num_classes,
          data_augmentation_fn=None,
          retain_original_image=eval_config.retain_original_images)
      tensor_dict = pad_input_data_to_static_shapes(
          tensor_dict=transform_data_fn(tensor_dict),
          max_num_boxes=eval_input_config.max_number_of_boxes,
          num_classes=num_classes,
          spatial_image_shape=config_util.get_spatial_image_size(
              image_resizer_config))
      return (_get_features_dict(tensor_dict), _get_labels_dict(tensor_dict))

    # `params` may carry unrelated entries only; fall back to the configured
    # batch size instead of failing with a KeyError in that case.
    batch_size = eval_config.batch_size
    if 'batch_size' in params:
      batch_size = params['batch_size']

    dataset = INPUT_BUILDER_UTIL_MAP['dataset_build'](
        eval_input_config,
        batch_size=batch_size,
        transform_input_data_fn=transform_and_pad_input_data_fn)
    return dataset

  return _eval_input_fn
def create_predict_input_fn(model_config, predict_input_config):
  """Creates a predict `input` function for `Estimator`.

  Args:
    model_config: A model_pb2.DetectionModel.
    predict_input_config: An input_reader_pb2.InputReader.

  Returns:
    `input_fn` for `Estimator` in PREDICT mode.
  """

  def _predict_input_fn(params=None):
    """Builds a `ServingInputReceiver` fed with one serialized tf.Example.

    Args:
      params: Parameter dictionary passed from the estimator. Unused.

    Returns:
      `ServingInputReceiver` whose features are the preprocessed image and its
      true shape, each with a leading dimension of 1, and whose receiver tensor
      is the scalar string placeholder for the serialized example.
    """
    del params
    serialized_example = tf.placeholder(
        dtype=tf.string, shape=[], name='tf_example')

    detection_model = model_builder.build(model_config, is_training=False)
    resizer_config = config_util.get_image_resizer_config(model_config)

    example_decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=False,
        num_additional_channels=predict_input_config.num_additional_channels)

    # Decode the example and run it through the same transformations as the
    # other input functions, without any augmentation.
    transformed = transform_input_data(
        example_decoder.decode(serialized_example),
        model_preprocess_fn=detection_model.preprocess,
        image_resizer_fn=image_resizer_builder.build(resizer_config),
        num_classes=config_util.get_number_of_classes(model_config),
        data_augmentation_fn=None)

    # Both features get a leading dimension of size one.
    batched_image = tf.expand_dims(
        tf.to_float(transformed[fields.InputDataFields.image]), axis=0)
    batched_true_shape = tf.expand_dims(
        transformed[fields.InputDataFields.true_image_shape], axis=0)

    serving_features = {
        fields.InputDataFields.image: batched_image,
        fields.InputDataFields.true_image_shape: batched_true_shape,
    }
    return tf.estimator.export.ServingInputReceiver(
        features=serving_features,
        receiver_tensors={SERVING_FED_EXAMPLE_KEY: serialized_example})

  return _predict_input_fn