File size: 8,610 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder."""
from typing import Any, List, Mapping, Optional, Union

# Import libraries

import tensorflow as tf, tf_keras

from official.modeling import hyperparams
from official.vision.modeling.decoders import factory
from official.vision.modeling.layers import deeplab
from official.vision.modeling.layers import nn_layers

# Type accepted and returned by `ASPP.call`: either a single feature tensor or
# a mapping from a string key (the feature level) to a feature tensor.
TensorMapUnion = Union[tf.Tensor, Mapping[str, tf.Tensor]]


@tf_keras.utils.register_keras_serializable(package='Vision')
class ASPP(tf_keras.layers.Layer):
  """Creates an Atrous Spatial Pyramid Pooling (ASPP) layer.

  This layer wraps a spatial pyramid pooling implementation
  (`deeplab.SpatialPyramidPooling` when `spp_layer_version` is `'v1'`,
  `nn_layers.SpatialPyramidPooling` otherwise). It picks the configured level
  out of a multilevel feature mapping (or takes a single tensor as is) and
  applies the pooling layer to it.
  """

  def __init__(
      self,
      level: int,
      dilation_rates: List[int],
      num_filters: int = 256,
      pool_kernel_size: Optional[List[int]] = None,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      activation: str = 'relu',
      dropout_rate: float = 0.0,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      interpolation: str = 'bilinear',
      use_depthwise_convolution: bool = False,
      spp_layer_version: str = 'v1',
      output_tensor: bool = False,
      **kwargs):
    """Initializes an Atrous Spatial Pyramid Pooling (ASPP) layer.

    Args:
      level: An `int` level to apply ASPP.
      dilation_rates: A `list` of dilation rates.
      num_filters: An `int` number of output filters in ASPP.
      pool_kernel_size: A `list` of [height, width] of pooling kernel size or
        None. Pooling size is with respect to original image size, it will be
        scaled down by 2**level. If None, global average pooling is used.
      use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      activation: A `str` activation to be used in ASPP.
      dropout_rate: A `float` rate for dropout regularization.
      kernel_initializer: A `str` name of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      interpolation: A `str` of interpolation method. It should be one of
        `bilinear`, `nearest`, `bicubic`, `area`, `lanczos3`, `lanczos5`,
        `gaussian`, or `mitchellcubic`.
      use_depthwise_convolution: If True depthwise separable convolutions will
        be added to the Atrous spatial pyramid pooling.
      spp_layer_version: A `str` of spatial pyramid pooling layer version.
        `'v1'` selects `deeplab.SpatialPyramidPooling`; any other value selects
        `nn_layers.SpatialPyramidPooling`.
      output_tensor: Whether to output a single tensor or a dictionary of
        tensor. Default is False (a dictionary is returned).
      **kwargs: Additional keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    # Everything needed to rebuild the layer; returned by `get_config`.
    self._config_dict = {
        'level': level,
        'dilation_rates': dilation_rates,
        'num_filters': num_filters,
        'pool_kernel_size': pool_kernel_size,
        'use_sync_bn': use_sync_bn,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'activation': activation,
        'dropout_rate': dropout_rate,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'interpolation': interpolation,
        'use_depthwise_convolution': use_depthwise_convolution,
        'spp_layer_version': spp_layer_version,
        'output_tensor': output_tensor
    }
    # Only the class is chosen here; it is instantiated later in `build`.
    if spp_layer_version == 'v1':
      self._aspp_layer = deeplab.SpatialPyramidPooling
    else:
      self._aspp_layer = nn_layers.SpatialPyramidPooling

  def build(self, input_shape):
    """Instantiates the underlying spatial pyramid pooling layer.

    Args:
      input_shape: Shape of the layer input. Not used by this method.
    """
    pool_kernel_size = None
    if self._config_dict['pool_kernel_size']:
      # The configured size is relative to the original image; rescale it to
      # the resolution of the selected feature level.
      pool_kernel_size = [
          int(p_size // 2**self._config_dict['level'])
          for p_size in self._config_dict['pool_kernel_size']
      ]

    self.aspp = self._aspp_layer(
        output_channels=self._config_dict['num_filters'],
        dilation_rates=self._config_dict['dilation_rates'],
        pool_kernel_size=pool_kernel_size,
        use_sync_bn=self._config_dict['use_sync_bn'],
        batchnorm_momentum=self._config_dict['norm_momentum'],
        batchnorm_epsilon=self._config_dict['norm_epsilon'],
        activation=self._config_dict['activation'],
        dropout=self._config_dict['dropout_rate'],
        kernel_initializer=self._config_dict['kernel_initializer'],
        kernel_regularizer=self._config_dict['kernel_regularizer'],
        interpolation=self._config_dict['interpolation'],
        use_depthwise_convolution=self._config_dict['use_depthwise_convolution']
    )

  def call(self, inputs: TensorMapUnion) -> TensorMapUnion:
    """Calls the Atrous Spatial Pyramid Pooling (ASPP) layer on an input.

    The output of ASPP will be a dict of {`level`, `tf.Tensor`} even if only one
    level is present, if output_tensor is false. Hence, this will be compatible
    with the rest of the segmentation model interfaces.
    If output_tensor is true, a single tensor is output.

    Args:
      inputs: A `tf.Tensor` of shape [batch, height_l, width_l, filter_size] or
        a `dict` of `tf.Tensor` where
        - key: A `str` of the level of the multilevel feature maps.
        - values: A `tf.Tensor` of shape [batch, height_l, width_l,
          filter_size].

    Returns:
      A `tf.Tensor` of shape [batch, height_l, width_l, filter_size] or a `dict`
        of `tf.Tensor` where
        - key: A `str` of the level of the multilevel feature maps.
        - values: A `tf.Tensor` of output of ASPP module.
    """
    level = str(self._config_dict['level'])
    # Accept any mapping (not only `dict`), matching the declared input type.
    backbone_output = inputs[level] if isinstance(inputs, Mapping) else inputs
    outputs = self.aspp(backbone_output)
    return outputs if self._config_dict['output_tensor'] else {level: outputs}

  def get_config(self) -> Mapping[str, Any]:
    """Returns the base layer config merged with this layer's arguments."""
    base_config = super().get_config()
    # Layer-specific entries come last so they win on any key collision.
    return {**base_config, **self._config_dict}

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a layer from `config`; `custom_objects` is accepted but unused."""
    return cls(**config)


@factory.register_decoder_builder('aspp')
def build_aspp_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Constructs the ASPP decoder described by a model config.

  Args:
    input_specs: A `dict` of {level: TensorShape} from a backbone. Accepted
      only so the signature matches other decoder builders; it is discarded.
    model_config: A OneOfConfig. Model config. Its `decoder` entry supplies
      the ASPP settings and its `norm_activation` entry supplies the
      normalization and activation settings.
    l2_regularizer: A `tf_keras.regularizers.Regularizer` instance, forwarded
      as the kernel regularizer. Default to None.

  Returns:
    An `ASPP` instance configured from `model_config`.

  Raises:
    ValueError: If the model_config.decoder.type is not `aspp`.
  """
  del input_specs  # Accepted only for interface consistency.
  configured_type = model_config.decoder.type
  aspp_cfg = model_config.decoder.get()
  if configured_type != 'aspp':
    raise ValueError(f'Inconsistent decoder type {configured_type}. '
                     'Need to be `aspp`.')

  norm_cfg = model_config.norm_activation
  # Collect the constructor arguments first, then build the layer in one call.
  aspp_kwargs = {
      'level': aspp_cfg.level,
      'dilation_rates': aspp_cfg.dilation_rates,
      'num_filters': aspp_cfg.num_filters,
      'use_depthwise_convolution': aspp_cfg.use_depthwise_convolution,
      'pool_kernel_size': aspp_cfg.pool_kernel_size,
      'dropout_rate': aspp_cfg.dropout_rate,
      'use_sync_bn': norm_cfg.use_sync_bn,
      'norm_momentum': norm_cfg.norm_momentum,
      'norm_epsilon': norm_cfg.norm_epsilon,
      'activation': norm_cfg.activation,
      'kernel_regularizer': l2_regularizer,
      'spp_layer_version': aspp_cfg.spp_layer_version,
      'output_tensor': aspp_cfg.output_tensor,
  }
  return ASPP(**aspp_kwargs)