# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tensorflow Example proto decoder for object detection. A decoder to decode string tensors containing serialized tensorflow.Example protos for object detection. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import enum import numpy as np from six.moves import zip import tensorflow.compat.v1 as tf from tf_slim import tfexample_decoder as slim_example_decoder from object_detection.core import data_decoder from object_detection.core import standard_fields as fields from object_detection.protos import input_reader_pb2 from object_detection.utils import label_map_util # pylint: disable=g-import-not-at-top try: from tensorflow.contrib import lookup as contrib_lookup except ImportError: # TF 2.0 doesn't ship with contrib. pass # pylint: enable=g-import-not-at-top _LABEL_OFFSET = 1 class Visibility(enum.Enum): """Visibility definitions. This follows the MS Coco convention (http://cocodataset.org/#format-data). """ # Keypoint is not labeled. UNLABELED = 0 # Keypoint is labeled but falls outside the object segment (e.g. occluded). NOT_VISIBLE = 1 # Keypoint is labeled and visible. VISIBLE = 2 class _ClassTensorHandler(slim_example_decoder.Tensor): """An ItemHandler to fetch class ids from class text.""" def __init__(self, tensor_key, label_map_proto_file, shape_keys=None, shape=None, default_value=''): """Initializes the LookupTensor handler. Simply calls a vocabulary (most often, a label mapping) lookup. Args: tensor_key: the name of the `TFExample` feature to read the tensor from. label_map_proto_file: File path to a text format LabelMapProto message mapping class text to id. shape_keys: Optional name or list of names of the TF-Example feature in which the tensor shape is stored. If a list, then each corresponds to one dimension of the shape. shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is reshaped accordingly. default_value: The value used when the `tensor_key` is not found in a particular `TFExample`. Raises: ValueError: if both `shape_keys` and `shape` are specified. """ name_to_id = label_map_util.get_label_map_dict( label_map_proto_file, use_display_name=False) # We use a default_value of -1, but we expect all labels to be contained # in the label map. try: # Dynamically try to load the tf v2 lookup, falling back to contrib lookup = tf.compat.v2.lookup hash_table_class = tf.compat.v2.lookup.StaticHashTable except AttributeError: lookup = contrib_lookup hash_table_class = contrib_lookup.HashTable name_to_id_table = hash_table_class( initializer=lookup.KeyValueTensorInitializer( keys=tf.constant(list(name_to_id.keys())), values=tf.constant(list(name_to_id.values()), dtype=tf.int64)), default_value=-1) display_name_to_id = label_map_util.get_label_map_dict( label_map_proto_file, use_display_name=True) # We use a default_value of -1, but we expect all labels to be contained # in the label map. display_name_to_id_table = hash_table_class( initializer=lookup.KeyValueTensorInitializer( keys=tf.constant(list(display_name_to_id.keys())), values=tf.constant( list(display_name_to_id.values()), dtype=tf.int64)), default_value=-1) self._name_to_id_table = name_to_id_table self._display_name_to_id_table = display_name_to_id_table super(_ClassTensorHandler, self).__init__(tensor_key, shape_keys, shape, default_value) def tensors_to_item(self, keys_to_tensors): unmapped_tensor = super(_ClassTensorHandler, self).tensors_to_item(keys_to_tensors) return tf.maximum(self._name_to_id_table.lookup(unmapped_tensor), self._display_name_to_id_table.lookup(unmapped_tensor)) class _BackupHandler(slim_example_decoder.ItemHandler): """An ItemHandler that tries two ItemHandlers in order.""" def __init__(self, handler, backup): """Initializes the BackupHandler handler. If the first Handler's tensors_to_item returns a Tensor with no elements, the second Handler is used. Args: handler: The primary ItemHandler. backup: The backup ItemHandler. Raises: ValueError: if either is not an ItemHandler. """ if not isinstance(handler, slim_example_decoder.ItemHandler): raise ValueError('Primary handler is of type %s instead of ItemHandler' % type(handler)) if not isinstance(backup, slim_example_decoder.ItemHandler): raise ValueError( 'Backup handler is of type %s instead of ItemHandler' % type(backup)) self._handler = handler self._backup = backup super(_BackupHandler, self).__init__(handler.keys + backup.keys) def tensors_to_item(self, keys_to_tensors): item = self._handler.tensors_to_item(keys_to_tensors) return tf.cond( pred=tf.equal(tf.reduce_prod(tf.shape(item)), 0), true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors), false_fn=lambda: item) class TfExampleDecoder(data_decoder.DataDecoder): """Tensorflow Example proto decoder.""" def __init__(self, load_instance_masks=False, instance_mask_type=input_reader_pb2.NUMERICAL_MASKS, label_map_proto_file=None, use_display_name=False, dct_method='', num_keypoints=0, num_additional_channels=0, load_multiclass_scores=False, load_context_features=False, expand_hierarchy_labels=False): """Constructor sets keys_to_features and items_to_handlers. Args: load_instance_masks: whether or not to load and handle instance masks. instance_mask_type: type of instance masks. Options are provided in input_reader.proto. This is only used if `load_instance_masks` is True. label_map_proto_file: a file path to a object_detection.protos.StringIntLabelMap proto. If provided, then the mapped IDs of 'image/object/class/text' will take precedence over the existing 'image/object/class/label' ID. Also, if provided, it is assumed that 'image/object/class/text' will be in the data. use_display_name: whether or not to use the `display_name` for label mapping (instead of `name`). Only used if label_map_proto_file is provided. dct_method: An optional string. Defaults to None. It only takes effect when image format is jpeg, used to specify a hint about the algorithm used for jpeg decompression. Currently valid values are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be ignored, for example, the jpeg library does not have that specific option. num_keypoints: the number of keypoints per object. num_additional_channels: how many additional channels to use. load_multiclass_scores: Whether to load multiclass scores associated with boxes. load_context_features: Whether to load information from context_features, to provide additional context to a detection model for training and/or inference. expand_hierarchy_labels: Expands the object and image labels taking into account the provided hierarchy in the label_map_proto_file. For positive classes, the labels are extended to ancestor. For negative classes, the labels are expanded to descendants. Raises: ValueError: If `instance_mask_type` option is not one of input_reader_pb2.DEFAULT, input_reader_pb2.NUMERICAL, or input_reader_pb2.PNG_MASKS. ValueError: If `expand_labels_hierarchy` is True, but the `label_map_proto_file` is not provided. """ # TODO(rathodv): delete unused `use_display_name` argument once we change # other decoders to handle label maps similarly. del use_display_name self.keys_to_features = { 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'), 'image/filename': tf.FixedLenFeature((), tf.string, default_value=''), 'image/key/sha256': tf.FixedLenFeature((), tf.string, default_value=''), 'image/source_id': tf.FixedLenFeature((), tf.string, default_value=''), 'image/height': tf.FixedLenFeature((), tf.int64, default_value=1), 'image/width': tf.FixedLenFeature((), tf.int64, default_value=1), # Image-level labels. 'image/class/text': tf.VarLenFeature(tf.string), 'image/class/label': tf.VarLenFeature(tf.int64), 'image/class/confidence': tf.VarLenFeature(tf.float32), # Object boxes and classes. 'image/object/bbox/xmin': tf.VarLenFeature(tf.float32), 'image/object/bbox/xmax': tf.VarLenFeature(tf.float32), 'image/object/bbox/ymin': tf.VarLenFeature(tf.float32), 'image/object/bbox/ymax': tf.VarLenFeature(tf.float32), 'image/object/class/label': tf.VarLenFeature(tf.int64), 'image/object/class/text': tf.VarLenFeature(tf.string), 'image/object/area': tf.VarLenFeature(tf.float32), 'image/object/is_crowd': tf.VarLenFeature(tf.int64), 'image/object/difficult': tf.VarLenFeature(tf.int64), 'image/object/group_of': tf.VarLenFeature(tf.int64), 'image/object/weight': tf.VarLenFeature(tf.float32), } # We are checking `dct_method` instead of passing it directly in order to # ensure TF version 1.6 compatibility. if dct_method: image = slim_example_decoder.Image( image_key='image/encoded', format_key='image/format', channels=3, dct_method=dct_method) additional_channel_image = slim_example_decoder.Image( image_key='image/additional_channels/encoded', format_key='image/format', channels=1, repeated=True, dct_method=dct_method) else: image = slim_example_decoder.Image( image_key='image/encoded', format_key='image/format', channels=3) additional_channel_image = slim_example_decoder.Image( image_key='image/additional_channels/encoded', format_key='image/format', channels=1, repeated=True) self.items_to_handlers = { fields.InputDataFields.image: image, fields.InputDataFields.source_id: ( slim_example_decoder.Tensor('image/source_id')), fields.InputDataFields.key: ( slim_example_decoder.Tensor('image/key/sha256')), fields.InputDataFields.filename: ( slim_example_decoder.Tensor('image/filename')), # Image-level labels. fields.InputDataFields.groundtruth_image_confidences: ( slim_example_decoder.Tensor('image/class/confidence')), # Object boxes and classes. fields.InputDataFields.groundtruth_boxes: ( slim_example_decoder.BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/')), fields.InputDataFields.groundtruth_area: slim_example_decoder.Tensor('image/object/area'), fields.InputDataFields.groundtruth_is_crowd: ( slim_example_decoder.Tensor('image/object/is_crowd')), fields.InputDataFields.groundtruth_difficult: ( slim_example_decoder.Tensor('image/object/difficult')), fields.InputDataFields.groundtruth_group_of: ( slim_example_decoder.Tensor('image/object/group_of')), fields.InputDataFields.groundtruth_weights: ( slim_example_decoder.Tensor('image/object/weight')), } if load_multiclass_scores: self.keys_to_features[ 'image/object/class/multiclass_scores'] = tf.VarLenFeature(tf.float32) self.items_to_handlers[fields.InputDataFields.multiclass_scores] = ( slim_example_decoder.Tensor('image/object/class/multiclass_scores')) if load_context_features: self.keys_to_features[ 'image/context_features'] = tf.VarLenFeature(tf.float32) self.items_to_handlers[fields.InputDataFields.context_features] = ( slim_example_decoder.ItemHandlerCallback( ['image/context_features', 'image/context_feature_length'], self._reshape_context_features)) self.keys_to_features[ 'image/context_feature_length'] = tf.FixedLenFeature((), tf.int64) self.items_to_handlers[fields.InputDataFields.context_feature_length] = ( slim_example_decoder.Tensor('image/context_feature_length')) if num_additional_channels > 0: self.keys_to_features[ 'image/additional_channels/encoded'] = tf.FixedLenFeature( (num_additional_channels,), tf.string) self.items_to_handlers[ fields.InputDataFields. image_additional_channels] = additional_channel_image self._num_keypoints = num_keypoints if num_keypoints > 0: self.keys_to_features['image/object/keypoint/x'] = ( tf.VarLenFeature(tf.float32)) self.keys_to_features['image/object/keypoint/y'] = ( tf.VarLenFeature(tf.float32)) self.keys_to_features['image/object/keypoint/visibility'] = ( tf.VarLenFeature(tf.int64)) self.items_to_handlers[fields.InputDataFields.groundtruth_keypoints] = ( slim_example_decoder.ItemHandlerCallback( ['image/object/keypoint/y', 'image/object/keypoint/x'], self._reshape_keypoints)) kpt_vis_field = fields.InputDataFields.groundtruth_keypoint_visibilities self.items_to_handlers[kpt_vis_field] = ( slim_example_decoder.ItemHandlerCallback( ['image/object/keypoint/x', 'image/object/keypoint/visibility'], self._reshape_keypoint_visibilities)) if load_instance_masks: if instance_mask_type in (input_reader_pb2.DEFAULT, input_reader_pb2.NUMERICAL_MASKS): self.keys_to_features['image/object/mask'] = ( tf.VarLenFeature(tf.float32)) self.items_to_handlers[ fields.InputDataFields.groundtruth_instance_masks] = ( slim_example_decoder.ItemHandlerCallback( ['image/object/mask', 'image/height', 'image/width'], self._reshape_instance_masks)) elif instance_mask_type == input_reader_pb2.PNG_MASKS: self.keys_to_features['image/object/mask'] = tf.VarLenFeature(tf.string) self.items_to_handlers[ fields.InputDataFields.groundtruth_instance_masks] = ( slim_example_decoder.ItemHandlerCallback( ['image/object/mask', 'image/height', 'image/width'], self._decode_png_instance_masks)) else: raise ValueError('Did not recognize the `instance_mask_type` option.') if label_map_proto_file: # If the label_map_proto is provided, try to use it in conjunction with # the class text, and fall back to a materialized ID. label_handler = _BackupHandler( _ClassTensorHandler( 'image/object/class/text', label_map_proto_file, default_value=''), slim_example_decoder.Tensor('image/object/class/label')) image_label_handler = _BackupHandler( _ClassTensorHandler( fields.TfExampleFields.image_class_text, label_map_proto_file, default_value=''), slim_example_decoder.Tensor(fields.TfExampleFields.image_class_label)) else: label_handler = slim_example_decoder.Tensor('image/object/class/label') image_label_handler = slim_example_decoder.Tensor( fields.TfExampleFields.image_class_label) self.items_to_handlers[ fields.InputDataFields.groundtruth_classes] = label_handler self.items_to_handlers[ fields.InputDataFields.groundtruth_image_classes] = image_label_handler self._expand_hierarchy_labels = expand_hierarchy_labels self._ancestors_lut = None self._descendants_lut = None if expand_hierarchy_labels: if label_map_proto_file: ancestors_lut, descendants_lut = ( label_map_util.get_label_map_hierarchy_lut(label_map_proto_file, True)) self._ancestors_lut = tf.constant(ancestors_lut, dtype=tf.int64) self._descendants_lut = tf.constant(descendants_lut, dtype=tf.int64) else: raise ValueError('In order to expand labels, the label_map_proto_file ' 'has to be provided.') def decode(self, tf_example_string_tensor): """Decodes serialized tensorflow example and returns a tensor dictionary. Args: tf_example_string_tensor: a string tensor holding a serialized tensorflow example proto. Returns: A dictionary of the following tensors. fields.InputDataFields.image - 3D uint8 tensor of shape [None, None, 3] containing image. fields.InputDataFields.original_image_spatial_shape - 1D int32 tensor of shape [2] containing shape of the image. fields.InputDataFields.source_id - string tensor containing original image id. fields.InputDataFields.key - string tensor with unique sha256 hash key. fields.InputDataFields.filename - string tensor with original dataset filename. fields.InputDataFields.groundtruth_boxes - 2D float32 tensor of shape [None, 4] containing box corners. fields.InputDataFields.groundtruth_classes - 1D int64 tensor of shape [None] containing classes for the boxes. fields.InputDataFields.groundtruth_weights - 1D float32 tensor of shape [None] indicating the weights of groundtruth boxes. fields.InputDataFields.groundtruth_area - 1D float32 tensor of shape [None] containing containing object mask area in pixel squared. fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape [None] indicating if the boxes enclose a crowd. Optional: fields.InputDataFields.groundtruth_image_confidences - 1D float tensor of shape [None] indicating if a class is present in the image (1.0) or a class is not present in the image (0.0). fields.InputDataFields.image_additional_channels - 3D uint8 tensor of shape [None, None, num_additional_channels]. 1st dim is height; 2nd dim is width; 3rd dim is the number of additional channels. fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape [None] indicating if the boxes represent `difficult` instances. fields.InputDataFields.groundtruth_group_of - 1D bool tensor of shape [None] indicating if the boxes represent `group_of` instances. fields.InputDataFields.groundtruth_keypoints - 3D float32 tensor of shape [None, num_keypoints, 2] containing keypoints, where the coordinates of the keypoints are ordered (y, x). fields.InputDataFields.groundtruth_keypoint_visibilities - 2D bool tensor of shape [None, num_keypoints] containing keypoint visibilites. fields.InputDataFields.groundtruth_instance_masks - 3D float32 tensor of shape [None, None, None] containing instance masks. fields.InputDataFields.groundtruth_image_classes - 1D int64 of shape [None] containing classes for the boxes. fields.InputDataFields.multiclass_scores - 1D float32 tensor of shape [None * num_classes] containing flattened multiclass scores for groundtruth boxes. fields.InputDataFields.context_features - 1D float32 tensor of shape [context_feature_length * num_context_features] fields.InputDataFields.context_feature_length - int32 tensor specifying the length of each feature in context_features """ serialized_example = tf.reshape(tf_example_string_tensor, shape=[]) decoder = slim_example_decoder.TFExampleDecoder(self.keys_to_features, self.items_to_handlers) keys = decoder.list_items() tensors = decoder.decode(serialized_example, items=keys) tensor_dict = dict(zip(keys, tensors)) is_crowd = fields.InputDataFields.groundtruth_is_crowd tensor_dict[is_crowd] = tf.cast(tensor_dict[is_crowd], dtype=tf.bool) tensor_dict[fields.InputDataFields.image].set_shape([None, None, 3]) tensor_dict[fields.InputDataFields.original_image_spatial_shape] = tf.shape( tensor_dict[fields.InputDataFields.image])[:2] if fields.InputDataFields.image_additional_channels in tensor_dict: channels = tensor_dict[fields.InputDataFields.image_additional_channels] channels = tf.squeeze(channels, axis=3) channels = tf.transpose(channels, perm=[1, 2, 0]) tensor_dict[fields.InputDataFields.image_additional_channels] = channels def default_groundtruth_weights(): return tf.ones( [tf.shape(tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]], dtype=tf.float32) tensor_dict[fields.InputDataFields.groundtruth_weights] = tf.cond( tf.greater( tf.shape( tensor_dict[fields.InputDataFields.groundtruth_weights])[0], 0), lambda: tensor_dict[fields.InputDataFields.groundtruth_weights], default_groundtruth_weights) if fields.InputDataFields.groundtruth_keypoints in tensor_dict: # Set all keypoints that are not labeled to NaN. gt_kpt_fld = fields.InputDataFields.groundtruth_keypoints gt_kpt_vis_fld = fields.InputDataFields.groundtruth_keypoint_visibilities visibilities_tiled = tf.tile( tf.expand_dims(tensor_dict[gt_kpt_vis_fld], -1), [1, 1, 2]) tensor_dict[gt_kpt_fld] = tf.where( visibilities_tiled, tensor_dict[gt_kpt_fld], np.nan * tf.ones_like(tensor_dict[gt_kpt_fld])) if self._expand_hierarchy_labels: input_fields = fields.InputDataFields image_classes, image_confidences = self._expand_image_label_hierarchy( tensor_dict[input_fields.groundtruth_image_classes], tensor_dict[input_fields.groundtruth_image_confidences]) tensor_dict[input_fields.groundtruth_image_classes] = image_classes tensor_dict[input_fields.groundtruth_image_confidences] = ( image_confidences) box_fields = [ fields.InputDataFields.groundtruth_group_of, fields.InputDataFields.groundtruth_is_crowd, fields.InputDataFields.groundtruth_difficult, fields.InputDataFields.groundtruth_area, fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_weights, ] def expand_field(field_name): return self._expansion_box_field_labels( tensor_dict[input_fields.groundtruth_classes], tensor_dict[field_name]) # pylint: disable=cell-var-from-loop for field in box_fields: if field in tensor_dict: tensor_dict[field] = tf.cond( tf.size(tensor_dict[field]) > 0, lambda: expand_field(field), lambda: tensor_dict[field]) # pylint: enable=cell-var-from-loop tensor_dict[input_fields.groundtruth_classes] = ( self._expansion_box_field_labels( tensor_dict[input_fields.groundtruth_classes], tensor_dict[input_fields.groundtruth_classes], True)) if fields.InputDataFields.groundtruth_group_of in tensor_dict: group_of = fields.InputDataFields.groundtruth_group_of tensor_dict[group_of] = tf.cast(tensor_dict[group_of], dtype=tf.bool) return tensor_dict def _reshape_keypoints(self, keys_to_tensors): """Reshape keypoints. The keypoints are reshaped to [num_instances, num_keypoints, 2]. Args: keys_to_tensors: a dictionary from keys to tensors. Expected keys are: 'image/object/keypoint/x' 'image/object/keypoint/y' Returns: A 3-D float tensor of shape [num_instances, num_keypoints, 2] with values in [0, 1]. """ y = keys_to_tensors['image/object/keypoint/y'] if isinstance(y, tf.SparseTensor): y = tf.sparse_tensor_to_dense(y) y = tf.expand_dims(y, 1) x = keys_to_tensors['image/object/keypoint/x'] if isinstance(x, tf.SparseTensor): x = tf.sparse_tensor_to_dense(x) x = tf.expand_dims(x, 1) keypoints = tf.concat([y, x], 1) keypoints = tf.reshape(keypoints, [-1, self._num_keypoints, 2]) return keypoints def _reshape_keypoint_visibilities(self, keys_to_tensors): """Reshape keypoint visibilities. The keypoint visibilities are reshaped to [num_instances, num_keypoints]. The raw keypoint visibilities are expected to conform to the MSCoco definition. See Visibility enum. The returned boolean is True for the labeled case (either Visibility.NOT_VISIBLE or Visibility.VISIBLE). These are the same categories that COCO uses to evaluate keypoint detection performance: http://cocodataset.org/#keypoints-eval If image/object/keypoint/visibility is not provided, visibilities will be set to True for finite keypoint coordinate values, and 0 if the coordinates are NaN. Args: keys_to_tensors: a dictionary from keys to tensors. Expected keys are: 'image/object/keypoint/x' 'image/object/keypoint/visibility' Returns: A 2-D bool tensor of shape [num_instances, num_keypoints] with values in {0, 1}. 1 if the keypoint is labeled, 0 otherwise. """ x = keys_to_tensors['image/object/keypoint/x'] vis = keys_to_tensors['image/object/keypoint/visibility'] if isinstance(vis, tf.SparseTensor): vis = tf.sparse_tensor_to_dense(vis) if isinstance(x, tf.SparseTensor): x = tf.sparse_tensor_to_dense(x) default_vis = tf.where( tf.math.is_nan(x), Visibility.UNLABELED.value * tf.ones_like(x, dtype=tf.int64), Visibility.VISIBLE.value * tf.ones_like(x, dtype=tf.int64)) # Use visibility if provided, otherwise use the default visibility. vis = tf.cond(tf.equal(tf.size(x), tf.size(vis)), true_fn=lambda: vis, false_fn=lambda: default_vis) vis = tf.math.logical_or( tf.math.equal(vis, Visibility.NOT_VISIBLE.value), tf.math.equal(vis, Visibility.VISIBLE.value)) vis = tf.reshape(vis, [-1, self._num_keypoints]) return vis def _reshape_instance_masks(self, keys_to_tensors): """Reshape instance segmentation masks. The instance segmentation masks are reshaped to [num_instances, height, width]. Args: keys_to_tensors: a dictionary from keys to tensors. Returns: A 3-D float tensor of shape [num_instances, height, width] with values in {0, 1}. """ height = keys_to_tensors['image/height'] width = keys_to_tensors['image/width'] to_shape = tf.cast(tf.stack([-1, height, width]), tf.int32) masks = keys_to_tensors['image/object/mask'] if isinstance(masks, tf.SparseTensor): masks = tf.sparse_tensor_to_dense(masks) masks = tf.reshape( tf.cast(tf.greater(masks, 0.0), dtype=tf.float32), to_shape) return tf.cast(masks, tf.float32) def _reshape_context_features(self, keys_to_tensors): """Reshape context features. The instance context_features are reshaped to [num_context_features, context_feature_length] Args: keys_to_tensors: a dictionary from keys to tensors. Returns: A 2-D float tensor of shape [num_context_features, context_feature_length] """ context_feature_length = keys_to_tensors['image/context_feature_length'] to_shape = tf.cast(tf.stack([-1, context_feature_length]), tf.int32) context_features = keys_to_tensors['image/context_features'] if isinstance(context_features, tf.SparseTensor): context_features = tf.sparse_tensor_to_dense(context_features) context_features = tf.reshape(context_features, to_shape) return context_features def _decode_png_instance_masks(self, keys_to_tensors): """Decode PNG instance segmentation masks and stack into dense tensor. The instance segmentation masks are reshaped to [num_instances, height, width]. Args: keys_to_tensors: a dictionary from keys to tensors. Returns: A 3-D float tensor of shape [num_instances, height, width] with values in {0, 1}. """ def decode_png_mask(image_buffer): image = tf.squeeze( tf.image.decode_image(image_buffer, channels=1), axis=2) image.set_shape([None, None]) image = tf.cast(tf.greater(image, 0), dtype=tf.float32) return image png_masks = keys_to_tensors['image/object/mask'] height = keys_to_tensors['image/height'] width = keys_to_tensors['image/width'] if isinstance(png_masks, tf.SparseTensor): png_masks = tf.sparse_tensor_to_dense(png_masks, default_value='') return tf.cond( tf.greater(tf.size(png_masks), 0), lambda: tf.map_fn(decode_png_mask, png_masks, dtype=tf.float32), lambda: tf.zeros(tf.cast(tf.stack([0, height, width]), dtype=tf.int32))) def _expand_image_label_hierarchy(self, image_classes, image_confidences): """Expand image level labels according to the hierarchy. Args: image_classes: Int64 tensor with the image level class ids for a sample. image_confidences: Float tensor signaling whether a class id is present in the image (1.0) or not present (0.0). Returns: new_image_classes: Int64 tensor equal to expanding image_classes. new_image_confidences: Float tensor equal to expanding image_confidences. """ def expand_labels(relation_tensor, confidence_value): """Expand to ancestors or descendants depending on arguments.""" mask = tf.equal(image_confidences, confidence_value) target_image_classes = tf.boolean_mask(image_classes, mask) expanded_indices = tf.reduce_any((tf.gather( relation_tensor, target_image_classes - _LABEL_OFFSET, axis=0) > 0), axis=0) expanded_indices = tf.where(expanded_indices)[:, 0] + _LABEL_OFFSET new_groundtruth_image_classes = ( tf.concat([ tf.boolean_mask(image_classes, tf.logical_not(mask)), expanded_indices, ], axis=0)) new_groundtruth_image_confidences = ( tf.concat([ tf.boolean_mask(image_confidences, tf.logical_not(mask)), tf.ones([tf.shape(expanded_indices)[0]], dtype=image_confidences.dtype) * confidence_value, ], axis=0)) return new_groundtruth_image_classes, new_groundtruth_image_confidences image_classes, image_confidences = expand_labels(self._ancestors_lut, 1.0) new_image_classes, new_image_confidences = expand_labels( self._descendants_lut, 0.0) return new_image_classes, new_image_confidences def _expansion_box_field_labels(self, object_classes, object_field, copy_class_id=False): """Expand the labels of a specific object field according to the hierarchy. Args: object_classes: Int64 tensor with the class id for each element in object_field. object_field: Tensor to be expanded. copy_class_id: Boolean to choose whether to use class id values in the output tensor instead of replicating the original values. Returns: A tensor with the result of expanding object_field. """ expanded_indices = tf.gather( self._ancestors_lut, object_classes - _LABEL_OFFSET, axis=0) if copy_class_id: new_object_field = tf.where(expanded_indices > 0)[:, 1] + _LABEL_OFFSET else: new_object_field = tf.repeat( object_field, tf.reduce_sum(expanded_indices, axis=1), axis=0) return new_object_field