deanna-emery's picture
updates
93528c6
raw
history blame
8.61 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder."""
from typing import Any, List, Mapping, Optional, Union
# Import libraries
import tensorflow as tf, tf_keras
from official.modeling import hyperparams
from official.vision.modeling.decoders import factory
from official.vision.modeling.layers import deeplab
from official.vision.modeling.layers import nn_layers
TensorMapUnion = Union[tf.Tensor, Mapping[str, tf.Tensor]]
@tf_keras.utils.register_keras_serializable(package='Vision')
class ASPP(tf_keras.layers.Layer):
"""Creates an Atrous Spatial Pyramid Pooling (ASPP) layer."""
def __init__(
self,
level: int,
dilation_rates: List[int],
num_filters: int = 256,
pool_kernel_size: Optional[int] = None,
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
activation: str = 'relu',
dropout_rate: float = 0.0,
kernel_initializer: str = 'VarianceScaling',
kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
interpolation: str = 'bilinear',
use_depthwise_convolution: bool = False,
spp_layer_version: str = 'v1',
output_tensor: bool = False,
**kwargs):
"""Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.
Args:
level: An `int` level to apply ASPP.
dilation_rates: A `list` of dilation rates.
num_filters: An `int` number of output filters in ASPP.
pool_kernel_size: A `list` of [height, width] of pooling kernel size or
None. Pooling size is with respect to original image size, it will be
scaled down by 2**level. If None, global average pooling is used.
use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A `float` added to variance to avoid dividing by zero.
activation: A `str` activation to be used in ASPP.
dropout_rate: A `float` rate for dropout regularization.
kernel_initializer: A `str` name of kernel_initializer for convolutional
layers.
kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
Conv2D. Default is None.
interpolation: A `str` of interpolation method. It should be one of
`bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`,
`gaussian`, or `mitchellcubic`.
use_depthwise_convolution: If True depthwise separable convolutions will
be added to the Atrous spatial pyramid pooling.
spp_layer_version: A `str` of spatial pyramid pooling layer version.
output_tensor: Whether to output a single tensor or a dictionary of tensor.
Default is false.
**kwargs: Additional keyword arguments to be passed.
"""
super().__init__(**kwargs)
self._config_dict = {
'level': level,
'dilation_rates': dilation_rates,
'num_filters': num_filters,
'pool_kernel_size': pool_kernel_size,
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'activation': activation,
'dropout_rate': dropout_rate,
'kernel_initializer': kernel_initializer,
'kernel_regularizer': kernel_regularizer,
'interpolation': interpolation,
'use_depthwise_convolution': use_depthwise_convolution,
'spp_layer_version': spp_layer_version,
'output_tensor': output_tensor
}
self._aspp_layer = deeplab.SpatialPyramidPooling if self._config_dict[
'spp_layer_version'] == 'v1' else nn_layers.SpatialPyramidPooling
def build(self, input_shape):
pool_kernel_size = None
if self._config_dict['pool_kernel_size']:
pool_kernel_size = [
int(p_size // 2**self._config_dict['level'])
for p_size in self._config_dict['pool_kernel_size'] # pytype: disable=attribute-error # trace-all-classes
]
self.aspp = self._aspp_layer(
output_channels=self._config_dict['num_filters'],
dilation_rates=self._config_dict['dilation_rates'],
pool_kernel_size=pool_kernel_size,
use_sync_bn=self._config_dict['use_sync_bn'],
batchnorm_momentum=self._config_dict['norm_momentum'],
batchnorm_epsilon=self._config_dict['norm_epsilon'],
activation=self._config_dict['activation'],
dropout=self._config_dict['dropout_rate'],
kernel_initializer=self._config_dict['kernel_initializer'],
kernel_regularizer=self._config_dict['kernel_regularizer'],
interpolation=self._config_dict['interpolation'],
use_depthwise_convolution=self._config_dict['use_depthwise_convolution']
)
def call(self, inputs: TensorMapUnion) -> TensorMapUnion:
"""Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.
The output of ASPP will be a dict of {`level`, `tf.Tensor`} even if only one
level is present, if output_tensor is false. Hence, this will be compatible
with the rest of the segmentation model interfaces.
If output_tensor is true, a single tensot is output.
Args:
inputs: A `tf.Tensor` of shape [batch, height_l, width_l, filter_size] or
a `dict` of `tf.Tensor` where
- key: A `str` of the level of the multilevel feature maps.
- values: A `tf.Tensor` of shape [batch, height_l, width_l,
filter_size].
Returns:
A `tf.Tensor` of shape [batch, height_l, width_l, filter_size] or a `dict`
of `tf.Tensor` where
- key: A `str` of the level of the multilevel feature maps.
- values: A `tf.Tensor` of output of ASPP module.
"""
outputs = {}
level = str(self._config_dict['level'])
backbone_output = inputs[level] if isinstance(inputs, dict) else inputs
outputs = self.aspp(backbone_output)
return outputs if self._config_dict['output_tensor'] else {level: outputs}
def get_config(self) -> Mapping[str, Any]:
base_config = super().get_config()
return dict(list(base_config.items()) + list(self._config_dict.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
@factory.register_decoder_builder('aspp')
def build_aspp_decoder(
input_specs: Mapping[str, tf.TensorShape],
model_config: hyperparams.Config,
l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
"""Builds ASPP decoder from a config.
Args:
input_specs: A `dict` of input specifications. A dictionary consists of
{level: TensorShape} from a backbone. Note this is for consistent
interface, and is not used by ASPP decoder.
model_config: A OneOfConfig. Model config.
l2_regularizer: A `tf_keras.regularizers.Regularizer` instance. Default to
None.
Returns:
A `tf_keras.Model` instance of the ASPP decoder.
Raises:
ValueError: If the model_config.decoder.type is not `aspp`.
"""
del input_specs # input_specs is not used by ASPP decoder.
decoder_type = model_config.decoder.type
decoder_cfg = model_config.decoder.get()
if decoder_type != 'aspp':
raise ValueError(f'Inconsistent decoder type {decoder_type}. '
'Need to be `aspp`.')
norm_activation_config = model_config.norm_activation
return ASPP(
level=decoder_cfg.level,
dilation_rates=decoder_cfg.dilation_rates,
num_filters=decoder_cfg.num_filters,
use_depthwise_convolution=decoder_cfg.use_depthwise_convolution,
pool_kernel_size=decoder_cfg.pool_kernel_size,
dropout_rate=decoder_cfg.dropout_rate,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
activation=norm_activation_config.activation,
kernel_regularizer=l2_regularizer,
spp_layer_version=decoder_cfg.spp_layer_version,
output_tensor=decoder_cfg.output_tensor)