# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains functions to build encoder and decoder."""

import tensorflow as tf

from deeplab2 import config_pb2
from deeplab2.model.decoder import deeplabv3
from deeplab2.model.decoder import deeplabv3plus
from deeplab2.model.decoder import max_deeplab
from deeplab2.model.decoder import motion_deeplab_decoder
from deeplab2.model.decoder import panoptic_deeplab
from deeplab2.model.decoder import vip_deeplab_decoder
from deeplab2.model.encoder import axial_resnet_instances
from deeplab2.model.encoder import mobilenet


def create_encoder(backbone_options: config_pb2.ModelOptions.BackboneOptions,
                   bn_layer: tf.keras.layers.Layer,
                   conv_kernel_weight_decay: float = 0.0) -> tf.keras.Model:
    """Creates an encoder.

    The backbone family is chosen by a case-sensitive substring match on
    `backbone_options.name`: names containing 'resnet', 'swidernet',
    'axial_deeplab' or 'max_deeplab' are built by `create_resnet_encoder`;
    names containing 'mobilenet' are built by `create_mobilenet_encoder`.
    The ResNet-family keywords are tested first.

    Args:
      backbone_options: A proto config of type
        config_pb2.ModelOptions.BackboneOptions.
      bn_layer: A tf.keras.layers.Layer that computes the normalization.
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Returns:
      An instance of tf.keras.Model containing the encoder.

    Raises:
      ValueError: An error occurs when the specified encoder meta
        architecture is not supported.
    """
    name = backbone_options.name
    resnet_family_keywords = (
        'resnet', 'swidernet', 'axial_deeplab', 'max_deeplab')
    if any(keyword in name for keyword in resnet_family_keywords):
        return create_resnet_encoder(
            backbone_options,
            bn_layer=bn_layer,
            conv_kernel_weight_decay=conv_kernel_weight_decay)
    if 'mobilenet' in name:
        return create_mobilenet_encoder(
            backbone_options,
            bn_layer=bn_layer,
            conv_kernel_weight_decay=conv_kernel_weight_decay)
    raise ValueError('The specified encoder %s is not a valid encoder.' %
                     backbone_options.name)


def create_mobilenet_encoder(
        backbone_options: config_pb2.ModelOptions.BackboneOptions,
        bn_layer: tf.keras.layers.Layer,
        conv_kernel_weight_decay: float = 0.0) -> tf.keras.Model:
    """Creates a MobileNet encoder specified by name.

    The name is compared case-insensitively against 'mobilenet_v3_large' and
    'mobilenet_v3_small'. Only `output_stride` and
    `backbone_width_multiplier` are forwarded to the backbone; the remaining
    backbone options checked below must hold the values this builder
    requires, otherwise the configuration is rejected.

    Args:
      backbone_options: A proto config of type
        config_pb2.ModelOptions.BackboneOptions.
      bn_layer: A tf.keras.layers.Layer that computes the normalization.
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Returns:
      An instance of tf.keras.Model containing the MobileNet encoder.

    Raises:
      ValueError: If the name is not a known MobileNet variant, or if
        `use_squeeze_and_excite` is unset, `drop_path_keep_prob` is not 1,
        `use_sac_beyond_stride` is not -1, or `backbone_layer_multiplier` is
        not 1.
    """
    name = backbone_options.name.lower()
    if name == 'mobilenet_v3_large':
        backbone = mobilenet.MobileNetV3Large
    elif name == 'mobilenet_v3_small':
        backbone = mobilenet.MobileNetV3Small
    else:
        raise ValueError('The specified encoder %s is not a valid encoder.' %
                         backbone_options.name)

    # Explicit checks rather than `assert`: assertions are removed when Python
    # runs with -O, which would let an unsupported configuration through and
    # silently build a model that ignores these options.
    if not backbone_options.use_squeeze_and_excite:
        raise ValueError(
            'The MobileNet encoder %s requires use_squeeze_and_excite to be '
            'set.' % backbone_options.name)
    if backbone_options.drop_path_keep_prob != 1:
        raise ValueError(
            'The MobileNet encoder %s requires drop_path_keep_prob == 1, '
            'got %s.' % (backbone_options.name,
                         backbone_options.drop_path_keep_prob))
    if backbone_options.use_sac_beyond_stride != -1:
        raise ValueError(
            'The MobileNet encoder %s requires use_sac_beyond_stride == -1, '
            'got %s.' % (backbone_options.name,
                         backbone_options.use_sac_beyond_stride))
    if backbone_options.backbone_layer_multiplier != 1:
        raise ValueError(
            'The MobileNet encoder %s requires backbone_layer_multiplier == '
            '1, got %s.' % (backbone_options.name,
                            backbone_options.backbone_layer_multiplier))

    return backbone(
        output_stride=backbone_options.output_stride,
        width_multiplier=backbone_options.backbone_width_multiplier,
        bn_layer=bn_layer,
        conv_kernel_weight_decay=conv_kernel_weight_decay)


