# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py
import copy
import logging

import numpy as np
import torch

from detectron2.config import configurable
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.data.transforms import TransformGen
from detectron2.structures import BitMasks, Instances

__all__ = ["DETRPanopticDatasetMapper"]


def build_transform_gen(cfg, is_train):
    """
    Create a list of :class:`TransformGen` from config.

    Args:
        cfg: config node. Reads ``cfg.INPUT.MIN_SIZE_TRAIN``,
            ``cfg.INPUT.MAX_SIZE_TRAIN`` and ``cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING``
            when training, or ``cfg.INPUT.MIN_SIZE_TEST`` and
            ``cfg.INPUT.MAX_SIZE_TEST`` otherwise.
        is_train (bool): whether to build the training pipeline. Training adds a
            random flip in front of the resize; inference only resizes and always
            uses the "choice" sampling style.

    Returns:
        list[TransformGen]
    """
    if is_train:
        min_size = cfg.INPUT.MIN_SIZE_TRAIN
        max_size = cfg.INPUT.MAX_SIZE_TRAIN
        sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
    else:
        min_size = cfg.INPUT.MIN_SIZE_TEST
        max_size = cfg.INPUT.MAX_SIZE_TEST
        sample_style = "choice"
    if sample_style == "range":
        # "range" sampling needs exactly a (low, high) pair.
        assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(
            len(min_size)
        )

    logger = logging.getLogger(__name__)
    tfm_gens = []
    if is_train:
        tfm_gens.append(T.RandomFlip())
    # The resize is always the LAST entry; DETRPanopticDatasetMapper.__call__
    # relies on that when it splices the crop transforms in front of it.
    tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style))
    if is_train:
        logger.info("TransformGens used in training: " + str(tfm_gens))
    return tfm_gens


# This is specifically designed for the COCO dataset.
class DETRPanopticDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by MaskFormer.

    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.

    The callable currently does the following:

    1. Reads the image from "file_name"
    2. Applies geometric transforms to the image and annotation
    3. Finds and applies suitable cropping to the image and annotation
    4. Converts the image and annotation to Tensors
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        crop_gen,
        tfm_gens,
        image_format,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: for training or inference
            crop_gen: list of crop augmentations that is spliced in before the
                last entry of ``tfm_gens`` on roughly half of the calls, or
                ``None`` to disable cropping
            tfm_gens: list of data augmentations applied to every image
            image_format: an image format supported by :func:`detection_utils.read_image`.
        """
        self.crop_gen = crop_gen
        self.tfm_gens = tfm_gens
        logging.getLogger(__name__).info(
            "[DETRPanopticDatasetMapper] Full TransformGens used in training: {}, crop: {}".format(
                str(self.tfm_gens), str(self.crop_gen)
            )
        )

        self.img_format = image_format
        self.is_train = is_train

    @classmethod
    def from_config(cls, cfg, is_train=True):
        """
        Translate a config node into the keyword arguments of ``__init__``.

        Args:
            cfg: config node; reads ``cfg.INPUT.CROP.*``, ``cfg.INPUT.FORMAT`` and
                the size options consumed by :func:`build_transform_gen`.
            is_train (bool): whether the mapper is built for training.

        Returns:
            dict: with keys "is_train", "crop_gen", "tfm_gens" and "image_format".
        """
        # Build augmentation
        if cfg.INPUT.CROP.ENABLED and is_train:
            crop_gen = [
                T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
                T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
            ]
        else:
            crop_gen = None

        tfm_gens = build_transform_gen(cfg, is_train)

        ret = {
            "is_train": is_train,
            "crop_gen": crop_gen,
            "tfm_gens": tfm_gens,
            "image_format": cfg.INPUT.FORMAT,
        }
        return ret

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept. It is a deep
            copy of the input with an added "image" tensor (channels first). At
            inference "annotations" is removed. In training, when the input has
            "pan_seg_file_name", that key is consumed and "instances" is added
            with ``gt_classes`` and ``gt_masks`` for every non-crowd segment.
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        if self.crop_gen is None or np.random.rand() > 0.5:
            # No cropping configured, or this draw skips the crop.
            # (np.random.rand() is only evaluated when crop_gen is set, exactly
            # as in the nested if/else form.)
            image, transforms = T.apply_transform_gens(self.tfm_gens, image)
        else:
            # Insert the crop transforms right before the final transform.
            image, transforms = T.apply_transform_gens(
                self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image
            )

        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "pan_seg_file_name" in dataset_dict:
            pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
            segments_info = dataset_dict["segments_info"]

            # apply the same transformation to panoptic segmentation
            pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)

            # Imported lazily so the module can be imported without panopticapi.
            from panopticapi.utils import rgb2id

            pan_seg_gt = rgb2id(pan_seg_gt)

            instances = Instances(image_shape)
            classes = []
            masks = []
            for segment_info in segments_info:
                # Crowd segments are skipped entirely.
                if not segment_info["iscrowd"]:
                    classes.append(segment_info["category_id"])
                    masks.append(pan_seg_gt == segment_info["id"])

            classes = np.array(classes)
            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
            if len(masks) == 0:
                # Some image does not have annotation (all ignored).
                # Use bool so the dtype matches BitMasks.tensor in the branch below.
                instances.gt_masks = torch.zeros(
                    (0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]), dtype=torch.bool
                )
            else:
                # Each mask is a fresh array produced by the comparison above,
                # so no defensive copy is needed before handing it to torch.
                masks = BitMasks(
                    torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
                )
                instances.gt_masks = masks.tensor

            dataset_dict["instances"] = instances

        return dataset_dict