"""Build a JSONL manifest for a directory of webdataset ``.tar`` shards.

For every ``*.tar`` file found directly inside ``--data-dir`` the script
counts the entries listed by ``tar tf`` and writes one JSON object per line
(``{"shard": ..., "num_sequences": ...}``) to the manifest file inside the
same directory. ``--data-dir`` may be a local path or a cloud URL handled by
``cloudpathlib``.
"""

import argparse
import multiprocessing as mp
import re
import subprocess
import sys
import tempfile
from operator import itemgetter
from pathlib import Path

import simdjson
from cloudpathlib import CloudPath
from tqdm import tqdm


def path_or_cloudpath(s):
    """Return a ``CloudPath`` for ``scheme://...`` strings, else a local ``Path``."""
    if re.match(r"^\w+://", s):
        return CloudPath(s)
    return Path(s)


def parse_args(args):
    """Parse the command-line arguments given as a list of strings."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-dir",
        type=path_or_cloudpath,
        required=True,
        help="Directory containing a dataset in webdataset format.",
    )
    parser.add_argument(
        "--manifest-filename",
        type=str,
        default="manifest.jsonl",
        help="Filename for the manifest that will be stored in the webdataset directory.",
    )
    parser.add_argument("--tmp-dir", type=str, default=None, help="Temporary directory.")
    parser.add_argument("--num-workers", type=int, default=2, help="Number of workers.")
    args = parser.parse_args(args)
    return args


def _count_tar_entries(tar_path):
    """Return the number of lines printed by ``tar tf`` for a local archive.

    The command is run without a shell so that paths containing spaces or
    shell metacharacters are passed through verbatim, and ``check=True``
    makes a failing ``tar`` raise ``subprocess.CalledProcessError`` instead of
    silently producing a bogus count.
    """
    listing = subprocess.run(
        ["tar", "tf", str(tar_path)],
        stdout=subprocess.PIPE,
        check=True,
    ).stdout
    # Same quantity ``wc -l`` reports: the number of newline characters.
    return listing.count(b"\n")


def count_samples(shard_path, tmp_dir):
    """Count the entries of one shard, downloading it first if it is remote.

    Args:
        shard_path: ``Path`` or ``CloudPath`` of the ``.tar`` shard.
        tmp_dir: Directory used to hold the temporary download of a remote
            shard. ``None`` falls back to the system default temp location.

    Returns:
        The number of entries listed by ``tar tf`` for the shard.

    Raises:
        subprocess.CalledProcessError: if ``tar`` cannot list the archive.
    """
    if not isinstance(shard_path, CloudPath):
        return _count_tar_entries(shard_path)

    if tmp_dir is not None:
        Path(tmp_dir).mkdir(parents=True, exist_ok=True)
    # A private scratch directory per call avoids name clashes with existing
    # files and guarantees the download is removed even when counting fails.
    with tempfile.TemporaryDirectory(dir=tmp_dir) as scratch_dir:
        local_copy = Path(scratch_dir) / shard_path.name
        shard_path.download_to(local_copy)
        return _count_tar_entries(local_copy)


def worker_fn(input_data):
    """Pool worker: compute the manifest entry for a single shard.

    Args:
        input_data: ``(basename, data_dir, tmp_dir)`` tuple.

    Returns:
        ``(basename, entry)`` where ``entry`` is the dict written to the
        manifest for that shard.
    """
    basename, data_dir, tmp_dir = input_data
    shard_path = data_dir / basename
    return (
        basename,
        {
            # Strip only the final extension so that shard names which
            # themselves contain dots are not truncated at the first dot.
            "shard": Path(basename).stem,
            "num_sequences": count_samples(shard_path, tmp_dir),
        },
    )


def main(args):
    """Entry point: scan the data directory and write the manifest file."""
    args = parse_args(args)

    shards = sorted(x for x in args.data_dir.iterdir() if x.name.endswith(".tar"))
    input_data = [(shard.name, args.data_dir, args.tmp_dir) for shard in shards]

    print(f"Shards to process: {len(shards)}")
    print("Creating pool.")
    with mp.Pool(args.num_workers) as pool:
        # ``total`` lets tqdm show real progress; results arrive unordered.
        data = list(tqdm(pool.imap_unordered(worker_fn, input_data), total=len(input_data)))

    # Restore a deterministic order by shard filename.
    data.sort(key=itemgetter(0))

    manifest_path = args.data_dir / args.manifest_filename
    with manifest_path.open("w") as fp:
        for _, item in data:
            simdjson.dump(item, fp)
            fp.write("\n")


if __name__ == "__main__":
    main(sys.argv[1:])