|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Implements Axial-ResNets proposed in Axial-DeepLab [1]. |
|
|
|
[1] Axial-Deeplab: Stand-Alone Axial-Attention for Panoptic Segmentation, |
|
ECCV 2020 Spotlight. |
|
Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille, |
|
Liang-Chieh Chen. |
|
""" |
|
|
|
import tensorflow as tf |
|
|
|
from deeplab2.model import utils |
|
from deeplab2.model.layers import activations |
|
from deeplab2.model.layers import axial_block_groups |
|
from deeplab2.model.layers import convolutions |
|
from deeplab2.model.layers import resized_fuse |
|
from deeplab2.model.layers import stems |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Layer-name suffix marking the extra decoder / head layers (e.g.
# '_stage5_extra'), which are trained from scratch rather than loaded from the
# ImageNet-pretrained backbone.
EXTRA = 'extra'

# Weight name for the learned memory feature (the mask-slot embeddings)
# created in AxialResNet.build().
MEMORY_FEATURE = 'memory_feature'
|
|
|
|
|
class AxialResNet(tf.keras.Model): |
|
"""An Axial-ResNet model as proposed in Axial-DeepLab [1] and MaX-DeepLab [2]. |
|
|
|
An Axial-ResNet [1] replaces 3x3 convolutions in a Resnet by axial-attention |
|
layers. A dual-path transformer [2] and a stacked decoder [2] can be used |
|
optionally. In addition, this class supports scaling models with SWideRNet [3] |
|
and augmenting convolutions with Switchable Atrous Convolution [4]. |
|
|
|
Reference: |
|
[1] Axial-Deeplab: Stand-Alone Axial-Attention for Panoptic Segmentation, |
|
ECCV 2020 Spotlight. https://arxiv.org/abs/2003.07853 |
|
Huiyu Wang, Yukun Zhu, Bradley Green, Hartwig Adam, Alan Yuille, |
|
Liang-Chieh Chen. |
|
[2] MaX-DeepLab: "End-to-End Panoptic Segmentation with Mask Transformers", |
|
CVPR 2021. https://arxiv.org/abs/2012.00759 |
|
Huiyu Wang, Yukun Zhu, Hartwig Adam, Alan Yuille, Liang-Chieh Chen. |
|
[3] Scaling Wide Residual Networks for Panoptic Segmentation, |
|
https://arxiv.org/abs/2011.11675 |
|
Liang-Chieh Chen, Huiyu Wang, Siyuan Qiao. |
|
[4] DetectoRS: Detecting Objects with Recursive Feature Pyramid and Switchable |
|
Atrous Convolution, CVPR 2021. https://arxiv.org/abs/2006.02334 |
|
Siyuan Qiao, Liang-Chieh Chen, Alan Yuille. |
|
""" |
|
|
|
def __init__(self, |
|
name, |
|
num_blocks=(3, 4, 6, 3), |
|
backbone_layer_multiplier=1.0, |
|
width_multiplier=1.0, |
|
stem_width_multiplier=1.0, |
|
output_stride=16, |
|
classification_mode=False, |
|
backbone_type='resnet_beta', |
|
use_axial_beyond_stride=16, |
|
backbone_use_transformer_beyond_stride=32, |
|
extra_decoder_use_transformer_beyond_stride=32, |
|
backbone_decoder_num_stacks=0, |
|
backbone_decoder_blocks_per_stage=1, |
|
extra_decoder_num_stacks=0, |
|
extra_decoder_blocks_per_stage=1, |
|
max_num_mask_slots=128, |
|
num_mask_slots=128, |
|
memory_channels=256, |
|
base_transformer_expansion=1.0, |
|
global_feed_forward_network_channels=256, |
|
high_resolution_output_stride=4, |
|
activation='relu', |
|
block_group_config=None, |
|
bn_layer=tf.keras.layers.BatchNormalization, |
|
conv_kernel_weight_decay=0.0): |
|
"""Initializes an AxialResNet model. |
|
|
|
Args: |
|
name: A string, the name of the model. |
|
num_blocks: A list of 4 integers. It denotes the number of blocks to |
|
include in the last 4 stages or block groups. Each group consists of |
|
blocks that output features of the same resolution. Defaults to (3, 4, |
|
6, 3) as in MaX-DeepLab-S. |
|
backbone_layer_multiplier: A float, layer_multiplier for the backbone, |
|
excluding the STEM. This flag controls the number of layers. Defaults to |
|
1.0 as in MaX-DeepLab-S. |
|
width_multiplier: A float, the channel multiplier for the block groups. |
|
Defaults to 1.0 as in MaX-DeepLab-S. |
|
stem_width_multiplier: A float, the channel multiplier for stem |
|
convolutions. Defaults to 1.0 as in MaX-DeepLab-S. |
|
output_stride: An integer, the maximum ratio of input to output spatial |
|
resolution. Defaults to 16 as in MaX-DeepLab-S. |
|
classification_mode: A boolean, whether to perform in a classification |
|
mode. If it is True, this function directly returns backbone feature |
|
endpoints. Note that these feature endpoints can also be used directly |
|
for Panoptic-DeepLab or Motion-DeepLab. If it is False, this function |
|
builds MaX-DeepLab extra decoder layers and extra transformer layers. |
|
Defaults to False as in MaX-DeepLab. |
|
backbone_type: A string, the type of backbone. Supports 'resnet', |
|
'resnet_beta', and 'wider_resnet'. It controls both the stem type and |
|
the residual block type. Defaults to 'resnet_beta' as in MaX-DeepLab-S. |
|
use_axial_beyond_stride: An integer, the stride beyond which we use axial |
|
attention. Set to 0 if no axial attention is desired. Defaults to 16 as |
|
in MaX-DeepLab. |
|
backbone_use_transformer_beyond_stride: An integer, the stride beyond |
|
which we use a memory path transformer block on top of a regular pixel |
|
path block, in the backbone. Set to 0 if no transformer block is desired |
|
in the backbone. Defaults to 32 as in MaX-DeepLab-S. |
|
extra_decoder_use_transformer_beyond_stride: An integer, the stride beyond |
|
which we use a memory path transformer block on top of a regular pixel |
|
path block, in the extra decoder stages. Set to 0 if no transformer |
|
block is desired in the extra decoder stages. Defaults to 32 as in |
|
MaX-DeepLab-S. |
|
backbone_decoder_num_stacks: An integer, the number of decoder stacks |
|
(introduced in MaX-DeepLab) that we use in the backbone. The stacked |
|
decoders are applied in a stacked hour-glass style. Defaults to 0 as in |
|
MaX-DeepLab-S. |
|
backbone_decoder_blocks_per_stage: An integer, the number of consecutive |
|
residual blocks to apply for each decoder stage, in the backbone. |
|
Defaults to 1 as in MaX-DeepLab-S. |
|
extra_decoder_num_stacks: An integer, the number of decoder stacks |
|
(introduced in MaX-DeepLab) that we use in the extra decoder layers. It |
|
is different from backbone_decoder_blocks_per_stage in that the extra |
|
decoder stacks will be trained from scratch on segmentation tasks, |
|
instead of pretrained on ImageNet classification. Defaults to 0 as in |
|
MaX-DeepLab-S. |
|
extra_decoder_blocks_per_stage: An integer, the number of consecutive |
|
residual blocks to apply for each decoder stage, in the extra decoder |
|
stages. Defaults to 1 as in MaX-DeepLab-S. |
|
max_num_mask_slots: An integer, the maximum possible number of mask slots |
|
that will be used. This will be used in a pretraining-finetuning use |
|
case with different num_mask_slots: We can set max_num_mask_slots to the |
|
maximum possible num_mask_slots, and then the saved checkpoint can be |
|
loaded for finetuning with a different num_mask_slots. Defaults to 128 |
|
as in MaX-DeepLab. |
|
num_mask_slots: An integer, the number of mask slots that will be used. |
|
Defaults to 128 as in MaX-DeepLab-S. |
|
memory_channels: An integer, the number of channels for the whole memory |
|
path. Defaults to 256 as in MaX-DeepLab-S. |
|
base_transformer_expansion: A float, the base width expansion rate for |
|
transformer layers. Defaults to 1.0 as in MaX-DeepLab-S. |
|
global_feed_forward_network_channels: An integer, the number of channels |
|
in the final global feed forward network, i.e. the mask feature head and |
|
the mask class head. Defaults to 256 as in MaX-DeepLab-S. |
|
high_resolution_output_stride: An integer, the final decoding output |
|
stride. Defaults to 4 as in MaX-DeepLab-S. |
|
activation: A string, type of activation function to apply. Support |
|
'relu', 'swish' (or 'silu'), 'gelu', 'approximated_gelu', and 'elu'. |
|
block_group_config: An argument dictionary that will be passed to |
|
block_group. |
|
bn_layer: An optional tf.keras.layers.Layer that computes the |
|
normalization (default: tf.keras.layers.BatchNormalization). |
|
conv_kernel_weight_decay: A float, the weight decay for convolution |
|
kernels. |
|
|
|
Raises: |
|
ValueError: If backbone_type is not one of 'resnet', 'resnet_beta', or |
|
'wider_resnet'. |
|
ValueError: If extra_decoder_blocks_per_stage is not greater than zero. |
|
""" |
|
super(AxialResNet, self).__init__(name=name) |
|
|
|
if extra_decoder_blocks_per_stage <= 0: |
|
raise ValueError( |
|
'Extra_decoder_blocks_per_stage should be great than zero.') |
|
if block_group_config is None: |
|
block_group_config = {} |
|
|
|
|
|
|
|
total_strides_list = [1, 2, 4, 8, 16] |
|
|
|
|
|
|
|
num_blocks_list = [3] + utils.scale_int_list(list(num_blocks), |
|
backbone_layer_multiplier) |
|
strides_list = [2] * 5 |
|
|
|
|
|
transformer_expansions_list = [] |
|
filters_list = [] |
|
for index, stride in enumerate(total_strides_list): |
|
|
|
|
|
|
|
transformer_expansions_list.append(base_transformer_expansion * stride / |
|
16.0) |
|
|
|
|
|
|
|
|
|
if backbone_type == 'wider_resnet' and index == 0: |
|
|
|
filters_list.append(int(round(stride * 32 * stem_width_multiplier))) |
|
else: |
|
filters_list.append(int(round(stride * 32 * width_multiplier))) |
|
|
|
self._num_mask_slots = None |
|
|
|
self._use_memory_feature = (backbone_use_transformer_beyond_stride or |
|
(extra_decoder_use_transformer_beyond_stride and |
|
(not classification_mode))) |
|
if self._use_memory_feature: |
|
self._memory_feature_shape = (1, max_num_mask_slots, memory_channels) |
|
self._memory_feature_initializer = ( |
|
tf.keras.initializers.TruncatedNormal(stddev=1.0)) |
|
self._memory_feature_regularizer = tf.keras.regularizers.l2( |
|
conv_kernel_weight_decay) |
|
if num_mask_slots: |
|
self._num_mask_slots = num_mask_slots |
|
|
|
|
|
stem_channels = int(round(64 * stem_width_multiplier)) |
|
self._activation_fn = activations.get_activation(activation) |
|
if use_axial_beyond_stride == 1: |
|
self._stem = tf.identity |
|
first_block_index = 0 |
|
elif backbone_type.lower() == 'wider_resnet': |
|
self._stem = convolutions.Conv2DSame( |
|
output_channels=stem_channels, |
|
kernel_size=3, |
|
name='stem', |
|
strides=2, |
|
use_bias=False, |
|
use_bn=True, |
|
bn_layer=bn_layer, |
|
activation='none', |
|
conv_kernel_weight_decay=conv_kernel_weight_decay) |
|
|
|
first_block_index = 0 |
|
|
|
|
|
strides_list[0] = 1 |
|
total_strides_list[0] = 2 |
|
elif backbone_type.lower() == 'resnet_beta': |
|
self._stem = stems.InceptionSTEM( |
|
bn_layer=bn_layer, |
|
width_multiplier=stem_width_multiplier, |
|
conv_kernel_weight_decay=conv_kernel_weight_decay, |
|
activation=activation) |
|
first_block_index = 1 |
|
elif backbone_type.lower() == 'resnet': |
|
self._stem = convolutions.Conv2DSame( |
|
output_channels=stem_channels, |
|
kernel_size=7, |
|
name='stem', |
|
strides=2, |
|
use_bias=False, |
|
use_bn=True, |
|
bn_layer=bn_layer, |
|
activation='none', |
|
conv_kernel_weight_decay=conv_kernel_weight_decay) |
|
first_block_index = 1 |
|
else: |
|
raise ValueError(backbone_type + ' is not supported.') |
|
|
|
self._first_block_index = first_block_index |
|
|
|
|
|
for index in range(first_block_index, 5): |
|
current_name = '_stage{}'.format(index + 1) |
|
utils.safe_setattr(self, current_name, axial_block_groups.BlockGroup( |
|
filters=filters_list[index], |
|
num_blocks=num_blocks_list[index], |
|
name=utils.get_layer_name(current_name), |
|
original_resnet_stride=strides_list[index], |
|
original_resnet_input_stride=total_strides_list[index], |
|
output_stride=output_stride, |
|
backbone_type=backbone_type, |
|
use_axial_beyond_stride=use_axial_beyond_stride, |
|
use_transformer_beyond_stride=( |
|
backbone_use_transformer_beyond_stride), |
|
transformer_expansion=transformer_expansions_list[index], |
|
activation=activation, |
|
bn_layer=bn_layer, |
|
conv_kernel_weight_decay=conv_kernel_weight_decay, |
|
**block_group_config)) |
|
self._backbone_decoder_num_stacks = backbone_decoder_num_stacks |
|
self._classification_mode = classification_mode |
|
self._extra_decoder_num_stacks = extra_decoder_num_stacks |
|
self._output_stride = output_stride |
|
self._high_resolution_output_stride = high_resolution_output_stride |
|
self._width_multiplier = width_multiplier |
|
self._activation = activation |
|
self._bn_layer = bn_layer |
|
self._conv_kernel_weight_decay = conv_kernel_weight_decay |
|
self._backbone_use_transformer_beyond_stride = ( |
|
backbone_use_transformer_beyond_stride) |
|
self._extra_decoder_use_transformer_beyond_stride = ( |
|
extra_decoder_use_transformer_beyond_stride) |
|
|
|
|
|
current_stack = 0 |
|
|
|
|
|
current_is_backbone = True |
|
|
|
if backbone_decoder_num_stacks == 0: |
|
|
|
|
|
|
|
if self._classification_mode: |
|
return |
|
else: |
|
current_is_backbone = False |
|
if not current_is_backbone: |
|
|
|
|
|
|
|
current_name = '_stage5_' + EXTRA |
|
utils.safe_setattr( |
|
self, current_name, axial_block_groups.BlockGroup( |
|
filters=filters_list[-1], |
|
num_blocks=extra_decoder_blocks_per_stage, |
|
name=utils.get_layer_name(current_name), |
|
original_resnet_stride=1, |
|
original_resnet_input_stride=32, |
|
output_stride=output_stride, |
|
backbone_type=backbone_type, |
|
use_axial_beyond_stride=use_axial_beyond_stride, |
|
use_transformer_beyond_stride=( |
|
extra_decoder_use_transformer_beyond_stride), |
|
transformer_expansion=base_transformer_expansion, |
|
activation=activation, |
|
bn_layer=bn_layer, |
|
conv_kernel_weight_decay=conv_kernel_weight_decay, |
|
**block_group_config)) |
|
|
|
|
|
total_decoder_num_stacks = ( |
|
backbone_decoder_num_stacks + extra_decoder_num_stacks) |
|
|
|
|
|
next_stride_fn = lambda x: x // 2 |
|
current_decoder_stride = output_stride |
|
decoder_stage = 0 |
|
|
|
|
|
while (current_stack < total_decoder_num_stacks or |
|
current_decoder_stride > high_resolution_output_stride): |
|
decoder_stage += 1 |
|
current_decoder_stride = next_stride_fn(current_decoder_stride) |
|
|
|
if current_decoder_stride == output_stride: |
|
current_stack += 1 |
|
|
|
|
|
original_resnet_input_stride = 32 |
|
|
|
|
|
next_stride_fn = lambda x: x // 2 |
|
else: |
|
original_resnet_input_stride = current_decoder_stride |
|
|
|
|
|
decoder_channels = original_resnet_input_stride * 64 * width_multiplier |
|
current_transformer_expansion = ( |
|
base_transformer_expansion * current_decoder_stride / 16.0) |
|
|
|
|
|
if current_is_backbone: |
|
current_name = '_decoder_stage{}'.format(decoder_stage) |
|
utils.safe_setattr( |
|
self, current_name, axial_block_groups.BlockGroup( |
|
filters=decoder_channels // 4, |
|
num_blocks=backbone_decoder_blocks_per_stage, |
|
name=utils.get_layer_name(current_name), |
|
original_resnet_stride=1, |
|
original_resnet_input_stride=original_resnet_input_stride, |
|
output_stride=output_stride, |
|
backbone_type=backbone_type, |
|
use_axial_beyond_stride=use_axial_beyond_stride, |
|
use_transformer_beyond_stride=( |
|
backbone_use_transformer_beyond_stride), |
|
transformer_expansion=current_transformer_expansion, |
|
activation=activation, |
|
bn_layer=bn_layer, |
|
conv_kernel_weight_decay=conv_kernel_weight_decay, |
|
**block_group_config)) |
|
|
|
if (current_decoder_stride == output_stride and |
|
current_stack == backbone_decoder_num_stacks): |
|
|
|
|
|
|
|
if classification_mode: |
|
return |
|
else: |
|
current_is_backbone = False |
|
|
|
|
|
if not current_is_backbone: |
|
|
|
|
|
current_name = '_decoder_stage{}_{}'.format(decoder_stage, EXTRA) |
|
utils.safe_setattr( |
|
self, current_name, axial_block_groups.BlockGroup( |
|
filters=decoder_channels // 4, |
|
num_blocks=extra_decoder_blocks_per_stage, |
|
name=utils.get_layer_name(current_name), |
|
original_resnet_stride=1, |
|
original_resnet_input_stride=original_resnet_input_stride, |
|
output_stride=output_stride, |
|
backbone_type=backbone_type, |
|
use_axial_beyond_stride=use_axial_beyond_stride, |
|
use_transformer_beyond_stride=( |
|
extra_decoder_use_transformer_beyond_stride), |
|
transformer_expansion=current_transformer_expansion, |
|
activation=activation, |
|
bn_layer=bn_layer, |
|
conv_kernel_weight_decay=conv_kernel_weight_decay, |
|
**block_group_config)) |
|
if current_decoder_stride == high_resolution_output_stride: |
|
next_stride_fn = lambda x: x * 2 |
|
|
|
|
|
assert not classification_mode |
|
if (backbone_use_transformer_beyond_stride or |
|
extra_decoder_use_transformer_beyond_stride): |
|
|
|
|
|
current_name = '_class_feature_' + EXTRA |
|
utils.safe_setattr( |
|
self, current_name, convolutions.Conv1D( |
|
global_feed_forward_network_channels, |
|
utils.get_layer_name(current_name), |
|
use_bias=False, |
|
use_bn=True, |
|
bn_layer=bn_layer, |
|
activation=activation, |
|
conv_kernel_weight_decay=conv_kernel_weight_decay)) |
|
current_name = '_mask_feature_' + EXTRA |
|
utils.safe_setattr( |
|
self, current_name, convolutions.Conv1D( |
|
global_feed_forward_network_channels, |
|
utils.get_layer_name(current_name), |
|
use_bias=False, |
|
use_bn=True, |
|
bn_layer=bn_layer, |
|
activation=activation, |
|
conv_kernel_weight_decay=conv_kernel_weight_decay)) |
|
|
|
  def build(self, input_shape):
    """Builds model weights and input shape dependent sub-layers.

    Args:
      input_shape: An input shape; input_shape[1:3] (height, width) is used to
        size the ResizedFuse layers for each decoder stride.
    """
    if self._use_memory_feature:
      # The learned memory feature (mask-slot embeddings), with shape
      # self._memory_feature_shape = (1, max_num_mask_slots, memory_channels).
      self._memory_feature = self.add_weight(
          name=MEMORY_FEATURE,
          shape=self._memory_feature_shape,
          initializer=self._memory_feature_initializer,
          regularizer=self._memory_feature_regularizer)
    else:
      self._memory_feature = None

    # Replay the hour-glass decoder stride traversal of __init__ in order to
    # build one ResizedFuse layer per decoder stage. These layers depend on the
    # input spatial shape, so they can only be built here.
    current_stack = 0

    # Whether the current stage still belongs to the backbone decoder (False
    # when there are no backbone decoder stacks at all).
    current_is_backbone = self._backbone_decoder_num_stacks != 0
    total_decoder_num_stacks = (
        self._backbone_decoder_num_stacks + self._extra_decoder_num_stacks)
    next_stride_fn = lambda x: x // 2
    current_decoder_stride = self._output_stride
    decoder_stage = 0
    while (current_stack < total_decoder_num_stacks or
           current_decoder_stride > self._high_resolution_output_stride):
      decoder_stage += 1
      current_decoder_stride = next_stride_fn(current_decoder_stride)
      if current_decoder_stride == self._output_stride:
        # Returning to output_stride completes one hour-glass stack; the stage
        # is configured like an original-ResNet stride-32 stage.
        current_stack += 1
        original_resnet_input_stride = 32
        # Keep downsampling for the next stack.
        next_stride_fn = lambda x: x // 2
      else:
        original_resnet_input_stride = current_decoder_stride

      # Channel count and spatial shape of the fused feature at this stride.
      decoder_channels = int(round(
          original_resnet_input_stride * 64 * self._width_multiplier))
      decoder_height, decoder_width = utils.scale_mutable_sequence(
          input_shape[1:3], 1.0 / current_decoder_stride)
      if current_is_backbone:
        current_name = '_decoder_stage{}_resized_fuse'.format(decoder_stage)
      else:
        current_name = '_decoder_stage{}_{}_resized_fuse'.format(
            decoder_stage, EXTRA)
      utils.safe_setattr(
          self, current_name, resized_fuse.ResizedFuse(
              name=utils.get_layer_name(current_name),
              height=decoder_height,
              width=decoder_width,
              num_channels=decoder_channels,
              activation=self._activation,
              bn_layer=self._bn_layer,
              conv_kernel_weight_decay=self._conv_kernel_weight_decay))
      if (current_decoder_stride == self._output_stride and
          current_stack == self._backbone_decoder_num_stacks):
        # The backbone decoder ends here; in classification mode nothing else
        # is needed, otherwise the remaining fuse layers belong to the extra
        # decoder.
        if self._classification_mode:
          return
        current_is_backbone = False
      if current_decoder_stride == self._high_resolution_output_stride:
        # After reaching the highest resolution, switch to downsampling.
        next_stride_fn = lambda x: x * 2
|
|
|
def call_encoder_before_stacked_decoder(self, inputs, training=False): |
|
"""Performs a forward pass of the encoder before stacking decoders. |
|
|
|
Args: |
|
inputs: An input [batch, height, width, channel] tensor. |
|
training: A boolean, whether the model is in training mode. |
|
|
|
Returns: |
|
current_output: An output tensor with shape [batch, new_height, new_width, |
|
new_channel]. |
|
activated_output: An activated output tensor with shape [batch, |
|
new_height, new_width, new_channel]. |
|
memory_feature: None if no transformer is used. A [batch, num_memory, |
|
memory_channel] tensor if transformer is used. |
|
endpoints: A dict, the network endpoints that might be used by DeepLab. |
|
""" |
|
memory_feature = self._memory_feature |
|
if self._use_memory_feature: |
|
if self._num_mask_slots: |
|
memory_feature = self._memory_feature[:, :self._num_mask_slots, :] |
|
memory_feature = tf.tile(memory_feature, |
|
[tf.shape(inputs)[0], 1, 1]) |
|
|
|
endpoints = {} |
|
output = self._stem(inputs) |
|
activated_output = self._activation_fn(output) |
|
endpoints['stage1'] = output |
|
endpoints['res1'] = activated_output |
|
|
|
|
|
|
|
for index in range(self._first_block_index, 5): |
|
current_name = '_stage{}'.format(index + 1) |
|
current_output, activated_output, memory_feature = ( |
|
getattr(self, current_name)( |
|
(activated_output, memory_feature), training=training)) |
|
endpoints[utils.get_layer_name(current_name)] = current_output |
|
activated_output_name = 'res{}'.format(index + 1) |
|
endpoints[activated_output_name] = activated_output |
|
return current_output, activated_output, memory_feature, endpoints |
|
|
|
  def call_stacked_decoder(self,
                           current_output,
                           activated_output,
                           memory_feature,
                           endpoints,
                           training=False):
    """Performs a forward pass of the stacked decoders.

    Args:
      current_output: An output tensor with shape [batch, new_height, new_width,
        new_channel].
      activated_output: An activated output tensor with shape [batch,
        new_height, new_width, new_channel].
      memory_feature: None if no transformer is used. A [batch, num_memory,
        memory_channel] tensor if transformer is used.
      endpoints: A dict, the network endpoints that might be used by DeepLab.
      training: A boolean, whether the model is in training mode.

    Returns:
      memory_feature: None if no transformer is used. A [batch, num_memory,
        memory_channel] tensor if transformer is used.
      high_resolution_outputs: A list of decoded tensors with
        high_resolution_output_stride.
      backbone_output: An output tensor of the backbone, with output_stride.
      endpoints: A dict, the network endpoints that might be used by DeepLab.
    """
    # Number of hour-glass stacks traversed so far.
    current_stack = 0

    # Whether the layers being executed still belong to the (pretrained)
    # backbone, as opposed to the extra decoder trained from scratch.
    current_is_backbone = True
    high_resolution_outputs = []

    if self._backbone_decoder_num_stacks == 0:
      # The backbone ends right after the encoder.
      backbone_output = activated_output

      if self._classification_mode:
        endpoints['backbone_output'] = backbone_output
        return None, None, None, endpoints
      else:
        current_is_backbone = False

    if not current_is_backbone:
      # Run the extra stage-5 block group built in __init__.
      current_name = '_stage5_' + EXTRA
      current_output, activated_output, memory_feature = (
          getattr(self, current_name)(
              (activated_output, memory_feature), training=training))

    total_decoder_num_stacks = (
        self._backbone_decoder_num_stacks + self._extra_decoder_num_stacks)

    # Track the features produced at each decoding stride so that later
    # decoder stages can fuse them; strides are capped at self._output_stride.
    stride_to_features = {}
    stride_to_features[min(2, self._output_stride)] = [endpoints['stage1']]
    stride_to_features[min(4, self._output_stride)] = [endpoints['stage2']]
    stride_to_features[min(8, self._output_stride)] = [endpoints['stage3']]
    stride_to_features[min(16, self._output_stride)] = [endpoints['stage4']]

    # Use the latest stage-5 output held in current_output (possibly from the
    # extra stage-5 block group above) rather than the stage-5 endpoint.
    stride_to_features[min(32, self._output_stride)] = [current_output]

    # Traverse strides in an hour-glass fashion, mirroring __init__/build().
    next_stride_fn = lambda x: x // 2
    current_decoder_stride = self._output_stride
    decoder_stage = 0

    while (current_stack < total_decoder_num_stacks or
           current_decoder_stride > self._high_resolution_output_stride):
      decoder_stage += 1
      current_decoder_stride = next_stride_fn(current_decoder_stride)

      if current_decoder_stride == self._output_stride:
        # Returning to output_stride completes one hour-glass stack.
        current_stack += 1

        # Keep downsampling for the next stack.
        next_stride_fn = lambda x: x // 2

      # Fuse the current output with the (up to) two most recent features
      # recorded at this stride.
      decoder_features_list = (
          [current_output] +
          stride_to_features[current_decoder_stride][-2:])

      if current_is_backbone:
        current_name = '_decoder_stage{}_resized_fuse'.format(decoder_stage)
      else:
        current_name = '_decoder_stage{}_{}_resized_fuse'.format(
            decoder_stage, EXTRA)
      activated_output = getattr(self, current_name)(
          decoder_features_list, training=training)

      if current_is_backbone:
        # Run the backbone decoder stage built in __init__.
        current_name = '_decoder_stage{}'.format(decoder_stage)
        current_output, activated_output, memory_feature = (
            getattr(self, current_name)(
                (activated_output, memory_feature), training=training))

      if (current_decoder_stride == self._output_stride and
          current_stack == self._backbone_decoder_num_stacks):
        # The backbone ends after its last stack returns to output_stride.
        backbone_output = activated_output

        if self._classification_mode:
          endpoints['backbone_output'] = backbone_output
          return None, None, None, endpoints
        else:
          current_is_backbone = False

      if not current_is_backbone:
        # Run the extra decoder stage built in __init__.
        current_name = '_decoder_stage{}_{}'.format(decoder_stage, EXTRA)
        current_output, activated_output, memory_feature = (
            getattr(self, current_name)(
                (activated_output, memory_feature), training=training))

      # Record this stage's output for future fusion at the same stride.
      stride_to_features[current_decoder_stride].append(current_output)
      if current_decoder_stride == self._high_resolution_output_stride:
        high_resolution_outputs.append(activated_output)
        # After reaching the highest resolution, switch to downsampling.
        next_stride_fn = lambda x: x * 2
    return memory_feature, high_resolution_outputs, backbone_output, endpoints
|
|
|
def call_extra_endpoints(self, |
|
memory_feature, |
|
high_resolution_outputs, |
|
backbone_output, |
|
endpoints, |
|
training=False): |
|
"""Performs a forward pass to generate extra endpoints. |
|
|
|
Args: |
|
memory_feature: None if no transformer is used. A [batch, num_memory, |
|
memory_channel] tensor if transformer is used. |
|
high_resolution_outputs: A list of decoded tensors with |
|
high_resolution_output_stride. |
|
backbone_output: An output tensor of the backbone, with output_stride. |
|
endpoints: A dict, the network endpoints that might be used by DeepLab. |
|
training: A boolean, whether the model is in training mode. |
|
|
|
Returns: |
|
endpoints: A dict, the network endpoints that might be used by DeepLab. |
|
""" |
|
|
|
assert not self._classification_mode |
|
if (self._backbone_use_transformer_beyond_stride or |
|
self._extra_decoder_use_transformer_beyond_stride): |
|
|
|
|
|
class_feature = getattr(self, '_class_feature_' + EXTRA)( |
|
memory_feature, training=training) |
|
mask_feature = getattr(self, '_mask_feature_' + EXTRA)( |
|
memory_feature, training=training) |
|
endpoints['transformer_class_feature'] = class_feature |
|
endpoints['transformer_mask_feature'] = mask_feature |
|
|
|
|
|
endpoints['feature_panoptic'] = high_resolution_outputs[-1] |
|
|
|
|
|
|
|
|
|
if self._extra_decoder_num_stacks: |
|
endpoints['feature_semantic'] = ( |
|
high_resolution_outputs[self._backbone_decoder_num_stacks]) |
|
else: |
|
endpoints['feature_semantic'] = backbone_output |
|
endpoints['backbone_output'] = backbone_output |
|
return endpoints |
|
|
|
def call(self, inputs, training=False): |
|
"""Performs a forward pass. |
|
|
|
Args: |
|
inputs: An input [batch, height, width, channel] tensor. |
|
training: A boolean, whether the model is in training mode. |
|
|
|
Returns: |
|
endpoints: A dict, the network endpoints that might be used by DeepLab. |
|
""" |
|
current_output, activated_output, memory_feature, endpoints = ( |
|
self.call_encoder_before_stacked_decoder(inputs, training=training)) |
|
memory_feature, high_resolution_outputs, backbone_output, endpoints = ( |
|
self.call_stacked_decoder(current_output, |
|
activated_output, |
|
memory_feature, |
|
endpoints, |
|
training=training)) |
|
if self._classification_mode: |
|
return endpoints |
|
endpoints = self.call_extra_endpoints(memory_feature, |
|
high_resolution_outputs, |
|
backbone_output, |
|
endpoints, |
|
training=training) |
|
return endpoints |
|
|