File size: 3,104 Bytes
e1aaaac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import argparse
import json
import os
import uuid
import zipfile
from PIL import Image
import base64
from io import BytesIO
import braceexpand
import webdataset as wds
# Command-line interface. The arguments are parsed at import time into the
# module-level ``args`` namespace, which main() reads.
arg_parser = argparse.ArgumentParser(
    description="Convert MMC4 zip shards plus downloaded images into webdataset tar shards."
)
arg_parser.add_argument(
    "--output_dir",
    type=str,
    # Required: without it os.makedirs(None) fails with an opaque TypeError.
    required=True,
    help="Pass in the directory where the output shards (as tar files) will be written to.",
)
arg_parser.add_argument(
    "--zip_files",
    type=str,
    required=True,
    help="Pass in a list of MMC4 shards in the format path_to_shard/shard_{0..23098}.zip",
)
arg_parser.add_argument(
    "--image_dir",
    type=str,
    # Required: without it every image lookup fails and documents are
    # silently written without any embedded images.
    required=True,
    help="Pass in the directory where the images have been downloaded to.",
)
arg_parser.add_argument(
    "--num_files_per_shard",
    type=int,
    default=1000,
    help="Number of input zip files to pack into each output tar shard (must be >= 1).",
)
args = arg_parser.parse_args()
# main() uses this value as a modulus; 0 would divide by zero and negative
# values would never start a new shard, so reject them up front.
if args.num_files_per_shard < 1:
    arg_parser.error("--num_files_per_shard must be a positive integer")
def _encode_image_as_base64_jpeg(image_path):
    """Load the image at ``image_path`` and return it as a base64 JPEG string.

    The image is converted to RGB before JPEG encoding. Any error raised by
    Pillow while opening, decoding or saving propagates to the caller, e.g.
    ``FileNotFoundError`` when the image was never downloaded.
    """
    # ``with`` releases the source file handle deterministically; Pillow may
    # otherwise keep it open (e.g. for multi-frame formats).
    with Image.open(image_path) as img:
        rgb_img = img.convert("RGB")
    buffered = BytesIO()
    rgb_img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _iter_samples(zip_path):
    """Yield every JSON document from the MMC4 zip shard at ``zip_path``.

    Assumes the JSON-lines file is the first member of the archive; each
    non-blank line is parsed as one JSON document.

    Raises:
        ValueError: if the archive contains no members.
    """
    with zipfile.ZipFile(zip_path, "r") as zip_file:
        member_names = zip_file.namelist()
        if not member_names:
            raise ValueError(f"{zip_path} contains no files")
        # Assumes the JSON file is the first file in the archive
        with zip_file.open(member_names[0], "r") as json_file:
            for line in json_file:
                # Skip blank lines (e.g. a trailing newline) instead of
                # crashing json.loads on them.
                if line.strip():
                    yield json.loads(line)


def main():
    """Convert MMC4 zip shards into webdataset tar shards with inlined images.

    Reads the module-level ``args``. The ``args.zip_files`` pattern is
    brace-expanded into a list of zip shards. For every document in every
    shard, each entry of ``image_info`` is loaded from
    ``<image_dir>/<shard index>/<image_name>``, re-encoded as a base64 JPEG
    and stored under that entry's ``"image_base64"`` key.

    Images that are missing or fail to process are reported on stdout and
    left without ``"image_base64"``; the document is still written. Each
    document is written under a random hex key. A new output tar is started
    after every ``args.num_files_per_shard`` input zip files.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    doc_shards = list(braceexpand.braceexpand(args.zip_files))
    num_shards = len(doc_shards)
    with wds.ShardWriter(os.path.join(args.output_dir, "%09d.tar")) as sink:
        for idx, doc_shard in enumerate(doc_shards):
            # Images for the zip at position ``idx`` in the expanded list are
            # expected under image_dir/<idx>/.
            shard_image_dir = os.path.join(args.image_dir, str(idx))
            for sample_data in _iter_samples(doc_shard):
                # Add each image to the document as base64
                for image in sample_data["image_info"]:
                    image_name = image["image_name"]
                    try:
                        image["image_base64"] = _encode_image_as_base64_jpeg(
                            os.path.join(shard_image_dir, image_name)
                        )
                    except FileNotFoundError:
                        print(
                            f"Did not find {image_name} downloaded. This can happen if the url is now 404."
                        )
                    except Exception as e:
                        # Best effort: a single corrupt image must not abort
                        # the whole conversion.
                        print(f"Error processing {image_name}: {e}")
                key_str = uuid.uuid4().hex
                sink.write({"__key__": key_str, "json": sample_data})
            # Roll over to a new tar every num_files_per_shard zips, but not
            # after the final zip, which would leave an empty trailing shard.
            is_last_shard = idx + 1 == num_shards
            if (idx + 1) % args.num_files_per_shard == 0 and not is_last_shard:
                sink.next_stream()


if __name__ == "__main__":
    main()
|