# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file contains the Motion-DeepLab architecture."""

import functools
from typing import Any, Dict, Text, Tuple

from absl import logging
import tensorflow as tf

from deeplab2 import common
from deeplab2 import config_pb2
from deeplab2.data import dataset
from deeplab2.model import builder
from deeplab2.model import utils
from deeplab2.model.post_processor import motion_deeplab
from deeplab2.model.post_processor import post_processor_builder


class MotionDeepLab(tf.keras.Model):
  """This class represents the Motion-DeepLab meta architecture.

  This class is the basis of the Motion-DeepLab architecture. This model can
  be used for Video Panoptic Segmentation or Segmenting and Tracking Every
  Pixel (STEP).
  """

  def __init__(self,
               config: config_pb2.ExperimentOptions,
               dataset_descriptor: dataset.DatasetDescriptor):
    """Initializes a Motion-DeepLab architecture.

    Args:
      config: A config_pb2.ExperimentOptions configuration.
      dataset_descriptor: A dataset.DatasetDescriptor.
    """
    super(MotionDeepLab, self).__init__(name='MotionDeepLab')

    if config.trainer_options.solver_options.use_sync_batchnorm:
      logging.info('Synchronized Batchnorm is used.')
      bn_layer = functools.partial(
          tf.keras.layers.experimental.SyncBatchNormalization,
          momentum=config.trainer_options.solver_options.batchnorm_momentum,
          epsilon=config.trainer_options.solver_options.batchnorm_epsilon)
    else:
      logging.info('Standard (unsynchronized) Batchnorm is used.')
      bn_layer = functools.partial(
          tf.keras.layers.BatchNormalization,
          momentum=config.trainer_options.solver_options.batchnorm_momentum,
          epsilon=config.trainer_options.solver_options.batchnorm_epsilon)

    self._encoder = builder.create_encoder(
        config.model_options.backbone,
        bn_layer,
        conv_kernel_weight_decay=(
            config.trainer_options.solver_options.weight_decay))
    self._decoder = builder.create_decoder(config.model_options, bn_layer,
                                           dataset_descriptor.ignore_label)

    # Non-trainable variables that carry the tracking state (previous center
    # heatmap, previous center list, and the next unused tracking ID) across
    # frames during evaluation.
    self._prev_center_prediction = tf.Variable(
        0.0,
        trainable=False,
        validate_shape=False,
        shape=tf.TensorShape(None),
        dtype=tf.float32,
        name='prev_prediction_buffer')
    self._prev_center_list = tf.Variable(
        tf.zeros((0, 5), dtype=tf.int32),
        trainable=False,
        validate_shape=False,
        shape=tf.TensorShape(None),
        name='prev_prediction_list')
    self._next_tracking_id = tf.Variable(
        1,
        trainable=False,
        validate_shape=False,
        dtype=tf.int32,
        name='next_tracking_id')

    self._post_processor = post_processor_builder.get_post_processor(
        config, dataset_descriptor)
    self._render_fn = functools.partial(
        motion_deeplab.render_panoptic_map_as_heatmap,
        sigma=8,
        label_divisor=dataset_descriptor.panoptic_label_divisor,
        void_label=dataset_descriptor.ignore_label)
    self._track_fn = functools.partial(
        motion_deeplab.assign_instances_to_previous_tracks,
        label_divisor=dataset_descriptor.panoptic_label_divisor)

    # The ASPP pooling size is always set to the train crop size, which was
    # found to work better experimentally.
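    # For example (hypothetical numbers, not taken from any shipped config):
    # a train crop size of (513, 513) with output stride 16 scales to a
    # pooling size of (33, 33), assuming utils.scale_mutable_sequence maps
    # each side s to int((s - 1) * scale) + 1.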
    pool_size = config.train_dataset_options.crop_size
    output_stride = float(config.model_options.backbone.output_stride)
    pool_size = tuple(
        utils.scale_mutable_sequence(pool_size, 1.0 / output_stride))
    logging.info('Setting pooling size to %s', pool_size)
    self.set_pool_size(pool_size)

  def call(self, input_tensor: tf.Tensor, training=False) -> Dict[Text, Any]:
    """Performs a forward pass.

    Args:
      input_tensor: An input tensor of type tf.Tensor with shape [batch,
        height, width, channels]. The input tensor should contain two RGB
        frames concatenated along the channel axis (plus, during training,
        the groundtruth center heatmap as a 7th channel).
      training: A boolean flag indicating whether training behavior should be
        used (default: False).

    Returns:
      A dictionary containing the results of the specified DeepLab
      architecture. The results are bilinearly upsampled to input size before
      returning.
    """
    if not training:
      # During evaluation, we add the previous predicted heatmap as the 7th
      # input channel (cf. during training, we use the groundtruth heatmap).
      input_tensor = self._add_previous_heatmap_to_input(input_tensor)
    # Normalize the input in the same way as Inception. We normalize it
    # outside the encoder so that we can extend encoders to different
    # backbones without copying the normalization to each encoder. We
    # normalize it after data preprocessing because it is faster on TPUs than
    # on host CPUs. The normalization should not increase TPU memory
    # consumption because it does not require gradient.
    input_tensor = input_tensor / 127.5 - 1.0
    # Get the static spatial shape of the input tensor.
    _, input_h, input_w, _ = input_tensor.get_shape().as_list()
    pred = self._decoder(
        self._encoder(input_tensor, training=training), training=training)
    result_dict = dict()
    for key, value in pred.items():
      if (key == common.PRED_OFFSET_MAP_KEY or
          key == common.PRED_FRAME_OFFSET_MAP_KEY):
        result_dict[key] = utils.resize_and_rescale_offsets(
            value, [input_h, input_w])
      else:
        result_dict[key] = utils.resize_bilinear(value, [input_h, input_w])

    # Change the semantic logits to probabilities with softmax.
    result_dict[common.PRED_SEMANTIC_PROBS_KEY] = tf.nn.softmax(
        result_dict[common.PRED_SEMANTIC_LOGITS_KEY])

    if not training:
      result_dict.update(self._post_processor(result_dict))
      next_heatmap, next_centers = self._render_fn(
          result_dict[common.PRED_PANOPTIC_KEY])
      panoptic_map, next_centers, next_id = self._track_fn(
          self._prev_center_list.value(),
          next_centers,
          next_heatmap,
          result_dict[common.PRED_FRAME_OFFSET_MAP_KEY],
          result_dict[common.PRED_PANOPTIC_KEY],
          self._next_tracking_id.value())
      result_dict[common.PRED_PANOPTIC_KEY] = panoptic_map
      self._next_tracking_id.assign(next_id)
      self._prev_center_prediction.assign(
          tf.expand_dims(next_heatmap, axis=3, name='expand_prev_centermap'))
      self._prev_center_list.assign(next_centers)
    if common.PRED_CENTER_HEATMAP_KEY in result_dict:
      result_dict[common.PRED_CENTER_HEATMAP_KEY] = tf.squeeze(
          result_dict[common.PRED_CENTER_HEATMAP_KEY], axis=3)
    return result_dict

  def _add_previous_heatmap_to_input(self,
                                     input_tensor: tf.Tensor) -> tf.Tensor:
    frame1, frame2 = tf.split(input_tensor, [3, 3], axis=3)
    # We use a simple way to detect if the first frame of a sequence is being
    # processed. For the first frame, frame1 and frame2 are identical.
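    # (This assumes the input pipeline duplicates the first frame at the start
    # of every sequence; when the two frames match, an all-zero heatmap is
    # used and the tracking state is reset below.)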
    if tf.reduce_all(tf.equal(frame1, frame2)):
      h = tf.shape(input_tensor)[1]
      w = tf.shape(input_tensor)[2]
      prev_center = tf.zeros((1, h, w, 1), dtype=tf.float32)
      self._prev_center_list.assign(tf.zeros((0, 5), dtype=tf.int32))
      self._next_tracking_id.assign(1)
    else:
      prev_center = self._prev_center_prediction
    output_tensor = tf.concat([frame1, frame2, prev_center], axis=3)
    output_tensor.set_shape([None, None, None, 7])
    return output_tensor

  def reset_pooling_layer(self):
    """Resets the ASPP pooling layer to global average pooling."""
    self._decoder.reset_pooling_layer()

  def set_pool_size(self, pool_size: Tuple[int, int]):
    """Sets the pooling size of the ASPP pooling layer.

    Args:
      pool_size: A tuple specifying the pooling size of the ASPP pooling
        layer.
    """
    self._decoder.set_pool_size(pool_size)

  @property
  def checkpoint_items(self) -> Dict[Text, Any]:
    items = dict(encoder=self._encoder)
    items.update(self._decoder.checkpoint_items)
    return items
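
# A minimal usage sketch (hypothetical; none of the values below are defined
# in this file, and the experiment options would normally be parsed from a
# textproto config):
#
#   config = config_pb2.ExperimentOptions()  # populate model/trainer options
#   descriptor = ...  # a dataset.DatasetDescriptor for the target dataset
#   model = MotionDeepLab(config, descriptor)
#   # At evaluation time, the input holds two RGB frames concatenated along
#   # the channel axis, i.e. shape [batch, height, width, 6]; the previous
#   # center heatmap is appended internally as the 7th channel.
#   frames = tf.zeros([1, 385, 385, 6])
#   outputs = model(frames, training=False)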