deeplab2 / model /decoder /deeplabv3plus.py
akhaliq3
spaces demo
506da10
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains code to build a DeepLabV3Plus.
Reference:
- [Encoder-Decoder with Atrous Separable Convolution for Semantic Image
Segmentation](https://arxiv.org/pdf/1802.02611.pdf)
"""
import tensorflow as tf
from deeplab2 import common
from deeplab2.model import utils
from deeplab2.model.decoder import aspp
from deeplab2.model.layers import convolutions
layers = tf.keras.layers
class DeepLabV3Plus(tf.keras.layers.Layer):
"""A DeepLabV3+ decoder model.
This model takes in low- and high-level features from an encoder and performs
multi-scale context aggregation with the help of an ASPP layer on high-level
features. These are concatenated with the low-level features and used as input
to the classification head that is used to predict a semantic segmentation.
"""
def __init__(self,
decoder_options,
deeplabv3plus_options,
bn_layer=tf.keras.layers.BatchNormalization):
"""Creates a DeepLabV3+ decoder of type tf.keras.layers.Layer.
Args:
decoder_options: Decoder options as defined in config_pb2.DecoderOptions.
deeplabv3plus_options: Model options as defined in
config_pb2.ModelOptions.DeeplabV3PlusOptions.
bn_layer: An optional tf.keras.layers.Layer that computes the
normalization (default: tf.keras.layers.BatchNormalization).
"""
super(DeepLabV3Plus, self).__init__(name='DeepLabv3Plus')
self._high_level_feature_name = decoder_options.feature_key
self._low_level_feature_name = deeplabv3plus_options.low_level.feature_key
self._aspp = aspp.ASPP(decoder_options.aspp_channels,
decoder_options.atrous_rates,
bn_layer=bn_layer)
# Layers for low-level feature transformation.
self._project_conv_bn_act = convolutions.Conv2DSame(
deeplabv3plus_options.low_level.channels_project,
kernel_size=1,
name='project_conv_bn_act',
use_bias=False,
use_bn=True,
bn_layer=bn_layer,
activation='relu')
# Layers for fusing low- and high-level features.
self._fuse = convolutions.StackedConv2DSame(
conv_type='depthwise_separable_conv',
num_layers=2,
output_channels=decoder_options.decoder_channels,
kernel_size=3,
name='fuse',
use_bias=False,
use_bn=True,
bn_layer=bn_layer,
activation='relu')
self._final_conv = convolutions.Conv2DSame(
deeplabv3plus_options.num_classes, kernel_size=1, name='final_conv')
def reset_pooling_layer(self):
"""Resets the ASPP pooling layer to global average pooling."""
self._aspp.reset_pooling_layer()
def set_pool_size(self, pool_size):
"""Sets the pooling size of the ASPP pooling layer.
Args:
pool_size: A tuple specifying the pooling size of the ASPP pooling layer.
"""
self._aspp.set_pool_size(pool_size)
def get_pool_size(self):
return self._aspp.get_pool_size()
@property
def checkpoint_items(self):
items = {
common.CKPT_DEEPLABV3PLUS_ASPP: self._aspp,
common.CKPT_DEEPLABV3PLUS_PROJECT_CONV_BN_ACT:
self._project_conv_bn_act,
common.CKPT_DEEPLABV3PLUS_FUSE: self._fuse,
common.CKPT_SEMANTIC_LAST_LAYER: self._final_conv,
}
return items
def call(self, features, training=False):
"""Performs a forward pass.
Args:
features: An input dict of tf.Tensor with shape [batch, height, width,
channels]. Different keys should point to different features extracted
by the encoder, e.g. low-level or high-level features.
training: A boolean flag indicating whether training behavior should be
used (default: False).
Returns:
A dictionary containing the semantic prediction under key
common.PRED_SEMANTIC_LOGITS_KEY.
"""
low_level_features = features[self._low_level_feature_name]
high_level_features = features[self._high_level_feature_name]
high_level_features = self._aspp(high_level_features, training=training)
low_level_features = self._project_conv_bn_act(low_level_features,
training=training)
target_h = tf.shape(low_level_features)[1]
target_w = tf.shape(low_level_features)[2]
high_level_features = utils.resize_align_corners(
high_level_features, [target_h, target_w])
x = tf.concat([high_level_features, low_level_features], 3)
x = self._fuse(x)
return {common.PRED_SEMANTIC_LOGITS_KEY: self._final_conv(x)}