# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements Axial-ResNets proposed in Axial-DeepLab [1].

[1] Axial-Deeplab: Stand-Alone Axial-Attention for Panoptic Segmentation,
    ECCV 2020 Spotlight.
      Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille,
      Liang-Chieh Chen.
"""

import tensorflow as tf

from deeplab2.model import utils
from deeplab2.model.layers import activations
from deeplab2.model.layers import axial_block_groups
from deeplab2.model.layers import convolutions
from deeplab2.model.layers import resized_fuse
from deeplab2.model.layers import stems

# Add a suffix in layer names that indicates whether the current layer is a
# part of the backbone or an extra layer, i.e. whether the current layer will
# be pretrained or not. This name will be used when we apply 10x larger
# learning rates for extra parameters that have not been pretrained, in
# panoptic segmentation. This keyword is reserved and should not be a part of
# the variable names in a classification pretrained backbone.
EXTRA = 'extra'
# Similarly, we will apply 10x larger learning rates on the memory feature.
# This global variable name will be accessed when we build the optimizers.
# This keyword is reserved and should not be a part of the variable names in
# a classification pretrained backbone.
MEMORY_FEATURE = 'memory_feature'


class AxialResNet(tf.keras.Model):
  """An Axial-ResNet model as proposed in Axial-DeepLab [1] and MaX-DeepLab [2].

  An Axial-ResNet [1] replaces 3x3 convolutions in a ResNet with
  axial-attention layers. A dual-path transformer [2] and a stacked decoder
  [2] can be used optionally. In addition, this class supports scaling models
  with SWideRNet [3] and augmenting convolutions with Switchable Atrous
  Convolution [4].

  Reference:
    [1] Axial-Deeplab: Stand-Alone Axial-Attention for Panoptic Segmentation,
        ECCV 2020 Spotlight. https://arxiv.org/abs/2003.07853
          Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille,
          Liang-Chieh Chen.
    [2] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
        CVPR 2021. https://arxiv.org/abs/2012.00759
          Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
    [3] Scaling Wide Residual Networks for Panoptic Segmentation,
        https://arxiv.org/abs/2011.11675
          Liang-Chieh Chen, Huiyu Wang, Siyuan Qiao.
    [4] DetectoRS: Detecting Objects with Recursive Feature Pyramid and
        Switchable Atrous Convolution, CVPR 2021.
        https://arxiv.org/abs/2006.02334
          Siyuan Qiao, Liang-Chieh Chen, Alan Yuille.
""" def __init__(self, name, num_blocks=(3, 4, 6, 3), backbone_layer_multiplier=1.0, width_multiplier=1.0, stem_width_multiplier=1.0, output_stride=16, classification_mode=False, backbone_type='resnet_beta', use_axial_beyond_stride=16, backbone_use_transformer_beyond_stride=32, extra_decoder_use_transformer_beyond_stride=32, backbone_decoder_num_stacks=0, backbone_decoder_blocks_per_stage=1, extra_decoder_num_stacks=0, extra_decoder_blocks_per_stage=1, max_num_mask_slots=128, num_mask_slots=128, memory_channels=256, base_transformer_expansion=1.0, global_feed_forward_network_channels=256, high_resolution_output_stride=4, activation='relu', block_group_config=None, bn_layer=tf.keras.layers.BatchNormalization, conv_kernel_weight_decay=0.0): """Initializes an AxialResNet model. Args: name: A string, the name of the model. num_blocks: A list of 4 integers. It denotes the number of blocks to include in the last 4 stages or block groups. Each group consists of blocks that output features of the same resolution. Defaults to (3, 4, 6, 3) as in MaX-DeepLab-S. backbone_layer_multiplier: A float, layer_multiplier for the backbone, excluding the STEM. This flag controls the number of layers. Defaults to 1.0 as in MaX-DeepLab-S. width_multiplier: A float, the channel multiplier for the block groups. Defaults to 1.0 as in MaX-DeepLab-S. stem_width_multiplier: A float, the channel multiplier for stem convolutions. Defaults to 1.0 as in MaX-DeepLab-S. output_stride: An integer, the maximum ratio of input to output spatial resolution. Defaults to 16 as in MaX-DeepLab-S. classification_mode: A boolean, whether to perform in a classification mode. If it is True, this function directly returns backbone feature endpoints. Note that these feature endpoints can also be used directly for Panoptic-DeepLab or Motion-DeepLab. If it is False, this function builds MaX-DeepLab extra decoder layers and extra transformer layers. Defaults to False as in MaX-DeepLab. backbone_type: A string, the type of backbone. Supports 'resnet', 'resnet_beta', and 'wider_resnet'. It controls both the stem type and the residual block type. Defaults to 'resnet_beta' as in MaX-DeepLab-S. use_axial_beyond_stride: An integer, the stride beyond which we use axial attention. Set to 0 if no axial attention is desired. Defaults to 16 as in MaX-DeepLab. backbone_use_transformer_beyond_stride: An integer, the stride beyond which we use a memory path transformer block on top of a regular pixel path block, in the backbone. Set to 0 if no transformer block is desired in the backbone. Defaults to 32 as in MaX-DeepLab-S. extra_decoder_use_transformer_beyond_stride: An integer, the stride beyond which we use a memory path transformer block on top of a regular pixel path block, in the extra decoder stages. Set to 0 if no transformer block is desired in the extra decoder stages. Defaults to 32 as in MaX-DeepLab-S. backbone_decoder_num_stacks: An integer, the number of decoder stacks (introduced in MaX-DeepLab) that we use in the backbone. The stacked decoders are applied in a stacked hour-glass style. Defaults to 0 as in MaX-DeepLab-S. backbone_decoder_blocks_per_stage: An integer, the number of consecutive residual blocks to apply for each decoder stage, in the backbone. Defaults to 1 as in MaX-DeepLab-S. extra_decoder_num_stacks: An integer, the number of decoder stacks (introduced in MaX-DeepLab) that we use in the extra decoder layers. 
        It is different from backbone_decoder_num_stacks in that the extra
        decoder stacks will be trained from scratch on segmentation tasks,
        instead of pretrained on ImageNet classification. Defaults to 0 as in
        MaX-DeepLab-S.
      extra_decoder_blocks_per_stage: An integer, the number of consecutive
        residual blocks to apply for each decoder stage, in the extra decoder
        stages. Defaults to 1 as in MaX-DeepLab-S.
      max_num_mask_slots: An integer, the maximum possible number of mask
        slots that will be used. This will be used in a pretraining-
        finetuning use case with different num_mask_slots: We can set
        max_num_mask_slots to the maximum possible num_mask_slots, and then
        the saved checkpoint can be loaded for finetuning with a different
        num_mask_slots. Defaults to 128 as in MaX-DeepLab.
      num_mask_slots: An integer, the number of mask slots that will be used.
        Defaults to 128 as in MaX-DeepLab-S.
      memory_channels: An integer, the number of channels for the whole
        memory path. Defaults to 256 as in MaX-DeepLab-S.
      base_transformer_expansion: A float, the base width expansion rate for
        transformer layers. Defaults to 1.0 as in MaX-DeepLab-S.
      global_feed_forward_network_channels: An integer, the number of
        channels in the final global feed forward network, i.e. the mask
        feature head and the mask class head. Defaults to 256 as in
        MaX-DeepLab-S.
      high_resolution_output_stride: An integer, the final decoding output
        stride. Defaults to 4 as in MaX-DeepLab-S.
      activation: A string, the type of activation function to apply.
        Supports 'relu', 'swish' (or 'silu'), 'gelu', 'approximated_gelu',
        and 'elu'.
      block_group_config: An argument dictionary that will be passed to
        block_group.
      bn_layer: An optional tf.keras.layers.Layer that computes the
        normalization (default: tf.keras.layers.BatchNormalization).
      conv_kernel_weight_decay: A float, the weight decay for convolution
        kernels.

    Raises:
      ValueError: If backbone_type is not one of 'resnet', 'resnet_beta', or
        'wider_resnet'.
      ValueError: If extra_decoder_blocks_per_stage is not greater than zero.
    """
    super(AxialResNet, self).__init__(name=name)
    if extra_decoder_blocks_per_stage <= 0:
      raise ValueError(
          'extra_decoder_blocks_per_stage should be greater than zero.')
    if block_group_config is None:
      block_group_config = {}

    # Compute parameter lists for block_groups. We consider five stages so
    # that it is general enough to cover fully axial ResNets and wider
    # ResNets.
    total_strides_list = [1, 2, 4, 8, 16]

    # Append 3 blocks for the first stage of fully axial ResNets and wider
    # ResNets.
    num_blocks_list = [3] + utils.scale_int_list(list(num_blocks),
                                                 backbone_layer_multiplier)
    strides_list = [2] * 5

    # Expand the transformer and the block filters with the stride.
    transformer_expansions_list = []
    filters_list = []
    for index, stride in enumerate(total_strides_list):
      # Reduce the number of channels when we apply the transformer to low
      # level features (stride = 2, 4, or 8). The base_transformer_expansion
      # is used for stride = 16, i.e. the standard output_stride for
      # MaX-DeepLab-S.
      transformer_expansions_list.append(
          base_transformer_expansion * stride / 16.0)
      # Compute the base number of filters in each stage. For example, the
      # last stage of ResNet50 has an input stride of 16, so we compute the
      # base number of filters for a bottleneck block as 16 * 32 = 512, which
      # is the number of filters for the 3x3 convolution in those blocks.
      if backbone_type == 'wider_resnet' and index == 0:
        # SWideRNet variants use stem_width_multiplier for the first block.
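        # For example, with the default stem_width_multiplier = 1.0, the
        # first wider_resnet stage (stride 1 in total_strides_list) gets
        # int(round(1 * 32 * 1.0)) = 32 base filters.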
        filters_list.append(int(round(stride * 32 * stem_width_multiplier)))
      else:
        filters_list.append(int(round(stride * 32 * width_multiplier)))

    self._num_mask_slots = None
    # Initialize memory_feature only when a transformer block is used.
    self._use_memory_feature = (backbone_use_transformer_beyond_stride or
                                (extra_decoder_use_transformer_beyond_stride
                                 and (not classification_mode)))
    if self._use_memory_feature:
      self._memory_feature_shape = (1, max_num_mask_slots, memory_channels)
      self._memory_feature_initializer = (
          tf.keras.initializers.TruncatedNormal(stddev=1.0))
      self._memory_feature_regularizer = tf.keras.regularizers.l2(
          conv_kernel_weight_decay)
      if num_mask_slots:
        self._num_mask_slots = num_mask_slots

    # Use a convolutional stem, except in fully axial cases.
    stem_channels = int(round(64 * stem_width_multiplier))
    self._activation_fn = activations.get_activation(activation)
    if use_axial_beyond_stride == 1:
      self._stem = tf.identity
      first_block_index = 0
    elif backbone_type.lower() == 'wider_resnet':
      self._stem = convolutions.Conv2DSame(
          output_channels=stem_channels,
          kernel_size=3,
          name='stem',
          strides=2,
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay)
      # Wider ResNet has five residual block stages, so we start from
      # index 0.
      first_block_index = 0
      # Since we have applied the first strided convolution here, we do not
      # use a stride for the first stage (which will operate on stride 2).
      strides_list[0] = 1
      total_strides_list[0] = 2
    elif backbone_type.lower() == 'resnet_beta':
      self._stem = stems.InceptionSTEM(
          bn_layer=bn_layer,
          width_multiplier=stem_width_multiplier,
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          activation=activation)
      first_block_index = 1
    elif backbone_type.lower() == 'resnet':
      self._stem = convolutions.Conv2DSame(
          output_channels=stem_channels,
          kernel_size=7,
          name='stem',
          strides=2,
          use_bias=False,
          use_bn=True,
          bn_layer=bn_layer,
          activation='none',
          conv_kernel_weight_decay=conv_kernel_weight_decay)
      first_block_index = 1
    else:
      raise ValueError(backbone_type + ' is not supported.')
    self._first_block_index = first_block_index

    # Apply standard ResNet block groups. We use first_block_index to
    # distinguish models with 4 stages and those with 5 stages.
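    # For example, 'resnet' and 'resnet_beta' set first_block_index = 1 and
    # thus build _stage2 through _stage5 below (4 stages), while
    # 'wider_resnet' starts from first_block_index = 0 and additionally
    # builds _stage1 (5 stages).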
    for index in range(first_block_index, 5):
      current_name = '_stage{}'.format(index + 1)
      utils.safe_setattr(self, current_name, axial_block_groups.BlockGroup(
          filters=filters_list[index],
          num_blocks=num_blocks_list[index],
          name=utils.get_layer_name(current_name),
          original_resnet_stride=strides_list[index],
          original_resnet_input_stride=total_strides_list[index],
          output_stride=output_stride,
          backbone_type=backbone_type,
          use_axial_beyond_stride=use_axial_beyond_stride,
          use_transformer_beyond_stride=(
              backbone_use_transformer_beyond_stride),
          transformer_expansion=transformer_expansions_list[index],
          activation=activation,
          bn_layer=bn_layer,
          conv_kernel_weight_decay=conv_kernel_weight_decay,
          **block_group_config))

    self._backbone_decoder_num_stacks = backbone_decoder_num_stacks
    self._classification_mode = classification_mode
    self._extra_decoder_num_stacks = extra_decoder_num_stacks
    self._output_stride = output_stride
    self._high_resolution_output_stride = high_resolution_output_stride
    self._width_multiplier = width_multiplier
    self._activation = activation
    self._bn_layer = bn_layer
    self._conv_kernel_weight_decay = conv_kernel_weight_decay
    self._backbone_use_transformer_beyond_stride = (
        backbone_use_transformer_beyond_stride)
    self._extra_decoder_use_transformer_beyond_stride = (
        extra_decoder_use_transformer_beyond_stride)

    # Keep track of the current stack so that we know when to stop.
    current_stack = 0
    # Track whether we are building the backbone. This will affect the
    # backbone related arguments, local learning rate, and so on.
    current_is_backbone = True

    if backbone_decoder_num_stacks == 0:
      # No stacked decoder is used in the backbone, so we have finished
      # building the backbone. We either return the classification endpoints,
      # or continue building a non-backbone decoder for panoptic
      # segmentation.
      if self._classification_mode:
        return
      else:
        current_is_backbone = False

    if not current_is_backbone:
      # We have finished building the backbone and no stacked decoder is used
      # in it, so we start building extra (i.e., non-backbone) layers for
      # panoptic segmentation.
      current_name = '_stage5_' + EXTRA
      utils.safe_setattr(
          self, current_name, axial_block_groups.BlockGroup(
              filters=filters_list[-1],
              num_blocks=extra_decoder_blocks_per_stage,
              name=utils.get_layer_name(current_name),
              original_resnet_stride=1,
              original_resnet_input_stride=32,
              output_stride=output_stride,
              backbone_type=backbone_type,
              use_axial_beyond_stride=use_axial_beyond_stride,
              use_transformer_beyond_stride=(
                  extra_decoder_use_transformer_beyond_stride),
              transformer_expansion=base_transformer_expansion,
              activation=activation,
              bn_layer=bn_layer,
              conv_kernel_weight_decay=conv_kernel_weight_decay,
              **block_group_config))

    # Compute parameter lists for stacked decoder.
    total_decoder_num_stacks = (
        backbone_decoder_num_stacks + extra_decoder_num_stacks)
    # Use a function to compute the next stride.
    next_stride_fn = lambda x: x // 2
    current_decoder_stride = output_stride
    decoder_stage = 0

    # Exit if we have enough stacks and reach the decoding output stride.
    while (current_stack < total_decoder_num_stacks or
           current_decoder_stride > high_resolution_output_stride):
      decoder_stage += 1
      current_decoder_stride = next_stride_fn(current_decoder_stride)

      if current_decoder_stride == output_stride:
        current_stack += 1
        # Always use blocks from the last resnet stage if the current stride
        # is the output stride (the largest stride).
        original_resnet_input_stride = 32
        # Switch the decoder direction if we reach the largest stride.
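        # For example, with output_stride = 16, high_resolution_output_stride
        # = 4 and one decoder stack, the loop below visits decoder strides
        # 8, 4, 8, 16, 8, 4: down to the high resolution stride, back up to
        # the output stride (completing one stack), then down again.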
        next_stride_fn = lambda x: x // 2
      else:
        original_resnet_input_stride = current_decoder_stride
      # Scale channels according to the strides.
      decoder_channels = original_resnet_input_stride * 64 * width_multiplier
      current_transformer_expansion = (
          base_transformer_expansion * current_decoder_stride / 16.0)

      # Apply a decoder block group for building the backbone.
      if current_is_backbone:
        current_name = '_decoder_stage{}'.format(decoder_stage)
        utils.safe_setattr(
            self, current_name, axial_block_groups.BlockGroup(
                filters=decoder_channels // 4,
                num_blocks=backbone_decoder_blocks_per_stage,
                name=utils.get_layer_name(current_name),
                original_resnet_stride=1,
                original_resnet_input_stride=original_resnet_input_stride,
                output_stride=output_stride,
                backbone_type=backbone_type,
                use_axial_beyond_stride=use_axial_beyond_stride,
                use_transformer_beyond_stride=(
                    backbone_use_transformer_beyond_stride),
                transformer_expansion=current_transformer_expansion,
                activation=activation,
                bn_layer=bn_layer,
                conv_kernel_weight_decay=conv_kernel_weight_decay,
                **block_group_config))

      if (current_decoder_stride == output_stride and
          current_stack == backbone_decoder_num_stacks):
        # Now that we have finished building the backbone, we either return
        # the classification endpoints, or continue building a non-backbone
        # decoder for panoptic segmentation.
        if classification_mode:
          return
        else:
          current_is_backbone = False

      # Apply a decoder block group for building the extra layers.
      if not current_is_backbone:
        # Continue building an extra (i.e., non-backbone) decoder for
        # panoptic segmentation.
        current_name = '_decoder_stage{}_{}'.format(decoder_stage, EXTRA)
        utils.safe_setattr(
            self, current_name, axial_block_groups.BlockGroup(
                filters=decoder_channels // 4,
                num_blocks=extra_decoder_blocks_per_stage,
                name=utils.get_layer_name(current_name),
                original_resnet_stride=1,
                original_resnet_input_stride=original_resnet_input_stride,
                output_stride=output_stride,
                backbone_type=backbone_type,
                use_axial_beyond_stride=use_axial_beyond_stride,
                use_transformer_beyond_stride=(
                    extra_decoder_use_transformer_beyond_stride),
                transformer_expansion=current_transformer_expansion,
                activation=activation,
                bn_layer=bn_layer,
                conv_kernel_weight_decay=conv_kernel_weight_decay,
                **block_group_config))

      if current_decoder_stride == high_resolution_output_stride:
        next_stride_fn = lambda x: x * 2

    # Assert that we have already returned if we are building a classifier.
    assert not classification_mode

    if (backbone_use_transformer_beyond_stride or
        extra_decoder_use_transformer_beyond_stride):
      # Build extra memory path feed forward networks for the class feature
      # and the mask feature.
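      # Note: both heads below are Conv1D layers on the memory path, so each
      # of the num_mask_slots memory slots is projected to
      # global_feed_forward_network_channels channels independently (assuming
      # convolutions.Conv1D defaults to a kernel size of 1).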
      current_name = '_class_feature_' + EXTRA
      utils.safe_setattr(
          self, current_name, convolutions.Conv1D(
              global_feed_forward_network_channels,
              utils.get_layer_name(current_name),
              use_bias=False,
              use_bn=True,
              bn_layer=bn_layer,
              activation=activation,
              conv_kernel_weight_decay=conv_kernel_weight_decay))

      current_name = '_mask_feature_' + EXTRA
      utils.safe_setattr(
          self, current_name, convolutions.Conv1D(
              global_feed_forward_network_channels,
              utils.get_layer_name(current_name),
              use_bias=False,
              use_bn=True,
              bn_layer=bn_layer,
              activation=activation,
              conv_kernel_weight_decay=conv_kernel_weight_decay))

  def build(self, input_shape):
    """Builds model weights and input shape dependent sub-layers."""
    if self._use_memory_feature:
      self._memory_feature = self.add_weight(
          name=MEMORY_FEATURE,
          shape=self._memory_feature_shape,
          initializer=self._memory_feature_initializer,
          regularizer=self._memory_feature_regularizer)
    else:
      self._memory_feature = None

    # Go through the loop to build the ResizedFuse layers.
    current_stack = 0
    # Track whether we are building the backbone. This will affect the
    # backbone related arguments, local learning rate, and so on.
    current_is_backbone = self._backbone_decoder_num_stacks != 0
    total_decoder_num_stacks = (
        self._backbone_decoder_num_stacks + self._extra_decoder_num_stacks)
    next_stride_fn = lambda x: x // 2
    current_decoder_stride = self._output_stride
    decoder_stage = 0

    while (current_stack < total_decoder_num_stacks or
           current_decoder_stride > self._high_resolution_output_stride):
      decoder_stage += 1
      current_decoder_stride = next_stride_fn(current_decoder_stride)
      if current_decoder_stride == self._output_stride:
        current_stack += 1
        original_resnet_input_stride = 32
        next_stride_fn = lambda x: x // 2
      else:
        original_resnet_input_stride = current_decoder_stride
      # Compute the decoder_channels according to
      # original_resnet_input_stride. For example, at stride 4 with width
      # multiplier = 1, we use 4 * 64 = 256 channels, which is the same as a
      # standard ResNet.
      decoder_channels = int(round(
          original_resnet_input_stride * 64 * self._width_multiplier))
      decoder_height, decoder_width = utils.scale_mutable_sequence(
          input_shape[1:3], 1.0 / current_decoder_stride)
      if current_is_backbone:
        current_name = '_decoder_stage{}_resized_fuse'.format(decoder_stage)
      else:
        current_name = '_decoder_stage{}_{}_resized_fuse'.format(
            decoder_stage, EXTRA)
      utils.safe_setattr(
          self, current_name, resized_fuse.ResizedFuse(
              name=utils.get_layer_name(current_name),
              height=decoder_height,
              width=decoder_width,
              num_channels=decoder_channels,
              activation=self._activation,
              bn_layer=self._bn_layer,
              conv_kernel_weight_decay=self._conv_kernel_weight_decay))
      if (current_decoder_stride == self._output_stride and
          current_stack == self._backbone_decoder_num_stacks):
        # Now that we have finished building the backbone, we either return
        # the classification endpoints, or continue building a non-backbone
        # decoder for panoptic segmentation.
        if self._classification_mode:
          return
        current_is_backbone = False
      if current_decoder_stride == self._high_resolution_output_stride:
        next_stride_fn = lambda x: x * 2

  def call_encoder_before_stacked_decoder(self, inputs, training=False):
    """Performs a forward pass of the encoder before stacking decoders.

    Args:
      inputs: An input [batch, height, width, channel] tensor.
      training: A boolean, whether the model is in training mode.

    Returns:
      current_output: An output tensor with shape [batch, new_height,
        new_width, new_channel].
      activated_output: An activated output tensor with shape [batch,
        new_height, new_width, new_channel].
      memory_feature: None if no transformer is used. A [batch, num_memory,
        memory_channel] tensor if transformer is used.
      endpoints: A dict, the network endpoints that might be used by DeepLab.
    """
    memory_feature = self._memory_feature
    if self._use_memory_feature:
      if self._num_mask_slots:
        memory_feature = self._memory_feature[:, :self._num_mask_slots, :]
      memory_feature = tf.tile(memory_feature, [tf.shape(inputs)[0], 1, 1])

    endpoints = {}
    output = self._stem(inputs)
    activated_output = self._activation_fn(output)
    endpoints['stage1'] = output
    endpoints['res1'] = activated_output

    # Apply standard ResNet block groups. We use first_block_index to
    # distinguish models with 4 stages and those with 5 stages.
    for index in range(self._first_block_index, 5):
      current_name = '_stage{}'.format(index + 1)
      current_output, activated_output, memory_feature = (
          getattr(self, current_name)((activated_output, memory_feature),
                                      training=training))
      endpoints[utils.get_layer_name(current_name)] = current_output
      activated_output_name = 'res{}'.format(index + 1)
      endpoints[activated_output_name] = activated_output
    return current_output, activated_output, memory_feature, endpoints

  def call_stacked_decoder(self,
                           current_output,
                           activated_output,
                           memory_feature,
                           endpoints,
                           training=False):
    """Performs a forward pass of the stacked decoders.

    Args:
      current_output: An output tensor with shape [batch, new_height,
        new_width, new_channel].
      activated_output: An activated output tensor with shape [batch,
        new_height, new_width, new_channel].
      memory_feature: None if no transformer is used. A [batch, num_memory,
        memory_channel] tensor if transformer is used.
      endpoints: A dict, the network endpoints that might be used by DeepLab.
      training: A boolean, whether the model is in training mode.

    Returns:
      memory_feature: None if no transformer is used. A [batch, num_memory,
        memory_channel] tensor if transformer is used.
      high_resolution_outputs: A list of decoded tensors with
        high_resolution_output_stride.
      backbone_output: An output tensor of the backbone, with output_stride.
      endpoints: A dict, the network endpoints that might be used by DeepLab.
    """
    # Keep track of the current stack so that we know when to stop.
    current_stack = 0
    # Track whether we are building the backbone. This will affect the
    # backbone related arguments, local learning rate, and so on.
    current_is_backbone = True
    high_resolution_outputs = []
    if self._backbone_decoder_num_stacks == 0:
      # Keep track of the backbone output, since it might be used as the
      # semantic feature output.
      backbone_output = activated_output
      # Now that we have finished building the backbone, we either return the
      # classification logits, or continue building a non-backbone decoder
      # for panoptic segmentation.
      if self._classification_mode:
        endpoints['backbone_output'] = backbone_output
        return None, None, None, endpoints
      else:
        current_is_backbone = False
    if not current_is_backbone:
      # Build extra layers if we have finished building the backbone.
      current_name = '_stage5_' + EXTRA
      current_output, activated_output, memory_feature = (
          getattr(self, current_name)((activated_output, memory_feature),
                                      training=training))

    # Compute parameter lists for stacked decoder.
    total_decoder_num_stacks = (
        self._backbone_decoder_num_stacks + self._extra_decoder_num_stacks)
    # Keep track of all endpoints that will be used in the stacked decoder.
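    # stride_to_features maps each stride to the list of features produced at
    # that resolution so far; the decoder below fuses the current output with
    # the last two features recorded at the target stride.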
    stride_to_features = {}
    stride_to_features[min(2, self._output_stride)] = [endpoints['stage1']]
    stride_to_features[min(4, self._output_stride)] = [endpoints['stage2']]
    stride_to_features[min(8, self._output_stride)] = [endpoints['stage3']]
    stride_to_features[min(16, self._output_stride)] = [endpoints['stage4']]
    # Only keep the last endpoint from the backbone with the same resolution,
    # i.e., if the output stride is 16, the current output will override the
    # stride 16 endpoint, endpoints['stage4'].
    stride_to_features[min(32, self._output_stride)] = [current_output]

    # Use a function to compute the next stride.
    next_stride_fn = lambda x: x // 2
    current_decoder_stride = self._output_stride
    decoder_stage = 0

    # Exit if we have enough stacks and reach the decoding output stride.
    while (current_stack < total_decoder_num_stacks or
           current_decoder_stride > self._high_resolution_output_stride):
      decoder_stage += 1
      current_decoder_stride = next_stride_fn(current_decoder_stride)
      if current_decoder_stride == self._output_stride:
        current_stack += 1
        # Switch the decoder direction if we reach the largest stride.
        next_stride_fn = lambda x: x // 2

      # Include the current feature and two previous features from the target
      # resolution in the decoder. We select two because it contains one
      # upward feature and one downward feature, but better choices are
      # possible.
      decoder_features_list = (
          [current_output] +
          stride_to_features[current_decoder_stride][-2:])

      # Fuse and resize features with striding, resizing and 1x1
      # convolutions.
      if current_is_backbone:
        current_name = '_decoder_stage{}_resized_fuse'.format(decoder_stage)
      else:
        current_name = '_decoder_stage{}_{}_resized_fuse'.format(
            decoder_stage, EXTRA)
      activated_output = getattr(self, current_name)(
          decoder_features_list, training=training)

      # Apply a decoder block group for building the backbone.
      if current_is_backbone:
        current_name = '_decoder_stage{}'.format(decoder_stage)
        current_output, activated_output, memory_feature = (
            getattr(self, current_name)((activated_output, memory_feature),
                                        training=training))

      if (current_decoder_stride == self._output_stride and
          current_stack == self._backbone_decoder_num_stacks):
        # Keep track of the backbone output, since it might be used as the
        # semantic feature output.
        backbone_output = activated_output
        # Now that we have finished building the backbone, we either return
        # the classification logits, or continue building a non-backbone
        # decoder for panoptic segmentation.
        if self._classification_mode:
          endpoints['backbone_output'] = backbone_output
          return None, None, None, endpoints
        else:
          current_is_backbone = False

      # Apply a decoder block group for building the extra layers.
      if not current_is_backbone:
        current_name = '_decoder_stage{}_{}'.format(decoder_stage, EXTRA)
        current_output, activated_output, memory_feature = (
            getattr(self, current_name)((activated_output, memory_feature),
                                        training=training))

      # Append the current feature into the feature dict for possible later
      # usage.
      stride_to_features[current_decoder_stride].append(current_output)
      if current_decoder_stride == self._high_resolution_output_stride:
        high_resolution_outputs.append(activated_output)
        next_stride_fn = lambda x: x * 2
    return (memory_feature, high_resolution_outputs, backbone_output,
            endpoints)

  def call_extra_endpoints(self,
                           memory_feature,
                           high_resolution_outputs,
                           backbone_output,
                           endpoints,
                           training=False):
    """Performs a forward pass to generate extra endpoints.

    Args:
      memory_feature: None if no transformer is used.
        A [batch, num_memory, memory_channel] tensor if transformer is used.
      high_resolution_outputs: A list of decoded tensors with
        high_resolution_output_stride.
      backbone_output: An output tensor of the backbone, with output_stride.
      endpoints: A dict, the network endpoints that might be used by DeepLab.
      training: A boolean, whether the model is in training mode.

    Returns:
      endpoints: A dict, the network endpoints that might be used by DeepLab.
    """
    # Assert that we have already returned if we are building a classifier.
    assert not self._classification_mode
    if (self._backbone_use_transformer_beyond_stride or
        self._extra_decoder_use_transformer_beyond_stride):
      # Apply the extra memory path feed forward networks to generate the
      # class feature and the mask feature.
      class_feature = getattr(self, '_class_feature_' + EXTRA)(
          memory_feature, training=training)
      mask_feature = getattr(self, '_mask_feature_' + EXTRA)(
          memory_feature, training=training)
      endpoints['transformer_class_feature'] = class_feature
      endpoints['transformer_mask_feature'] = mask_feature

    # Output the last high resolution feature as panoptic feature.
    endpoints['feature_panoptic'] = high_resolution_outputs[-1]

    # Avoid sharing our panoptic feature with the semantic auxiliary loss, so
    # we use the backbone feature or the decoded backbone feature for the
    # semantic segmentation head (i.e. the auxiliary loss).
    if self._extra_decoder_num_stacks:
      endpoints['feature_semantic'] = (
          high_resolution_outputs[self._backbone_decoder_num_stacks])
    else:
      endpoints['feature_semantic'] = backbone_output
    endpoints['backbone_output'] = backbone_output
    return endpoints

  def call(self, inputs, training=False):
    """Performs a forward pass.

    Args:
      inputs: An input [batch, height, width, channel] tensor.
      training: A boolean, whether the model is in training mode.

    Returns:
      endpoints: A dict, the network endpoints that might be used by DeepLab.
    """
    current_output, activated_output, memory_feature, endpoints = (
        self.call_encoder_before_stacked_decoder(inputs, training=training))
    memory_feature, high_resolution_outputs, backbone_output, endpoints = (
        self.call_stacked_decoder(current_output,
                                  activated_output,
                                  memory_feature,
                                  endpoints,
                                  training=training))
    if self._classification_mode:
      return endpoints
    endpoints = self.call_extra_endpoints(memory_feature,
                                          high_resolution_outputs,
                                          backbone_output,
                                          endpoints,
                                          training=training)
    return endpoints
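

# A minimal usage sketch (not part of the original file): the configuration
# below keeps the MaX-DeepLab-S style defaults but sets
# classification_mode=True, so calling the model returns the backbone
# endpoints dict. The model name and the 224x224 input size are arbitrary
# assumptions for illustration.
if __name__ == '__main__':
  backbone = AxialResNet(name='max_deeplab_s', classification_mode=True)
  images = tf.zeros([1, 224, 224, 3])  # Dummy input batch.
  endpoints = backbone(images, training=False)
  # With the default 'resnet_beta' backbone, the endpoints include 'stage1'
  # through 'stage5', 'res1' through 'res5', and 'backbone_output'.
  print(sorted(endpoints.keys()))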