zdou0830's picture
desco
749745d
raw
history blame
No virus
6.49 kB
import os
import os.path
from pathlib import Path
from typing import Any, Callable, Optional, Tuple
import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList
import pdb
from PIL import Image, ImageDraw
from torchvision.datasets.vision import VisionDataset
from .modulated_coco import ConvertCocoPolysToMask, has_valid_annotation
from maskrcnn_benchmark.data.datasets._caption_aug import CaptionAugmentation
import numpy as np
class CustomCocoDetection(VisionDataset):
    """Coco-style dataset imported from TorchVision.

    It is modified to handle several image sources: every image record in the
    annotation file carries a ``data_source`` key, and the image is read from
    ``root_coco`` when that key equals ``"coco"`` and from ``root_vg``
    otherwise.

    Only images for which ``has_valid_annotation`` accepts the annotation
    list are kept in ``self.ids``.

    Args:
        root_coco (string): Path to the coco images
        root_vg (string): Path to the vg images
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root_coco: str,
        root_vg: str,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root_coco, transforms, transform, target_transform)
        # Imported lazily so the module can be imported without pycocotools.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)

        # Keep only the images that have at least one usable annotation.
        # The id is always wrapped in a list: pycocotools treats any object
        # with ``__iter__``/``__len__`` (including ``str``) as an array of
        # ids, so a bare string id would be split into characters.
        kept_ids = []
        for img_id in sorted(self.coco.imgs.keys()):
            ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
            if has_valid_annotation(self.coco.loadAnns(ann_ids)):
                kept_ids.append(img_id)
        self.ids = kept_ids

        self.root_coco = root_coco
        self.root_vg = root_vg

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load one image and its annotations.

        Args:
            index (int): Index into ``self.ids``.

        Returns:
            tuple: Tuple (image, target). target is the object returned by
            ``coco.loadAnns`` (or whatever ``self.transforms`` turns it into
            when transforms are set).
        """
        coco = self.coco
        img_id = self.ids[index]

        # Same list-wrapping as in ``__init__`` so string ids work here too.
        target = coco.loadAnns(coco.getAnnIds(imgIds=[img_id]))

        # Direct dict lookup works for both int and str ids, unlike
        # ``loadImgs(img_id)`` which would iterate over a string id.
        img_info = coco.imgs[img_id]
        path = img_info["file_name"]
        cur_root = self.root_coco if img_info["data_source"] == "coco" else self.root_vg

        # ``convert`` returns a new image, so the file handle can be released
        # as soon as the conversion is done.
        with Image.open(os.path.join(cur_root, path)) as raw_img:
            img = raw_img.convert("RGB")

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self) -> int:
        """Return the number of images that passed the annotation filter."""
        return len(self.ids)
class MixedDataset(CustomCocoDetection):
    """Same as the modulated detection dataset, except with multiple img sources.

    Each item is returned as ``(img, target, idx)`` where ``target`` is a
    ``BoxList`` in ``xyxy`` mode carrying a ``labels`` field plus one extra
    field for every key of the annotation dict produced by
    ``ConvertCocoPolysToMask``.

    Args:
        img_folder_coco: Root folder of the coco images.
        img_folder_vg: Root folder of the vg images.
        ann_file: Path to the json annotation file.
        transforms: Optional callable applied to ``(img, BoxList)`` after the
            annotations were converted; may be ``None``.
        return_masks: Forwarded to ``ConvertCocoPolysToMask``.
        return_tokens: Forwarded to ``ConvertCocoPolysToMask``.
        tokenizer: Forwarded to ``ConvertCocoPolysToMask`` and, when caption
            augmentation is enabled, to ``CaptionAugmentation``.
        disable_clip_to_image: When true, boxes are not clipped to the image.
        no_mask_for_gold: When true, an extra ``(-1, -1, -1)`` entry is
            appended to ``greenlight_span_for_masked_lm_objective``.
        max_query_len: Forwarded to ``ConvertCocoPolysToMask``.
        caption_augmentation_version: When not ``None``, captions and targets
            are passed through ``CaptionAugmentation`` on every access.
        caption_vocab_file: Forwarded to ``CaptionAugmentation``.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(
        self,
        img_folder_coco,
        img_folder_vg,
        ann_file,
        transforms,
        return_masks,
        return_tokens,
        tokenizer=None,
        disable_clip_to_image=False,
        no_mask_for_gold=False,
        max_query_len=256,
        caption_augmentation_version=None,
        caption_vocab_file=None,
        **kwargs
    ):
        # The parent is deliberately built WITHOUT transforms: they must run
        # on the BoxList created in ``__getitem__``, not on the raw COCO
        # annotation list the parent returns.
        super().__init__(img_folder_coco, img_folder_vg, ann_file)
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(
            return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len
        )
        # Position in ``self.ids`` -> COCO image id.
        self.id_to_img_map = dict(enumerate(self.ids))
        self.disable_clip_to_image = disable_clip_to_image
        self.no_mask_for_gold = no_mask_for_gold
        self.caption_augmentation_version = caption_augmentation_version
        # ``self.caption_augmentation`` only exists when augmentation is on;
        # ``__getitem__`` guards on ``caption_augmentation_version``.
        if self.caption_augmentation_version is not None:
            self.caption_augmentation = CaptionAugmentation(
                self.caption_augmentation_version,
                tokenizer,
                caption_vocab_file=caption_vocab_file,
            )

    def __getitem__(self, idx):
        """Return ``(img, target, idx)`` for the sample at position ``idx``.

        Args:
            idx (int): Index into ``self.ids``.

        Returns:
            tuple: ``(img, target, idx)`` with ``target`` a ``BoxList``.

        Raises:
            AssertionError: If clipping to the image removed a box (only when
                ``disable_clip_to_image`` is false).
        """
        img, target = super().__getitem__(idx)
        image_id = self.ids[idx]
        # Direct dict lookup (as in ``get_img_info``) works for int and str
        # ids alike, whereas ``loadImgs(image_id)`` iterates over a str id.
        img_info = self.coco.imgs[image_id]
        caption = img_info["caption"]

        if self.caption_augmentation_version is not None:
            # The augmentation also returns spans; they are currently not
            # attached to the target.
            caption, target, _spans = self.caption_augmentation(
                caption, target, gpt3_outputs=img_info.get("gpt3_outputs")
            )

        anno = {"image_id": image_id, "annotations": target, "caption": caption}
        # By default the whole caption is one span for the masked-LM objective.
        anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))]
        if self.no_mask_for_gold:
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))

        img, anno = self.prepare(img, anno)

        # convert to BoxList (bboxes, labels)
        boxes = torch.as_tensor(anno["boxes"]).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xyxy")
        target.add_field("labels", anno["labels"])

        if not self.disable_clip_to_image:
            num_boxes = len(boxes)
            target = target.clip_to_image(remove_empty=True)
            assert len(target.bbox) == num_boxes, "Box removed in MixedDataset!!!"

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property: expose every entry of the prepared
        # annotation dict as a field on the BoxList.
        # NOTE(review): assumes ``self.prepare`` returns a plain dict — confirm.
        for key, value in anno.items():
            target.add_field(key, value)

        return img, target, idx

    def get_img_info(self, index):
        """Return the COCO image record for the sample at position ``index``."""
        return self.coco.imgs[self.id_to_img_map[index]]