|
import json
import os
import posixpath
from typing import Optional

import datasets
import dotenv
from pycocotools.coco import COCO
|
|
|
# Module-level logger obtained through the logging helpers of `datasets`.
logger = datasets.logging.get_logger(__name__)


def _region_features():
    """Return a fresh schema dict describing a single region.

    A region carries its own id, the id of the image it belongs to, a list of
    phrases, and a box given as x / y / width / height.  A new dict (with new
    ``datasets.Value`` objects) is built on every call so that the schemas
    derived from it do not share feature instances.
    """
    return dict(
        region_id=datasets.Value("int64"),
        image_id=datasets.Value("int32"),
        phrases=[datasets.Value("string")],
        x=datasets.Value("int32"),
        y=datasets.Value("int32"),
        width=datasets.Value("int32"),
        height=datasets.Value("int32"),
    )


# Per-image scalar columns that are part of every example.
_BASE_IMAGE_METADATA_FEATURES = {
    column: datasets.Value(dtype)
    for column, dtype in (
        ("image_id", "int32"),
        ("width", "int32"),
        ("height", "int32"),
        ("file_name", "string"),
        ("coco_url", "string"),
        ("task_type", "string"),
    )
}

# Schema of one region without a mask.
_BASE_REGION_FEATURES = _region_features()

# Schema of a mask: a list of int32 under "size" and a string under "counts"
# (filled from the output of `COCO.annToRLE` by the example generator).
_BASE_MASK_FEATURES = dict(
    size=[datasets.Value("int32")],
    counts=datasets.Value("string"),
)

# Schema of one region with a mask: the plain region fields followed by "mask".
_BASE_MASK_REGION_FEATURES = {**_region_features(), "mask": _BASE_MASK_FEATURES}

# Annotation schema per annotation type; each one is a "regions" list column.
_ANNOTATION_FEATURES = {
    annotation_type: {"regions": [region_schema]}
    for annotation_type, region_schema in (
        ("region_descriptions", _BASE_REGION_FEATURES),
        ("mask_region_descriptions", _BASE_MASK_REGION_FEATURES),
    )
}
|
|
|
|
|
class COCOBuilderConfig(datasets.BuilderConfig):
    """Builder configuration for the COCO caption dataset script.

    Args:
        name: Configuration name, forwarded to ``datasets.BuilderConfig``.
        splits: Iterable of split names to generate.
        with_image: When true, examples get an ``image`` column.
        with_mask: Currently unsupported; ``features`` raises ``ValueError``
            when this flag is truthy.
        coco_zip_url: Base location under which the image zip archives live.
        coco_annotations_zip_url: Base location under which the annotation
            zip archives live.
        task_type: String stored in the ``task_type`` column of every example.
        **kwargs: Forwarded unchanged to ``datasets.BuilderConfig``.
    """

    def __init__(
        self,
        name,
        splits,
        with_image: bool = True,
        with_mask: bool = False,
        coco_zip_url: Optional[str] = None,
        coco_annotations_zip_url: Optional[str] = None,
        task_type: str = "caption",
        **kwargs,
    ):
        super().__init__(name, **kwargs)
        self.splits = splits
        self.with_image = with_image
        self.with_mask = with_mask
        self.coco_zip_url = coco_zip_url
        self.coco_annotations_zip_url = coco_annotations_zip_url
        self.task_type = task_type

    @property
    def features(self):
        """Return the ``datasets.Features`` schema for this configuration.

        The schema is the optional ``image`` column, the per-image metadata
        columns and the ``regions`` annotation column.

        Raises:
            ValueError: If ``with_mask`` is truthy.
        """
        # Test truthiness rather than `is True`: the selection below (and the
        # example generator) also use truthiness, so a flag such as `1` must be
        # rejected here instead of silently selecting the mask schema.
        if self.with_mask:
            raise ValueError("with_mask=True is not supported yet in COCO caption.")

        # The mask branch stays unreachable until the guard above is lifted.
        annotation_type = "mask_region_descriptions" if self.with_mask else "region_descriptions"
        logger.info(f"Using annotation type: {annotation_type} due to with_mask={self.with_mask}")
        return datasets.Features(
            {
                **({"image": datasets.Image()} if self.with_image else {}),
                **_BASE_IMAGE_METADATA_FEATURES,
                **_ANNOTATION_FEATURES[annotation_type],
            }
        )
