DR-App / object_detection /dataset_tools /create_coco_tf_record_test.py
pat229988's picture
Upload 653 files
9a393e2
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for create_coco_tf_record.py."""
import io
import json
import os
import numpy as np
import PIL.Image
import tensorflow as tf
from object_detection.dataset_tools import create_coco_tf_record
class CreateCocoTFRecordTest(tf.test.TestCase):
def _assertProtoEqual(self, proto_field, expectation):
"""Helper function to assert if a proto field equals some value.
Args:
proto_field: The protobuf field to compare.
expectation: The expected value of the protobuf field.
"""
proto_list = [p for p in proto_field]
self.assertListEqual(proto_list, expectation)
def test_create_tf_example(self):
image_file_name = 'tmp_image.jpg'
image_data = np.random.rand(256, 256, 3)
tmp_dir = self.get_temp_dir()
save_path = os.path.join(tmp_dir, image_file_name)
image = PIL.Image.fromarray(image_data, 'RGB')
image.save(save_path)
image = {
'file_name': image_file_name,
'height': 256,
'width': 256,
'id': 11,
}
annotations_list = [{
'area': .5,
'iscrowd': False,
'image_id': 11,
'bbox': [64, 64, 128, 128],
'category_id': 2,
'id': 1000,
}]
image_dir = tmp_dir
category_index = {
1: {
'name': 'dog',
'id': 1
},
2: {
'name': 'cat',
'id': 2
},
3: {
'name': 'human',
'id': 3
}
}
(_, example,
num_annotations_skipped) = create_coco_tf_record.create_tf_example(
image, annotations_list, image_dir, category_index)
self.assertEqual(num_annotations_skipped, 0)
self._assertProtoEqual(
example.features.feature['image/height'].int64_list.value, [256])
self._assertProtoEqual(
example.features.feature['image/width'].int64_list.value, [256])
self._assertProtoEqual(
example.features.feature['image/filename'].bytes_list.value,
[image_file_name])
self._assertProtoEqual(
example.features.feature['image/source_id'].bytes_list.value,
[str(image['id'])])
self._assertProtoEqual(
example.features.feature['image/format'].bytes_list.value, ['jpeg'])
self._assertProtoEqual(
example.features.feature['image/object/bbox/xmin'].float_list.value,
[0.25])
self._assertProtoEqual(
example.features.feature['image/object/bbox/ymin'].float_list.value,
[0.25])
self._assertProtoEqual(
example.features.feature['image/object/bbox/xmax'].float_list.value,
[0.75])
self._assertProtoEqual(
example.features.feature['image/object/bbox/ymax'].float_list.value,
[0.75])
self._assertProtoEqual(
example.features.feature['image/object/class/text'].bytes_list.value,
['cat'])
def test_create_tf_example_with_instance_masks(self):
image_file_name = 'tmp_image.jpg'
image_data = np.random.rand(8, 8, 3)
tmp_dir = self.get_temp_dir()
save_path = os.path.join(tmp_dir, image_file_name)
image = PIL.Image.fromarray(image_data, 'RGB')
image.save(save_path)
image = {
'file_name': image_file_name,
'height': 8,
'width': 8,
'id': 11,
}
annotations_list = [{
'area': .5,
'iscrowd': False,
'image_id': 11,
'bbox': [0, 0, 8, 8],
'segmentation': [[4, 0, 0, 0, 0, 4], [8, 4, 4, 8, 8, 8]],
'category_id': 1,
'id': 1000,
}]
image_dir = tmp_dir
category_index = {
1: {
'name': 'dog',
'id': 1
},
}
(_, example,
num_annotations_skipped) = create_coco_tf_record.create_tf_example(
image, annotations_list, image_dir, category_index, include_masks=True)
self.assertEqual(num_annotations_skipped, 0)
self._assertProtoEqual(
example.features.feature['image/height'].int64_list.value, [8])
self._assertProtoEqual(
example.features.feature['image/width'].int64_list.value, [8])
self._assertProtoEqual(
example.features.feature['image/filename'].bytes_list.value,
[image_file_name])
self._assertProtoEqual(
example.features.feature['image/source_id'].bytes_list.value,
[str(image['id'])])
self._assertProtoEqual(
example.features.feature['image/format'].bytes_list.value, ['jpeg'])
self._assertProtoEqual(
example.features.feature['image/object/bbox/xmin'].float_list.value,
[0])
self._assertProtoEqual(
example.features.feature['image/object/bbox/ymin'].float_list.value,
[0])
self._assertProtoEqual(
example.features.feature['image/object/bbox/xmax'].float_list.value,
[1])
self._assertProtoEqual(
example.features.feature['image/object/bbox/ymax'].float_list.value,
[1])
self._assertProtoEqual(
example.features.feature['image/object/class/text'].bytes_list.value,
['dog'])
encoded_mask_pngs = [
io.BytesIO(encoded_masks) for encoded_masks in example.features.feature[
'image/object/mask'].bytes_list.value
]
pil_masks = [
np.array(PIL.Image.open(encoded_mask_png))
for encoded_mask_png in encoded_mask_pngs
]
self.assertTrue(len(pil_masks) == 1)
self.assertAllEqual(pil_masks[0],
[[1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 1, 1, 1, 1]])
def test_create_sharded_tf_record(self):
tmp_dir = self.get_temp_dir()
image_paths = ['tmp1_image.jpg', 'tmp2_image.jpg']
for image_path in image_paths:
image_data = np.random.rand(256, 256, 3)
save_path = os.path.join(tmp_dir, image_path)
image = PIL.Image.fromarray(image_data, 'RGB')
image.save(save_path)
images = [{
'file_name': image_paths[0],
'height': 256,
'width': 256,
'id': 11,
}, {
'file_name': image_paths[1],
'height': 256,
'width': 256,
'id': 12,
}]
annotations = [{
'area': .5,
'iscrowd': False,
'image_id': 11,
'bbox': [64, 64, 128, 128],
'category_id': 2,
'id': 1000,
}]
category_index = [{
'name': 'dog',
'id': 1
}, {
'name': 'cat',
'id': 2
}, {
'name': 'human',
'id': 3
}]
groundtruth_data = {'images': images, 'annotations': annotations,
'categories': category_index}
annotation_file = os.path.join(tmp_dir, 'annotation.json')
with open(annotation_file, 'w') as annotation_fid:
json.dump(groundtruth_data, annotation_fid)
output_path = os.path.join(tmp_dir, 'out.record')
create_coco_tf_record._create_tf_record_from_coco_annotations(
annotation_file,
tmp_dir,
output_path,
False,
2)
self.assertTrue(os.path.exists(output_path + '-00000-of-00002'))
self.assertTrue(os.path.exists(output_path + '-00001-of-00002'))
if __name__ == '__main__':
tf.test.main()