import copy
import numpy as np
import torch

from detectron2.config import configurable
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.data.dataset_mapper import DatasetMapper
from .custom_build_augmentation import build_custom_augmentation
from itertools import compress
import logging

__all__ = ["CustomDatasetMapper", "ObjDescription"]
logger = logging.getLogger(__name__)


class CustomDatasetMapper(DatasetMapper):
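    """DatasetMapper that supports multiple source datasets.

    At train time each record carries a 'dataset_source' index that selects
    one of several per-dataset augmentation pipelines, and per-instance text
    descriptions are attached to the Instances as an ObjDescription field.
    """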

    @configurable
    def __init__(self, is_train: bool,
                 dataset_augs=(),
                 **kwargs):
        if is_train:
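            # One T.AugmentationList per source dataset, selected at call time
            # by the record's 'dataset_source' index.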
            self.dataset_augs = [T.AugmentationList(x) for x in dataset_augs]
        super().__init__(is_train, **kwargs)

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
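        # Two supported schemes: 'EfficientDetResizeCrop' builds per-dataset
        # (scale, size) pipelines, while 'ResizeShortestEdge' builds
        # per-dataset (min_size, max_size) pipelines.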
|
ret = super().from_config(cfg, is_train) |
|
if is_train: |
|
if cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop': |
|
dataset_scales = cfg.DATALOADER.DATASET_INPUT_SCALE |
|
dataset_sizes = cfg.DATALOADER.DATASET_INPUT_SIZE |
|
ret['dataset_augs'] = [ |
|
build_custom_augmentation(cfg, True, scale, size) \ |
|
for scale, size in zip(dataset_scales, dataset_sizes)] |
|
else: |
|
assert cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge' |
|
min_sizes = cfg.DATALOADER.DATASET_MIN_SIZES |
|
max_sizes = cfg.DATALOADER.DATASET_MAX_SIZES |
|
ret['dataset_augs'] = [ |
|
build_custom_augmentation( |
|
cfg, True, min_size=mi, max_size=ma) \ |
|
for mi, ma in zip(min_sizes, max_sizes)] |
|
else: |
|
ret['dataset_augs'] = [] |
|
|
|
return ret |

    def __call__(self, dataset_dict):
        dataset_dict_out = self.prepare_data(dataset_dict)
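
        # Random augmentation can occasionally produce a degenerate crop;
        # re-run preparation until the augmented image is at least 32x32.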
        retry = 0
        while (dataset_dict_out["image"].shape[1] < 32
               or dataset_dict_out["image"].shape[2] < 32):
            retry += 1
            if retry == 100:
                logger.info(
                    'Retried augmentation 100 times. '
                    'Make sure the image size is not too small.')
                logger.info('Image information below:')
                logger.info(dataset_dict)
            dataset_dict_out = self.prepare_data(dataset_dict)
        return dataset_dict_out

    def prepare_data(self, dataset_dict_in):
        dataset_dict = copy.deepcopy(dataset_dict_in)
        if 'file_name' in dataset_dict:
            ori_image = utils.read_image(
                dataset_dict["file_name"], format=self.image_format)
        else:
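            # Records without a file path index into a tar-backed dataset;
            # self.tar_dataset is assumed to be attached to this mapper
            # elsewhere (it is not set up in this file).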
            ori_image, _, _ = self.tar_dataset[dataset_dict["tar_index"]]
            ori_image = utils._apply_exif_orientation(ori_image)
            ori_image = utils.convert_PIL_to_numpy(ori_image, self.image_format)
        utils.check_image_size(dataset_dict, ori_image)

        aug_input = T.AugInput(copy.deepcopy(ori_image), sem_seg=None)
        if self.is_train:
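            # Apply the augmentation list registered for this record's
            # source dataset.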
            transforms = self.dataset_augs[
                dataset_dict['dataset_source']](aug_input)
        else:
            transforms = self.augmentations(aug_input)
        image = aug_input.image

        image_shape = image.shape[:2]
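        # HWC numpy array -> contiguous CHW tensor, as the model expects.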
        dataset_dict["image"] = torch.as_tensor(
            np.ascontiguousarray(image.transpose(2, 0, 1)))

        if not self.is_train:
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            if len(dataset_dict["annotations"]) > 0:
                object_descriptions = [
                    an['object_description'] for an in dataset_dict["annotations"]]
            else:
                object_descriptions = []

            for anno in dataset_dict["annotations"]:
                if not self.use_instance_mask:
                    anno.pop("segmentation", None)
                if not self.use_keypoint:
                    anno.pop("keypoints", None)

            all_annos = [
                (utils.transform_instance_annotations(
                    obj, transforms, image_shape,
                    keypoint_hflip_indices=self.keypoint_hflip_indices,
                ), obj.get("iscrowd", 0))
                for obj in dataset_dict.pop("annotations")
            ]
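            # Keep only non-crowd annotations as instance targets; crowd
            # regions are dropped.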
            annos = [ann[0] for ann in all_annos if ann[1] == 0]
            instances = utils.annotations_to_instances(
                annos, image_shape, mask_format=self.instance_mask_format
            )

            instances.gt_object_descriptions = ObjDescription(object_descriptions)

            del all_annos
            if self.recompute_boxes:
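                # Recompute tight boxes from the transformed masks
                # (needed when cropping can clip the original boxes).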
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)

        return dataset_dict


class ObjDescription:
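    """List-like container of per-instance text descriptions.

    Supports indexing with 1-D int64 or bool tensors so that it can be
    carried as a field on detectron2 Instances objects (e.g. it survives
    utils.filter_empty_instances).
    """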

    def __init__(self, object_descriptions):
        self.data = object_descriptions

    def __getitem__(self, item):
        assert isinstance(item, torch.Tensor)
        assert item.dim() == 1
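        # int64 tensors gather by position; bool tensors act as a keep-mask.
        # An empty index tensor selects nothing.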
        if len(item) > 0:
            assert item.dtype == torch.int64 or item.dtype == torch.bool
            if item.dtype == torch.int64:
                return ObjDescription([self.data[x.item()] for x in item])
            elif item.dtype == torch.bool:
                return ObjDescription(list(compress(self.data, item)))
        return ObjDescription(list(compress(self.data, item)))

    def __len__(self):
        return len(self.data)

    def __repr__(self):
        return "ObjDescription({})".format(self.data)