|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for object_detection.utils.label_map_util.""" |
|
|
|
import os |
|
import tensorflow as tf |
|
|
|
from google.protobuf import text_format |
|
from object_detection.protos import string_int_label_map_pb2 |
|
from object_detection.utils import label_map_util |
|
|
|
|
|
class LabelMapUtilTest(tf.test.TestCase):
  """Unit tests for the helpers in `object_detection.utils.label_map_util`."""

  def _generate_label_map(self, num_classes):
    """Builds an in-memory label map with ids 1..num_classes.

    Item `i` gets `name='label_<i>'` and `display_name='<i>'`, so tests can
    tell which of the two fields a helper reads.

    Args:
      num_classes: number of items to add to the label map.

    Returns:
      A populated `string_int_label_map_pb2.StringIntLabelMap` proto.
    """
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    for i in range(1, num_classes + 1):
      item = label_map_proto.item.add()
      item.id = i
      item.name = 'label_' + str(i)
      item.display_name = str(i)
    return label_map_proto

  def _write_label_map(self, label_map_string):
    """Writes a text-format label map into this test's temp directory.

    Shared by every test that needs a label map on disk. It replaces the
    join/open/write boilerplate that was previously repeated in each test.

    Args:
      label_map_string: label map contents in protobuf text format.

    Returns:
      Path of the written 'label_map.pbtxt' file.
    """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
      f.write(label_map_string)
    return label_map_path

  def test_get_label_map_dict(self):
    """get_label_map_dict maps each item's `name` to its `id`."""
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(label_map_path)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_get_label_map_dict_display(self):
    """With use_display_name=True the dict is keyed by `display_name`."""
    label_map_string = """
      item {
        id:2
        display_name:'cat'
      }
      item {
        id:1
        display_name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(
        label_map_path, use_display_name=True)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_load_bad_label_map(self):
    """load_labelmap rejects id 0 when it is not named 'background'."""
    label_map_string = """
      item {
        id:0
        name:'class that should not be indexed at zero'
      }
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    with self.assertRaises(ValueError):
      label_map_util.load_labelmap(label_map_path)

  def test_load_label_map_with_background(self):
    """An id-0 item named 'background' is accepted and kept in the dict."""
    label_map_string = """
      item {
        id:0
        name:'background'
      }
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(label_map_path)
    self.assertEqual(label_map_dict['background'], 0)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_get_label_map_dict_with_fill_in_gaps_and_background(self):
    """fill_in_gaps_and_background adds 'background'=0 and fills missing ids.

    Id 2 is absent from the file, so it is expected under the key '2'. The
    resulting dict must be dense over 0..max_id.
    """
    label_map_string = """
      item {
        id:3
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(
        label_map_path, fill_in_gaps_and_background=True)

    self.assertEqual(label_map_dict['background'], 0)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['2'], 2)
    self.assertEqual(label_map_dict['cat'], 3)
    # Dense ids 0..max: one entry per id, no holes.
    self.assertEqual(len(label_map_dict), max(label_map_dict.values()) + 1)

  def test_keep_categories_with_unique_id(self):
    """Duplicate ids collapse to the first item seen with that id.

    Three items share id 1. Only the first ('child') is expected in the
    output.
    """
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'child'
      }
      item {
        id:1
        name:'person'
      }
      item {
        id:1
        name:'n00007846'
      }
    """
    text_format.Merge(label_map_string, label_map_proto)
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=3)
    self.assertListEqual([{
        'id': 2,
        'name': u'cat'
    }, {
        'id': 1,
        'name': u'child'
    }], categories)

  def test_convert_label_map_to_categories_no_label_map(self):
    """A None label map yields placeholder 'category_<i>' names for 1..max."""
    categories = label_map_util.convert_label_map_to_categories(
        None, max_num_classes=3)
    expected_categories_list = [{
        'name': u'category_1',
        'id': 1
    }, {
        'name': u'category_2',
        'id': 2
    }, {
        'name': u'category_3',
        'id': 3
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_convert_label_map_to_categories(self):
    """Ids above max_num_classes are dropped; names come from display_name.

    _generate_label_map sets display_name to '<i>', which is what the
    expected names use.
    """
    label_map_proto = self._generate_label_map(num_classes=4)
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=3)
    expected_categories_list = [{
        'name': u'1',
        'id': 1
    }, {
        'name': u'2',
        'id': 2
    }, {
        'name': u'3',
        'id': 3
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_convert_label_map_to_categories_with_few_classes(self):
    """max_num_classes smaller than the label map truncates the output."""
    label_map_proto = self._generate_label_map(num_classes=4)
    cat_no_offset = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=2)
    expected_categories_list = [{
        'name': u'1',
        'id': 1
    }, {
        'name': u'2',
        'id': 2
    }]
    self.assertListEqual(expected_categories_list, cat_no_offset)

  def test_get_max_label_map_index(self):
    """get_max_label_map_index returns the largest id in the label map."""
    num_classes = 4
    label_map_proto = self._generate_label_map(num_classes=num_classes)
    max_index = label_map_util.get_max_label_map_index(label_map_proto)
    self.assertEqual(num_classes, max_index)

  def test_create_category_index(self):
    """create_category_index keys each category dict by its 'id'."""
    categories = [{'name': u'1', 'id': 1}, {'name': u'2', 'id': 2}]
    category_index = label_map_util.create_category_index(categories)
    self.assertDictEqual({
        1: {
            'name': u'1',
            'id': 1
        },
        2: {
            'name': u'2',
            'id': 2
        }
    }, category_index)

  def test_create_categories_from_labelmap(self):
    """create_categories_from_labelmap reads a file into category dicts."""
    label_map_string = """
      item {
        id:1
        name:'dog'
      }
      item {
        id:2
        name:'cat'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    categories = label_map_util.create_categories_from_labelmap(label_map_path)
    self.assertListEqual([{
        'name': u'dog',
        'id': 1
    }, {
        'name': u'cat',
        'id': 2
    }], categories)

  def test_create_category_index_from_labelmap(self):
    """create_category_index_from_labelmap builds an id-keyed index."""
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    category_index = label_map_util.create_category_index_from_labelmap(
        label_map_path)
    self.assertDictEqual({
        1: {
            'name': u'dog',
            'id': 1
        },
        2: {
            'name': u'cat',
            'id': 2
        }
    }, category_index)

  def test_create_category_index_from_labelmap_display(self):
    """The second argument selects between `name` and `display_name`.

    Passing False yields `name` ('dog'/'cat'). The default yields
    `display_name` ('woof'/'meow').
    """
    label_map_string = """
      item {
        id:2
        name:'cat'
        display_name:'meow'
      }
      item {
        id:1
        name:'dog'
        display_name:'woof'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    # NOTE(review): the second positional arg is presumably
    # `use_display_name`; kept positional to match the original call.
    self.assertDictEqual({
        1: {
            'name': u'dog',
            'id': 1
        },
        2: {
            'name': u'cat',
            'id': 2
        }
    }, label_map_util.create_category_index_from_labelmap(
        label_map_path, False))

    self.assertDictEqual({
        1: {
            'name': u'woof',
            'id': 1
        },
        2: {
            'name': u'meow',
            'id': 2
        }
    }, label_map_util.create_category_index_from_labelmap(label_map_path))
|
|
|
|
|
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner so the
  # tf.test.TestCase tests defined in this module are executed.
  tf.test.main()
|
|