import argparse
import json
import os
import tarfile
import tempfile
from typing import Dict, List

from loguru import logger
from tqdm import tqdm


# fmt: off
parser = argparse.ArgumentParser(
    description="""Pre-process RedCaps dataset for training VirTex models - make
    small shards of TAR files containing images and captions."""
)
parser.add_argument(
    "-a", "--annotations", required=True,
    help="Path to a RedCaps annotation file."
)
parser.add_argument(
    "-i", "--images", default="datasets/redcaps/images",
    help="""Path to RedCaps image directory. This directory is expected to have
    subreddit specific sub-directories containing images.""",
)
parser.add_argument(
    "-z", "--shard-size", type=int, default=1000,
    help="Maximum number of RedCaps instances in a single TAR file shard.",
)
parser.add_argument(
    "-o", "--output-prefix", required=True,
    help="Path prefix for saving TAR file shards. For example, `/tmp/tarfiles` "
    "will save as `/tmp/tarfiles_000000.tar`, `/tmp/tarfiles_000001.tar`, ...",
)
# fmt: on


def main(_A: argparse.Namespace):
    r"""
    Make TAR files containing images and annotations from a single RedCaps
    annotations file. These TAR files are arranged in a way that
    `WebDataset <https://github.com/webdataset/webdataset>`_ can understand.
    """

    with open(_A.annotations) as f:
        ANNOTATIONS: List[Dict] = json.load(f)["annotations"]

    # Keep track of the current index of TAR file shard and dataset index.
    SHARD_INDEX: int = 0
    DATASET_INDEX: int = 0

    # Create TAR file handle for the initial shard.
    tar_handle = tarfile.open(f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w")

    # Keep a count of submissions that were skipped because their image was
    # not downloaded (not present in image dir).
    SKIPPED: int = 0

    for ann in tqdm(ANNOTATIONS):
        image_path = os.path.join(
            _A.images, ann["subreddit"], f"{ann['image_id']}.jpg"
        )

        # Add current image in shard if it exists.
        if os.path.exists(image_path):
            tar_handle.add(image_path, arcname=f"{ann['image_id']}.jpg")

            # Save subreddit name and caption as a JSON file. NOTE: adding a
            # still-open NamedTemporaryFile to the TAR by name works on POSIX
            # systems, but not on Windows.
            subreddit_and_caption = {
                "subreddit": ann["subreddit"], "caption": ann["caption"]
            }
            tmpfile = tempfile.NamedTemporaryFile("w+")
            tmpfile.write(json.dumps(subreddit_and_caption))
            tmpfile.seek(0)
            tar_handle.add(tmpfile.name, arcname=f"{ann['image_id']}.json")
            tmpfile.close()

            DATASET_INDEX += 1

            # Create new shard if current shard is full.
            if DATASET_INDEX % _A.shard_size == 0 and DATASET_INDEX > 0:
                tar_handle.close()
                logger.success(
                    f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar"
                )
                SHARD_INDEX += 1

                # Open new TAR file shard.
                tar_handle = tarfile.open(
                    f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w"
                )
        else:
            SKIPPED += 1

    # Close the file handle to properly save it.
    tar_handle.close()
    logger.success(f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar\n")
    logger.info(f"Skipped {SKIPPED} annotations due to missing images.")


if __name__ == "__main__":
    _A = parser.parse_args()
    main(_A)
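
# ---------------------------------------------------------------------------
# Usage sketch (not part of the script): shards written above pair each image
# with its JSON metadata by basename, following the WebDataset convention, so
# they can be read back roughly as below. This is a minimal, hypothetical
# example assuming the `webdataset` package is installed; the shard pattern
# `/tmp/tarfiles_{000000..000009}.tar` is a placeholder - adjust it to match
# whatever `--output-prefix` and shard count you actually produced.
#
#   import webdataset as wds
#
#   dataset = (
#       wds.WebDataset("/tmp/tarfiles_{000000..000009}.tar")
#       .decode("pil")            # decode `.jpg` entries into PIL images
#       .to_tuple("jpg", "json")  # pair each image with its JSON metadata
#   )
#   for image, annotation in dataset:
#       print(annotation["subreddit"], annotation["caption"])
# ---------------------------------------------------------------------------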