# deeplab2/model/builder.py
# akhaliq3
# spaces demo
# 506da10
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains functions to build encoder and decoder."""
import tensorflow as tf
from deeplab2 import config_pb2
from deeplab2.model.decoder import deeplabv3
from deeplab2.model.decoder import deeplabv3plus
from deeplab2.model.decoder import max_deeplab
from deeplab2.model.decoder import motion_deeplab_decoder
from deeplab2.model.decoder import panoptic_deeplab
from deeplab2.model.decoder import vip_deeplab_decoder
from deeplab2.model.encoder import axial_resnet_instances
from deeplab2.model.encoder import mobilenet
def create_encoder(backbone_options: config_pb2.ModelOptions.BackboneOptions,
                   bn_layer: tf.keras.layers.Layer,
                   conv_kernel_weight_decay: float = 0.0) -> tf.keras.Model:
  """Builds the encoder selected by `backbone_options.name`.

  The backbone name is matched by (case-sensitive) substring and the call is
  forwarded to the matching family-specific builder.

  Args:
    backbone_options: A proto config of type
      config_pb2.ModelOptions.BackboneOptions.
    bn_layer: A tf.keras.layers.Layer that computes the normalization.
    conv_kernel_weight_decay: A float, the weight decay for convolution kernels.

  Returns:
    An instance of tf.keras.Model containing the encoder.

  Raises:
    ValueError: If the backbone name matches none of the supported encoder
      families.
  """
  encoder_name = backbone_options.name

  # Names containing any of these substrings are served by the ResNet builder.
  resnet_family_keys = ('resnet', 'swidernet', 'axial_deeplab', 'max_deeplab')

  if any(key in encoder_name for key in resnet_family_keys):
    build_encoder = create_resnet_encoder
  elif 'mobilenet' in encoder_name:
    build_encoder = create_mobilenet_encoder
  else:
    raise ValueError('The specified encoder %s is not a valid encoder.' %
                     encoder_name)

  return build_encoder(
      backbone_options,
      bn_layer=bn_layer,
      conv_kernel_weight_decay=conv_kernel_weight_decay)
def create_mobilenet_encoder(
    backbone_options: config_pb2.ModelOptions.BackboneOptions,
    bn_layer: tf.keras.layers.Layer,
    conv_kernel_weight_decay: float = 0.0) -> tf.keras.Model:
  """Creates a MobileNet encoder specified by name.

  The name is compared case-insensitively against the supported MobileNet
  variants. Only `output_stride` and `backbone_width_multiplier` are forwarded
  from `backbone_options` to the MobileNet constructor; the remaining backbone
  options checked below must hold the single value this builder accepts.

  Args:
    backbone_options: A proto config of type
      config_pb2.ModelOptions.BackboneOptions.
    bn_layer: A tf.keras.layers.Layer that computes the normalization.
    conv_kernel_weight_decay: A float, the weight decay for convolution kernels.

  Returns:
    An instance of tf.keras.Model containing the MobileNet encoder.

  Raises:
    ValueError: If the name is not 'mobilenet_v3_large' or
      'mobilenet_v3_small' (ignoring case), or if `use_squeeze_and_excite`,
      `drop_path_keep_prob`, `use_sac_beyond_stride` or
      `backbone_layer_multiplier` is set to a value this builder does not
      accept.
  """
  backbone_name = backbone_options.name.lower()
  if backbone_name not in ('mobilenet_v3_large', 'mobilenet_v3_small'):
    raise ValueError('The specified encoder %s is not a valid encoder.' %
                     backbone_options.name)

  # Validate with explicit exceptions rather than `assert`: assertions are
  # removed when Python runs with -O, which would let an unsupported
  # configuration through without any error.
  if not backbone_options.use_squeeze_and_excite:
    raise ValueError(
        'MobileNet encoders require use_squeeze_and_excite to be True.')

  # These options are not passed to the MobileNet constructor, so only one
  # value of each is accepted here.
  fixed_options = (
      ('drop_path_keep_prob', 1),
      ('use_sac_beyond_stride', -1),
      ('backbone_layer_multiplier', 1),
  )
  for option_name, required_value in fixed_options:
    actual_value = getattr(backbone_options, option_name)
    if actual_value != required_value:
      raise ValueError(
          'MobileNet encoders require %s == %s, but got %s.' %
          (option_name, required_value, actual_value))

  if backbone_name == 'mobilenet_v3_large':
    backbone = mobilenet.MobileNetV3Large
  else:
    backbone = mobilenet.MobileNetV3Small

  return backbone(
      output_stride=backbone_options.output_stride,
      width_multiplier=backbone_options.backbone_width_multiplier,
      bn_layer=bn_layer,
      conv_kernel_weight_decay=conv_kernel_weight_decay)
def create_resnet_encoder(
    backbone_options: config_pb2.ModelOptions.BackboneOptions,
    bn_layer: tf.keras.layers.Layer,
    conv_kernel_weight_decay: float = 0.0) -> tf.keras.Model:
  """Builds the ResNet-family encoder named by `backbone_options.name`.

  All backbone options are translated into keyword arguments and handed to
  `axial_resnet_instances.get_model`, which performs the actual construction.

  Args:
    backbone_options: A proto config of type
      config_pb2.ModelOptions.BackboneOptions.
    bn_layer: A tf.keras.layers.Layer that computes the normalization.
    conv_kernel_weight_decay: A float, the weight decay for convolution kernels.

  Returns:
    An instance of tf.keras.Model containing the ResNet encoder.
  """
  # Options that are grouped into the single `block_group_config` argument.
  block_group_config = dict(
      use_squeeze_and_excite=backbone_options.use_squeeze_and_excite,
      drop_path_keep_prob=backbone_options.drop_path_keep_prob,
      drop_path_schedule=backbone_options.drop_path_schedule,
      use_sac_beyond_stride=backbone_options.use_sac_beyond_stride)

  model_kwargs = dict(
      output_stride=backbone_options.output_stride,
      stem_width_multiplier=backbone_options.stem_width_multiplier,
      width_multiplier=backbone_options.backbone_width_multiplier,
      backbone_layer_multiplier=backbone_options.backbone_layer_multiplier,
      block_group_config=block_group_config,
      bn_layer=bn_layer,
      conv_kernel_weight_decay=conv_kernel_weight_decay)

  return axial_resnet_instances.get_model(backbone_options.name,
                                          **model_kwargs)
def create_decoder(model_options: config_pb2.ModelOptions,
                   bn_layer: tf.keras.layers.Layer,
                   ignore_label: int) -> tf.keras.Model:
  """Builds the decoder for the meta architecture set in `model_options`.

  The meta architecture is read from the `meta_architecture` oneof of the
  config, and the decoder is constructed from the shared decoder options plus
  the options of that architecture.

  Args:
    model_options: A proto config of type config_pb2.ModelOptions.
    bn_layer: A tf.keras.layers.Layer that computes the normalization.
    ignore_label: An integer specifying the ignore label. It is only passed on
      to the MaX-DeepLab decoder.

  Returns:
    An instance of tf.keras.layers.Layer containing the decoder.

  Raises:
    ValueError: If the selected meta architecture (or the absence of one) is
      not handled here.
  """
  selected = model_options.WhichOneof('meta_architecture')

  # Each entry is a zero-argument callable so that only the decoder for the
  # selected architecture is ever constructed.
  decoder_builders = {
      'deeplab_v3': lambda: deeplabv3.DeepLabV3(
          model_options.decoder,
          model_options.deeplab_v3,
          bn_layer=bn_layer),
      'deeplab_v3_plus': lambda: deeplabv3plus.DeepLabV3Plus(
          model_options.decoder,
          model_options.deeplab_v3_plus,
          bn_layer=bn_layer),
      'panoptic_deeplab': lambda: panoptic_deeplab.PanopticDeepLab(
          model_options.decoder,
          model_options.panoptic_deeplab,
          bn_layer=bn_layer),
      'motion_deeplab': lambda: motion_deeplab_decoder.MotionDeepLabDecoder(
          model_options.decoder,
          model_options.motion_deeplab,
          bn_layer=bn_layer),
      'vip_deeplab': lambda: vip_deeplab_decoder.ViPDeepLabDecoder(
          model_options.decoder,
          model_options.vip_deeplab,
          bn_layer=bn_layer),
      'max_deeplab': lambda: max_deeplab.MaXDeepLab(
          model_options.decoder,
          model_options.max_deeplab,
          ignore_label=ignore_label,
          bn_layer=bn_layer),
  }

  build_decoder = decoder_builders.get(selected)
  if build_decoder is None:
    raise ValueError('The specified meta architecture %s is not implemented.' %
                     selected)
  return build_decoder()