# deeplab2/video/motion_deeplab.py
# akhaliq3
# spaces demo
# 506da10
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains the Motion-DeepLab architecture."""
import functools
from typing import Any, Dict, Text, Tuple
from absl import logging
import tensorflow as tf
from deeplab2 import common
from deeplab2 import config_pb2
from deeplab2.data import dataset
from deeplab2.model import builder
from deeplab2.model import utils
from deeplab2.model.post_processor import motion_deeplab
from deeplab2.model.post_processor import post_processor_builder
class MotionDeepLab(tf.keras.Model):
    """Motion-DeepLab meta architecture.

    This model is the basis of the Motion-DeepLab architecture and can be used
    for Video Panoptic Segmentation or Segmenting and Tracking Every Pixel
    (STEP).

    Outside of training the model keeps state between calls in non-trainable
    variables: the previously rendered center heatmap, the previous center
    list and the next tracking id to hand out.
    """

    def __init__(self,
                 config: config_pb2.ExperimentOptions,
                 dataset_descriptor: dataset.DatasetDescriptor):
        """Initializes a Motion-DeepLab architecture.

        Args:
            config: A config_pb2.ExperimentOptions configuration.
            dataset_descriptor: A dataset.DatasetDescriptor.
        """
        super().__init__(name='MotionDeepLab')

        solver_options = config.trainer_options.solver_options

        # Only the normalization class differs between the two modes; the
        # momentum and epsilon settings are shared.
        if solver_options.use_sync_batchnorm:
            logging.info('Synchronized Batchnorm is used.')
            norm_class = tf.keras.layers.experimental.SyncBatchNormalization
        else:
            logging.info('Standard (unsynchronized) Batchnorm is used.')
            norm_class = tf.keras.layers.BatchNormalization
        bn_layer = functools.partial(
            norm_class,
            momentum=solver_options.batchnorm_momentum,
            epsilon=solver_options.batchnorm_epsilon)

        self._encoder = builder.create_encoder(
            config.model_options.backbone,
            bn_layer,
            conv_kernel_weight_decay=solver_options.weight_decay)
        self._decoder = builder.create_decoder(
            config.model_options, bn_layer, dataset_descriptor.ignore_label)

        # Tracking state carried from one evaluation call to the next. The
        # buffers are declared with an unknown shape and without shape
        # validation so that differently shaped tensors can be assigned later.
        self._prev_center_prediction = tf.Variable(
            0.0,
            trainable=False,
            validate_shape=False,
            shape=tf.TensorShape(None),
            dtype=tf.float32,
            name='prev_prediction_buffer')
        self._prev_center_list = tf.Variable(
            tf.zeros((0, 5), dtype=tf.int32),
            trainable=False,
            validate_shape=False,
            shape=tf.TensorShape(None),
            name='prev_prediction_list')
        self._next_tracking_id = tf.Variable(
            1,
            trainable=False,
            validate_shape=False,
            dtype=tf.int32,
            name='next+_tracking_id')

        self._post_processor = post_processor_builder.get_post_processor(
            config, dataset_descriptor)

        label_divisor = dataset_descriptor.panoptic_label_divisor
        self._render_fn = functools.partial(
            motion_deeplab.render_panoptic_map_as_heatmap,
            sigma=8,
            label_divisor=label_divisor,
            void_label=dataset_descriptor.ignore_label)
        self._track_fn = functools.partial(
            motion_deeplab.assign_instances_to_previous_tracks,
            label_divisor=label_divisor)

        # The ASPP pooling size is always derived from the train crop size,
        # which was found to be experimentally better.
        stride = float(config.model_options.backbone.output_stride)
        pool_size = tuple(
            utils.scale_mutable_sequence(
                config.train_dataset_options.crop_size, 1.0 / stride))
        logging.info('Setting pooling size to %s', pool_size)
        self.set_pool_size(pool_size)

    def call(self, input_tensor: tf.Tensor, training=False) -> Dict[Text, Any]:
        """Performs a forward pass.

        Args:
            input_tensor: An input tensor of type tf.Tensor with shape
                [batch, height, width, channels]. The input tensor should
                contain batches of RGB images.
            training: A boolean flag indicating whether training behavior
                should be used (default: False).

        Returns:
            A dictionary containing the results of the specified DeepLab
            architecture. The results are bilinearly upsampled to input size
            before returning.
        """
        if not training:
            # At evaluation time the previously predicted heatmap becomes the
            # 7th input channel (during training the groundtruth heatmap is
            # used instead).
            input_tensor = self._add_previous_heatmap_to_input(input_tensor)

        # Normalize the input in the same way as Inception. This is done
        # outside the encoder so that other backbones can be plugged in
        # without duplicating the normalization, and after data preprocessing
        # because it is faster on TPUs than on host CPUs. It should not
        # increase TPU memory consumption since it requires no gradient.
        input_tensor = input_tensor / 127.5 - 1.0

        # Static spatial shape of the input tensor.
        _, input_h, input_w, _ = input_tensor.get_shape().as_list()

        features = self._encoder(input_tensor, training=training)
        pred = self._decoder(features, training=training)

        # Offset maps are resized with the offset-aware helper; every other
        # prediction is resized bilinearly.
        offset_keys = (common.PRED_OFFSET_MAP_KEY,
                       common.PRED_FRAME_OFFSET_MAP_KEY)
        result_dict = {}
        for key, value in pred.items():
            if key in offset_keys:
                resize_fn = utils.resize_and_rescale_offsets
            else:
                resize_fn = utils.resize_bilinear
            result_dict[key] = resize_fn(value, [input_h, input_w])

        # Change the semantic logits to probabilities with softmax.
        semantic_logits = result_dict[common.PRED_SEMANTIC_LOGITS_KEY]
        result_dict[common.PRED_SEMANTIC_PROBS_KEY] = tf.nn.softmax(
            semantic_logits)

        if not training:
            result_dict.update(self._post_processor(result_dict))

            current_panoptic = result_dict[common.PRED_PANOPTIC_KEY]
            rendered_heatmap, rendered_centers = self._render_fn(
                current_panoptic)
            tracked_panoptic, tracked_centers, next_id = self._track_fn(
                self._prev_center_list.value(),
                rendered_centers,
                rendered_heatmap,
                result_dict[common.PRED_FRAME_OFFSET_MAP_KEY],
                current_panoptic,
                self._next_tracking_id.value())
            result_dict[common.PRED_PANOPTIC_KEY] = tracked_panoptic

            # Store the state consumed by the next evaluation call.
            self._next_tracking_id.assign(next_id)
            self._prev_center_prediction.assign(
                tf.expand_dims(
                    rendered_heatmap, axis=3, name='expand_prev_centermap'))
            self._prev_center_list.assign(tracked_centers)

        if common.PRED_CENTER_HEATMAP_KEY in result_dict:
            center_heatmap = result_dict[common.PRED_CENTER_HEATMAP_KEY]
            result_dict[common.PRED_CENTER_HEATMAP_KEY] = tf.squeeze(
                center_heatmap, axis=3)

        return result_dict

    def _add_previous_heatmap_to_input(self, input_tensor: tf.Tensor
                                       ) -> tf.Tensor:
        """Appends the previous center heatmap as an extra input channel.

        The input is split along axis 3 into two 3-channel frames. When both
        frames are identical, a zero heatmap is used and the stored center
        list and tracking id are reset; otherwise the stored previous center
        prediction is used.

        Args:
            input_tensor: A tf.Tensor holding two 3-channel frames stacked
                along axis 3.

        Returns:
            A tf.Tensor with 7 channels along axis 3: both frames followed by
            the heatmap channel.
        """
        frame1, frame2 = tf.split(input_tensor, [3, 3], axis=3)

        # Simple first-frame detection: for the first frame of a sequence,
        # frame1 and frame2 are identical.
        if tf.reduce_all(tf.equal(frame1, frame2)):
            input_shape = tf.shape(input_tensor)
            height = input_shape[1]
            width = input_shape[2]
            prev_center = tf.zeros((1, height, width, 1), dtype=tf.float32)
            self._prev_center_list.assign(tf.zeros((0, 5), dtype=tf.int32))
            self._next_tracking_id.assign(1)
        else:
            prev_center = self._prev_center_prediction

        output_tensor = tf.concat([frame1, frame2, prev_center], axis=3)
        output_tensor.set_shape([None, None, None, 7])
        return output_tensor

    def reset_pooling_layer(self):
        """Resets the ASPP pooling layer to global average pooling."""
        self._decoder.reset_pooling_layer()

    def set_pool_size(self, pool_size: Tuple[int, int]):
        """Sets the pooling size of the ASPP pooling layer.

        Args:
            pool_size: A tuple specifying the pooling size of the ASPP pooling
                layer.
        """
        self._decoder.set_pool_size(pool_size)

    @property
    def checkpoint_items(self) -> Dict[Text, Any]:
        """Returns the encoder plus the decoder's own checkpoint items."""
        items = {'encoder': self._encoder}
        items.update(self._decoder.checkpoint_items)
        return items