# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tensorflow Example proto parser for data loading. A parser to decode data containing serialized tensorflow.Example protos into materialized tensors (numpy arrays). """ import numpy as np from object_detection.core import data_parser from object_detection.core import standard_fields as fields class FloatParser(data_parser.DataToNumpyParser): """Tensorflow Example float parser.""" def __init__(self, field_name): self.field_name = field_name def parse(self, tf_example): return np.array( tf_example.features.feature[self.field_name].float_list.value, dtype=np.float).transpose() if tf_example.features.feature[ self.field_name].HasField("float_list") else None class StringParser(data_parser.DataToNumpyParser): """Tensorflow Example string parser.""" def __init__(self, field_name): self.field_name = field_name def parse(self, tf_example): return "".join(tf_example.features.feature[self.field_name] .bytes_list.value) if tf_example.features.feature[ self.field_name].HasField("bytes_list") else None class Int64Parser(data_parser.DataToNumpyParser): """Tensorflow Example int64 parser.""" def __init__(self, field_name): self.field_name = field_name def parse(self, tf_example): return np.array( tf_example.features.feature[self.field_name].int64_list.value, dtype=np.int64).transpose() if tf_example.features.feature[ self.field_name].HasField("int64_list") else None class BoundingBoxParser(data_parser.DataToNumpyParser): """Tensorflow Example bounding box parser.""" def __init__(self, xmin_field_name, ymin_field_name, xmax_field_name, ymax_field_name): self.field_names = [ ymin_field_name, xmin_field_name, ymax_field_name, xmax_field_name ] def parse(self, tf_example): result = [] parsed = True for field_name in self.field_names: result.append(tf_example.features.feature[field_name].float_list.value) parsed &= ( tf_example.features.feature[field_name].HasField("float_list")) return np.array(result).transpose() if parsed else None class TfExampleDetectionAndGTParser(data_parser.DataToNumpyParser): """Tensorflow Example proto parser.""" def __init__(self): self.items_to_handlers = { fields.DetectionResultFields.key: StringParser(fields.TfExampleFields.source_id), # Object ground truth boxes and classes. fields.InputDataFields.groundtruth_boxes: (BoundingBoxParser( fields.TfExampleFields.object_bbox_xmin, fields.TfExampleFields.object_bbox_ymin, fields.TfExampleFields.object_bbox_xmax, fields.TfExampleFields.object_bbox_ymax)), fields.InputDataFields.groundtruth_classes: ( Int64Parser(fields.TfExampleFields.object_class_label)), # Object detections. fields.DetectionResultFields.detection_boxes: (BoundingBoxParser( fields.TfExampleFields.detection_bbox_xmin, fields.TfExampleFields.detection_bbox_ymin, fields.TfExampleFields.detection_bbox_xmax, fields.TfExampleFields.detection_bbox_ymax)), fields.DetectionResultFields.detection_classes: ( Int64Parser(fields.TfExampleFields.detection_class_label)), fields.DetectionResultFields.detection_scores: ( FloatParser(fields.TfExampleFields.detection_score)), } self.optional_items_to_handlers = { fields.InputDataFields.groundtruth_difficult: Int64Parser(fields.TfExampleFields.object_difficult), fields.InputDataFields.groundtruth_group_of: Int64Parser(fields.TfExampleFields.object_group_of), fields.InputDataFields.groundtruth_image_classes: Int64Parser(fields.TfExampleFields.image_class_label), } def parse(self, tf_example): """Parses tensorflow example and returns a tensor dictionary. Args: tf_example: a tf.Example object. Returns: A dictionary of the following numpy arrays: fields.DetectionResultFields.source_id - string containing original image id. fields.InputDataFields.groundtruth_boxes - a numpy array containing groundtruth boxes. fields.InputDataFields.groundtruth_classes - a numpy array containing groundtruth classes. fields.InputDataFields.groundtruth_group_of - a numpy array containing groundtruth group of flag (optional, None if not specified). fields.InputDataFields.groundtruth_difficult - a numpy array containing groundtruth difficult flag (optional, None if not specified). fields.InputDataFields.groundtruth_image_classes - a numpy array containing groundtruth image-level labels. fields.DetectionResultFields.detection_boxes - a numpy array containing detection boxes. fields.DetectionResultFields.detection_classes - a numpy array containing detection class labels. fields.DetectionResultFields.detection_scores - a numpy array containing detection scores. Returns None if tf.Example was not parsed or non-optional fields were not found. """ results_dict = {} parsed = True for key, parser in self.items_to_handlers.items(): results_dict[key] = parser.parse(tf_example) parsed &= (results_dict[key] is not None) for key, parser in self.optional_items_to_handlers.items(): results_dict[key] = parser.parse(tf_example) return results_dict if parsed else None