from typing import Optional, Sequence

import numpy as np
from mmcv.transforms import to_tensor
from mmcv.transforms.base import BaseTransform
from mmdet.registry import TRANSFORMS
from mmdet.structures import DetDataSample, TrackDataSample
from mmdet.structures.bbox import BaseBoxes
from mmengine.structures import InstanceData


@TRANSFORMS.register_module(force=True)
class PackMatchInputs(BaseTransform):
    """Pack the input data for multi-object tracking and video instance
    segmentation.

    All image information is packed into ``inputs``. Everything except the
    images is packed into ``data_samples``. To keep the original annotations
    and meta info available, the ``instances`` key is added to the meta keys.

    Args:
        meta_keys (Sequence[str], optional): Meta keys to be collected in
            ``data_sample.metainfo``. Defaults to None.
        default_meta_keys (tuple): Default meta keys. Defaults to ('img_id',
            'img_path', 'ori_shape', 'img_shape', 'scale_factor', 'flip',
            'flip_direction', 'frame_id', 'video_id', 'video_length',
            'ori_video_length', 'instances').
    """

    mapping_table = {
        "gt_bboxes": "bboxes",
        "gt_bboxes_labels": "labels",
        "gt_masks": "masks",
        "gt_instances_ids": "instances_ids",
    }

    def __init__(
        self,
        meta_keys: Optional[Sequence[str]] = None,
        default_meta_keys: tuple = (
            "img_id",
            "img_path",
            "ori_shape",
            "img_shape",
            "scale_factor",
            "flip",
            "flip_direction",
            "frame_id",
            "video_id",
            "video_length",
            "ori_video_length",
            "instances",
        ),
    ):
        self.default_meta_keys = default_meta_keys
        self.meta_keys = default_meta_keys
        if meta_keys is not None:
            if isinstance(meta_keys, str):
                meta_keys = (meta_keys,)
            else:
                assert isinstance(meta_keys, tuple), "meta_keys must be str or tuple"
            self.meta_keys += meta_keys

    def transform(self, results: dict) -> dict:
        """Method to pack the input data.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict:
            - 'inputs' (Tensor): Stacked images with shape (N, C, H, W) when
              'img' is present in the results.
            - 'data_samples' (:obj:`TrackDataSample`): The annotation info of
              the samples.
        """
        packed_results = dict()
        packed_results["inputs"] = dict()

        # 1. Pack images: stack the per-frame images into one
        # (N, C, H, W) tensor.
        if "img" in results:
            imgs = results["img"]
            imgs = np.stack(imgs, axis=0)
            if not imgs.flags.c_contiguous:
                imgs = np.ascontiguousarray(imgs.transpose(0, 3, 1, 2))
                imgs = to_tensor(imgs)
            else:
                imgs = to_tensor(imgs).permute(0, 3, 1, 2).contiguous()
            packed_results["inputs"] = imgs
        # 2. Pack InstanceData: split annotations into valid and ignored
        # instances per frame.
        if "gt_ignore_flags" in results:
            gt_ignore_flags_list = results["gt_ignore_flags"]
            valid_idx_list, ignore_idx_list = [], []
            for gt_ignore_flags in gt_ignore_flags_list:
                valid_idx = np.where(gt_ignore_flags == 0)[0]
                ignore_idx = np.where(gt_ignore_flags == 1)[0]
                valid_idx_list.append(valid_idx)
                ignore_idx_list.append(ignore_idx)

        assert "img_id" in results, (
            "'img_id' must be contained in the results "
            "for counting the number of images"
        )
        num_imgs = len(results["img_id"])
        instance_data_list = [InstanceData() for _ in range(num_imgs)]
        ignore_instance_data_list = [InstanceData() for _ in range(num_imgs)]

        for key in self.mapping_table.keys():
            if key not in results:
                continue
            if key == "gt_masks":
                mapped_key = self.mapping_table[key]
                gt_masks_list = results[key]
                if "gt_ignore_flags" in results:
                    for i, gt_mask in enumerate(gt_masks_list):
                        valid_idx, ignore_idx = valid_idx_list[i], ignore_idx_list[i]
                        instance_data_list[i][mapped_key] = gt_mask[valid_idx]
                        ignore_instance_data_list[i][mapped_key] = gt_mask[ignore_idx]
                else:
                    for i, gt_mask in enumerate(gt_masks_list):
                        instance_data_list[i][mapped_key] = gt_mask
            elif isinstance(results[key][0], BaseBoxes):
                mapped_key = self.mapping_table[key]
                gt_bboxes_list = results[key]
                if "gt_ignore_flags" in results:
                    for i, gt_bbox in enumerate(gt_bboxes_list):
                        gt_bbox = gt_bbox.tensor
                        valid_idx, ignore_idx = valid_idx_list[i], ignore_idx_list[i]
                        instance_data_list[i][mapped_key] = gt_bbox[valid_idx]
                        ignore_instance_data_list[i][mapped_key] = gt_bbox[ignore_idx]
                else:
                    for i, gt_bbox in enumerate(gt_bboxes_list):
                        instance_data_list[i][mapped_key] = gt_bbox.tensor
            else:
                anns_list = results[key]
                if "gt_ignore_flags" in results:
                    for i, ann in enumerate(anns_list):
                        valid_idx, ignore_idx = valid_idx_list[i], ignore_idx_list[i]
                        instance_data_list[i][self.mapping_table[key]] = to_tensor(
                            ann[valid_idx]
                        )
                        ignore_instance_data_list[i][
                            self.mapping_table[key]
                        ] = to_tensor(ann[ignore_idx])
                else:
                    for i, ann in enumerate(anns_list):
                        instance_data_list[i][self.mapping_table[key]] = to_tensor(ann)

        det_data_samples_list = []
        for i in range(num_imgs):
            det_data_sample = DetDataSample()
            det_data_sample.gt_instances = instance_data_list[i]
            det_data_sample.ignored_instances = ignore_instance_data_list[i]
            det_data_samples_list.append(det_data_sample)

        # 3. Pack metainfo: each meta key holds a list with one entry per
        # frame, attached to the matching DetDataSample.
        for key in self.meta_keys:
            if key not in results:
                continue
            img_metas_list = results[key]
            for i, img_meta in enumerate(img_metas_list):
                det_data_samples_list[i].set_metainfo({f"{key}": img_meta})

        track_data_sample = TrackDataSample()
        track_data_sample.video_data_samples = det_data_samples_list
        if "key_frame_flags" in results:
            key_frame_flags = np.asarray(results["key_frame_flags"])
            key_frames_inds = np.where(key_frame_flags)[0].tolist()
            ref_frames_inds = np.where(~key_frame_flags)[0].tolist()
            track_data_sample.set_metainfo(dict(key_frames_inds=key_frames_inds))
            track_data_sample.set_metainfo(dict(ref_frames_inds=ref_frames_inds))

        packed_results["data_samples"] = track_data_sample
        return packed_results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f"(meta_keys={self.meta_keys}, "
        repr_str += f"default_meta_keys={self.default_meta_keys})"
        return repr_str
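

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original transform): it builds a fake
# two-frame ``results`` dict roughly shaped like what an mmdet video pipeline
# would feed this transform and runs it end to end. The array shapes and key
# values below are illustrative assumptions, not fixtures from this repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    transform = PackMatchInputs()
    fake_results = dict(
        img=[np.zeros((32, 32, 3), dtype=np.uint8) for _ in range(2)],
        img_id=[0, 1],
        frame_id=[0, 1],
        gt_bboxes=[
            np.array([[1.0, 1.0, 10.0, 10.0]], dtype=np.float32) for _ in range(2)
        ],
        gt_bboxes_labels=[np.array([0], dtype=np.int64) for _ in range(2)],
        gt_instances_ids=[np.array([1], dtype=np.int64) for _ in range(2)],
        gt_ignore_flags=[np.array([False]) for _ in range(2)],
        key_frame_flags=[True, False],
    )
    packed = transform(fake_results)
    # Stacked images come out as a (num_frames, C, H, W) tensor.
    print(packed["inputs"].shape)
    # A single TrackDataSample wrapping one DetDataSample per frame.
    print(packed["data_samples"])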