# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements Axial-ResNets proposed in Axial-DeepLab [1].

[1] Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation,
ECCV 2020 Spotlight.
Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille,
Liang-Chieh Chen.
"""

import tensorflow as tf
from deeplab2.model import utils
from deeplab2.model.layers import activations
from deeplab2.model.layers import axial_block_groups
from deeplab2.model.layers import convolutions
from deeplab2.model.layers import resized_fuse
from deeplab2.model.layers import stems

# Add a suffix in layer names that indicates whether the current layer is part
# of the backbone or an extra layer, i.e., whether the current layer will be
# pretrained or not. This name will be used when we apply 10x larger learning
# rates to extra parameters that have not been pretrained, in panoptic
# segmentation. This keyword is reserved and should not be a part of the
# variable names in a classification pretrained backbone.
EXTRA = 'extra'
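# For example, the extra decoder stages created below carry names such as
# 'stage5_extra' or 'decoder_stage1_extra' (an assumption based on how
# current_name is composed in __init__), which an optimizer builder can match
# against this keyword.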
# Similarly, we will apply 10x larger learning rates on the memory feature.
# This global variable name will be accessed when we build the optimizers. This
# keyword is reserved and should not be a part of the variable names in a
# classification pretrained backbone.
MEMORY_FEATURE = 'memory_feature'
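# The memory feature itself is created as a model weight with this name in
# AxialResNet.build() below.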


class AxialResNet(tf.keras.Model):
"""An Axial-ResNet model as proposed in Axial-DeepLab [1] and MaX-DeepLab [2].

  An Axial-ResNet [1] replaces 3x3 convolutions in a ResNet by axial-attention
layers. A dual-path transformer [2] and a stacked decoder [2] can be used
optionally. In addition, this class supports scaling models with SWideRNet [3]
and augmenting convolutions with Switchable Atrous Convolution [4].

  Reference:
  [1] Axial-DeepLab: Stand-Alone Axial-Attention for Panoptic Segmentation,
ECCV 2020 Spotlight. https://arxiv.org/abs/2003.07853
Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille,
Liang-Chieh Chen.
  [2] MaX-DeepLab: End-to-End Panoptic Segmentation with Mask Transformers,
CVPR 2021. https://arxiv.org/abs/2012.00759
Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen.
[3] Scaling Wide Residual Networks for Panoptic Segmentation,
https://arxiv.org/abs/2011.11675
Liang-Chieh Chen, Huiyu Wang, Siyuan Qiao.
[4] DetectoRS: Detecting Objects with Recursive Feature Pyramid and Switchable
Atrous Convolution, CVPR 2021. https://arxiv.org/abs/2006.02334
Siyuan Qiao, Liang-Chieh Chen, Alan Yuille.
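
  A minimal usage sketch (hedged): the model name and the input size below are
  illustrative, and every other constructor argument is left at its default;
  this is not a configuration prescribed by the papers.

    model = AxialResNet(name='axial_resnet_example')
    endpoints = model(tf.zeros([2, 65, 65, 3]), training=False)
    panoptic_feature = endpoints['feature_panoptic']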
"""

  def __init__(self,
name,
num_blocks=(3, 4, 6, 3),
backbone_layer_multiplier=1.0,
width_multiplier=1.0,
stem_width_multiplier=1.0,
output_stride=16,
classification_mode=False,
backbone_type='resnet_beta',
use_axial_beyond_stride=16,
backbone_use_transformer_beyond_stride=32,
extra_decoder_use_transformer_beyond_stride=32,
backbone_decoder_num_stacks=0,
backbone_decoder_blocks_per_stage=1,
extra_decoder_num_stacks=0,
extra_decoder_blocks_per_stage=1,
max_num_mask_slots=128,
num_mask_slots=128,
memory_channels=256,
base_transformer_expansion=1.0,
global_feed_forward_network_channels=256,
high_resolution_output_stride=4,
activation='relu',
block_group_config=None,
bn_layer=tf.keras.layers.BatchNormalization,
conv_kernel_weight_decay=0.0):
"""Initializes an AxialResNet model.

    Args:
name: A string, the name of the model.
num_blocks: A list of 4 integers. It denotes the number of blocks to
include in the last 4 stages or block groups. Each group consists of
blocks that output features of the same resolution. Defaults to (3, 4,
6, 3) as in MaX-DeepLab-S.
backbone_layer_multiplier: A float, layer_multiplier for the backbone,
excluding the STEM. This flag controls the number of layers. Defaults to
1.0 as in MaX-DeepLab-S.
width_multiplier: A float, the channel multiplier for the block groups.
Defaults to 1.0 as in MaX-DeepLab-S.
stem_width_multiplier: A float, the channel multiplier for stem
convolutions. Defaults to 1.0 as in MaX-DeepLab-S.
output_stride: An integer, the maximum ratio of input to output spatial
resolution. Defaults to 16 as in MaX-DeepLab-S.
classification_mode: A boolean, whether to perform in a classification
mode. If it is True, this function directly returns backbone feature
endpoints. Note that these feature endpoints can also be used directly
for Panoptic-DeepLab or Motion-DeepLab. If it is False, this function
builds MaX-DeepLab extra decoder layers and extra transformer layers.
Defaults to False as in MaX-DeepLab.
backbone_type: A string, the type of backbone. Supports 'resnet',
'resnet_beta', and 'wider_resnet'. It controls both the stem type and
the residual block type. Defaults to 'resnet_beta' as in MaX-DeepLab-S.
use_axial_beyond_stride: An integer, the stride beyond which we use axial
attention. Set to 0 if no axial attention is desired. Defaults to 16 as
in MaX-DeepLab.
backbone_use_transformer_beyond_stride: An integer, the stride beyond
which we use a memory path transformer block on top of a regular pixel
path block, in the backbone. Set to 0 if no transformer block is desired
in the backbone. Defaults to 32 as in MaX-DeepLab-S.
extra_decoder_use_transformer_beyond_stride: An integer, the stride beyond
which we use a memory path transformer block on top of a regular pixel
path block, in the extra decoder stages. Set to 0 if no transformer
block is desired in the extra decoder stages. Defaults to 32 as in
MaX-DeepLab-S.
backbone_decoder_num_stacks: An integer, the number of decoder stacks
(introduced in MaX-DeepLab) that we use in the backbone. The stacked
decoders are applied in a stacked hour-glass style. Defaults to 0 as in
MaX-DeepLab-S.
backbone_decoder_blocks_per_stage: An integer, the number of consecutive
residual blocks to apply for each decoder stage, in the backbone.
Defaults to 1 as in MaX-DeepLab-S.
extra_decoder_num_stacks: An integer, the number of decoder stacks
(introduced in MaX-DeepLab) that we use in the extra decoder layers. It
        is different from backbone_decoder_num_stacks in that the extra
decoder stacks will be trained from scratch on segmentation tasks,
instead of pretrained on ImageNet classification. Defaults to 0 as in
MaX-DeepLab-S.
extra_decoder_blocks_per_stage: An integer, the number of consecutive
residual blocks to apply for each decoder stage, in the extra decoder
stages. Defaults to 1 as in MaX-DeepLab-S.
max_num_mask_slots: An integer, the maximum possible number of mask slots
that will be used. This will be used in a pretraining-finetuning use
case with different num_mask_slots: We can set max_num_mask_slots to the
maximum possible num_mask_slots, and then the saved checkpoint can be
loaded for finetuning with a different num_mask_slots. Defaults to 128
as in MaX-DeepLab.
num_mask_slots: An integer, the number of mask slots that will be used.
Defaults to 128 as in MaX-DeepLab-S.
memory_channels: An integer, the number of channels for the whole memory
path. Defaults to 256 as in MaX-DeepLab-S.
base_transformer_expansion: A float, the base width expansion rate for
transformer layers. Defaults to 1.0 as in MaX-DeepLab-S.
global_feed_forward_network_channels: An integer, the number of channels
in the final global feed forward network, i.e. the mask feature head and
the mask class head. Defaults to 256 as in MaX-DeepLab-S.
high_resolution_output_stride: An integer, the final decoding output
stride. Defaults to 4 as in MaX-DeepLab-S.
      activation: A string, type of activation function to apply. Supports
'relu', 'swish' (or 'silu'), 'gelu', 'approximated_gelu', and 'elu'.
block_group_config: An argument dictionary that will be passed to
block_group.
bn_layer: An optional tf.keras.layers.Layer that computes the
normalization (default: tf.keras.layers.BatchNormalization).
conv_kernel_weight_decay: A float, the weight decay for convolution
kernels.

    Raises:
ValueError: If backbone_type is not one of 'resnet', 'resnet_beta', or
'wider_resnet'.
ValueError: If extra_decoder_blocks_per_stage is not greater than zero.
"""
super(AxialResNet, self).__init__(name=name)
if extra_decoder_blocks_per_stage <= 0:
raise ValueError(
          'extra_decoder_blocks_per_stage should be greater than zero.')
if block_group_config is None:
block_group_config = {}
# Compute parameter lists for block_groups. We consider five stages so that
# it is general enough to cover fully axial resnets and wider resnets.
total_strides_list = [1, 2, 4, 8, 16]
    # Prepend 3 blocks for the first stage of fully axial resnets and wider
# resnets.
num_blocks_list = [3] + utils.scale_int_list(list(num_blocks),
backbone_layer_multiplier)
strides_list = [2] * 5
# Expand the transformer and the block filters with the stride.
transformer_expansions_list = []
filters_list = []
for index, stride in enumerate(total_strides_list):
# Reduce the number of channels when we apply transformer to low level
# features (stride = 2, 4, or 8). The base_transformer_expansion is used
# for stride = 16, i.e. the standard output_stride for MaX-DeepLab-S.
transformer_expansions_list.append(base_transformer_expansion * stride /
16.0)
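      # For example, with base_transformer_expansion=1.0, the strides
      # (1, 2, 4, 8, 16) yield expansions (0.0625, 0.125, 0.25, 0.5, 1.0).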
# Compute the base number of filters in each stage. For example, the last
# stage of ResNet50 has an input stride of 16, then we compute the base
# number of filters for a bottleneck block as 16 * 32 = 512, which is the
# number of filters for the 3x3 convolution in those blocks.
if backbone_type == 'wider_resnet' and index == 0:
# SWideRNet variants use stem_width_multiplier for the first block.
filters_list.append(int(round(stride * 32 * stem_width_multiplier)))
else:
filters_list.append(int(round(stride * 32 * width_multiplier)))
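    # With all multipliers at 1.0, filters_list ends up as
    # [32, 64, 128, 256, 512].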
self._num_mask_slots = None
# Initialize memory_feature only when a transformer block is used.
self._use_memory_feature = (backbone_use_transformer_beyond_stride or
(extra_decoder_use_transformer_beyond_stride and
(not classification_mode)))
if self._use_memory_feature:
self._memory_feature_shape = (1, max_num_mask_slots, memory_channels)
self._memory_feature_initializer = (
tf.keras.initializers.TruncatedNormal(stddev=1.0))
self._memory_feature_regularizer = tf.keras.regularizers.l2(
conv_kernel_weight_decay)
if num_mask_slots:
self._num_mask_slots = num_mask_slots
    # Use a convolutional stem except in the fully axial cases.
stem_channels = int(round(64 * stem_width_multiplier))
self._activation_fn = activations.get_activation(activation)
if use_axial_beyond_stride == 1:
self._stem = tf.identity
first_block_index = 0
elif backbone_type.lower() == 'wider_resnet':
self._stem = convolutions.Conv2DSame(
output_channels=stem_channels,
kernel_size=3,
name='stem',
strides=2,
use_bias=False,
use_bn=True,
bn_layer=bn_layer,
activation='none',
conv_kernel_weight_decay=conv_kernel_weight_decay)
# Wider ResNet has five residual block stages, so we start from index 0.
first_block_index = 0
# Since we have applied the first strided convolution here, we do not use
# a stride for the first stage (which will operate on stride 2).
strides_list[0] = 1
total_strides_list[0] = 2
elif backbone_type.lower() == 'resnet_beta':
self._stem = stems.InceptionSTEM(
bn_layer=bn_layer,
width_multiplier=stem_width_multiplier,
conv_kernel_weight_decay=conv_kernel_weight_decay,
activation=activation)
first_block_index = 1
elif backbone_type.lower() == 'resnet':
self._stem = convolutions.Conv2DSame(
output_channels=stem_channels,
kernel_size=7,
name='stem',
strides=2,
use_bias=False,
use_bn=True,
bn_layer=bn_layer,
activation='none',
conv_kernel_weight_decay=conv_kernel_weight_decay)
first_block_index = 1
else:
raise ValueError(backbone_type + ' is not supported.')
self._first_block_index = first_block_index
# Apply standard ResNet block groups. We use first_block_index to
# distinguish models with 4 stages and those with 5 stages.
for index in range(first_block_index, 5):
current_name = '_stage{}'.format(index + 1)
utils.safe_setattr(self, current_name, axial_block_groups.BlockGroup(
filters=filters_list[index],
num_blocks=num_blocks_list[index],
name=utils.get_layer_name(current_name),
original_resnet_stride=strides_list[index],
original_resnet_input_stride=total_strides_list[index],
output_stride=output_stride,
backbone_type=backbone_type,
use_axial_beyond_stride=use_axial_beyond_stride,
use_transformer_beyond_stride=(
backbone_use_transformer_beyond_stride),
transformer_expansion=transformer_expansions_list[index],
activation=activation,
bn_layer=bn_layer,
conv_kernel_weight_decay=conv_kernel_weight_decay,
**block_group_config))
self._backbone_decoder_num_stacks = backbone_decoder_num_stacks
self._classification_mode = classification_mode
self._extra_decoder_num_stacks = extra_decoder_num_stacks
self._output_stride = output_stride
self._high_resolution_output_stride = high_resolution_output_stride
self._width_multiplier = width_multiplier
self._activation = activation
self._bn_layer = bn_layer
self._conv_kernel_weight_decay = conv_kernel_weight_decay
self._backbone_use_transformer_beyond_stride = (
backbone_use_transformer_beyond_stride)
self._extra_decoder_use_transformer_beyond_stride = (
extra_decoder_use_transformer_beyond_stride)
# Keep track of the current stack so that we know when to stop.
current_stack = 0
# Track whether we are building the backbone. This will affect the backbone
# related arguments, local learning rate, and so on.
current_is_backbone = True
if backbone_decoder_num_stacks == 0:
# No stacked decoder is used in the backbone, so we have finished building
# the backbone. We either return the classification endpoints, or continue
# building a non-backbone decoder for panoptic segmentation.
if self._classification_mode:
return
else:
current_is_backbone = False
if not current_is_backbone:
      # We have finished building the backbone and no stacked decoder is used
      # in the backbone, so we start to build extra (i.e., non-backbone) layers
      # for panoptic segmentation.
current_name = '_stage5_' + EXTRA
utils.safe_setattr(
self, current_name, axial_block_groups.BlockGroup(
filters=filters_list[-1],
num_blocks=extra_decoder_blocks_per_stage,
name=utils.get_layer_name(current_name),
original_resnet_stride=1,
original_resnet_input_stride=32,
output_stride=output_stride,
backbone_type=backbone_type,
use_axial_beyond_stride=use_axial_beyond_stride,
use_transformer_beyond_stride=(
extra_decoder_use_transformer_beyond_stride),
transformer_expansion=base_transformer_expansion,
activation=activation,
bn_layer=bn_layer,
conv_kernel_weight_decay=conv_kernel_weight_decay,
**block_group_config))
# Compute parameter lists for stacked decoder.
total_decoder_num_stacks = (
backbone_decoder_num_stacks + extra_decoder_num_stacks)
# Use a function to compute the next stride.
next_stride_fn = lambda x: x // 2
current_decoder_stride = output_stride
decoder_stage = 0
# Exit if we have enough stacks and reach the decoding output stride.
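    # For example, with output_stride=16, high_resolution_output_stride=4, and
    # one decoder stack in total, current_decoder_stride (start value included)
    # traverses 16 -> 8 -> 4 -> 8 -> 16 -> 8 -> 4 in an hour-glass fashion.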
while (current_stack < total_decoder_num_stacks or
current_decoder_stride > high_resolution_output_stride):
decoder_stage += 1
current_decoder_stride = next_stride_fn(current_decoder_stride)
if current_decoder_stride == output_stride:
current_stack += 1
# Always use blocks from the last resnet stage if the current stride is
# output stride (the largest stride).
original_resnet_input_stride = 32
# Switch the decoder direction if we reach the largest stride.
next_stride_fn = lambda x: x // 2
else:
original_resnet_input_stride = current_decoder_stride
# Scale channels according to the strides.
decoder_channels = original_resnet_input_stride * 64 * width_multiplier
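      # For example, at the coarsest level (original_resnet_input_stride=32)
      # with width_multiplier=1.0, decoder_channels is 32 * 64 = 2048, matching
      # the output channels of the last ResNet stage.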
current_transformer_expansion = (
base_transformer_expansion * current_decoder_stride / 16.0)
# Apply a decoder block group for building the backbone.
if current_is_backbone:
current_name = '_decoder_stage{}'.format(decoder_stage)
utils.safe_setattr(
self, current_name, axial_block_groups.BlockGroup(
filters=decoder_channels // 4,
num_blocks=backbone_decoder_blocks_per_stage,
name=utils.get_layer_name(current_name),
original_resnet_stride=1,
original_resnet_input_stride=original_resnet_input_stride,
output_stride=output_stride,
backbone_type=backbone_type,
use_axial_beyond_stride=use_axial_beyond_stride,
use_transformer_beyond_stride=(
backbone_use_transformer_beyond_stride),
transformer_expansion=current_transformer_expansion,
activation=activation,
bn_layer=bn_layer,
conv_kernel_weight_decay=conv_kernel_weight_decay,
**block_group_config))
if (current_decoder_stride == output_stride and
current_stack == backbone_decoder_num_stacks):
# Now that we have finished building the backbone, we either return the
# classification endpoints, or continue building a non-backbone decoder
# for panoptic segmentation.
if classification_mode:
return
else:
current_is_backbone = False
# Apply a decoder block group for building the extra layers.
if not current_is_backbone:
# Continue building an extra (i.e., non-backbone) decoder for panoptic
# segmentation.
current_name = '_decoder_stage{}_{}'.format(decoder_stage, EXTRA)
utils.safe_setattr(
self, current_name, axial_block_groups.BlockGroup(
filters=decoder_channels // 4,
num_blocks=extra_decoder_blocks_per_stage,
name=utils.get_layer_name(current_name),
original_resnet_stride=1,
original_resnet_input_stride=original_resnet_input_stride,
output_stride=output_stride,
backbone_type=backbone_type,
use_axial_beyond_stride=use_axial_beyond_stride,
use_transformer_beyond_stride=(
extra_decoder_use_transformer_beyond_stride),
transformer_expansion=current_transformer_expansion,
activation=activation,
bn_layer=bn_layer,
conv_kernel_weight_decay=conv_kernel_weight_decay,
**block_group_config))
if current_decoder_stride == high_resolution_output_stride:
next_stride_fn = lambda x: x * 2
# Assert that we have already returned if we are building a classifier.
assert not classification_mode
if (backbone_use_transformer_beyond_stride or
extra_decoder_use_transformer_beyond_stride):
# Build extra memory path feed forward networks for the class feature and
# the mask feature.
current_name = '_class_feature_' + EXTRA
utils.safe_setattr(
self, current_name, convolutions.Conv1D(
global_feed_forward_network_channels,
utils.get_layer_name(current_name),
use_bias=False,
use_bn=True,
bn_layer=bn_layer,
activation=activation,
conv_kernel_weight_decay=conv_kernel_weight_decay))
current_name = '_mask_feature_' + EXTRA
utils.safe_setattr(
self, current_name, convolutions.Conv1D(
global_feed_forward_network_channels,
utils.get_layer_name(current_name),
use_bias=False,
use_bn=True,
bn_layer=bn_layer,
activation=activation,
conv_kernel_weight_decay=conv_kernel_weight_decay))

  def build(self, input_shape):
"""Builds model weights and input shape dependent sub-layers."""
if self._use_memory_feature:
self._memory_feature = self.add_weight(
name=MEMORY_FEATURE,
shape=self._memory_feature_shape,
initializer=self._memory_feature_initializer,
regularizer=self._memory_feature_regularizer)
else:
self._memory_feature = None
# Go through the loop to build the ResizedFuse layers.
current_stack = 0
# Track whether we are building the backbone. This will affect the backbone
# related arguments, local learning rate, and so on.
current_is_backbone = self._backbone_decoder_num_stacks != 0
total_decoder_num_stacks = (
self._backbone_decoder_num_stacks + self._extra_decoder_num_stacks)
next_stride_fn = lambda x: x // 2
current_decoder_stride = self._output_stride
decoder_stage = 0
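    # This loop mirrors the decoder stride schedule of __init__ so that each
    # ResizedFuse layer is built with the target height and width of its stage.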
while (current_stack < total_decoder_num_stacks or
current_decoder_stride > self._high_resolution_output_stride):
decoder_stage += 1
current_decoder_stride = next_stride_fn(current_decoder_stride)
if current_decoder_stride == self._output_stride:
current_stack += 1
original_resnet_input_stride = 32
next_stride_fn = lambda x: x // 2
else:
original_resnet_input_stride = current_decoder_stride
# Compute the decoder_channels according to original_resnet_input_stride.
# For example, at stride 4 with width multiplier = 1, we use 4 * 64 = 256
# channels, which is the same as a standard ResNet.
decoder_channels = int(round(
original_resnet_input_stride * 64 * self._width_multiplier))
decoder_height, decoder_width = utils.scale_mutable_sequence(
input_shape[1:3], 1.0 / current_decoder_stride)
if current_is_backbone:
current_name = '_decoder_stage{}_resized_fuse'.format(decoder_stage)
else:
current_name = '_decoder_stage{}_{}_resized_fuse'.format(
decoder_stage, EXTRA)
utils.safe_setattr(
self, current_name, resized_fuse.ResizedFuse(
name=utils.get_layer_name(current_name),
height=decoder_height,
width=decoder_width,
num_channels=decoder_channels,
activation=self._activation,
bn_layer=self._bn_layer,
conv_kernel_weight_decay=self._conv_kernel_weight_decay))
if (current_decoder_stride == self._output_stride and
current_stack == self._backbone_decoder_num_stacks):
# Now that we have finished building the backbone, we either return the
# classification endpoints, or continue building a non-backbone decoder
# for panoptic segmentation.
if self._classification_mode:
return
current_is_backbone = False
if current_decoder_stride == self._high_resolution_output_stride:
next_stride_fn = lambda x: x * 2

  def call_encoder_before_stacked_decoder(self, inputs, training=False):
"""Performs a forward pass of the encoder before stacking decoders.

    Args:
inputs: An input [batch, height, width, channel] tensor.
training: A boolean, whether the model is in training mode.

    Returns:
current_output: An output tensor with shape [batch, new_height, new_width,
new_channel].
activated_output: An activated output tensor with shape [batch,
new_height, new_width, new_channel].
memory_feature: None if no transformer is used. A [batch, num_memory,
memory_channel] tensor if transformer is used.
endpoints: A dict, the network endpoints that might be used by DeepLab.
"""
memory_feature = self._memory_feature
if self._use_memory_feature:
if self._num_mask_slots:
memory_feature = self._memory_feature[:, :self._num_mask_slots, :]
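      # Tile the learned memory feature along the batch dimension, so every
      # example in the batch starts from the same learned embeddings.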
memory_feature = tf.tile(memory_feature,
[tf.shape(inputs)[0], 1, 1])
endpoints = {}
output = self._stem(inputs)
activated_output = self._activation_fn(output)
endpoints['stage1'] = output
endpoints['res1'] = activated_output
# Apply standard ResNet block groups. We use first_block_index to
# distinguish models with 4 stages and those with 5 stages.
for index in range(self._first_block_index, 5):
current_name = '_stage{}'.format(index + 1)
current_output, activated_output, memory_feature = (
getattr(self, current_name)(
(activated_output, memory_feature), training=training))
endpoints[utils.get_layer_name(current_name)] = current_output
activated_output_name = 'res{}'.format(index + 1)
endpoints[activated_output_name] = activated_output
return current_output, activated_output, memory_feature, endpoints

  def call_stacked_decoder(self,
current_output,
activated_output,
memory_feature,
endpoints,
training=False):
"""Performs a forward pass of the stacked decoders.

    Args:
current_output: An output tensor with shape [batch, new_height, new_width,
new_channel].
activated_output: An activated output tensor with shape [batch,
new_height, new_width, new_channel].
memory_feature: None if no transformer is used. A [batch, num_memory,
memory_channel] tensor if transformer is used.
endpoints: A dict, the network endpoints that might be used by DeepLab.
training: A boolean, whether the model is in training mode.

    Returns:
memory_feature: None if no transformer is used. A [batch, num_memory,
memory_channel] tensor if transformer is used.
high_resolution_outputs: A list of decoded tensors with
high_resolution_output_stride.
backbone_output: An output tensor of the backbone, with output_stride.
endpoints: A dict, the network endpoints that might be used by DeepLab.
"""
# Keep track of the current stack so that we know when to stop.
current_stack = 0
# Track whether we are building the backbone. This will affect the backbone
# related arguments, local learning rate, and so on.
current_is_backbone = True
high_resolution_outputs = []
if self._backbone_decoder_num_stacks == 0:
# Keep track of the backbone output, since it might be used as the
# semantic feature output.
backbone_output = activated_output
# Now that we have finished building the backbone, we either return the
# classification logits, or continue building a non-backbone decoder for
# panoptic segmentation.
if self._classification_mode:
endpoints['backbone_output'] = backbone_output
return None, None, None, endpoints
else:
current_is_backbone = False
if not current_is_backbone:
# Build extra layers if we have finished building the backbone.
current_name = '_stage5_' + EXTRA
current_output, activated_output, memory_feature = (
getattr(self, current_name)(
(activated_output, memory_feature), training=training))
# Compute parameter lists for stacked decoder.
total_decoder_num_stacks = (
self._backbone_decoder_num_stacks + self._extra_decoder_num_stacks)
# Keep track of all endpoints that will be used in the stacked decoder.
stride_to_features = {}
stride_to_features[min(2, self._output_stride)] = [endpoints['stage1']]
stride_to_features[min(4, self._output_stride)] = [endpoints['stage2']]
stride_to_features[min(8, self._output_stride)] = [endpoints['stage3']]
stride_to_features[min(16, self._output_stride)] = [endpoints['stage4']]
    # Only keep the last endpoint from the backbone with the same resolution,
    # i.e., if the output stride is 16, the current output will override the
    # stride 16 endpoint stored above, endpoints['stage4'].
stride_to_features[min(32, self._output_stride)] = [current_output]
# Use a function to compute the next stride.
next_stride_fn = lambda x: x // 2
current_decoder_stride = self._output_stride
decoder_stage = 0
# Exit if we have enough stacks and reach the decoding output stride.
while (current_stack < total_decoder_num_stacks or
current_decoder_stride > self._high_resolution_output_stride):
decoder_stage += 1
current_decoder_stride = next_stride_fn(current_decoder_stride)
if current_decoder_stride == self._output_stride:
current_stack += 1
# Switch the decoder direction if we reach the largest stride.
next_stride_fn = lambda x: x // 2
      # Include the current feature and two previous features from the target
      # resolution in the decoder. We select two because they contain one
      # upward feature and one downward feature, but better choices are
      # possible.
decoder_features_list = (
[current_output] +
stride_to_features[current_decoder_stride][-2:])
# Fuse and resize features with striding, resizing and 1x1 convolutions.
if current_is_backbone:
current_name = '_decoder_stage{}_resized_fuse'.format(decoder_stage)
else:
current_name = '_decoder_stage{}_{}_resized_fuse'.format(
decoder_stage, EXTRA)
activated_output = getattr(self, current_name)(
decoder_features_list, training=training)
# Apply a decoder block group for building the backbone.
if current_is_backbone:
current_name = '_decoder_stage{}'.format(decoder_stage)
current_output, activated_output, memory_feature = (
getattr(self, current_name)(
(activated_output, memory_feature), training=training))
if (current_decoder_stride == self._output_stride and
current_stack == self._backbone_decoder_num_stacks):
# Keep track of the backbone output, since it might be used as the
# semantic feature output.
backbone_output = activated_output
# Now that we have finished building the backbone, we either return the
# classification logits, or continue building a non-backbone decoder for
# panoptic segmentation.
if self._classification_mode:
endpoints['backbone_output'] = backbone_output
return None, None, None, endpoints
else:
current_is_backbone = False
# Apply a decoder block group for building the extra layers.
if not current_is_backbone:
current_name = '_decoder_stage{}_{}'.format(decoder_stage, EXTRA)
current_output, activated_output, memory_feature = (
getattr(self, current_name)(
(activated_output, memory_feature), training=training))
# Append the current feature into the feature dict for possible later
# usage.
stride_to_features[current_decoder_stride].append(current_output)
if current_decoder_stride == self._high_resolution_output_stride:
high_resolution_outputs.append(activated_output)
next_stride_fn = lambda x: x * 2
return memory_feature, high_resolution_outputs, backbone_output, endpoints

  def call_extra_endpoints(self,
memory_feature,
high_resolution_outputs,
backbone_output,
endpoints,
training=False):
"""Performs a forward pass to generate extra endpoints.

    Args:
memory_feature: None if no transformer is used. A [batch, num_memory,
memory_channel] tensor if transformer is used.
high_resolution_outputs: A list of decoded tensors with
high_resolution_output_stride.
backbone_output: An output tensor of the backbone, with output_stride.
endpoints: A dict, the network endpoints that might be used by DeepLab.
training: A boolean, whether the model is in training mode.

    Returns:
endpoints: A dict, the network endpoints that might be used by DeepLab.
"""
# Assert that we have already returned if we are building a classifier.
assert not self._classification_mode
if (self._backbone_use_transformer_beyond_stride or
self._extra_decoder_use_transformer_beyond_stride):
# Build extra memory path feed forward networks for the class feature and
# the mask feature.
class_feature = getattr(self, '_class_feature_' + EXTRA)(
memory_feature, training=training)
mask_feature = getattr(self, '_mask_feature_' + EXTRA)(
memory_feature, training=training)
endpoints['transformer_class_feature'] = class_feature
endpoints['transformer_mask_feature'] = mask_feature
# Output the last high resolution feature as panoptic feature.
endpoints['feature_panoptic'] = high_resolution_outputs[-1]
    # To avoid sharing the panoptic feature with the semantic auxiliary loss,
    # we use the backbone feature or the decoded backbone feature for the
    # semantic segmentation head (i.e., the auxiliary loss).
if self._extra_decoder_num_stacks:
endpoints['feature_semantic'] = (
high_resolution_outputs[self._backbone_decoder_num_stacks])
else:
endpoints['feature_semantic'] = backbone_output
endpoints['backbone_output'] = backbone_output
return endpoints

  def call(self, inputs, training=False):
"""Performs a forward pass.

    Args:
inputs: An input [batch, height, width, channel] tensor.
training: A boolean, whether the model is in training mode.

    Returns:
endpoints: A dict, the network endpoints that might be used by DeepLab.
"""
current_output, activated_output, memory_feature, endpoints = (
self.call_encoder_before_stacked_decoder(inputs, training=training))
memory_feature, high_resolution_outputs, backbone_output, endpoints = (
self.call_stacked_decoder(current_output,
activated_output,
memory_feature,
endpoints,
training=training))
if self._classification_mode:
return endpoints
endpoints = self.call_extra_endpoints(memory_feature,
high_resolution_outputs,
backbone_output,
endpoints,
training=training)
return endpoints