import copy
import logging
from typing import List, Optional, Union

import numpy as np
import torch

from detectron2.config import configurable
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T

from oneformer.data.tokenizer import SimpleTokenizer, Tokenize

__all__ = ["DatasetMapper"]


class DatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by the model.

    This is the default callable to be used to map your dataset dict into training data.
    You may need to follow it to implement your own mapper for customized logic,
    such as a different way to read or transform images.
    See :doc:`/tutorials/data_loading` for details.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Apply cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        task_seq_len: int,
        task: str = "panoptic",
        use_instance_mask: bool = False,
        use_keypoint: bool = False,
        instance_mask_format: str = "polygon",
        keypoint_hflip_indices: Optional[np.ndarray] = None,
        precomputed_proposal_topk: Optional[int] = None,
        recompute_boxes: bool = False,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: whether it's used in training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            task_seq_len: max sequence length used when tokenizing the task prompt
            task: the task to map data for; one of "panoptic", "semantic" or "instance"
            use_instance_mask: whether to process instance segmentation annotations, if available
            use_keypoint: whether to process keypoint annotations if available
            instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
                masks into this format.
            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
            precomputed_proposal_topk: if given, will load pre-computed
                proposals from dataset_dict and keep the top k proposals for each image.
            recompute_boxes: whether to overwrite bounding box annotations
                by computing tight bounding boxes from instance mask annotations.
        """
        if recompute_boxes:
            assert use_instance_mask, "recompute_boxes requires instance masks"

        self.is_train = is_train
        self.augmentations = T.AugmentationList(augmentations)
        self.image_format = image_format
        self.use_instance_mask = use_instance_mask
        self.instance_mask_format = instance_mask_format
        self.use_keypoint = use_keypoint
        self.keypoint_hflip_indices = keypoint_hflip_indices
        self.proposal_topk = precomputed_proposal_topk
        self.recompute_boxes = recompute_boxes
        self.task_tokenizer = Tokenize(SimpleTokenizer(), max_seq_len=task_seq_len)
        self.task = task
        assert self.task in ["panoptic", "semantic", "instance"]

        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        augs = utils.build_augmentation(cfg, is_train)
        if cfg.INPUT.CROP.ENABLED and is_train:
            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
            recompute_boxes = cfg.MODEL.MASK_ON
        else:
            recompute_boxes = False

        ret = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "use_instance_mask": cfg.MODEL.MASK_ON,
            "instance_mask_format": cfg.INPUT.MASK_FORMAT,
            "use_keypoint": cfg.MODEL.KEYPOINT_ON,
            "task_seq_len": cfg.INPUT.TASK_SEQ_LEN,
            "recompute_boxes": recompute_boxes,
            "task": cfg.MODEL.TEST.TASK,
        }

        if cfg.MODEL.KEYPOINT_ON:
            ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)

        if cfg.MODEL.LOAD_PROPOSALS:
            ret["precomputed_proposal_topk"] = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        return ret
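
    # Because of @configurable, the mapper can also be built without a cfg by
    # passing the kwargs above explicitly; a hand-built equivalent (all values
    # below are illustrative, not defaults from this repo):
    #
    #     mapper = DatasetMapper(
    #         is_train=True,
    #         augmentations=[T.ResizeShortestEdge(short_edge_length=800, max_size=1333)],
    #         image_format="RGB",
    #         task_seq_len=77,
    #         task="panoptic",
    #     )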

    def _transform_annotations(self, dataset_dict, transforms, image_shape):
        # USER: Modify this if you want to keep these annotations for some reason.
        for anno in dataset_dict["annotations"]:
            if not self.use_instance_mask:
                anno.pop("segmentation", None)
            if not self.use_keypoint:
                anno.pop("keypoints", None)

        # USER: Implement additional transformations if you have other types of data
        annos = [
            utils.transform_instance_annotations(
                obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
            )
            for obj in dataset_dict.pop("annotations")
            if obj.get("iscrowd", 0) == 0
        ]
        instances = utils.annotations_to_instances(
            annos, image_shape, mask_format=self.instance_mask_format
        )

        # After transforms such as cropping are applied, the bounding box may no longer
        # tightly bound the object; recompute tight boxes from the transformed masks.
        if self.recompute_boxes:
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
        dataset_dict["instances"] = utils.filter_empty_instances(instances)

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)

        task = f"The task is {self.task}"
        dataset_dict["task"] = task

        # USER: Remove if you don't do semantic/panoptic segmentation.
        if "sem_seg_file_name" in dataset_dict:
            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
        else:
            sem_seg_gt = None

        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        transforms = self.augmentations(aug_input)
        image, sem_seg_gt = aug_input.image, aug_input.sem_seg

        image_shape = image.shape[:2]  # h, w
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not on large generic data structures due to the use of pickling & mmap.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

        # USER: Remove if you don't use pre-computed proposals.
        if self.proposal_topk is not None:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
            )

        if not self.is_train:
            # USER: Modify this if you want to keep the annotations at inference.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            self._transform_annotations(dataset_dict, transforms, image_shape)

        return dataset_dict
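
# A minimal end-to-end sketch of calling the mapper directly (assumes `cfg` is
# an OneFormer config defining INPUT.TASK_SEQ_LEN and MODEL.TEST.TASK, and that
# `dataset_dicts` came from DatasetCatalog.get(...); both names are placeholders):
#
#     mapper = DatasetMapper(cfg, is_train=False)
#     out = mapper(dataset_dicts[0])
#     # out["image"] is a CHW tensor and out["task"] the prompt string, e.g.
#     # "The task is panoptic"; "annotations" are dropped when is_train=False.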