def create_resnet_encoder(
        backbone_options: config_pb2.ModelOptions.BackboneOptions,
        bn_layer: tf.keras.layers.Layer,
        conv_kernel_weight_decay: float = 0.0) -> tf.keras.Model:
    """Creates a ResNet encoder specified by name.

    Delegates to `axial_resnet_instances.get_model`, forwarding the backbone
    name, the stride and multiplier options, and a block-group config built
    from the squeeze-and-excite, drop-path and SAC options.

    Args:
      backbone_options: A proto config of type
        config_pb2.ModelOptions.BackboneOptions.
      bn_layer: A tf.keras.layers.Layer that computes the normalization.
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Returns:
      An instance of tf.keras.Model containing the ResNet encoder.
    """
    block_group_config = {
        'use_squeeze_and_excite': backbone_options.use_squeeze_and_excite,
        'drop_path_keep_prob': backbone_options.drop_path_keep_prob,
        'drop_path_schedule': backbone_options.drop_path_schedule,
        'use_sac_beyond_stride': backbone_options.use_sac_beyond_stride,
    }
    return axial_resnet_instances.get_model(
        backbone_options.name,
        output_stride=backbone_options.output_stride,
        stem_width_multiplier=backbone_options.stem_width_multiplier,
        width_multiplier=backbone_options.backbone_width_multiplier,
        backbone_layer_multiplier=backbone_options.backbone_layer_multiplier,
        block_group_config=block_group_config,
        bn_layer=bn_layer,
        conv_kernel_weight_decay=conv_kernel_weight_decay)


def create_decoder(model_options: config_pb2.ModelOptions,
                   bn_layer: tf.keras.layers.Layer,
                   ignore_label: int) -> tf.keras.Model:
    """Creates a DeepLab decoder.

    The decoder class is selected by whichever field of the
    `meta_architecture` oneof is set on `model_options`. Each decoder receives
    `model_options.decoder` plus its own meta-architecture options;
    `ignore_label` is only passed to the MaX-DeepLab decoder.

    Args:
      model_options: A proto config of type config_pb2.ModelOptions.
      bn_layer: A tf.keras.layers.Layer that computes the normalization.
      ignore_label: An integer specifying the ignore label.

    Returns:
      The decoder for the selected meta architecture.
      NOTE(review): the signature annotates tf.keras.Model, while this
      docstring previously described a tf.keras.layers.Layer; confirm
      against the decoder classes which one is accurate.

    Raises:
      ValueError: An error occurs when the specified meta architecture is not
        supported.
    """
    meta_architecture = model_options.WhichOneof('meta_architecture')
    if meta_architecture == 'deeplab_v3':
        return deeplabv3.DeepLabV3(
            model_options.decoder,
            model_options.deeplab_v3,
            bn_layer=bn_layer)
    elif meta_architecture == 'deeplab_v3_plus':
        return deeplabv3plus.DeepLabV3Plus(
            model_options.decoder,
            model_options.deeplab_v3_plus,
            bn_layer=bn_layer)
    elif meta_architecture == 'panoptic_deeplab':
        return panoptic_deeplab.PanopticDeepLab(
            model_options.decoder,
            model_options.panoptic_deeplab,
            bn_layer=bn_layer)
    elif meta_architecture == 'motion_deeplab':
        return motion_deeplab_decoder.MotionDeepLabDecoder(
            model_options.decoder,
            model_options.motion_deeplab,
            bn_layer=bn_layer)
    elif meta_architecture == 'vip_deeplab':
        return vip_deeplab_decoder.ViPDeepLabDecoder(
            model_options.decoder,
            model_options.vip_deeplab,
            bn_layer=bn_layer)
    elif meta_architecture == 'max_deeplab':
        return max_deeplab.MaXDeepLab(
            model_options.decoder,
            model_options.max_deeplab,
            ignore_label=ignore_label,
            bn_layer=bn_layer)
    # WhichOneof yields None when no meta architecture is set; that case and
    # any unhandled architecture both end up here.
    raise ValueError('The specified meta architecture %s is not implemented.' %
                     meta_architecture)