import argparse
import os
import pickle
import platform
from typing import Any, List

import albumentations as alb
import lmdb
from tqdm import tqdm
from torch.utils.data import DataLoader

# Assumption: the module path was missing in the source; this follows the
# virtex project layout. Adjust it to wherever `SimpleCocoCaptionsReader` lives.
from virtex.data.readers import SimpleCocoCaptionsReader


# fmt: off
parser = argparse.ArgumentParser("Serialize a COCO Captions split to LMDB.")
parser.add_argument(
    "-d", "--data-root", default="datasets/coco",
    help="Path to the root directory of COCO dataset.",
)
parser.add_argument(
    "-s", "--split", choices=["train", "val"],
    help="Which split to process, either `train` or `val`.",
)
parser.add_argument(
    "-b", "--batch-size", type=int, default=16,
    help="Batch size to process and serialize data. Set as per CPU memory.",
)
parser.add_argument(
    "-j", "--cpu-workers", type=int, default=4,
    help="Number of CPU workers for data loading.",
)
parser.add_argument(
    "-e", "--short-edge-size", type=int, default=None,
    help="""Resize shorter edge to this size (keeping aspect ratio constant)
    before serializing. Useful for saving disk space and for faster reads.
    If None, no images are resized."""
)
parser.add_argument(
    "-o", "--output", default="datasets/serialized/coco_train2017.lmdb",
    help="Path to store the file containing serialized dataset.",
)


def collate_fn(instances: List[Any]):
    r"""Collate function for data loader to return list of instances as-is."""
    return instances


if __name__ == "__main__":

    _A = parser.parse_args()
    # `os.path.dirname` is empty for a bare filename, so fall back to the
    # current directory instead of calling `os.makedirs("")`.
    os.makedirs(os.path.dirname(_A.output) or ".", exist_ok=True)

    dloader = DataLoader(
        SimpleCocoCaptionsReader(_A.data_root, _A.split),
        batch_size=_A.batch_size,
        num_workers=_A.cpu_workers,
        shuffle=False,
        drop_last=False,
        collate_fn=collate_fn
    )
    # Open an LMDB database.
    # Set a sufficiently large map size for LMDB (based on platform):
    # 2 TiB on Linux, 1,280,000 bytes elsewhere.
    map_size = 1099511627776 * 2 if platform.system() == "Linux" else 1280000
    db = lmdb.open(
        _A.output, map_size=map_size, subdir=False, meminit=False, map_async=True
    )

    # Transform to resize shortest edge and keep aspect ratio same.
    if _A.short_edge_size is not None:
        resize = alb.SmallestMaxSize(max_size=_A.short_edge_size, always_apply=True)

    # Serialize each instance (as a tuple). Use `pickle.dumps`. Key will
    # be an integer (cast as string) starting from `0`.
    INSTANCE_COUNTER: int = 0

    for idx, batch in enumerate(tqdm(dloader)):

        txn = db.begin(write=True)

        for instance in batch:
            image = instance["image"]
            # Image arrays are laid out as (height, width, channels).
            height, width, channels = image.shape

            # Resize image from instance and convert instance to tuple.
            if _A.short_edge_size is not None and min(width, height) > _A.short_edge_size:
                image = resize(image=image)["image"]

            # Store the (possibly resized) image rather than the original one.
            instance = (instance["image_id"], image, instance["captions"])
            txn.put(
                f"{INSTANCE_COUNTER}".encode("ascii"),
                pickle.dumps(instance, protocol=-1)
            )
            INSTANCE_COUNTER += 1

        txn.commit()

    db.sync()
    db.close()
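

r"""
Usage note: the script stores each record under an ASCII-encoded integer key,
with a pickled `(image_id, image, captions)` tuple as the value. The sketch
below illustrates that round trip using only the standard library. It is not
part of the script: `encode_key`, `write_records` and `read_record` are
hypothetical helper names, a plain dict stands in for the LMDB file, and
nested lists stand in for image arrays.

```python
import pickle
from typing import Any, Dict, List, Tuple

# Record layout mirroring the tuple written by the script:
# (image_id, image, captions). Nested lists stand in for image arrays.
Record = Tuple[int, Any, List[str]]


def encode_key(index: int) -> bytes:
    # Build a key the same way the script does: integer -> ASCII bytes.
    return f"{index}".encode("ascii")


def write_records(store: Dict[bytes, bytes], records: List[Record]) -> int:
    # Pickle each record into `store` under consecutive integer keys.
    for index, record in enumerate(records):
        store[encode_key(index)] = pickle.dumps(record, protocol=-1)
    return len(records)


def read_record(store: Dict[bytes, bytes], index: int) -> Record:
    # Look up one record by its integer index and unpickle it.
    return pickle.loads(store[encode_key(index)])


# Assumption: a dict replaces the database for this sketch. With a real LMDB
# file, the value bytes would come from a read transaction instead.
store: Dict[bytes, bytes] = {}
count = write_records(
    store,
    [
        (42, [[0, 0, 0]], ["a cat on a mat"]),
        (43, [[255, 255, 255]], ["a dog in a park", "a brown dog"]),
    ],
)
image_id, image, captions = read_record(store, 1)
```

Since keys are consecutive integers starting from `0`, a reader can treat the
number of stored entries as the dataset length and index records directly.
"""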