|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for box_coder_builder.""" |
|
|
|
import tensorflow as tf |
|
|
|
from google.protobuf import text_format |
|
from object_detection.box_coders import faster_rcnn_box_coder |
|
from object_detection.box_coders import keypoint_box_coder |
|
from object_detection.box_coders import mean_stddev_box_coder |
|
from object_detection.box_coders import square_box_coder |
|
from object_detection.builders import box_coder_builder |
|
from object_detection.protos import box_coder_pb2 |
|
|
|
|
|
class BoxCoderBuilderTest(tf.test.TestCase):
  """Tests that `box_coder_builder.build` returns the right coder type.

  Each test writes a `BoxCoder` proto in text format, builds it, and checks
  the concrete coder class. Where applicable it also checks the private
  configuration attributes the builder forwarded (for example
  `_scale_factors` and `_num_keypoints`).
  """

  def _build_box_coder(self, box_coder_text_proto):
    """Parses a text-format `BoxCoder` proto and builds a coder from it.

    Args:
      box_coder_text_proto: string holding a `box_coder_pb2.BoxCoder`
        message in protobuf text format.

    Returns:
      The object returned by `box_coder_builder.build` for the parsed proto.
    """
    box_coder_proto = box_coder_pb2.BoxCoder()
    text_format.Merge(box_coder_text_proto, box_coder_proto)
    return box_coder_builder.build(box_coder_proto)

  def test_build_faster_rcnn_box_coder_with_defaults(self):
    box_coder_object = self._build_box_coder("""
      faster_rcnn_box_coder {
      }
    """)
    self.assertIsInstance(box_coder_object,
                          faster_rcnn_box_coder.FasterRcnnBoxCoder)
    # Expected proto defaults, in (y, x, height, width) order.
    self.assertEqual(box_coder_object._scale_factors, [10.0, 10.0, 5.0, 5.0])

  def test_build_faster_rcnn_box_coder_with_non_default_parameters(self):
    box_coder_object = self._build_box_coder("""
      faster_rcnn_box_coder {
        y_scale: 6.0
        x_scale: 3.0
        height_scale: 7.0
        width_scale: 8.0
      }
    """)
    self.assertIsInstance(box_coder_object,
                          faster_rcnn_box_coder.FasterRcnnBoxCoder)
    # The builder must keep the (y, x, height, width) ordering.
    self.assertEqual(box_coder_object._scale_factors, [6.0, 3.0, 7.0, 8.0])

  def test_build_keypoint_box_coder_with_defaults(self):
    box_coder_object = self._build_box_coder("""
      keypoint_box_coder {
      }
    """)
    self.assertIsInstance(box_coder_object, keypoint_box_coder.KeypointBoxCoder)
    self.assertEqual(box_coder_object._scale_factors, [10.0, 10.0, 5.0, 5.0])

  def test_build_keypoint_box_coder_with_non_default_parameters(self):
    box_coder_object = self._build_box_coder("""
      keypoint_box_coder {
        num_keypoints: 6
        y_scale: 6.0
        x_scale: 3.0
        height_scale: 7.0
        width_scale: 8.0
      }
    """)
    self.assertIsInstance(box_coder_object, keypoint_box_coder.KeypointBoxCoder)
    self.assertEqual(box_coder_object._num_keypoints, 6)
    self.assertEqual(box_coder_object._scale_factors, [6.0, 3.0, 7.0, 8.0])

  def test_build_mean_stddev_box_coder(self):
    box_coder_object = self._build_box_coder("""
      mean_stddev_box_coder {
      }
    """)
    self.assertIsInstance(box_coder_object,
                          mean_stddev_box_coder.MeanStddevBoxCoder)

  def test_build_square_box_coder_with_defaults(self):
    box_coder_object = self._build_box_coder("""
      square_box_coder {
      }
    """)
    self.assertIsInstance(box_coder_object, square_box_coder.SquareBoxCoder)
    # The square coder has three scales: (y, x, length).
    self.assertEqual(box_coder_object._scale_factors, [10.0, 10.0, 5.0])

  def test_build_square_box_coder_with_non_default_parameters(self):
    box_coder_object = self._build_box_coder("""
      square_box_coder {
        y_scale: 6.0
        x_scale: 3.0
        length_scale: 7.0
      }
    """)
    self.assertIsInstance(box_coder_object, square_box_coder.SquareBoxCoder)
    self.assertEqual(box_coder_object._scale_factors, [6.0, 3.0, 7.0])

  def test_raise_error_on_empty_box_coder(self):
    # Parse outside the assertRaises context so that the only call that can
    # satisfy the assertion is box_coder_builder.build itself.
    box_coder_proto = box_coder_pb2.BoxCoder()
    text_format.Merge("", box_coder_proto)
    with self.assertRaises(ValueError):
      box_coder_builder.build(box_coder_proto)
|
|
|
|
|
if __name__ == '__main__':
  # Run this module's test cases with TensorFlow's test runner when the file
  # is executed directly as a script.
  tf.test.main()
|
|