Spaces:
Build error
Build error
import os | |
import os.path | |
from pathlib import Path | |
from typing import Any, Callable, Optional, Tuple | |
import torch | |
from maskrcnn_benchmark.structures.bounding_box import BoxList | |
from PIL import Image, ImageDraw | |
from torchvision.datasets.vision import VisionDataset | |
from .modulated_coco import ConvertCocoPolysToMask, has_valid_annotation | |
class CustomCocoDetection(VisionDataset):
    """COCO-style detection dataset adapted from TorchVision's ``CocoDetection``.

    It is modified to handle several image sources: an image whose
    ``data_source`` field equals ``"coco"`` is read from ``root_coco``; every
    other image is read from ``root_vg``. Only images whose annotation list is
    accepted by ``has_valid_annotation`` are kept in ``self.ids``.

    Args:
        root_coco (string): Path to the COCO images.
        root_vg (string): Path to the VG images (used for every image whose
            ``data_source`` is not ``"coco"``).
        annFile (string): Path to the JSON annotation file.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version, e.g.
            ``transforms.ToTensor``.
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        transforms (callable, optional): A function/transform that takes an
            input sample and its target and returns a transformed version.
    """

    def __init__(
        self,
        root_coco: str,
        root_vg: str,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root_coco, transforms, transform, target_transform)
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        valid_ids = []
        for img_id in sorted(self.coco.imgs.keys()):
            # Always pass the id wrapped in a list: pycocotools treats any
            # iterable with a length (including str) as a sequence of ids, so
            # a bare string id would be split into characters.
            ann_ids = self.coco.getAnnIds(imgIds=[img_id], iscrowd=None)
            if has_valid_annotation(self.coco.loadAnns(ann_ids)):
                valid_ids.append(img_id)
        self.ids = valid_ids
        self.root_coco = root_coco
        self.root_vg = root_vg

    def __getitem__(self, index):
        """Load one image and its raw COCO annotations.

        Args:
            index (int): Position in ``self.ids``.

        Returns:
            tuple: ``(image, target)``, where ``image`` is an RGB PIL image and
            ``target`` is the list returned by ``coco.loadAnns``. Both are
            passed through ``self.transforms`` when it is set.

        Raises:
            KeyError: If the image record lacks ``file_name`` or ``data_source``.
        """
        coco = self.coco
        img_id = self.ids[index]
        # Wrap the id in a list (as __init__ does) so that string ids are not
        # iterated character by character by pycocotools.
        ann_ids = coco.getAnnIds(imgIds=[img_id])
        target = coco.loadAnns(ann_ids)
        img_info = coco.loadImgs([img_id])[0]
        path = img_info["file_name"]
        dataset = img_info["data_source"]
        cur_root = self.root_coco if dataset == "coco" else self.root_vg
        # convert() returns a new image, so the source file can be closed
        # deterministically by the context manager.
        with Image.open(os.path.join(cur_root, path)) as raw_img:
            img = raw_img.convert("RGB")
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        """Return the number of images that passed the annotation filter."""
        return len(self.ids)
class MixedDataset(CustomCocoDetection):
    """Same as the modulated detection dataset, except with multiple image sources.

    Images are read from either the COCO or the VG folder (chosen per image by
    ``CustomCocoDetection``). Each item is ``(image, target, idx)``, where
    ``target`` is a ``BoxList`` in ``"xyxy"`` mode carrying ``labels`` and
    every field produced by ``ConvertCocoPolysToMask``.

    Args:
        img_folder_coco: Directory holding the COCO images.
        img_folder_vg: Directory holding the VG images.
        ann_file: Path to the COCO-format JSON annotation file. Each image
            record is read for ``caption`` (and, by the base class,
            ``file_name`` / ``data_source``).
        transforms: Callable applied to ``(image, BoxList)`` after conversion,
            or ``None``.
        return_masks, return_tokens, tokenizer, max_query_len: Forwarded to
            ``ConvertCocoPolysToMask``.
        disable_clip_to_image: If true, boxes are not clipped to the image.
        no_mask_for_gold: If true, ``(-1, -1, -1)`` is appended to
            ``greenlight_span_for_masked_lm_objective``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 img_folder_coco,
                 img_folder_vg,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 tokenizer=None,
                 disable_clip_to_image=False,
                 no_mask_for_gold=False,
                 max_query_len=256,
                 **kwargs):
        # The base class gets no transforms: they are applied in __getitem__,
        # after the annotations have been converted into a BoxList.
        super().__init__(img_folder_coco, img_folder_vg, ann_file)
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(
            return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len
        )
        # Dataset index -> COCO image id.
        self.id_to_img_map = dict(enumerate(self.ids))
        self.disable_clip_to_image = disable_clip_to_image
        self.no_mask_for_gold = no_mask_for_gold

    def __getitem__(self, idx):
        """Return ``(image, BoxList target, idx)`` for dataset index ``idx``.

        Raises:
            AssertionError: If clipping to the image removes any box (only
                checked when ``disable_clip_to_image`` is false).
        """
        img, target = super().__getitem__(idx)
        image_id = self.ids[idx]
        # Wrap the id in a list so string image ids are not split into
        # characters by pycocotools.
        caption = self.coco.loadImgs([image_id])[0]["caption"]
        anno = {"image_id": image_id, "annotations": target, "caption": caption}
        # Span covering the whole caption (presumably the region eligible for
        # the masked-LM objective; confirm against ConvertCocoPolysToMask).
        anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))]
        if self.no_mask_for_gold:
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))
        img, anno = self.prepare(img, anno)

        # Convert to BoxList (bboxes, labels); reshape guards against no boxes.
        boxes = torch.as_tensor(anno["boxes"]).reshape(-1, 4)
        target = BoxList(boxes, img.size, mode="xyxy")
        target.add_field("labels", anno["labels"])
        if not self.disable_clip_to_image:
            num_boxes = len(boxes)
            target = target.clip_to_image(remove_empty=True)
            assert len(target.bbox) == num_boxes, "Box removed in MixedDataset!!!"
        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # Also attach every prepared annotation field to the target.
        for key, value in anno.items():
            target.add_field(key, value)
        return img, target, idx

    def get_img_info(self, index):
        """Return the raw ``coco.imgs`` record for dataset index ``index``."""
        img_id = self.id_to_img_map[index]
        return self.coco.imgs[img_id]