# Copyright (c) OpenMMLab. All rights reserved. import numpy as np from mmcv.parallel import DataContainer as DC from mmseg.datasets.builder import PIPELINES from mmseg.datasets.pipelines.formatting import to_tensor @PIPELINES.register_module(force=True) class DefaultFormatBundle(object): """Default formatting bundle. It simplifies the pipeline of formatting common fields, including "img" and "gt_semantic_seg". These fields are formatted as follows. - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, (3)to DataContainer (stack=True) """ def __call__(self, results): """Call function to transform and format common fields in results. Args: results (dict): Result dict contains the data to convert. Returns: dict: The result dict contains the data that is formatted with default bundle. """ if 'img' in results: img = results['img'] if len(img.shape) < 3: img = np.expand_dims(img, -1) img = np.ascontiguousarray(img.transpose(2, 0, 1)) results['img'] = DC(to_tensor(img), stack=True) if 'gt_semantic_seg' in results: # convert to long results['gt_semantic_seg'] = DC(to_tensor( results['gt_semantic_seg'][None, ...].astype(np.int64)), stack=True) if 'gt_masks' in results: results['gt_masks'] = DC(to_tensor(results['gt_masks'])) if 'gt_labels' in results: results['gt_labels'] = DC(to_tensor(results['gt_labels'])) return results def __repr__(self): return self.__class__.__name__ @PIPELINES.register_module() class ToMask(object): """Transfer gt_semantic_seg to binary mask and generate gt_labels.""" def __init__(self, ignore_index=255): self.ignore_index = ignore_index def __call__(self, results): gt_semantic_seg = results['gt_semantic_seg'] gt_labels = np.unique(gt_semantic_seg) # remove ignored region gt_labels = gt_labels[gt_labels != self.ignore_index] gt_masks = [] for class_id in gt_labels: gt_masks.append(gt_semantic_seg == class_id) if len(gt_masks) == 0: # Some image does not have annotation (all ignored) gt_masks = np.empty((0, ) + results['pad_shape'][:-1], dtype=np.int64) gt_labels = np.empty((0, ), dtype=np.int64) else: gt_masks = np.asarray(gt_masks, dtype=np.int64) gt_labels = np.asarray(gt_labels, dtype=np.int64) results['gt_labels'] = gt_labels results['gt_masks'] = gt_masks return results def __repr__(self): return self.__class__.__name__ + \ f'(ignore_index={self.ignore_index})'