|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for the OpenImages label expansion (OIDHierarchicalLabelsExpansion).""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import tensorflow as tf |
|
|
|
from object_detection.dataset_tools import oid_hierarchical_labels_expansion |
|
|
|
|
|
def create_test_data():
  """Builds the shared fixture for the hierarchical expansion tests.

  The hierarchy is rooted at 'a' and has children 'b', 'c' and 'f'. 'c' has
  children 'd' and 'e', and 'f' has the single child 'd', so 'd' appears
  under two different parents.

  Returns:
    A tuple (hierarchy, bbox_rows, label_rows) where
      hierarchy: nested dicts, each with a 'LabelName' key; non-leaf nodes
        also carry a 'Subcategory' list of child dicts.
      bbox_rows: CSV row strings with 'xclick' in the second column.
      label_rows: CSV row strings with 'verification' in the second column.
  """

  def make_node(name, *children):
    # Leaves get only 'LabelName'; a 'Subcategory' key is added solely for
    # nodes that actually have children.
    node = {'LabelName': name}
    if children:
      node['Subcategory'] = list(children)
    return node

  # Each make_node('d') call yields its own dict, so the two 'd' leaves are
  # distinct objects.
  tree = make_node(
      'a',
      make_node('b'),
      make_node('c', make_node('d'), make_node('e')),
      make_node('f', make_node('d')),
  )

  box_lines = []
  box_lines.append('123,xclick,b,1,0.1,0.2,0.1,0.2,1,1,0,0,0')
  box_lines.append('123,xclick,d,1,0.2,0.3,0.1,0.2,1,1,0,0,0')

  verification_lines = [
      '123,verification,b,0',
      '123,verification,c,0',
      '124,verification,d,1',
  ]

  return tree, box_lines, verification_lines
|
|
|
|
|
class HierarchicalLabelsExpansionTest(tf.test.TestCase):
  """Tests OIDHierarchicalLabelsExpansion on the create_test_data fixture."""

  @staticmethod
  def _expand_all(expand_fn, rows):
    """Applies `expand_fn` to every CSV row and concatenates the outputs.

    Args:
      expand_fn: callable taking one CSV row string and returning an iterable
        of expanded CSV row strings.
      rows: iterable of CSV row strings.

    Returns:
      A flat list of all expanded rows, grouped in input-row order.
    """
    return [expanded for row in rows for expanded in expand_fn(row)]

  def test_bbox_expansion(self):
    """A box on 'd' is also emitted for both of its parents, 'c' and 'f'.

    The box on 'b' comes back unchanged, and no row is expected for the root
    label 'a'.
    """
    hierarchy, bbox_rows, _ = create_test_data()
    expansion_generator = (
        oid_hierarchical_labels_expansion.OIDHierarchicalLabelsExpansion(
            hierarchy))
    all_result_rows = self._expand_all(
        expansion_generator.expand_boxes_from_csv, bbox_rows)
    # assertCountEqual compares as multisets (order ignored, duplicates
    # counted). assertItemsEqual is the Python 2 name and is not part of
    # Python 3's unittest API.
    self.assertCountEqual([
        '123,xclick,b,1,0.1,0.2,0.1,0.2,1,1,0,0,0',
        '123,xclick,d,1,0.2,0.3,0.1,0.2,1,1,0,0,0',
        '123,xclick,f,1,0.2,0.3,0.1,0.2,1,1,0,0,0',
        '123,xclick,c,1,0.2,0.3,0.1,0.2,1,1,0,0,0'
    ], all_result_rows)

  def test_labels_expansion(self):
    """Image-level labels propagate in opposite directions by confidence.

    The negative (0) label on 'c' is expected to be copied to its children
    'd' and 'e'. The positive (1) label on 'd' is expected to be copied to
    its parents 'c' and 'f'.
    """
    hierarchy, _, label_rows = create_test_data()
    expansion_generator = (
        oid_hierarchical_labels_expansion.OIDHierarchicalLabelsExpansion(
            hierarchy))
    all_result_rows = self._expand_all(
        expansion_generator.expand_labels_from_csv, label_rows)
    self.assertCountEqual([
        '123,verification,b,0', '123,verification,c,0', '123,verification,d,0',
        '123,verification,e,0', '124,verification,d,1', '124,verification,f,1',
        '124,verification,c,1'
    ], all_result_rows)
|
|
|
if __name__ == '__main__':
  # When this file is executed directly, hand control to TensorFlow's test
  # runner, which discovers and runs the TestCase classes defined above.
  tf.test.main()
|
|