import argparse
import json
import os
import tarfile
import tempfile
from typing import Dict, List

from loguru import logger
from tqdm import tqdm


# fmt: off
parser = argparse.ArgumentParser(
    description="""Pre-process the RedCaps dataset for training VirTex models:
    pack images and captions into small TAR file shards."""
)
parser.add_argument(
    "-a", "--annotations", required=True, help="Path to a RedCaps annotation file."
)
parser.add_argument(
    "-i", "--images", default="datasets/redcaps/images",
    help="""Path to RedCaps image directory. This directory is expected to have
    subreddit specific sub-directories containing images.""",
)
parser.add_argument(
    "-z", "--shard-size", type=int, default=1000,
    help="Maximum number of RedCaps instances in a single TAR file shard.",
)
parser.add_argument(
    "-o", "--output-prefix", required=True,
    help="Path prefix for saving TAR file shards. For example, `/tmp/tarfiles` "
    "will save as `/tmp/tarfiles_000000.tar`, `/tmp/tarfiles_000001.tar`, ...",
)
# fmt: on
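
# Illustrative invocation (the script name and paths below are assumptions,
# not fixed by this file):
#
#   python preprocess_redcaps.py \
#       --annotations datasets/redcaps/annotations/foodporn_2020.json \
#       --output-prefix /tmp/shards/foodporn
#
# With the default `--shard-size 1000`, this writes /tmp/shards/foodporn_000000.tar,
# /tmp/shards/foodporn_000001.tar, ... each holding up to 1000 image-caption pairs.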


def main(_A: argparse.Namespace):
    r"""
    Make TAR files containing images and annotations from a single RedCaps
    annotations file. These TAR files are arranged in a way that
    `WebDataset <https://github.com/tmbdev/webdataset>`_ can understand.
    """

    # Load annotations with a context manager so the file handle is closed.
    with open(_A.annotations) as annotations_file:
        ANNOTATIONS: List[Dict] = json.load(annotations_file)["annotations"]

    # Keep track of the current index of TAR file shard and dataset index.
    SHARD_INDEX: int = 0
    DATASET_INDEX: int = 0

    # Create TAR file handle for the initial shard.
    # Pad the shard index to six digits, matching the shard names logged below.
    tar_handle = tarfile.open(f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w")

    # Keep a count of submissions that were skipped because their image was
    # not downloaded (not present in image dir).
    SKIPPED: int = 0

    for ann in tqdm(ANNOTATIONS):

        image_path = os.path.join(
            _A.images, ann["subreddit"], f"{ann['image_id']}.jpg"
        )
        # Add current image in shard if it exists.
        if os.path.exists(image_path):

            tar_handle.add(image_path, arcname=f"{ann['image_id']}.jpg")

            # Save subreddit name and caption as a JSON file.
            subreddit_and_caption = {
                "subreddit": ann["subreddit"], "caption": ann["caption"]
            }
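            # `tarfile.add` reads member contents from a path on disk, so
            # stage the JSON in a named temporary file; `seek(0)` flushes
            # the write buffer before the TAR handle reads the file.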
            tmpfile = tempfile.NamedTemporaryFile("w+")
            tmpfile.write(json.dumps(subreddit_and_caption))
            tmpfile.seek(0)
            tar_handle.add(tmpfile.name, arcname=f"{ann['image_id']}.json")
            tmpfile.close()

            DATASET_INDEX += 1

            # Create new shard if current shard is full.
            if DATASET_INDEX % _A.shard_size == 0 and DATASET_INDEX > 0:
                tar_handle.close()
                logger.success(
                    f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar"
                )
                SHARD_INDEX += 1

                # Open new TAR file shard.
                tar_handle = tarfile.open(
                    f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w"
                )
        else:
            SKIPPED += 1

    # Close the file handle to properly save it.
    tar_handle.close()
    logger.success(f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar\n")
    logger.info(f"Skipped {SKIPPED} annotations due to missing images.")


if __name__ == "__main__":
    _A = parser.parse_args()
    main(_A)
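
# ---------------------------------------------------------------------------
# Minimal sketch of reading the shards back with the `webdataset` package
# (https://github.com/tmbdev/webdataset), kept commented out so this script
# stays import-safe. The shard range and output prefix are hypothetical
# examples; the "jpg"/"json" keys follow the arcnames written in `main`.
#
#   import webdataset as wds
#
#   dataset = (
#       wds.WebDataset("/tmp/shards/foodporn_{000000..000009}.tar")
#       .decode("pil")              # decode JPEG bytes into PIL images
#       .to_tuple("jpg", "json")    # -> (image, {"subreddit": ..., "caption": ...})
#   )
#   for image, metadata in dataset:
#       print(metadata["subreddit"], metadata["caption"])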