# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList, TrackDataSample
from seg.models.detectors import Mask2formerVideo
from seg.models.utils import mask_pool

# Maximum number of frames fed through the backbone in one forward pass.
BACKBONE_BATCH = 50


def video_split(total, tube_size, overlap=0):
    """Split ``total`` frames into clip end indices of roughly ``tube_size``
    frames each, with ``overlap`` frames shared between consecutive clips."""
    assert tube_size > overlap
    total -= overlap
    tube_size -= overlap
    if total % tube_size == 0:
        splits = total // tube_size
    else:
        splits = (total // tube_size) + 1
    ind_list = []
    for i in range(splits):
        ind_list.append((i + 1) * tube_size)
    diff = ind_list[-1] - total
    if diff < splits:
        # Expected case: shorten each of the last ``diff`` clips by one frame
        # so that the final boundary lands exactly on ``total``.
        for i in range(diff):
            ind_list[splits - 1 - i] -= diff - i
    else:
        # Fallback: shrink only the last clip and warn.
        ind_list[splits - 1] -= diff
        assert ind_list[splits - 1] > 0
        print("Warning: {} / {}".format(total, tube_size))
    for idx in range(len(ind_list)):
        ind_list[idx] += overlap
    return ind_list


def match_from_embeds(tgt_embds, cur_embds):
    """Hungarian-match current queries to target queries by cosine similarity."""
    cur_embds = cur_embds / cur_embds.norm(dim=-1, keepdim=True)
    tgt_embds = tgt_embds / tgt_embds.norm(dim=-1, keepdim=True)
    cos_sim = torch.bmm(cur_embds, tgt_embds.transpose(1, 2))
    cost_embd = 1 - cos_sim

    C = 1.0 * cost_embd
    C = C.cpu()

    indices = []
    for i in range(len(cur_embds)):
        indice = linear_sum_assignment(C[i].transpose(0, 1))  # target x current
        indice = indice[1]  # permutation that aligns current to target
        indices.append(indice)
    return indices
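

# A rough sanity check of ``video_split`` (illustrative numbers, not taken
# from any config of this repository):
#
#     >>> video_split(12, 5, overlap=1)
#     [5, 9, 12]
#
# These are clip end indices.  Since the caller steps back by ``overlap``
# before starting the next clip, a 12-frame video is processed as clips
# [0:5], [4:9] and [8:12], and the duplicated leading frames of the second
# and third clips are dropped again when the per-clip predictions are
# stitched together.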


@MODELS.register_module()
class Mask2formerVideoMinVIS(Mask2formerVideo):
    r"""Mask2Former video segmentor with MinVIS-style near-online inference.

    Long videos are split into (optionally overlapping) clips with
    :func:`video_split`, every clip is predicted independently, and the
    per-clip queries are associated across clips by Hungarian matching on
    their query embeddings (:func:`match_from_embeds`).
    """
    OVERLAPPING = None

    def __init__(self, *args,
                 clip_size=6,
                 clip_size_small=3,
                 whole_clip_thr=0,
                 small_clip_thr=12,
                 overlap=0,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.clip_size = clip_size
        self.clip_size_small = clip_size_small
        self.overlap = overlap
        self.whole_clip_thr = whole_clip_thr
        self.small_clip_thr = small_clip_thr

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, T, 3, H, W).
            batch_data_samples (List[:obj:`TrackDataSample`]): The data
                samples. They usually include information such as
                ``gt_instance``, ``gt_panoptic_seg`` and ``gt_sem_seg``.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the input images.
            Each DetDataSample usually contains ``pred_instances`` and
            ``pred_panoptic_seg``. The ``pred_instances`` usually contains the
            following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instances, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4), the last
              dimension 4 arranged as (x1, y1, x2, y2).
            - masks (Tensor): Has a shape (num_instances, H, W).

            The ``pred_panoptic_seg`` contains the following key:

            - sem_seg (Tensor): panoptic segmentation mask, has a shape
              (1, h, w).
        """
        assert isinstance(batch_data_samples[0], TrackDataSample)
        bs, num_frames, three, h, w = batch_inputs.shape
        assert three == 3, "Only supporting images with 3 channels."

        # Videos short enough are handled by the whole-video parent implementation.
        if num_frames <= self.whole_clip_thr:
            return super().predict(batch_inputs, batch_data_samples, rescale)

        device = batch_inputs.device
        # Use the smaller clip size for short videos, the regular one otherwise.
        if num_frames > self.small_clip_thr:
            tube_inds = video_split(num_frames, self.clip_size, self.overlap)
        else:
            tube_inds = video_split(num_frames, self.clip_size_small, self.overlap)

        if num_frames > BACKBONE_BATCH:
            # Run the backbone in chunks and keep the features on CPU to
            # bound GPU memory for long videos.
            feat_bins = [[], [], [], []]
            num_clip = num_frames // BACKBONE_BATCH + 1
            step_size = num_frames // num_clip + 1
            for i in range(num_clip):
                start = i * step_size
                end = min(num_frames, (i + 1) * step_size)
                inputs = batch_inputs[:, start:end].reshape(
                    (bs * (end - start), three, h, w))
                _feats = self.extract_feat(inputs)
                assert len(_feats) == 4
                for idx, item in enumerate(_feats):
                    feat_bins[idx].append(item.to('cpu'))
            feats = []
            for item in feat_bins:
                feat = torch.cat(item, dim=0)
                assert feat.size(0) == bs * num_frames, \
                    "{} vs {}".format(feat.size(0), bs * num_frames)
                feats.append(feat)
        else:
            x = batch_inputs.reshape((bs * num_frames, three, h, w))
            feats = self.extract_feat(x)
            assert len(feats[0]) == bs * num_frames
        del batch_inputs

        # Predict each clip independently; queries are matched across clips below.
        ind_pre = 0
        cls_list = []
        mask_list = []
        query_list = []
        iou_list = []
        flag = False
        for ind in tube_inds:
            tube_feats = [itm[ind_pre:ind].to(device=device) for itm in feats]
            tube_data_samples = [TrackDataSample(video_data_samples=itm[ind_pre:ind])
                                 for itm in batch_data_samples]
            _mask_cls_results, _mask_pred_results, _query_feat, _iou_results = \
                self.panoptic_head.predict(tube_feats, tube_data_samples, return_query=True)
            cls_list.append(_mask_cls_results)
            if not flag:
                mask_list.append(_mask_pred_results.cpu())
                flag = True
            else:
                # Drop the overlapping leading frames (dim 2 is the temporal axis).
                mask_list.append(_mask_pred_results[:, :, self.overlap:].cpu())
            query_list.append(_query_feat.cpu())
            iou_list.append(_iou_results)
            ind_pre = ind
            ind_pre -= self.overlap
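
        # MinVIS-style association: each clip's queries are reordered so that
        # query slot k refers to the same object as slot k of the previous
        # clip before the per-clip logits and masks are merged.  For example
        # (hypothetical indices), if clip i's query 7 is assigned to clip
        # i-1's query 2 by the Hungarian matching, then ``indices[2] == 7``
        # and ``cls_list[i][:, indices]`` moves that query's predictions into
        # slot 2.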
        num_tubes = len(tube_inds)
        out_cls = [cls_list[0]]
        out_mask = [mask_list[0]]
        out_embed = [query_list[0]]
        ious = [iou_list[0]]
        for i in range(1, num_tubes):
            indices = match_from_embeds(out_embed[-1], query_list[i])
            indices = indices[0]  # since bs == 1
            out_cls.append(cls_list[i][:, indices])
            out_mask.append(mask_list[i][:, indices])
            out_embed.append(query_list[i][:, indices])
            ious.append(iou_list[i][:, indices])
        del mask_list
        del out_embed

        # Average class/IoU logits over clips and stitch the masks along time.
        mask_cls_results = sum(out_cls) / num_tubes
        mask_pred_results = torch.cat(out_mask, dim=2)
        iou_results = sum(ious) / num_tubes

        if self.OVERLAPPING is not None:
            assert len(self.OVERLAPPING) == self.num_classes
            mask_cls_results = self.open_voc_inference(feats, mask_cls_results, mask_pred_results)
        del feats

        mask_cls_results = mask_cls_results.to(device='cpu')
        iou_results = iou_results.to(device='cpu')

        # Maps predicted proposal index -> ground-truth instance id, built on
        # the first frame and reused for every frame of the video.
        id_assigner = [{} for _ in range(bs)]
        for frame_id in range(num_frames):
            results_list_img = self.panoptic_fusion_head.predict(
                mask_cls_results,
                mask_pred_results[:, :, frame_id],
                [batch_data_samples[idx][frame_id] for idx in range(bs)],
                iou_results=iou_results,
                rescale=rescale
            )
            if frame_id == 0 and 'pro_results' in results_list_img[0]:
                # Assign ground-truth instance ids to the predicted proposals
                # by Hungarian matching on mask IoU.
                for batch_id in range(bs):
                    mask = results_list_img[batch_id]['pro_results'].to(dtype=torch.int32)
                    mask_gt = torch.tensor(
                        batch_data_samples[batch_id][frame_id].gt_instances.masks.masks,
                        dtype=torch.int32)
                    a, b = mask.flatten(1), mask_gt.flatten(1)
                    intersection = torch.einsum('nc,mc->nm', a, b)
                    union = (a[:, None] + b[None]).clamp(min=0, max=1).sum(-1)
                    iou_cost = intersection / union
                    a_indices, b_indices = linear_sum_assignment(-iou_cost.numpy())
                    for a_ind, b_ind in zip(a_indices, b_indices):
                        id_assigner[batch_id][a_ind] = batch_data_samples[batch_id][
                            frame_id].gt_instances.instances_ids[b_ind].item()
            if 'pro_results' in results_list_img[0]:
                # Render the proposal masks of this frame into a single id map
                # using the ids assigned on the first frame.
                for batch_id in range(bs):
                    h, w = results_list_img[batch_id]['pro_results'].shape[-2:]
                    seg_map = torch.full((h, w), 0, dtype=torch.int32, device='cpu')
                    for ind in id_assigner[batch_id]:
                        seg_map[results_list_img[batch_id]['pro_results'][ind]] = \
                            id_assigner[batch_id][ind]
                    results_list_img[batch_id]['pro_results'] = seg_map.cpu().numpy()

            _ = self.add_track_pred_to_datasample(
                [batch_data_samples[idx][frame_id] for idx in range(bs)], results_list_img
            )
        results = batch_data_samples
        return results

    def open_voc_inference(self, feats, mask_cls_results, mask_pred_results):
        if len(mask_pred_results.shape) == 5:
            batch_size = mask_cls_results.shape[0]
            num_frames = mask_pred_results.shape[2]
            # (bs, Q, T, H, W) -> (bs * T, Q, H, W) so the masks can be
            # interpolated to the CLIP feature resolution frame by frame.
            mask_pred_results = mask_pred_results.permute(0, 2, 1, 3, 4).flatten(0, 1)
        else:
            batch_size = mask_cls_results.shape[0]
            num_frames = 0
        clip_feat = self.backbone.get_clip_feature(feats[-1]).to(device=mask_cls_results.device)
        clip_feat_mask = F.interpolate(
            mask_pred_results,
            size=clip_feat.shape[-2:],
            mode='bilinear',
            align_corners=False
        ).to(device=mask_cls_results.device)
        if num_frames > 0:
            clip_feat_mask = clip_feat_mask.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3)
            clip_feat = clip_feat.unflatten(0, (batch_size, num_frames)).permute(0, 2, 1, 3, 4).flatten(2, 3)
        # Mask-pool the CLIP features to get one feature per query, then
        # classify it with the CLIP-based classification head.
        instance_feat = mask_pool(clip_feat, clip_feat_mask)
        instance_feat = self.backbone.forward_feat(instance_feat)
        clip_logit = self.panoptic_head.forward_logit(instance_feat)
        clip_logit = clip_logit[..., :-1]
        query_logit = mask_cls_results[..., :-1]

        clip_logit = clip_logit.softmax(-1)
        query_logit = query_logit.softmax(-1)
        overlapping_mask = torch.tensor(self.OVERLAPPING, dtype=torch.float32, device=clip_logit.device)

        # Only ensemble with the CLIP logits for queries whose interpolated
        # mask is non-empty; otherwise fall back to the query logits.
        valid_masking = ((clip_feat_mask > 0).to(dtype=torch.float32).flatten(-2).sum(-1) > 0).to(
            torch.float32)[..., None]
        alpha = torch.ones_like(clip_logit) * self.alpha * valid_masking
        beta = torch.ones_like(clip_logit) * self.beta * valid_masking

        # Geometric ensemble of the in-vocabulary (query) and CLIP classifiers:
        # seen classes use weight alpha, unseen classes use weight beta.
        cls_logits_seen = (
            (query_logit ** (1 - alpha) * clip_logit ** alpha).log()
            * overlapping_mask
        )
        cls_logits_unseen = (
            (query_logit ** (1 - beta) * clip_logit ** beta).log()
            * (1 - overlapping_mask)
        )
        cls_results = cls_logits_seen + cls_logits_unseen

        # Keep the original void probability and renormalise the class scores.
        is_void_prob = F.softmax(mask_cls_results, dim=-1)[..., -1:]
        mask_cls_results = torch.cat([
            cls_results.softmax(-1) * (1.0 - is_void_prob), is_void_prob], dim=-1)
        mask_cls_results = torch.log(mask_cls_results + 1e-8)

        return mask_cls_results
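

if __name__ == '__main__':
    # Minimal, self-contained sketch of the geometric ensembling performed in
    # ``open_voc_inference`` above.  All numbers below (tensor shapes, the five
    # "seen" classes, alpha=0.4, beta=0.8) are made up for illustration and are
    # not taken from any config of this repository.
    query_logit = torch.randn(1, 100, 10).softmax(-1)  # in-vocabulary head: (bs, queries, classes)
    clip_logit = torch.randn(1, 100, 10).softmax(-1)   # CLIP-based head: same shape
    overlapping = torch.zeros(10)
    overlapping[:5] = 1.0                              # pretend the first five classes are "seen"
    alpha, beta = 0.4, 0.8
    seen = (query_logit ** (1 - alpha) * clip_logit ** alpha).log() * overlapping
    unseen = (query_logit ** (1 - beta) * clip_logit ** beta).log() * (1 - overlapping)
    fused = seen + unseen                              # (1, 100, 10) ensembled class logits
    print(fused.shape)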