|
import json |
|
import os |
|
import pickle |
|
import logging |
|
|
|
import datasets |
|
import pycocotools.mask as mask |
|
import dotenv |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
# BibTeX entry for the MS-COCO paper (Lin et al., 2014); the RefCOCO-family
# referring-expression datasets are built on COCO imagery.
_CITATION = """\
@article{DBLP:journals/corr/LinMBHPRDZ14,
author = {Tsung{-}Yi Lin and
Michael Maire and
Serge J. Belongie and
Lubomir D. Bourdev and
Ross B. Girshick and
James Hays and
Pietro Perona and
Deva Ramanan and
Piotr Doll{'{a} }r and
C. Lawrence Zitnick},
title = {Microsoft {COCO:} Common Objects in Context},
journal = {CoRR},
volume = {abs/1405.0312},
year = {2014},
url = {http://arxiv.org/abs/1405.0312},
archivePrefix = {arXiv},
eprint = {1405.0312},
timestamp = {Mon, 13 Aug 2018 16:48:13 +0200},
biburl = {https://dblp.org/rec/bib/journals/corr/LinMBHPRDZ14},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
"""

_DESCRIPTION = """\
COCO is a large-scale object detection, segmentation, and captioning dataset.
"""

_HOMEPAGE = "http://cocodataset.org/#home"

_LICENSE = ""

# NOTE(review): unused — download URLs are assembled from `base_url` in
# `_split_generators`. Presumably left over from the loader template; confirm
# before removing.
_URLs = {}

# One referring-expression region: a bounding box plus its phrases.
_BASE_REGION_FEATURES = {
    "region_id": datasets.Value("int64"),
    "image_id": datasets.Value("int32"),
    "phrases": [datasets.Value("string")],
    "x": datasets.Value("int32"),
    "y": datasets.Value("int32"),
    "width": datasets.Value("int32"),
    "height": datasets.Value("int32"),
}

# Run-length-encoded segmentation mask as produced by pycocotools
# ("size" / "counts" fields; counts stored as a UTF-8 string).
_BASE_MASK_FEATURES = {
    "size": [datasets.Value("int32")],
    "counts": datasets.Value("string"),
}

# Region features extended with a segmentation mask.
_BASE_MASK_REGION_FEATURES = {
    "region_id": datasets.Value("int64"),
    "image_id": datasets.Value("int32"),
    "phrases": [datasets.Value("string")],
    "x": datasets.Value("int32"),
    "y": datasets.Value("int32"),
    "width": datasets.Value("int32"),
    "height": datasets.Value("int32"),
    "mask": _BASE_MASK_FEATURES,
}

# Annotation schema keyed by the type chosen in RefCOCOBuilderConfig.features
# (with or without masks).
_ANNOTATION_FEATURES = {
    "region_descriptions": {"regions": [_BASE_REGION_FEATURES]},
    "mask_region_descriptions": {"regions": [_BASE_MASK_REGION_FEATURES]},
}

# Per-image metadata stored alongside the regions of each example.
_BASE_IMAGE_METADATA_FEATURES = {
    "image_id": datasets.Value("int32"),
    "height": datasets.Value("int32"),
    "width": datasets.Value("int32"),
    "file_name": datasets.Value("string"),
    "coco_url": datasets.Value("string"),
    "task_type": datasets.Value("string"),
}

