|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""This file contains the Motion-DeepLab architecture.""" |
|
|
|
import functools |
|
from typing import Any, Dict, Text, Tuple |
|
|
|
from absl import logging |
|
import tensorflow as tf |
|
|
|
from deeplab2 import common |
|
from deeplab2 import config_pb2 |
|
from deeplab2.data import dataset |
|
from deeplab2.model import builder |
|
from deeplab2.model import utils |
|
from deeplab2.model.post_processor import motion_deeplab |
|
from deeplab2.model.post_processor import post_processor_builder |
|
|
|
|
|
class MotionDeepLab(tf.keras.Model):
  """This class represents the Motion-DeepLab meta architecture.

  This class is the basis of the Motion-DeepLab architecture. This Model can be
  used for Video Panoptic Segmentation or Segmenting and Tracking Every Pixel
  (STEP).

  At inference time (`training=False`) the model is stateful. It keeps these
  values in non-trainable variables and updates them on every call:
    * the previous frame's rendered center heatmap,
    * the list of previously tracked centers,
    * the next tracking id to assign.
  """

  def __init__(self,
               config: config_pb2.ExperimentOptions,
               dataset_descriptor: dataset.DatasetDescriptor):
    """Initializes a Motion-DeepLab architecture.

    Args:
      config: A config_pb2.ExperimentOptions configuration.
      dataset_descriptor: A dataset.DatasetDescriptor.
    """
    super().__init__(name='MotionDeepLab')

    solver_options = config.trainer_options.solver_options
    if solver_options.use_sync_batchnorm:
      logging.info('Synchronized Batchnorm is used.')
      bn_class = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      logging.info('Standard (unsynchronized) Batchnorm is used.')
      bn_class = tf.keras.layers.BatchNormalization
    # Both batch-norm variants share the same momentum/epsilon settings.
    bn_layer = functools.partial(
        bn_class,
        momentum=solver_options.batchnorm_momentum,
        epsilon=solver_options.batchnorm_epsilon)

    self._encoder = builder.create_encoder(
        config.model_options.backbone, bn_layer,
        conv_kernel_weight_decay=solver_options.weight_decay)

    self._decoder = builder.create_decoder(config.model_options, bn_layer,
                                           dataset_descriptor.ignore_label)

    # Inference-time tracking state. These buffers are re-assigned in `call`
    # and `_add_previous_heatmap_to_input`. `validate_shape=False` allows those
    # assignments to change the stored shape.

    # Rendered center heatmap of the previous frame. It starts as a scalar
    # placeholder and later holds a tensor with a trailing channel axis.
    self._prev_center_prediction = tf.Variable(
        0.0,
        trainable=False,
        validate_shape=False,
        shape=tf.TensorShape(None),
        dtype=tf.float32,
        name='prev_prediction_buffer')
    # Centers tracked in the previous frame. Starts empty, with 5 columns.
    self._prev_center_list = tf.Variable(
        tf.zeros((0, 5), dtype=tf.int32),
        trainable=False,
        validate_shape=False,
        shape=tf.TensorShape(None),
        name='prev_prediction_list')
    # Next tracking id to hand out. Reset to 1 whenever a new sequence starts.
    self._next_tracking_id = tf.Variable(
        1,
        trainable=False,
        validate_shape=False,
        dtype=tf.int32,
        name='next_tracking_id')

    self._post_processor = post_processor_builder.get_post_processor(
        config, dataset_descriptor)
    self._render_fn = functools.partial(
        motion_deeplab.render_panoptic_map_as_heatmap,
        sigma=8,
        label_divisor=dataset_descriptor.panoptic_label_divisor,
        void_label=dataset_descriptor.ignore_label)
    self._track_fn = functools.partial(
        motion_deeplab.assign_instances_to_previous_tracks,
        label_divisor=dataset_descriptor.panoptic_label_divisor)

    # The ASPP pooling size is derived from the training crop size, scaled
    # down by the backbone output stride.
    output_stride = float(config.model_options.backbone.output_stride)
    pool_size = tuple(
        utils.scale_mutable_sequence(
            config.train_dataset_options.crop_size, 1.0 / output_stride))
    logging.info('Setting pooling size to %s', pool_size)
    self.set_pool_size(pool_size)

  def call(self, input_tensor: tf.Tensor, training=False) -> Dict[Text, Any]:
    """Performs a forward pass.

    Args:
      input_tensor: An input tensor of type tf.Tensor with shape [batch, height,
        width, channels]. The input tensor should contain batches of RGB images.
        At inference the channels must be two stacked RGB frames (6 channels).
        The cached previous center heatmap is appended as a 7th channel. During
        training the input is used as-is. It presumably already includes the
        heatmap channel (TODO confirm against the input pipeline).
      training: A boolean flag indicating whether training behavior should be
        used (default: False). When False, post-processing and tracking run,
        and the internal tracking buffers are updated.

    Returns:
      A dictionary containing the results of the specified DeepLab architecture.
      The results are bilinearly upsampled to input size before returning.
    """
    if not training:
      # At inference the previous heatmap comes from the cached buffer.
      input_tensor = self._add_previous_heatmap_to_input(input_tensor)

    # Scale inputs so that [0, 255] maps to [-1, 1]. This applies to every
    # channel, including the appended heatmap channel.
    input_tensor = input_tensor / 127.5 - 1.0

    # Static spatial size, used as the target size for resizing predictions.
    _, input_h, input_w, _ = input_tensor.get_shape().as_list()

    pred = self._decoder(
        self._encoder(input_tensor, training=training), training=training)

    offset_keys = (common.PRED_OFFSET_MAP_KEY,
                   common.PRED_FRAME_OFFSET_MAP_KEY)
    result_dict = {}
    for key, value in pred.items():
      if key in offset_keys:
        # Offset maps use `resize_and_rescale_offsets`, not plain bilinear
        # resizing.
        result_dict[key] = utils.resize_and_rescale_offsets(
            value, [input_h, input_w])
      else:
        result_dict[key] = utils.resize_bilinear(
            value, [input_h, input_w])

    # Semantic probabilities are derived from the logits. The logits are kept
    # in the result as well.
    result_dict[common.PRED_SEMANTIC_PROBS_KEY] = tf.nn.softmax(
        result_dict[common.PRED_SEMANTIC_LOGITS_KEY])

    if not training:
      result_dict.update(self._post_processor(result_dict))

      # Render the current panoptic prediction as a center heatmap and center
      # list. Then match instances against the previous frame's tracks.
      next_heatmap, next_centers = self._render_fn(
          result_dict[common.PRED_PANOPTIC_KEY])
      panoptic_map, next_centers, next_id = self._track_fn(
          self._prev_center_list.value(),
          next_centers,
          next_heatmap,
          result_dict[common.PRED_FRAME_OFFSET_MAP_KEY],
          result_dict[common.PRED_PANOPTIC_KEY],
          self._next_tracking_id.value()
      )

      # Replace the panoptic map with the tracked one, then cache the state
      # needed for the next frame.
      result_dict[common.PRED_PANOPTIC_KEY] = panoptic_map
      self._next_tracking_id.assign(next_id)
      self._prev_center_prediction.assign(
          tf.expand_dims(next_heatmap, axis=3, name='expand_prev_centermap'))
      self._prev_center_list.assign(next_centers)

    if common.PRED_CENTER_HEATMAP_KEY in result_dict:
      # Drop the trailing channel axis of the center heatmap.
      result_dict[common.PRED_CENTER_HEATMAP_KEY] = tf.squeeze(
          result_dict[common.PRED_CENTER_HEATMAP_KEY], axis=3)
    return result_dict

  def _add_previous_heatmap_to_input(self, input_tensor: tf.Tensor
                                     ) -> tf.Tensor:
    """Appends the cached previous center heatmap as a 7th input channel.

    The input is split into two 3-channel frames. If both frames are
    identical, the call is treated as the start of a new sequence: an all-zero
    heatmap is used, and the tracked-center list and next tracking id are
    reset. Otherwise the heatmap cached by the previous `call` is used.

    Args:
      input_tensor: A tensor with exactly 6 channels on axis 3 (two stacked
        3-channel frames).

    Returns:
      A tensor with 7 channels: frame1, frame2, and the previous heatmap.
    """
    frame1, frame2 = tf.split(input_tensor, [3, 3], axis=3)
    # Identical frames signal the first frame of a sequence. Presumably callers
    # duplicate the first frame deliberately; verify against the data pipeline.
    if tf.reduce_all(tf.equal(frame1, frame2)):
      h = tf.shape(input_tensor)[1]
      w = tf.shape(input_tensor)[2]
      # NOTE(review): the zero heatmap has batch size 1, so inference assumes
      # a batch size of 1 — confirm.
      prev_center = tf.zeros((1, h, w, 1), dtype=tf.float32)
      self._prev_center_list.assign(tf.zeros((0, 5), dtype=tf.int32))
      self._next_tracking_id.assign(1)
    else:
      # NOTE(review): if the very first inference call does not use identical
      # frames, this buffer still holds its scalar initial value. The concat
      # below would then fail with a rank mismatch.
      prev_center = self._prev_center_prediction
    output_tensor = tf.concat([frame1, frame2, prev_center], axis=3)
    output_tensor.set_shape([None, None, None, 7])
    return output_tensor

  def reset_pooling_layer(self):
    """Resets the ASPP pooling layer to global average pooling."""
    self._decoder.reset_pooling_layer()

  def set_pool_size(self, pool_size: Tuple[int, int]):
    """Sets the pooling size of the ASPP pooling layer.

    Args:
      pool_size: A tuple specifying the pooling size of the ASPP pooling layer.
    """
    self._decoder.set_pool_size(pool_size)

  @property
  def checkpoint_items(self) -> Dict[Text, Any]:
    """Returns the objects to track in checkpoints.

    This is the encoder under the key `encoder`, plus the decoder's own
    `checkpoint_items`.
    """
    items = dict(encoder=self._encoder)
    items.update(self._decoder.checkpoint_items)
    return items
|
|