# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import logging
import numpy as np
from typing import List, Union
import torch

import detectron2.data.detection_utils as utils
import detectron2.data.transforms as T
from detectron2.config import configurable

from .detection_utils import annotations_to_instances, transform_instance_annotations

__all__ = [
"PointSupDatasetMapper",
]


class PointSupDatasetMapper:
"""
The callable currently does the following:
1. Read the image from "file_name"
2. Applies transforms to the image and annotations
3. Prepare data and annotations to Tensor and :class:`Instances`
"""
@configurable
def __init__(
self,
is_train: bool,
*,
augmentations: List[Union[T.Augmentation, T.Transform]],
image_format: str,
# Extra data augmentation for point supervision
sample_points: int = 0,
):
"""
NOTE: this interface is experimental.
Args:
is_train: whether it's used in training or inference
augmentations: a list of augmentations or deterministic transforms to apply
image_format: an image format supported by :func:`detection_utils.read_image`.
sample_points: subsample points at each iteration
"""
        # fmt: off
        self.is_train      = is_train
        self.augmentations = T.AugmentationList(augmentations)
        self.image_format  = image_format
        self.sample_points = sample_points
        # fmt: on
logger = logging.getLogger(__name__)
mode = "training" if is_train else "inference"
logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
logger.info(f"Point Augmentations used in {mode}: sample {sample_points} points")

    @classmethod
def from_config(cls, cfg, is_train: bool = True):
augs = utils.build_augmentation(cfg, is_train)
if cfg.INPUT.CROP.ENABLED and is_train:
            raise ValueError("Crop augmentation is not supported for point supervision.")
ret = {
"is_train": is_train,
"augmentations": augs,
"image_format": cfg.INPUT.FORMAT,
"sample_points": cfg.INPUT.SAMPLE_POINTS,
}
return ret
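
    # Note: @configurable routes PointSupDatasetMapper(cfg, is_train=True) through
    # from_config above, so the mapper can be built directly from a config object.
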
def __call__(self, dataset_dict):
"""
Args:
dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
Returns:
dict: a format that builtin models in detectron2 accept
"""
dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
utils.check_image_size(dataset_dict, image)
aug_input = T.AugInput(image)
transforms = self.augmentations(aug_input)
image = aug_input.image
image_shape = image.shape[:2] # h, w
        # PyTorch's dataloader is efficient on torch.Tensor due to shared memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
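        # transpose(2, 0, 1) converts the HWC numpy array to the CHW layout the model expects.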
dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
if not self.is_train:
dataset_dict.pop("annotations", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # Maps points from the closed interval [0, image_size - 1] on discrete
            # image coordinates to the half-open interval [0, image_size) on continuous
            # image coordinates. We use the continuous-discrete conversion from Heckbert
            # 1990 ("What is the coordinate of a pixel?"): d = floor(c) and c = d + 0.5,
            # where d is a discrete coordinate and c is a continuous coordinate.
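            # For example, a point annotated at discrete pixel (3, 7) maps to the
            # continuous coordinate (3.5, 7.5), the center of that pixel.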
for ann in dataset_dict["annotations"]:
point_coords_wrt_image = np.array(ann["point_coords"]).astype(float)
point_coords_wrt_image = point_coords_wrt_image + 0.5
ann["point_coords"] = point_coords_wrt_image
annos = [
# also need to transform point coordinates
transform_instance_annotations(
obj,
transforms,
image_shape,
)
for obj in dataset_dict.pop("annotations")
if obj.get("iscrowd", 0) == 0
]
instances = annotations_to_instances(
annos,
image_shape,
sample_points=self.sample_points,
)
dataset_dict["instances"] = utils.filter_empty_instances(instances)
return dataset_dict
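

# ---------------------------------------------------------------------------
# Minimal usage sketch (an editor's illustration, not part of the original
# file). Assumes the PointSup-specific config key INPUT.SAMPLE_POINTS has been
# added to cfg and that the registered training dataset carries point
# annotations:
#
#   from detectron2.config import get_cfg
#   from detectron2.data import build_detection_train_loader
#
#   cfg = get_cfg()
#   cfg.INPUT.SAMPLE_POINTS = 5      # hypothetical value; a PointSup config key
#   cfg.INPUT.CROP.ENABLED = False   # cropping is rejected by from_config above
#   mapper = PointSupDatasetMapper(cfg, is_train=True)
#   train_loader = build_detection_train_loader(cfg, mapper=mapper)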