|
|
|
import numpy as np |
|
from mmcv.parallel import DataContainer as DC |
|
from mmseg.datasets.builder import PIPELINES |
|
from mmseg.datasets.pipelines.formatting import to_tensor |
|
|
|
|
|
@PIPELINES.register_module(force=True) |
|
class DefaultFormatBundle(object): |
|
"""Default formatting bundle. |
|
|
|
It simplifies the pipeline of formatting common fields, including "img" |
|
and "gt_semantic_seg". These fields are formatted as follows. |
|
|
|
- img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) |
|
- gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, |
|
(3)to DataContainer (stack=True) |
|
""" |
|
def __call__(self, results): |
|
"""Call function to transform and format common fields in results. |
|
|
|
Args: |
|
results (dict): Result dict contains the data to convert. |
|
|
|
Returns: |
|
dict: The result dict contains the data that is formatted with |
|
default bundle. |
|
""" |
|
|
|
if 'img' in results: |
|
img = results['img'] |
|
if len(img.shape) < 3: |
|
img = np.expand_dims(img, -1) |
|
img = np.ascontiguousarray(img.transpose(2, 0, 1)) |
|
results['img'] = DC(to_tensor(img), stack=True) |
|
if 'gt_semantic_seg' in results: |
|
|
|
results['gt_semantic_seg'] = DC(to_tensor( |
|
results['gt_semantic_seg'][None, ...].astype(np.int64)), |
|
stack=True) |
|
if 'gt_masks' in results: |
|
results['gt_masks'] = DC(to_tensor(results['gt_masks'])) |
|
if 'gt_labels' in results: |
|
results['gt_labels'] = DC(to_tensor(results['gt_labels'])) |
|
|
|
return results |
|
|
|
def __repr__(self): |
|
return self.__class__.__name__ |
|
|
|
|
|
@PIPELINES.register_module() |
|
class ToMask(object): |
|
"""Transfer gt_semantic_seg to binary mask and generate gt_labels.""" |
|
def __init__(self, ignore_index=255): |
|
self.ignore_index = ignore_index |
|
|
|
def __call__(self, results): |
|
gt_semantic_seg = results['gt_semantic_seg'] |
|
gt_labels = np.unique(gt_semantic_seg) |
|
|
|
gt_labels = gt_labels[gt_labels != self.ignore_index] |
|
|
|
gt_masks = [] |
|
for class_id in gt_labels: |
|
gt_masks.append(gt_semantic_seg == class_id) |
|
|
|
if len(gt_masks) == 0: |
|
|
|
gt_masks = np.empty((0, ) + results['pad_shape'][:-1], dtype=np.int64) |
|
gt_labels = np.empty((0, ), dtype=np.int64) |
|
else: |
|
gt_masks = np.asarray(gt_masks, dtype=np.int64) |
|
gt_labels = np.asarray(gt_labels, dtype=np.int64) |
|
|
|
results['gt_labels'] = gt_labels |
|
results['gt_masks'] = gt_masks |
|
return results |
|
|
|
def __repr__(self): |
|
return self.__class__.__name__ + \ |
|
f'(ignore_index={self.ignore_index})' |
|
|