|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for model.builder.""" |
|
|
|
import os |
|
from absl.testing import parameterized |
|
|
|
import tensorflow as tf |
|
|
|
from google.protobuf import text_format |
|
from deeplab2 import config_pb2 |
|
from deeplab2.model import builder |
|
from deeplab2.model.decoder import motion_deeplab_decoder |
|
from deeplab2.model.encoder import axial_resnet_instances |
|
from deeplab2.model.encoder import mobilenet |
|
|
|
|
|
|
|
# Directory holding the example `.textproto` configs; `test_decoder_creation`
# joins a config filename onto it. This is a relative path.
# NOTE(review): assumes the test runs from the directory that contains the
# `deeplab2` package — confirm against the test runner's working directory.
_CONFIG_PATH = 'deeplab2/configs/example'
|
|
|
|
|
def _read_proto_file(filename, proto):
  """Parses a text-format protobuf file into the given message.

  Args:
    filename: Path to the text-format protobuf file. The file is opened with
      `tf.io.gfile.GFile`, so any path that API understands is accepted.
    proto: A protobuf message instance that is populated from the file.

  Returns:
    The `proto` message passed in, populated with the file's contents.

  Raises:
    text_format.ParseError: If the file contents are not valid text format
      for `proto`'s message type.
  """
  # The former `filename = filename` self-assignment was a dead no-op and has
  # been dropped; the path is used exactly as given.
  with tf.io.gfile.GFile(filename, 'r') as proto_file:
    return text_format.ParseLines(proto_file, proto)
|
|
|
|
|
class BuilderTest(tf.test.TestCase, parameterized.TestCase):
  """Checks that the builder factories return the expected model classes."""

  def test_resnet50_encoder_creation(self):
    """`create_encoder` builds a ResNet50 for the 'resnet50' backbone."""
    options = config_pb2.ModelOptions.BackboneOptions(
        output_stride=32, name='resnet50')
    self.assertIsInstance(
        builder.create_encoder(
            options, tf.keras.layers.experimental.SyncBatchNormalization),
        axial_resnet_instances.ResNet50)

  @parameterized.parameters('mobilenet_v3_large', 'mobilenet_v3_small')
  def test_mobilenet_encoder_creation(self, model_name):
    """Both MobileNetV3 backbone names are built as `mobilenet.MobileNet`."""
    options = config_pb2.ModelOptions.BackboneOptions(
        name=model_name,
        output_stride=32,
        use_squeeze_and_excite=True)
    model = builder.create_encoder(
        options, tf.keras.layers.experimental.SyncBatchNormalization)
    self.assertIsInstance(model, mobilenet.MobileNet)

  def test_resnet_encoder_creation(self):
    """`create_resnet_encoder` builds MaXDeepLabS for 'max_deeplab_s'."""
    sync_bn = tf.keras.layers.experimental.SyncBatchNormalization
    options = config_pb2.ModelOptions.BackboneOptions(
        name='max_deeplab_s', output_stride=32)
    model = builder.create_resnet_encoder(options, bn_layer=sync_bn)
    self.assertIsInstance(model, axial_resnet_instances.MaXDeepLabS)

  def test_decoder_creation(self):
    """The KITTI-STEP Motion-DeepLab example config yields that decoder."""
    config_file = os.path.join(
        _CONFIG_PATH, 'example_kitti-step_motion_deeplab.textproto')
    options = _read_proto_file(config_file, config_pb2.ModelOptions())
    decoder = builder.create_decoder(
        options,
        tf.keras.layers.experimental.SyncBatchNormalization,
        ignore_label=255)
    self.assertIsInstance(decoder, motion_deeplab_decoder.MotionDeepLabDecoder)
|
|
|
|
|
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test runner when the file is
  # executed directly.
  tf.test.main()
|
|