File size: 4,827 Bytes
5282eae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import argparse
import base64
import json
import os
import tarfile
import uuid
import zipfile
import time

import braceexpand
import webdataset as wds
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

# Command-line interface.  The three path options are marked required because
# main() passes them straight to os.makedirs()/braceexpand(); leaving one out
# used to surface as an unrelated TypeError instead of a usage message.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Root output directory; shards are written into a timestamped subdirectory of it",
)
arg_parser.add_argument(
    "--image_shards",
    type=str,
    required=True,
    help="Pass in a list of shards in the format path_to_shard/shard_{0..23098}_images_v2.tar",
)
arg_parser.add_argument(
    "--doc_shards",
    type=str,
    required=True,
    help="Pass in a list of shards in the format path_to_shard/docs_shard_{0..23098}_v2.jsonl.zip",
)
arg_parser.add_argument(
    "--thread",
    type=int,
    default=128,
    help="Number of buckets the (doc shard, image shard) pairs are split into (default: 128)",
)
# NOTE: parsing happens at import time; main() reads this module-level `args`.
args = arg_parser.parse_args()

def get_txt_to_filename_dict(image_shards, disable_tqdm=False):
    """Index the samples of ``image_shards`` by the stem of their ``txt`` field.

    Each sample is read as a ``(txt, json)`` pair.  The part of ``txt`` before
    its first "." becomes the dictionary key and the ``"key"`` entry of the
    sample's ``json`` field becomes the value.

    Args:
        image_shards: Shard specification handed to ``wds.WebDataset``.
        disable_tqdm: When true, the progress bar is suppressed.

    Returns:
        A dict mapping ``txt`` stem -> ``json["key"]``; on duplicate stems the
        last sample seen wins.
    """
    samples = wds.WebDataset(image_shards).decode("pil").to_tuple("txt", "json")
    stem_to_key = {}
    for txt_field, meta in tqdm(samples, disable=disable_tqdm):
        stem = txt_field.split(".")[0]
        stem_to_key[stem] = meta['key']
    return stem_to_key


def _read_image_base64(image_tar, txt_to_filename_dict, image_name):
    """Return one image from ``image_tar`` as base64 text, or None if unavailable.

    The tar member name is ``txt_to_filename_dict[<image_name up to its first
    ".">] + ".jpg"``.
    """
    try:
        member_name = txt_to_filename_dict[image_name.split(".")[0]] + ".jpg"
        image = image_tar.extractfile(member_name)
        return base64.b64encode(image.read()).decode("utf-8")
    except Exception:
        # Deliberately best-effort: an unknown name, a member missing from the
        # tar, or an unreadable member must not abort the whole conversion.
        # Unlike a bare ``except:``, KeyboardInterrupt/SystemExit still propagate.
        return None


def single_thread(args):
    """Convert one worker's share of shard pairs into webdataset tar files.

    For every (doc shard, image shard) pair, each line of the first file inside
    the doc-shard ZIP is parsed as a JSON document.  Every entry of the
    document's ``"image_info"`` list gets an ``"image_base64"`` field holding
    the base64 text of the matching image from the image shard, or the string
    ``"null"`` when the image cannot be read (its name stem is then printed).
    The augmented document is written to ``output_dir`` under a random key.

    Args:
        args: Task dict with keys ``"i"`` (worker index; only worker 0 prints
            and shows progress), ``"output_dir"``, ``"doc_shards"`` and
            ``"image_shards"`` (parallel lists of paths).
    """
    i = args["i"]
    output_dir = args["output_dir"]
    doc_shards = args["doc_shards"]
    image_shards = args["image_shards"]
    quiet = i != 0
    if not quiet:
        tqdm.write(f"output_dir: {output_dir}")
        tqdm.write(f"doc_shards: {doc_shards[:5]}")
        tqdm.write(f"image_shards: {image_shards[:5]}")
    with wds.ShardWriter(os.path.join(output_dir, "%09d.tar"), maxcount=1000) as sink:
        sink.verbose = False
        for doc_shard, image_shard in tqdm(zip(doc_shards, image_shards), disable=quiet, total=len(doc_shards)):
            # These two objects were previously only defined in commented-out
            # code, so every lookup raised NameError (swallowed by a bare
            # except) and every image was stored as "null".
            txt_to_filename_dict = get_txt_to_filename_dict(image_shard, disable_tqdm=quiet)
            with tarfile.open(image_shard) as image_tar, zipfile.ZipFile(doc_shard, "r") as zip_file:
                # Assumes the JSON file is the first file in the archive
                json_filename = zip_file.namelist()[0]
                with zip_file.open(json_filename, "r") as json_file:
                    pbar = tqdm(json_file, disable=True)
                    total_num = 0
                    exist_num = 0
                    for line in pbar:
                        sample_data = json.loads(line)

                        # Attach the base64 payload to each image entry in place.
                        for image_entry in sample_data["image_info"]:
                            image_name = image_entry["image_name"]
                            total_num += 1
                            image_base64 = _read_image_base64(image_tar, txt_to_filename_dict, image_name)
                            if image_base64 is None:
                                tqdm.write(f"{image_name.split('.')[0]}")
                                image_base64 = "null"
                            else:
                                exist_num += 1
                            image_entry["image_base64"] = image_base64

                        key_str = uuid.uuid4().hex
                        sink.write({"__key__": key_str, "json": sample_data})
                        # A shard whose first documents carry no images would
                        # otherwise divide by zero here.
                        if total_num:
                            pbar.set_description(f"{exist_num/total_num:.2f}")


def main():
    """Expand the shard patterns, split the pairs into buckets and convert them.

    Doc and image shard patterns from the command line are brace-expanded and
    paired positionally.  Pair ``n`` goes to bucket ``n % <bucket count>``, and
    each bucket writes to ``<output_dir>/<unix timestamp>/<bucket index>``.
    All buckets are processed in parallel worker processes.

    Raises:
        ValueError: If ``--thread`` is not positive, if no shards were given,
            or if the two patterns expand to different numbers of shards.
    """
    if args.thread < 1:
        raise ValueError("--thread must be a positive integer")

    doc_shards = list(braceexpand.braceexpand(args.doc_shards))
    image_shards = list(braceexpand.braceexpand(args.image_shards))

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(doc_shards) != len(image_shards):
        raise ValueError("Each doc shards must have a corresponding image shard")
    if not doc_shards:
        raise ValueError("No shards to process")

    # Never create more buckets than there are shard pairs: an empty bucket
    # would only produce an empty output directory and tar file.  The
    # round-robin assignment below is unchanged by this cap.
    num_workers = min(args.thread, len(doc_shards))

    timestamp = int(time.time())
    run_dir = os.path.join(args.output_dir, str(timestamp))
    tasks = []
    for i in range(num_workers):
        thread_dir = os.path.join(run_dir, str(i))
        os.makedirs(thread_dir, exist_ok=True)
        tasks.append({
            "i": i,
            "output_dir": thread_dir,
            "doc_shards": [],
            "image_shards": [],
        })

    for i, (doc_shard, image_shard) in enumerate(zip(doc_shards, image_shards)):
        bucket = tasks[i % num_workers]
        bucket["doc_shards"].append(doc_shard)
        bucket["image_shards"].append(image_shard)

    # Run every bucket.  Previously only tasks[0] was processed, silently
    # dropping all shards assigned to the other buckets.  The outer progress
    # bar is disabled; worker 0 reports progress itself.
    process_map(single_thread, tasks, max_workers=num_workers, disable=True)

# Call main() only when this file is executed as a script.
if __name__ == "__main__":
    main()