|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""This file contains functions to post-process MaX-DeepLab results.""" |
|
|
|
import functools |
|
from typing import List, Tuple, Dict, Text |
|
|
|
import tensorflow as tf |
|
|
|
from deeplab2 import common |
|
from deeplab2 import config_pb2 |
|
from deeplab2.data import dataset |
|
from deeplab2.model import utils |
|
|
|
|
|
def _get_transformer_class_prediction(
    transformer_class_probs: tf.Tensor,
    transformer_class_confidence_threshold: float
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
  """Computes the transformer class prediction and confidence score.

  Args:
    transformer_class_probs: A tf.Tensor of shape [num_mask_slots,
      num_thing_stuff_classes + 1]. It contains the class probabilities of each
      mask slot, where num_mask_slots is the number of mask slots (for both
      thing classes and stuff classes) in MaX-DeepLab. The last channel
      indicates a `void` class.
    transformer_class_confidence_threshold: A float for thresholding the
      confidence of the transformer_class_probs. The panoptic mask slots with
      class confidence less than the threshold are filtered and not used for
      panoptic prediction. Only masks whose confidence is larger than the
      threshold are counted in num_detections.

  Returns:
    A tuple of:
    - the detected mask class prediction as float32 tf.Tensor of shape
      [num_detections].
    - the detected mask indices as tf.Tensor of shape [num_detections].
    - the number of detections as a scalar int32 tf.Tensor.
  """
  # Predicted class per mask slot (argmax over classes, cast to float so it
  # can later be gathered into float maps).
  transformer_class_pred = tf.cast(
      tf.argmax(transformer_class_probs, axis=-1), tf.float32)
  # Confidence per mask slot is the maximum class probability.
  transformer_class_confidence = tf.reduce_max(
      transformer_class_probs, axis=-1, keepdims=False)

  # 1.0 for mask slots whose confidence passes the threshold, 0.0 otherwise.
  # Note: the original code also multiplied transformer_class_confidence by
  # this mask, but the product was never used — that dead store is removed.
  thresholded_mask = tf.cast(
      tf.greater_equal(transformer_class_confidence,
                       transformer_class_confidence_threshold), tf.float32)

  # Indices of the surviving (detected) mask slots.
  detected_mask_indices = tf.where(tf.greater(thresholded_mask, 0.5))[:, 0]
  detected_mask_class_pred = tf.gather(
      transformer_class_pred, detected_mask_indices)
  num_detections = tf.shape(detected_mask_indices)[0]
  return detected_mask_class_pred, detected_mask_indices, num_detections
|
|
|
|
|
def _get_mask_id_and_semantic_maps(
    thing_class_ids: List[int],
    stuff_class_ids: List[int],
    pixel_space_mask_logits: tf.Tensor,
    transformer_class_probs: tf.Tensor,
    image_shape: List[int],
    pixel_confidence_threshold=0.4,
    transformer_class_confidence_threshold=0.7,
    pieces=1) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
  """Computes the pixel-level mask ID map and semantic map per image.

  Args:
    thing_class_ids: A List of integers of shape [num_thing_classes] containing
      thing class indices.
    stuff_class_ids: A List of integers of shape [num_stuff_classes] containing
      stuff class indices.
    pixel_space_mask_logits: A tf.Tensor of shape [height, width,
      num_mask_slots]. It is a pixel level logit scores where the
      num_mask_slots is the number of mask slots (for both thing classes
      and stuff classes) in MaX-DeepLab.
    transformer_class_probs: A tf.Tensor of shape [num_mask_slots,
      num_thing_stuff_classes + 1]. It contains the class probabilities of each
      mask slot, where num_mask_slots is the number of mask slots (for both
      thing classes and stuff classes) in MaX-DeepLab. The last channel
      indicates a `void` class.
    image_shape: A list of integers specifying the [height, width] of input
      image.
    pixel_confidence_threshold: A float indicating a threshold for the pixel
      level softmax probability confidence of transformer mask logits. If less
      than the threshold, the pixel locations have confidence `0` in
      `confident_regions` output, and represent `void` (ignore) regions.
    transformer_class_confidence_threshold: A float for thresholding the
      confidence of the transformer_class_probs. The panoptic mask slots with
      class confidence less than the threshold are filtered and not used for
      panoptic prediction.
    pieces: An integer indicating the number of pieces in the piece-wise
      operation. When computing panoptic prediction and confident regions, the
      mask logits are divided width-wise into multiple pieces and processed
      piece-wise due to the GPU memory limit. Then, the piece-wise outputs are
      concatenated along the width into the original mask shape. Defaults to 1.

  Returns:
    A tuple of:
    - the mask ID prediction as tf.Tensor with shape [height, width].
    - the semantic prediction as tf.Tensor with shape [height, width].
    - the thing region mask as tf.Tensor with shape [height, width].
    - the stuff region mask as tf.Tensor with shape [height, width].

  Raises:
    ValueError: When input image's `width - 1` is not divisible by `pieces`.
  """
  # Drop the last channel, the `void` class, keeping only real classes.
  transformer_class_probs = transformer_class_probs[..., :-1]

  # Mask slots are ordered thing-classes-first, then stuff classes; this
  # ordering is what the thing/stuff split below relies on.
  thing_stuff_class_ids = thing_class_ids + stuff_class_ids

  detected_mask_class_pred, detected_mask_indices, num_detections = (
      _get_transformer_class_prediction(transformer_class_probs,
                                        transformer_class_confidence_threshold))

  def _return_empty_mask_id_and_semantic_maps():
    # No mask slot survived the confidence threshold: mask IDs are all 1,
    # semantic labels all 0, and both thing/stuff confidence masks are empty.
    return (
        tf.ones([image_shape[0], image_shape[1]], dtype=tf.int32),
        tf.zeros([image_shape[0], image_shape[1]], dtype=tf.int32),
        tf.zeros([image_shape[0], image_shape[1]], dtype=tf.float32),
        tf.zeros([image_shape[0], image_shape[1]], dtype=tf.float32))

  def _generate_mask_id_and_semantic_maps():
    output_mask_id_map = []
    output_confident_region = []
    logits_width = pixel_space_mask_logits.get_shape().as_list()[1]
    output_width = image_shape[1]

    if (output_width - 1) % pieces > 0:
      raise ValueError('`output_width - 1` must be divisible by `pieces`.')

    # Each piece covers (output_width - 1) / pieces output columns plus one
    # shared boundary column (dropped below for all but the last piece).
    piece_output_width = (output_width - 1) // pieces + 1

    for piece_id in range(pieces):
      # Slice with one-column overlap between consecutive pieces so that
      # bilinear resizing with align_corners=True stays consistent with
      # resizing the full-width logits at once.
      piece_begin = (logits_width - 1) // pieces * piece_id
      piece_end = (logits_width - 1) // pieces * (piece_id + 1) + 1
      piece_pixel_mask_logits = (
          pixel_space_mask_logits[:, piece_begin:piece_end, :])
      piece_pixel_mask_logits = tf.compat.v1.image.resize_bilinear(
          tf.expand_dims(piece_pixel_mask_logits, 0),
          (image_shape[0], piece_output_width),
          align_corners=True)
      piece_pixel_mask_logits = tf.squeeze(piece_pixel_mask_logits, axis=0)
      # Keep only the mask slots that passed the class-confidence threshold.
      piece_detected_pixel_mask_logits = tf.gather(
          piece_pixel_mask_logits, detected_mask_indices, axis=-1)

      # A pixel is kept only when its overall best slot (over all slots) is
      # one of the detected slots; otherwise it belongs to a filtered slot.
      piece_max_logits = tf.reduce_max(piece_pixel_mask_logits, axis=-1)
      piece_detected_max_logits = tf.reduce_max(
          piece_detected_pixel_mask_logits, axis=-1)
      piece_detected_mask = tf.cast(tf.math.equal(
          piece_max_logits, piece_detected_max_logits), tf.float32)

      # Pixel-level confidence = max softmax probability over detected slots.
      piece_pixel_confidence_map = tf.reduce_max(
          tf.nn.softmax(piece_detected_pixel_mask_logits, axis=-1), axis=-1)
      piece_confident_region = tf.cast(
          piece_pixel_confidence_map > pixel_confidence_threshold, tf.float32)
      piece_confident_region = piece_confident_region * piece_detected_mask
      # Mask ID is the index into the *detected* slots (re-indexed by gather).
      piece_mask_id_map = tf.cast(
          tf.argmax(piece_detected_pixel_mask_logits, axis=-1), tf.int32)
      if piece_id == pieces - 1:
        output_mask_id_map.append(piece_mask_id_map)
        output_confident_region.append(piece_confident_region)
      else:
        # Drop the shared boundary column; the next piece re-computes it.
        output_mask_id_map.append(piece_mask_id_map[:, :-1])
        output_confident_region.append(piece_confident_region[:, :-1])

    mask_id_map = tf.concat(output_mask_id_map, axis=1)
    confident_region = tf.concat(output_confident_region, axis=1)
    mask_id_map_flat = tf.reshape(mask_id_map, [-1])
    # Map each pixel's mask ID to the (float) class prediction of that slot.
    mask_id_semantic_map_flat = tf.gather(
        detected_mask_class_pred, mask_id_map_flat)
    mask_id_semantic_map = tf.reshape(
        mask_id_semantic_map_flat, [image_shape[0], image_shape[1]])

    # Class predictions below len(thing_class_ids) are thing classes because
    # thing_stuff_class_ids lists thing classes first.
    thing_mask = tf.cast(mask_id_semantic_map < len(thing_class_ids),
                         tf.float32) * confident_region
    stuff_mask = tf.cast(mask_id_semantic_map >= len(thing_class_ids),
                         tf.float32) * confident_region

    # Translate contiguous class predictions into the dataset's class IDs.
    semantic_map = tf.gather(
        tf.convert_to_tensor(thing_stuff_class_ids),
        tf.cast(tf.round(mask_id_semantic_map_flat), tf.int32))
    semantic_map = tf.reshape(semantic_map, [image_shape[0], image_shape[1]])

    # Shift mask IDs by one so that 0 is reserved (e.g. for void regions).
    mask_id_map_plus_one = mask_id_map + 1
    semantic_map = tf.cast(tf.round(semantic_map), tf.int32)
    return (mask_id_map_plus_one, semantic_map, thing_mask, stuff_mask)

  # num_detections < 0.5 means no detection at all — return empty maps.
  mask_id_map_plus_one, semantic_map, thing_mask, stuff_mask = tf.cond(
      tf.cast(num_detections, tf.float32) < tf.cast(0.5, tf.float32),
      _return_empty_mask_id_and_semantic_maps,
      _generate_mask_id_and_semantic_maps)

  return (mask_id_map_plus_one, semantic_map, thing_mask, stuff_mask)
|
|
|
|
|
def _filter_by_count(input_index_map: tf.Tensor,
                     area_limit: int) -> Tuple[tf.Tensor, tf.Tensor]:
  """Zeroes out index regions whose pixel count is not above `area_limit`.

  Args:
    input_index_map: A float32 tf.Tensor of shape [batch, height, width].
    area_limit: An integer specifying the number of pixels that each index
      regions need to have at least. If not over the limit, the index regions
      are masked (zeroed) out.

  Returns:
    masked input_index_map: A tf.Tensor with shape [batch, height, width],
      masked by the area_limit threshold.
    mask: A tf.Tensor with shape [batch, height, width]. It is a pixel-level
      mask with 1. indicating the regions over the area limit, and 0. otherwise.
  """
  num_samples = tf.shape(input_index_map)[0]
  rounded_index_map = tf.cast(tf.round(input_index_map), tf.int32)
  flat_index_map = tf.reshape(rounded_index_map, [num_samples, -1])
  # Per-sample histogram: how many pixels carry each index value.
  index_counts = tf.math.bincount(flat_index_map, axis=-1)
  # Broadcast each pixel's own count back onto the pixel grid.
  flat_counts_map = tf.gather(index_counts, flat_index_map, batch_dims=1)
  counts_map = tf.reshape(flat_counts_map, tf.shape(rounded_index_map))

  # Compare against area_limit - 0.5 so the integer count must be strictly
  # greater than area_limit - 1, i.e. at least area_limit... (>{limit-0.5}
  # keeps counts >= area_limit).
  keep_mask = tf.cast(
      tf.cast(counts_map, tf.float32) > tf.cast(area_limit - 0.5, tf.float32),
      input_index_map.dtype)
  return input_index_map * keep_mask, keep_mask
|
|
|
|
|
def _merge_mask_id_and_semantic_maps(
    mask_id_maps_plus_one: tf.Tensor,
    semantic_maps: tf.Tensor,
    thing_masks: tf.Tensor,
    stuff_masks: tf.Tensor,
    void_label: int,
    label_divisor: int,
    thing_area_limit: int,
    stuff_area_limit: int,) -> tf.Tensor:
  """Merges mask_id maps and semantic_maps to obtain panoptic segmentation.

  Args:
    mask_id_maps_plus_one: A tf.Tensor of shape [batch, height, width].
    semantic_maps: A tf.Tensor of shape [batch, height, width].
    thing_masks: A float32 tf.Tensor of shape [batch, height, width] containing
      masks with 1. at thing regions, 0. otherwise.
    stuff_masks: A float32 tf.Tensor of shape [batch, height, width] containing
      masks with 1. at stuff regions, 0. otherwise.
    void_label: An integer specifying the void label.
    label_divisor: An integer specifying the label divisor of the dataset.
    thing_area_limit: An integer specifying the number of pixels that thing
      regions need to have at least. The thing region will be included in the
      panoptic prediction, only if its area is larger than the limit; otherwise,
      it will be re-assigned as void_label.
    stuff_area_limit: An integer specifying the number of pixels that stuff
      regions need to have at least. The stuff region will be included in the
      panoptic prediction, only if its area is larger than the limit; otherwise,
      it will be re-assigned as void_label.

  Returns:
    panoptic_maps: A tf.Tensor with shape [batch, height, width].

  """
  # Keep (already 1-shifted) mask IDs only inside thing regions; stuff and
  # void pixels become 0 here, so their final instance ID is 0.
  thing_mask_id_maps_plus_one = (tf.cast(mask_id_maps_plus_one, tf.float32)
                                 * thing_masks)

  # Shift semantic labels by +1 so 0 can mean "filtered / not covered"; the
  # shift is undone with the -1.0 below.
  semantic_maps_plus_one = semantic_maps + 1
  tf.debugging.assert_less(
      tf.reduce_sum(thing_masks * stuff_masks), 0.5,
      message='thing_masks and stuff_masks must be mutually exclusive.')

  thing_semantic_maps = (tf.cast(semantic_maps_plus_one, tf.float32)
                         * thing_masks)
  stuff_semantic_maps = (tf.cast(semantic_maps_plus_one, tf.float32)
                         * stuff_masks)

  # Remove stuff regions (per semantic label) smaller than stuff_area_limit.
  stuff_semantic_maps, _ = _filter_by_count(
      stuff_semantic_maps, stuff_area_limit)

  # Remove thing instances (per mask ID) smaller than thing_area_limit, and
  # clear the semantic label of the removed instances as well.
  thing_mask_id_maps_plus_one, mask_id_count_filter_mask = _filter_by_count(
      thing_mask_id_maps_plus_one, thing_area_limit)
  thing_semantic_maps = thing_semantic_maps * mask_id_count_filter_mask

  # Merge the (disjoint) thing and stuff maps and undo the +1 shift. Pixels
  # covered by neither map end up at -1 and are re-assigned void_label below
  # (-1 + void_label + 1 == void_label).
  semantic_maps_new = thing_semantic_maps + stuff_semantic_maps - 1.0
  semantic_maps_new = (tf.cast(semantic_maps_new < -0.5, tf.float32)
                       * tf.cast(void_label + 1, tf.float32)
                       + semantic_maps_new)
  # Standard panoptic encoding: semantic * label_divisor + instance ID
  # (instance ID is 0 for stuff/void pixels).
  panoptic_maps = (semantic_maps_new * label_divisor
                   + thing_mask_id_maps_plus_one)
  panoptic_maps = tf.cast(tf.round(panoptic_maps), tf.int32)
  return panoptic_maps
|
|
|
|
|
def _get_panoptic_predictions(
    pixel_space_mask_logits: tf.Tensor,
    transformer_class_logits: tf.Tensor,
    thing_class_ids: List[int],
    void_label: int,
    label_divisor: int,
    thing_area_limit: int,
    stuff_area_limit: int,
    image_shape: List[int],
    pixel_confidence_threshold=0.4,
    transformer_class_confidence_threshold=0.7,
    pieces=1) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
  """Computes the pixel-level panoptic, mask ID, and semantic maps.

  Args:
    pixel_space_mask_logits: A tf.Tensor of shape [batch, strided_height,
      strided_width, num_mask_slots]. It is a pixel level logit scores where the
      num_mask_slots is the number of mask slots (for both thing classes
      and stuff classes) in MaX-DeepLab.
    transformer_class_logits: A tf.Tensor of shape [batch, num_mask_slots,
      num_thing_stuff_classes + 1]. It contains the class logits of each mask
      slot, where num_mask_slots is the number of mask slots (for both thing
      classes and stuff classes) in MaX-DeepLab. The last channel indicates a
      `void` class.
    thing_class_ids: A List of integers of shape [num_thing_classes] containing
      thing class indices.
    void_label: An integer specifying the void label.
    label_divisor: An integer specifying the label divisor of the dataset.
    thing_area_limit: An integer specifying the number of pixels that thing
      regions need to have at least. The thing region will be included in the
      panoptic prediction, only if its area is larger than the limit; otherwise,
      it will be re-assigned as void_label.
    stuff_area_limit: An integer specifying the number of pixels that stuff
      regions need to have at least. The stuff region will be included in the
      panoptic prediction, only if its area is larger than the limit; otherwise,
      it will be re-assigned as void_label.
    image_shape: A list of integers specifying the [height, width] of input
      image.
    pixel_confidence_threshold: A float indicating a threshold for the pixel
      level softmax probability confidence of transformer mask logits. If less
      than the threshold, the pixel locations have confidence `0` in
      `confident_regions` output, and represent `void` (ignore) regions.
    transformer_class_confidence_threshold: A float for thresholding the
      confidence of the transformer_class_probs. The panoptic mask slots with
      class confidence less than the threshold are filtered and not used for
      panoptic prediction.
    pieces: An integer indicating the number of pieces in the piece-wise
      operation in `_get_mask_id_and_semantic_maps`. When computing panoptic
      prediction and confident regions, the mask logits are divided width-wise
      into multiple pieces and processed piece-wise due to the GPU memory limit.
      Then, the piece-wise outputs are concatenated along the width into the
      original mask shape. Defaults to 1.

  Returns:
    A tuple of:
    - the panoptic prediction as tf.Tensor with shape [batch, height, width].
    - the mask ID prediction as tf.Tensor with shape [batch, height, width].
    - the semantic prediction as tf.Tensor with shape [batch, height, width].
  """
  transformer_class_probs = tf.nn.softmax(transformer_class_logits, axis=-1)
  batch_size = tf.shape(transformer_class_logits)[0]
  # Derive num_thing_stuff_classes from the static channel count; the last
  # channel is the `void` class and is excluded.
  num_thing_stuff_classes = (
      transformer_class_logits.get_shape().as_list()[-1] - 1)
  # Stuff class IDs are all class IDs that are neither thing nor void.
  stuff_class_ids = utils.get_stuff_class_ids(
      num_thing_stuff_classes, thing_class_ids, void_label)

  # TensorArrays collect per-image results so the tf.range loop below is
  # graph-compatible (batch_size may be a dynamic tensor).
  mask_id_map_plus_one_lists = tf.TensorArray(
      tf.int32, size=batch_size, dynamic_size=False)
  semantic_map_lists = tf.TensorArray(
      tf.int32, size=batch_size, dynamic_size=False)
  thing_mask_lists = tf.TensorArray(
      tf.float32, size=batch_size, dynamic_size=False)
  stuff_mask_lists = tf.TensorArray(
      tf.float32, size=batch_size, dynamic_size=False)
  # Process each image in the batch independently.
  for i in tf.range(batch_size):
    mask_id_map_plus_one, semantic_map, thing_mask, stuff_mask = (
        _get_mask_id_and_semantic_maps(
            thing_class_ids, stuff_class_ids,
            pixel_space_mask_logits[i, ...], transformer_class_probs[i, ...],
            image_shape, pixel_confidence_threshold,
            transformer_class_confidence_threshold, pieces)
    )
    mask_id_map_plus_one_lists = mask_id_map_plus_one_lists.write(
        i, mask_id_map_plus_one)
    semantic_map_lists = semantic_map_lists.write(i, semantic_map)
    thing_mask_lists = thing_mask_lists.write(i, thing_mask)
    stuff_mask_lists = stuff_mask_lists.write(i, stuff_mask)

  # Stack the per-image results back into batched tensors.
  mask_id_maps_plus_one = mask_id_map_plus_one_lists.stack()
  semantic_maps = semantic_map_lists.stack()
  thing_masks = thing_mask_lists.stack()
  stuff_masks = stuff_mask_lists.stack()

  panoptic_maps = _merge_mask_id_and_semantic_maps(
      mask_id_maps_plus_one, semantic_maps, thing_masks, stuff_masks,
      void_label, label_divisor, thing_area_limit, stuff_area_limit)
  return panoptic_maps, mask_id_maps_plus_one, semantic_maps
|
|
|
|
|
class PostProcessor(tf.keras.layers.Layer):
  """A Keras layer implementing the MaX-DeepLab post-processor."""

  def __init__(
      self,
      config: config_pb2.ExperimentOptions,
      dataset_descriptor: dataset.DatasetDescriptor):
    """Initializes a MaX-DeepLab post-processor.

    Args:
      config: A config_pb2.ExperimentOptions configuration.
      dataset_descriptor: A dataset.DatasetDescriptor.
    """
    super(PostProcessor, self).__init__(name='PostProcessor')
    evaluator_options = config.evaluator_options
    # Bind all dataset- and config-derived constants once; `call` then only
    # supplies the per-batch tensors.
    self._post_processor = functools.partial(
        _get_panoptic_predictions,
        thing_class_ids=list(dataset_descriptor.class_has_instances_list),
        void_label=dataset_descriptor.ignore_label,
        label_divisor=dataset_descriptor.panoptic_label_divisor,
        thing_area_limit=evaluator_options.thing_area_limit,
        stuff_area_limit=evaluator_options.stuff_area_limit,
        image_shape=list(config.eval_dataset_options.crop_size),
        transformer_class_confidence_threshold=(
            evaluator_options.transformer_class_confidence_threshold),
        pixel_confidence_threshold=(
            evaluator_options.pixel_confidence_threshold),
        pieces=1)

  def call(self, result_dict: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
    """Performs the post-processing given model predicted results.

    Args:
      result_dict: A dictionary of tf.Tensor containing model results. The dict
        has to contain
        - common.PRED_PIXEL_SPACE_MASK_LOGITS_KEY,
        - common.PRED_TRANSFORMER_CLASS_LOGITS_KEY,

    Returns:
      The post-processed dict of tf.Tensor, containing the following:
      - common.PRED_SEMANTIC_KEY,
      - common.PRED_INSTANCE_KEY,
      - common.PRED_PANOPTIC_KEY,
    """
    panoptic_maps, mask_id_maps_plus_one, semantic_maps = (
        self._post_processor(
            result_dict[common.PRED_PIXEL_SPACE_MASK_LOGITS_KEY],
            result_dict[common.PRED_TRANSFORMER_CLASS_LOGITS_KEY]))
    return {
        common.PRED_PANOPTIC_KEY: panoptic_maps,
        common.PRED_INSTANCE_KEY: mask_id_maps_plus_one,
        common.PRED_SEMANTIC_KEY: semantic_maps,
    }
|
|