|
|
|
|
|
|
|
class COCOCaptionPseudoRegionDataset(datasets.GeneratorBasedBuilder):
    """An example dataset script to work with the local (downloaded) COCO dataset.

    Every caption annotation of an image is emitted as one "region" whose box
    spans the whole image (x=0, y=0, image width, image height) and whose
    ``phrases`` list holds that single caption.
    """

    VERSION = datasets.Version("0.0.0")

    BUILDER_CONFIG_CLASS = COCOBuilderConfig
    BUILDER_CONFIGS = [
        COCOBuilderConfig(name="2017", splits=["train", "valid"]),
    ]
    DEFAULT_CONFIG_NAME = "2017"
    config: COCOBuilderConfig

    def _info(self):
        """Return the dataset info; the schema comes from the active config."""
        return datasets.DatasetInfo(features=self.config.features)

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators.

        Only the archives needed by the configured splits are downloaded and
        extracted. ``"train"`` maps to the train split; ``"val"``, ``"valid"``,
        ``"validation"`` and ``"dev"`` map to the validation split. Any other
        split name is skipped with a warning.

        Raises:
            ValueError: If ``coco_zip_url`` or ``coco_annotations_zip_url`` is
                not set on the config.
        """
        coco_zip_url = self.config.coco_zip_url
        coco_annotations_zip_url = self.config.coco_annotations_zip_url
        if coco_zip_url is None:
            raise ValueError(
                "This script is supposed to work with local (downloaded) COCO dataset. The argument `coco_zip_url` in `load_dataset()` is required."
            )
        if coco_annotations_zip_url is None:
            raise ValueError(
                "This script is supposed to work with local (downloaded) COCO dataset. The argument `coco_annotations_zip_url` in `load_dataset()` is required."
            )

        # Loading `.env` must happen before the environment lookups below.
        dotenv_loaded = dotenv.load_dotenv('.env')
        logger.info(f"Try to load sas_key from .env file: {dotenv_loaded}.")
        coco_zip_url_sas_key = os.getenv("COCO_ZIP_URL_SAS_KEY", "")
        coco_annotations_zip_url_sas_key = os.getenv("COCO_ANNOTATIONS_ZIP_URL_SAS_KEY", "")

        # Resolve every requested split to
        # (split name, image archive key, caption annotation file).
        requested = []
        for split in self.config.splits:
            if split == "train":
                requested.append((datasets.Split.TRAIN, "train", "captions_train2017.json"))
            elif split in ("val", "valid", "validation", "dev"):
                requested.append((datasets.Split.VALIDATION, "val", "captions_val2017.json"))
            else:
                # Previously dropped without any trace; keep skipping, but say so.
                logger.warning(f"Skipping unsupported split name: {split!r}.")

        # `posixpath.join` always joins with "/", so the URLs stay valid on
        # Windows too; on POSIX it is the same function as `os.path.join`.
        dl_urls = {}
        for _, archive_key, _ in requested:
            dl_urls[archive_key] = (
                posixpath.join(coco_zip_url, f"{archive_key}2017.zip") + coco_zip_url_sas_key
            )
        if requested:
            dl_urls["annotations_trainval"] = (
                posixpath.join(coco_annotations_zip_url, "annotations_trainval2017.zip")
                + coco_annotations_zip_url_sas_key
            )

        # Nothing to fetch when no supported split was requested.
        archive_path = dl_manager.download_and_extract(dl_urls) if dl_urls else {}

        # NOTE(review): `image_dir` is the extraction root itself. If the image
        # archives contain a top-level `train2017/` / `val2017/` folder, that
        # folder would have to be appended here -- confirm the archive layout.
        return [
            datasets.SplitGenerator(
                name=split_name,
                gen_kwargs={
                    "json_path": os.path.join(
                        archive_path["annotations_trainval"], "annotations", caption_file
                    ),
                    "image_dir": archive_path[archive_key],
                },
            )
            for split_name, archive_key, caption_file in requested
        ]

    def _generate_examples(
        self,
        json_path,
        image_dir,
    ):
        """Yields examples as (key, example) tuples.

        Args:
            json_path: Path of the COCO annotation JSON file to read.
            image_dir: Directory that is joined with each image ``file_name``.

        The key is the position of the image in ``coco.getImgIds()``. Images
        that have no annotation are skipped, so keys may have gaps.
        """
        coco = COCO(json_path)
        with_image = self.config.with_image
        with_mask = self.config.with_mask
        task_type = self.config.task_type

        for idx, img_id in enumerate(coco.getImgIds()):
            # Skip early, before any per-image dict is built.
            if img_id not in coco.imgToAnns:
                continue

            img = coco.imgs[img_id]
            width, height = img["width"], img["height"]

            regions = []
            for ann in coco.imgToAnns[img_id]:
                # The region is the whole image; its only phrase is the caption.
                region = {
                    "region_id": ann["id"],
                    "image_id": ann["image_id"],
                    "x": 0,
                    "y": 0,
                    "width": width,
                    "height": height,
                    "phrases": [ann["caption"]],
                }
                if with_mask:
                    rle = coco.annToRLE(ann)
                    region["mask"] = {
                        "size": rle["size"],
                        # "counts" is decoded to text to fit the string column.
                        "counts": rle["counts"].decode("utf-8"),
                    }
                regions.append(region)

            example = {"image": os.path.join(image_dir, img["file_name"])} if with_image else {}
            example.update(
                {
                    "coco_url": img["coco_url"],
                    "file_name": img["file_name"],
                    "height": height,
                    "width": width,
                    "image_id": img["id"],
                    "regions": regions,
                    "task_type": task_type,
                }
            )
            yield idx, example
|
|