deeplab2 / model /post_processor /vip_deeplab.py
akhaliq3
spaces demo
506da10
# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains functions to post-process ViP-DeepLab results."""
import numpy as np
def stitch_video_panoptic_prediction(
concat_panoptic: np.ndarray,
next_panoptic: np.ndarray,
label_divisor: int,
overlap_offset: int = 128,
combine_offset: int = 2 ** 32) -> np.ndarray:
"""The stitching algorithm in ViP-DeepLab.
This function stitches a pair of image panoptic predictions to form video
panoptic predictions by propagating instance IDs from concat_panoptic to
next_panoptic based on IoU matching.
Siyuan Qiao, Yukun Zhu, Hartwig Adam, Alan Yuille, and Liang-Chieh Chen.
"ViP-DeepLab: Learning Visual Perception with Depth-aware Video Panoptic
Segmentation." CVPR, 2021.
Args:
concat_panoptic: Panoptic prediction of the next frame by concatenating
it with the current frame.
next_panoptic: Panoptic prediction of the next frame.
label_divisor: An integer specifying the label divisor of the dataset.
overlap_offset: An integer offset to avoid overlap between the IDs in
next_panoptic and the propagated IDs from concat_panoptic.
combine_offset: An integer offset to combine concat and next panoptic.
Returns:
Panoptic prediction of the next frame with the instance IDs propragated
from the concatenated panoptic prediction.
"""
def _ids_to_counts(id_array: np.ndarray):
"""Given a numpy array, a mapping from each entry to its count."""
ids, counts = np.unique(id_array, return_counts=True)
return dict(zip(ids, counts))
new_panoptic = next_panoptic.copy()
# Increase the panoptic instance ID to avoid overlap.
new_category = new_panoptic // label_divisor
new_instance = new_panoptic % label_divisor
# We skip 0 which is reserved for crowd.
instance_mask = new_instance > 0
new_instance[instance_mask] = new_instance[instance_mask] + overlap_offset
new_panoptic = new_category * label_divisor + new_instance
# Pre-compute areas for all the segments.
concat_segment_areas = _ids_to_counts(concat_panoptic)
next_segment_areas = _ids_to_counts(next_panoptic)
# Combine concat_panoptic and next_panoptic.
intersection_id_array = (concat_panoptic.astype(np.int64) *
combine_offset + next_panoptic.astype(np.int64))
intersection_areas = _ids_to_counts(intersection_id_array)
# Compute IoU and sort them.
intersection_ious = []
for intersection_id, intersection_area in intersection_areas.items():
concat_panoptic_label = int(intersection_id // combine_offset)
next_panoptic_label = int(intersection_id % combine_offset)
concat_category_label = concat_panoptic_label // label_divisor
next_category_label = next_panoptic_label // label_divisor
if concat_category_label != next_category_label:
continue
concat_instance_label = concat_panoptic_label % label_divisor
next_instance_label = next_panoptic_label % label_divisor
# We skip 0 which is reserved for crowd.
if concat_instance_label == 0 or next_instance_label == 0:
continue
union = (
concat_segment_areas[concat_panoptic_label] +
next_segment_areas[next_panoptic_label] -
intersection_area)
iou = intersection_area / union
intersection_ious.append([
concat_panoptic_label, next_panoptic_label, iou])
intersection_ious = sorted(
intersection_ious, key=lambda e: e[2])
# Build mapping and inverse mapping. Two-way mapping guarantees 1-to-1
# matching.
map_concat_to_next = {}
map_next_to_concat = {}
for (concat_panoptic_label, next_panoptic_label,
iou) in intersection_ious:
map_concat_to_next[concat_panoptic_label] = next_panoptic_label
map_next_to_concat[next_panoptic_label] = concat_panoptic_label
# Match and propagate.
for (concat_panoptic_label,
next_panoptic_label) in map_concat_to_next.items():
if map_next_to_concat[next_panoptic_label] == concat_panoptic_label:
propagate_mask = next_panoptic == next_panoptic_label
new_panoptic[propagate_mask] = concat_panoptic_label
return new_panoptic