# Which annotation providers ("split_by") exist per dataset; each pair maps to
# a "refs(<split_by>).p" pickle inside the annotation archive.
_SPLIT_BYS = {
    "refclef": ["unc", "berkeley"],
    "refcoco": ["unc"],
    "refcoco+": ["unc"],
    "refcocog": ["umd", "google"],
}
# Available split names for every "<dataset>-<split_by>" builder config.
_SPLITS = {
    "refclef-unc": ["train", "val", "testA", "testB", "testC"],
    "refclef-berkeley": ["train", "val", "test"],
    **{f"refcoco-{_split_by}": ["train", "val", "testA", "testB"] for _split_by in _SPLIT_BYS["refcoco"]},
    **{f"refcoco+-{_split_by}": ["train", "val", "testA", "testB"] for _split_by in _SPLIT_BYS["refcoco+"]},
    **{f"refcocog-{_split_by}": ["train", "val"] for _split_by in _SPLIT_BYS["refcocog"]},
}
# NOTE(review): these two expressions build NamedSplit objects and discard the
# result — presumably to pre-validate the custom split names "testA"/"testB"
# at import time. Confirm whether they can be removed.
datasets.Split("testA")
datasets.Split("testB")
|
|
|
|
|
class RefCOCOBuilderConfig(datasets.BuilderConfig):
    """BuilderConfig for the RefCOCO-family datasets.

    Args:
        name: Config name of the form ``"<dataset>-<split_by>"``, e.g.
            ``"refcoco-unc"`` or ``"refcocog-umd"``.
        splits: Split names available for this config
            (e.g. ``["train", "val", "testA", "testB"]``).
        with_image: Include the decoded image in each example.
        with_mask: Include a segmentation mask in every region annotation.
        base_url: Local path or remote URL hosting the image/annotation zips.
        sas_key: Optional Azure SAS token appended to the download URLs.
        task_type: Free-form task tag stored on every example.
    """

    def __init__(
        self,
        name,
        splits,
        with_image=True,
        with_mask=True,
        base_url=None,
        sas_key=None,
        task_type="caption",
        **kwargs,
    ):
        super().__init__(name, **kwargs)
        self.splits = splits
        # "refcoco+-unc" -> dataset_name "refcoco+", split_by "unc".
        self.dataset_name = name.split("-")[0]
        self.split_by = name.split("-")[-1]
        self.with_image = with_image
        self.with_mask = with_mask
        self.base_url = base_url
        self.sas_key = sas_key
        self.task_type = task_type

    @property
    def features(self):
        """Features for this config; regions carry a "mask" field iff with_mask."""
        # Fixed typo ("annoation") and switched to lazy %-style logging args.
        annotation_type = "mask_region_descriptions" if self.with_mask else "region_descriptions"
        logger.info("Using annotation type: %s due to with_mask=%s", annotation_type, self.with_mask)
        return datasets.Features(
            {
                **({"image": datasets.Image()} if self.with_image else {}),
                **_BASE_IMAGE_METADATA_FEATURES,
                **_ANNOTATION_FEATURES[annotation_type],
            }
        )
|
|
|
|
|
|
|
class RefCOCODataset(datasets.GeneratorBasedBuilder):
    """An example dataset script to work with the local (downloaded) COCO dataset"""

    VERSION = datasets.Version("0.0.0")

    BUILDER_CONFIG_CLASS = RefCOCOBuilderConfig
    # One builder config per "<dataset>-<split_by>" pair declared in _SPLITS.
    BUILDER_CONFIGS = [RefCOCOBuilderConfig(name=name, splits=splits) for name, splits in _SPLITS.items()]

    DEFAULT_CONFIG_NAME = "refcoco-unc"
    # Narrow the inherited `config` attribute's type for IDEs / type checkers.
    config: RefCOCOBuilderConfig
