|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""This file contains functions to build encoder and decoder.""" |
|
import tensorflow as tf |
|
|
|
from deeplab2 import config_pb2 |
|
from deeplab2.model.decoder import deeplabv3 |
|
from deeplab2.model.decoder import deeplabv3plus |
|
from deeplab2.model.decoder import max_deeplab |
|
from deeplab2.model.decoder import motion_deeplab_decoder |
|
from deeplab2.model.decoder import panoptic_deeplab |
|
from deeplab2.model.decoder import vip_deeplab_decoder |
|
from deeplab2.model.encoder import axial_resnet_instances |
|
from deeplab2.model.encoder import mobilenet |
|
|
|
|
|
def create_encoder(backbone_options: config_pb2.ModelOptions.BackboneOptions,
                   bn_layer: tf.keras.layers.Layer,
                   conv_kernel_weight_decay: float = 0.0) -> tf.keras.Model:
  """Builds the encoder (backbone) selected by `backbone_options.name`.

  Selection is a case-sensitive substring match on the backbone name. Names
  containing 'resnet', 'swidernet', 'axial_deeplab' or 'max_deeplab' are built
  by `create_resnet_encoder`. Otherwise, names containing 'mobilenet' are built
  by `create_mobilenet_encoder`.

  Args:
    backbone_options: A proto config of type
      config_pb2.ModelOptions.BackboneOptions.
    bn_layer: A tf.keras.layers.Layer that computes the normalization.
    conv_kernel_weight_decay: A float, the weight decay for convolution kernels.

  Returns:
    An instance of tf.keras.Model containing the encoder.

  Raises:
    ValueError: If the backbone name matches none of the supported encoder
      families. The selected builder may also raise its own errors.
  """
  encoder_name = backbone_options.name
  # The ResNet-family check runs first, so it wins for any name that would
  # also contain 'mobilenet'.
  resnet_family_keys = ('resnet', 'swidernet', 'axial_deeplab', 'max_deeplab')
  if any(key in encoder_name for key in resnet_family_keys):
    builder = create_resnet_encoder
  elif 'mobilenet' in encoder_name:
    builder = create_mobilenet_encoder
  else:
    raise ValueError('The specified encoder %s is not a valid encoder.' %
                     backbone_options.name)
  return builder(
      backbone_options,
      bn_layer=bn_layer,
      conv_kernel_weight_decay=conv_kernel_weight_decay)
|
|
|
|
|
def create_mobilenet_encoder(
    backbone_options: config_pb2.ModelOptions.BackboneOptions,
    bn_layer: tf.keras.layers.Layer,
    conv_kernel_weight_decay: float = 0.0) -> tf.keras.Model:
  """Creates a MobileNet encoder specified by name.

  Only 'mobilenet_v3_large' and 'mobilenet_v3_small' (matched
  case-insensitively) are supported. Only `output_stride` and
  `backbone_width_multiplier` from `backbone_options` are forwarded to the
  MobileNet builder. The remaining block options must stay at the values
  checked below, so a config that changes them fails loudly instead of being
  ignored.

  Args:
    backbone_options: A proto config of type
      config_pb2.ModelOptions.BackboneOptions.
    bn_layer: A tf.keras.layers.Layer that computes the normalization.
    conv_kernel_weight_decay: A float, the weight decay for convolution kernels.

  Returns:
    An instance of tf.keras.Model containing the MobileNet encoder.

  Raises:
    ValueError: If the backbone name is not a supported MobileNet variant, or
      if `backbone_options` sets a block option that this encoder does not
      support.
  """
  backbones = {
      'mobilenet_v3_large': mobilenet.MobileNetV3Large,
      'mobilenet_v3_small': mobilenet.MobileNetV3Small,
  }
  backbone = backbones.get(backbone_options.name.lower())
  if backbone is None:
    raise ValueError('The specified encoder %s is not a valid encoder.' %
                     backbone_options.name)

  # Explicit checks (rather than `assert`) keep this validation active when
  # Python runs with -O, which strips assert statements.
  if not backbone_options.use_squeeze_and_excite:
    raise ValueError(
        'MobileNet encoder %s requires use_squeeze_and_excite to be set.' %
        backbone_options.name)
  if backbone_options.drop_path_keep_prob != 1:
    raise ValueError(
        'MobileNet encoder %s does not support drop_path_keep_prob=%r; '
        'expected 1.' % (backbone_options.name,
                         backbone_options.drop_path_keep_prob))
  if backbone_options.use_sac_beyond_stride != -1:
    raise ValueError(
        'MobileNet encoder %s does not support use_sac_beyond_stride=%r; '
        'expected -1.' % (backbone_options.name,
                          backbone_options.use_sac_beyond_stride))
  if backbone_options.backbone_layer_multiplier != 1:
    raise ValueError(
        'MobileNet encoder %s does not support backbone_layer_multiplier=%r; '
        'expected 1.' % (backbone_options.name,
                         backbone_options.backbone_layer_multiplier))

  return backbone(
      output_stride=backbone_options.output_stride,
      width_multiplier=backbone_options.backbone_width_multiplier,
      bn_layer=bn_layer,
      conv_kernel_weight_decay=conv_kernel_weight_decay)
|
|
|
|
|
def create_resnet_encoder(
    backbone_options: config_pb2.ModelOptions.BackboneOptions,
    bn_layer: tf.keras.layers.Layer,
    conv_kernel_weight_decay: float = 0.0) -> tf.keras.Model:
  """Builds a ResNet-family encoder via `axial_resnet_instances.get_model`.

  The backbone name selects the model. Stride and width/layer multipliers are
  forwarded as keyword arguments. The squeeze-and-excite, drop-path and SAC
  options are bundled into the `block_group_config` dict.

  Args:
    backbone_options: A proto config of type
      config_pb2.ModelOptions.BackboneOptions.
    bn_layer: A tf.keras.layers.Layer that computes the normalization.
    conv_kernel_weight_decay: A float, the weight decay for convolution kernels.

  Returns:
    An instance of tf.keras.Model containing the ResNet encoder.
  """
  opts = backbone_options
  block_group_config = {
      'use_squeeze_and_excite': opts.use_squeeze_and_excite,
      'drop_path_keep_prob': opts.drop_path_keep_prob,
      'drop_path_schedule': opts.drop_path_schedule,
      'use_sac_beyond_stride': opts.use_sac_beyond_stride,
  }
  return axial_resnet_instances.get_model(
      opts.name,
      output_stride=opts.output_stride,
      stem_width_multiplier=opts.stem_width_multiplier,
      width_multiplier=opts.backbone_width_multiplier,
      backbone_layer_multiplier=opts.backbone_layer_multiplier,
      block_group_config=block_group_config,
      bn_layer=bn_layer,
      conv_kernel_weight_decay=conv_kernel_weight_decay)
|
|
|
|
|
def create_decoder(model_options: config_pb2.ModelOptions,
                   bn_layer: tf.keras.layers.Layer,
                   ignore_label: int) -> tf.keras.Model:
  """Creates a DeepLab decoder.

  The decoder class is chosen by whichever field of the `meta_architecture`
  oneof is set in `model_options`. Each decoder receives the shared
  `model_options.decoder` options together with its architecture-specific
  sub-config (e.g. `model_options.panoptic_deeplab`).

  Args:
    model_options: A proto config of type config_pb2.ModelOptions.
    bn_layer: A tf.keras.layers.Layer that computes the normalization.
    ignore_label: An integer specifying the ignore label. Only the MaX-DeepLab
      decoder uses it.

  Returns:
    An instance of tf.keras.layers.Layer containing the decoder.

  Raises:
    ValueError: If no meta architecture is set in `model_options`, or if the
      set meta architecture is not supported.
  """
  meta_architecture = model_options.WhichOneof('meta_architecture')
  # WhichOneof returns None when no field of the oneof is set. Report that
  # directly instead of claiming "None is not implemented".
  if meta_architecture is None:
    raise ValueError('No meta architecture is specified in model_options; '
                     'exactly one meta_architecture field must be set.')

  if meta_architecture == 'deeplab_v3':
    return deeplabv3.DeepLabV3(
        model_options.decoder, model_options.deeplab_v3, bn_layer=bn_layer)
  elif meta_architecture == 'deeplab_v3_plus':
    return deeplabv3plus.DeepLabV3Plus(
        model_options.decoder, model_options.deeplab_v3_plus, bn_layer=bn_layer)
  elif meta_architecture == 'panoptic_deeplab':
    return panoptic_deeplab.PanopticDeepLab(
        model_options.decoder,
        model_options.panoptic_deeplab,
        bn_layer=bn_layer)
  elif meta_architecture == 'motion_deeplab':
    return motion_deeplab_decoder.MotionDeepLabDecoder(
        model_options.decoder,
        model_options.motion_deeplab,
        bn_layer=bn_layer)
  elif meta_architecture == 'vip_deeplab':
    return vip_deeplab_decoder.ViPDeepLabDecoder(
        model_options.decoder,
        model_options.vip_deeplab,
        bn_layer=bn_layer)
  elif meta_architecture == 'max_deeplab':
    return max_deeplab.MaXDeepLab(
        model_options.decoder,
        model_options.max_deeplab,
        ignore_label=ignore_label,
        bn_layer=bn_layer)
  raise ValueError('The specified meta architecture %s is not implemented.' %
                   meta_architecture)
|
|