Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/d2/detr/dataset_mapper.py | |
import copy | |
import logging | |
import numpy as np | |
import torch | |
from detectron2.config import configurable | |
from detectron2.data import detection_utils as utils | |
from detectron2.data import transforms as T | |
from detectron2.data.transforms import TransformGen | |
from detectron2.structures import BitMasks, Instances | |
__all__ = ["DETRPanopticDatasetMapper"] | |
def build_transform_gen(cfg, is_train): | |
""" | |
Create a list of :class:`TransformGen` from config. | |
Returns: | |
list[TransformGen] | |
""" | |
if is_train: | |
min_size = cfg.INPUT.MIN_SIZE_TRAIN | |
max_size = cfg.INPUT.MAX_SIZE_TRAIN | |
sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING | |
else: | |
min_size = cfg.INPUT.MIN_SIZE_TEST | |
max_size = cfg.INPUT.MAX_SIZE_TEST | |
sample_style = "choice" | |
if sample_style == "range": | |
assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format( | |
len(min_size) | |
) | |
logger = logging.getLogger(__name__) | |
tfm_gens = [] | |
if is_train: | |
tfm_gens.append(T.RandomFlip()) | |
tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) | |
if is_train: | |
logger.info("TransformGens used in training: " + str(tfm_gens)) | |
return tfm_gens | |
# This is specifically designed for the COCO dataset.
class DETRPanopticDatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by MaskFormer.

    This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation.

    The callable currently does the following:

    1. Reads the image from "file_name"
    2. Applies geometric transforms to the image and annotation
    3. Randomly applies the crop augmentation (if configured) to the image and annotation
    4. Converts image and annotation to Tensors
    """

    @configurable
    def __init__(
        self,
        is_train=True,
        *,
        crop_gen,
        tfm_gens,
        image_format,
    ):
        """
        NOTE: this interface is experimental.

        Can be called either with explicit keyword arguments, or (via
        ``@configurable``) as ``DETRPanopticDatasetMapper(cfg, is_train)``, in which
        case the arguments are produced by :meth:`from_config`.

        Args:
            is_train: for training or inference
            crop_gen: list of crop augmentations, or None to disable cropping. When
                set, it is inserted before the last entry of ``tfm_gens`` for about
                half of the images (see ``__call__``).
            tfm_gens: list of data augmentations applied to every image
            image_format: an image format supported by :func:`detection_utils.read_image`.
        """
        self.crop_gen = crop_gen
        self.tfm_gens = tfm_gens
        logging.getLogger(__name__).info(
            "[DETRPanopticDatasetMapper] Full TransformGens used in training: {}, crop: {}".format(
                str(self.tfm_gens), str(self.crop_gen)
            )
        )
        self.img_format = image_format
        self.is_train = is_train

    @classmethod
    def from_config(cls, cfg, is_train=True):
        """
        Build the ``__init__`` keyword arguments from a config.

        Cropping is only used when ``cfg.INPUT.CROP.ENABLED`` is set and
        ``is_train`` is True; it consists of a resize to one of 400/500/600
        followed by ``RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)``.
        """
        if cfg.INPUT.CROP.ENABLED and is_train:
            crop_gen = [
                T.ResizeShortestEdge([400, 500, 600], sample_style="choice"),
                T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE),
            ]
        else:
            crop_gen = None

        tfm_gens = build_transform_gen(cfg, is_train)

        return {
            "is_train": is_train,
            "crop_gen": crop_gen,
            "tfm_gens": tfm_gens,
            "image_format": cfg.INPUT.FORMAT,
        }

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept. It always gets
            an "image" tensor (channels first). In inference mode "annotations" is
            dropped. In training mode, if "pan_seg_file_name" is present, an
            "instances" field is added holding ``gt_classes`` (int64) and
            ``gt_masks`` (bool, N x H x W) for every non-crowd segment.
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        # The short-circuit keeps RNG usage unchanged: np.random.rand() is only
        # drawn when a crop augmentation is configured.
        if self.crop_gen is None or np.random.rand() > 0.5:
            tfm_gens = self.tfm_gens
        else:
            # DETR recipe: insert the crop before the final transform (the resize).
            tfm_gens = self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:]
        image, transforms = T.apply_transform_gens(tfm_gens, image)
        image_shape = image.shape[:2]  # h, w

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            return dataset_dict

        if "pan_seg_file_name" in dataset_dict:
            pan_seg_gt = utils.read_image(dataset_dict.pop("pan_seg_file_name"), "RGB")
            segments_info = dataset_dict["segments_info"]

            # apply the same transformation to panoptic segmentation
            pan_seg_gt = transforms.apply_segmentation(pan_seg_gt)

            # Local import: panopticapi is only needed when panoptic GT is present.
            from panopticapi.utils import rgb2id

            pan_seg_gt = rgb2id(pan_seg_gt)

            instances = Instances(image_shape)
            classes = []
            masks = []
            for segment_info in segments_info:
                if segment_info["iscrowd"]:
                    continue  # crowd segments are not turned into targets
                classes.append(segment_info["category_id"])
                masks.append(pan_seg_gt == segment_info["id"])

            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
            if not masks:
                # Some image does not have annotation (all ignored). Use bool so the
                # dtype matches the non-empty branch (BitMasks tensors are bool).
                instances.gt_masks = torch.zeros(
                    (0, pan_seg_gt.shape[-2], pan_seg_gt.shape[-1]), dtype=torch.bool
                )
            else:
                # Each comparison result is a fresh array, so no extra copy is needed.
                masks = BitMasks(
                    torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])
                )
                instances.gt_masks = masks.tensor

            dataset_dict["instances"] = instances

        return dataset_dict