|
|
|
def _info(self): |
|
|
|
features = self.config.features |
|
|
|
return datasets.DatasetInfo( |
|
|
|
description=_DESCRIPTION, |
|
|
|
features=features, |
|
|
|
|
|
|
|
supervised_keys=None, |
|
|
|
homepage=_HOMEPAGE, |
|
|
|
license=_LICENSE, |
|
|
|
citation=_CITATION, |
|
) |
|
|
|
def _split_generators(self, dl_manager): |
|
"""Returns SplitGenerators.""" |
|
|
|
|
|
|
|
|
|
|
|
base_url = self.config.base_url |
|
if base_url is None: |
|
raise ValueError( |
|
"This script is supposed to work with local or remote RefCOCO dataset. It is either a local path or remote url. The argument `base_url` in `load_dataset()` is required." |
|
) |
|
logger.info(f"Using base_url: {base_url}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_DL_URLS = {} |
|
if self.config.dataset_name in ["refcoco", "refcoco+", "refcocog"]: |
|
_DL_URLS["image_dir"] = os.path.join(base_url, "train2014.zip") |
|
elif self.config.dataset_name == "refclef": |
|
_DL_URLS["image_dir"] = os.path.join(base_url, "saiapr_tc-12.zip") |
|
else: |
|
raise ValueError(f"Unknown dataset name: {self.config.dataset_name}") |
|
_DL_URLS["annotation_dir"] = os.path.join(base_url, f"{self.config.dataset_name}.zip") |
|
|
|
sas_key = self.config.sas_key |
|
if sas_key is None: |
|
|
|
logger.info(f"Try to load sas_key from .env file: {dotenv.load_dotenv('.env')}.") |
|
sas_key = os.getenv("REFCOCO_SAS_KEY") |
|
if sas_key is not None and not os.path.exists(base_url): |
|
logger.info(f"Using sas_key: {sas_key}") |
|
_DL_URLS = {k: f"{v}{sas_key}" for k, v in _DL_URLS.items()} |
|
|
|
if dl_manager.is_streaming is True: |
|
raise ValueError( |
|
"dl_manager.is_streaming is True, which is very slow due to the random access inside zip files with streaming loading." |
|
) |
|
|
|
archive_path = dl_manager.download_and_extract(_DL_URLS) |
|
|
|
|
|
with open( |
|
os.path.join(archive_path["annotation_dir"], self.config.dataset_name, f"refs({self.config.split_by}).p"), |
|
"rb", |
|
) as fp: |
|
refs = pickle.load(fp) |
|
with open( |
|
os.path.join(archive_path["annotation_dir"], self.config.dataset_name, f"instances.json"), |
|
"r", |
|
encoding="UTF-8", |
|
) as fp: |
|
instances = json.load(fp) |
|
self.data = {} |
|
self.data["dataset"] = self.config.dataset_name |
|
self.data["refs"] = refs |
|
self.data["images"] = instances["images"] |
|
self.data["annotations"] = instances["annotations"] |
|
self.data["categories"] = instances["categories"] |
|
self.createIndex() |
|
print(f"num refs: {len(self.Refs)}") |
|
|
|
splits = [] |
|
for split in self.config.splits: |
|
if split == "train": |
|
dataset = datasets.SplitGenerator( |
|
name=datasets.Split.TRAIN, |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gen_kwargs={ |
|
"image_dir": archive_path["image_dir"], |
|
"split": split, |
|
}, |
|
) |
|
elif split in ["val"]: |
|
dataset = datasets.SplitGenerator( |
|
name=datasets.Split.VALIDATION, |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gen_kwargs={ |
|
"image_dir": archive_path["image_dir"], |
|
"split": split, |
|
}, |
|
) |
|
elif split == "test": |
|
dataset = datasets.SplitGenerator( |
|
name=datasets.Split.TEST, |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gen_kwargs={ |
|
"image_dir": archive_path["image_dir"], |
|
"split": split, |
|
}, |
|
) |
|
elif split in ["testA", "testB", "testC"]: |
|
dataset = datasets.SplitGenerator( |
|
name=datasets.Split(split), |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gen_kwargs={ |
|
"image_dir": archive_path["image_dir"], |
|
"split": split, |
|
}, |
|
) |
|
else: |
|
raise ValueError(f"Unknown split name: {split}") |
|
|
|
splits.append(dataset) |
|
|
|
return splits |
|
|
|
def _generate_examples( |
|
|
|
self, |
|
image_dir, |
|
split, |
|
): |
|
"""Yields examples as (key, example) tuples.""" |
|
|
|
|
|
|
|
ref_ids = self.getRefIds(split=split) |
|
img_ids = self.getImgIds(ref_ids=ref_ids) |
|
|
|
logger.info(f"Generating examples from {len(ref_ids)} refs and {len(img_ids)} images in split {split}...") |
|
|
|
if self.config.dataset_name in ["refcoco", "refcoco+", "refcocog"]: |
|
image_dir_name = "train2014" |
|
elif self.config.dataset_name == "refclef": |
|
image_dir_name = "saiapr_tc-12" |
|
else: |
|
raise ValueError(f"Unknown dataset name: {self.config.dataset_name}") |
|
|
|
for idx, img_id in enumerate(img_ids): |
|
img = self.Imgs[img_id] |
|
image_metadata = { |
|
"coco_url": img.get("coco_url", None), |
|
"file_name": img["file_name"], |
|
"height": img["height"], |
|
"width": img["width"], |
|
"image_id": img["id"], |
|
} |
|
image_dict = ( |
|
{"image": os.path.join(image_dir, image_dir_name, img["file_name"])} if self.config.with_image else {} |
|
) |
|
|
|
annotation = [] |
|
|
|
img_to_refs = self.imgToRefs[img_id] |
|
for img_to_ref in img_to_refs: |
|
ref_to_ann = self.refToAnn[img_to_ref["ref_id"]] |
|
x, y, width, height = ref_to_ann["bbox"] |
|
|
|
annotation_dict = { |
|
"image_id": img_to_ref["image_id"], |
|
"region_id": img_to_ref["ref_id"], |
|
"x": int(x), |
|
"y": int(y), |
|
"width": int(width), |
|
"height": int(height), |
|
} |
|
annotation_dict["phrases"] = [sent["sent"] for sent in img_to_ref["sentences"]] |
|
|
|
if self.config.with_mask: |
|
if type(ref_to_ann["segmentation"][0]) == list: |
|
rle = mask.frPyObjects(ref_to_ann["segmentation"], img["height"], img["width"]) |
|
else: |
|
rle = ref_to_ann["segmentation"] |
|
mask_dict = rle[0] |
|
annotation_dict["mask"] = { |
|
"size": mask_dict["size"], |
|
"counts": mask_dict["counts"].decode("utf-8"), |
|
} |
|
annotation.append(annotation_dict) |
|
annotation = {"regions": annotation} |
|
yield idx, {**image_dict, **image_metadata, **annotation, "task_type": self.config.task_type} |
|
|
|
""" |
|
{ |
|
'coco_url': Value(dtype='string', id=None), |
|
'file_name': Value(dtype='string', id=None), |
|
'height': Value(dtype='int32', id=None), |
|
'image': Image(decode=True, id=None), |
|
'image_id': Value(dtype='int32', id=None), |
|
'regions': [{ |
|
'height': Value(dtype='int32', id=None), |
|
'image_id': Value(dtype='int32', id=None), |
|
'mask': { |
|
'counts': Value(dtype='string', id=None), |
|
'size': [Value(dtype='int32', id=None)] |
|
}, |
|
'phrases': [Value(dtype='string', id=None)], |
|
'region_id': Value(dtype='int32', id=None), |
|
'width': Value(dtype='int32', id=None), |
|
'x': Value(dtype='int32', id=None), |
|
'y': Value(dtype='int32', id=None) |
|
}], |
|
'width': Value(dtype='int32', id=None) |
|
} |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def createIndex(self): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"creating index for {self.config.name}...") |
|
|
|
Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {} |
|
for ann in self.data["annotations"]: |
|
Anns[ann["id"]] = ann |
|
imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann] |
|
for img in self.data["images"]: |
|
Imgs[img["id"]] = img |
|
for cat in self.data["categories"]: |
|
Cats[cat["id"]] = cat["name"] |
|
|
|
|
|
Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {} |
|
Sents, sentToRef, sentToTokens = {}, {}, {} |
|
for ref in self.data["refs"]: |
|
|
|
ref_id = ref["ref_id"] |
|
ann_id = ref["ann_id"] |
|
category_id = ref["category_id"] |
|
image_id = ref["image_id"] |
|
|
|
|
|
Refs[ref_id] = ref |
|
imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref] |
|
catToRefs[category_id] = catToRefs.get(category_id, []) + [ref] |
|
refToAnn[ref_id] = Anns[ann_id] |
|
annToRef[ann_id] = ref |
|
|
|
|
|
for sent in ref["sentences"]: |
|
Sents[sent["sent_id"]] = sent |
|
sentToRef[sent["sent_id"]] = ref |
|
sentToTokens[sent["sent_id"]] = sent["tokens"] |
|
|
|
|
|
self.Refs = Refs |
|
self.Anns = Anns |
|
self.Imgs = Imgs |
|
self.Cats = Cats |
|
self.Sents = Sents |
|
self.imgToRefs = imgToRefs |
|
self.imgToAnns = imgToAnns |
|
self.refToAnn = refToAnn |
|
self.annToRef = annToRef |
|
self.catToRefs = catToRefs |
|
self.sentToRef = sentToRef |
|
self.sentToTokens = sentToTokens |
|
logger.info("index created.") |
|
""" |
|
Dataset Statistic: |
|
refcoco-unc |
|
Refs 50000 |
|
Anns 196771 |
|
Imgs 19994 |
|
Cats 80 |
|
Sents 142210 |
|
imgToRefs 19994 |
|
imgToAnns 19994 |
|
refToAnn 50000 |
|
annToRef 50000 |
|
catToRefs 78 |
|
sentToRef 142210 |
|
sentToTokens 142210 |
|
""" |
|
|
|
def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""): |
|
image_ids = image_ids if type(image_ids) == list else [image_ids] |
|
cat_ids = cat_ids if type(cat_ids) == list else [cat_ids] |
|
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] |
|
|
|
if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0: |
|
refs = self.data["refs"] |
|
else: |
|
if not len(image_ids) == 0: |
|
refs = [self.imgToRefs[image_id] for image_id in image_ids] |
|
else: |
|
refs = self.data["refs"] |
|
if not len(cat_ids) == 0: |
|
refs = [ref for ref in refs if ref["category_id"] in cat_ids] |
|
if not len(ref_ids) == 0: |
|
refs = [ref for ref in refs if ref["ref_id"] in ref_ids] |
|
if not len(split) == 0: |
|
if split in ["testA", "testB", "testC"]: |
|
|
|
refs = [ref for ref in refs if split[-1] in ref["split"]] |
|
elif split in ["testAB", "testBC", "testAC"]: |
|
|
|
refs = [ref for ref in refs if ref["split"] == split] |
|
elif split == "test": |
|
refs = [ref for ref in refs if "test" in ref["split"]] |
|
elif split == "train" or split == "val": |
|
refs = [ref for ref in refs if ref["split"] == split] |
|
else: |
|
raise ValueError("No such split [%s]" % split) |
|
ref_ids = [ref["ref_id"] for ref in refs] |
|
return ref_ids |
|
|
|
def getImgIds(self, ref_ids=[]): |
|
ref_ids = ref_ids if type(ref_ids) == list else [ref_ids] |
|
|
|
if not len(ref_ids) == 0: |
|
image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids])) |
|
else: |
|
image_ids = list(self.Imgs.keys()) |
|
return image_ids |
|
|