|
import argparse |
|
import re |
|
import simdjson |
|
import sys |
|
import subprocess |
|
import multiprocessing as mp |
|
from pathlib import Path |
|
from cloudpathlib import CloudPath |
|
from tqdm import tqdm |
|
|
|
|
|
def path_or_cloudpath(s):
    """Convert a string into a ``CloudPath`` or a local ``Path``.

    A string that starts with a ``<scheme>://`` prefix (one or more word
    characters followed by ``://``) is wrapped in ``CloudPath``; any other
    string is wrapped in ``pathlib.Path``.
    """
    has_scheme = re.match(r"^\w+://", s) is not None
    return CloudPath(s) if has_scheme else Path(s)
|
|
|
|
|
def parse_args(args):
    """Parse the command-line options of the manifest script.

    Args:
        args: List of argument strings (typically ``sys.argv[1:]``).

    Returns:
        argparse.Namespace with ``data_dir``, ``manifest_filename``,
        ``tmp_dir`` and ``num_workers``.
    """
    # Options are registered in this order, which is also the order they
    # appear in ``--help``.
    option_specs = [
        (
            "--data-dir",
            {
                "type": path_or_cloudpath,
                "required": True,
                "help": "Directory containing a dataset in webdataset format.",
            },
        ),
        (
            "--manifest-filename",
            {
                "type": str,
                "default": "manifest.jsonl",
                "help": "Filename for the manifest that will be stored in the webdataset directory.",
            },
        ),
        ("--tmp-dir", {"type": str, "default": None, "help": "Temporary directory."}),
        ("--num-workers", {"type": int, "default": 2, "help": "Number of workers."}),
    ]

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args(args)
|
|
|
|
|
def count_samples(shard_path, tmp_dir):
    """Return the number of entries that ``tar tf`` lists for one shard.

    Args:
        shard_path: Local ``Path`` or ``CloudPath`` of the tar shard. A
            ``CloudPath`` is first downloaded into ``tmp_dir`` and the
            downloaded copy is deleted again before returning.
        tmp_dir: Directory for downloaded copies. Only used for ``CloudPath``
            shards; ``None`` falls back to the system temporary directory.

    Returns:
        int: Number of lines printed by ``tar tf`` for the shard.

    Raises:
        subprocess.CalledProcessError: If ``tar`` exits with a non-zero
            status (e.g. the archive is corrupt or missing).
    """
    is_remote = isinstance(shard_path, CloudPath)
    if is_remote:
        if tmp_dir is None:
            # --tmp-dir defaults to None; Path(None) would raise TypeError.
            import tempfile

            tmp_dir = tempfile.gettempdir()
        temp_shard_path = Path(tmp_dir) / shard_path.name
        shard_path.download_to(temp_shard_path)
    else:
        temp_shard_path = shard_path

    try:
        # Argument list instead of a shell pipeline: paths with spaces or
        # shell metacharacters are passed verbatim, and a failing ``tar`` is
        # reported instead of being masked by the exit status of ``wc``.
        listing = subprocess.check_output(["tar", "tf", str(temp_shard_path)])
    finally:
        # Remove the downloaded copy even when listing the archive failed.
        if is_remote:
            temp_shard_path.unlink()

    # Equivalent to ``wc -l``: count newline characters in the listing.
    return listing.count(b"\n")
|
|
|
|
|
def worker_fn(input_data):
    """Build the manifest entry for a single shard.

    Args:
        input_data: Tuple ``(basename, data_dir, tmp_dir)`` where ``basename``
            is the shard's file name, ``data_dir`` is the directory that
            contains it, and ``tmp_dir`` is forwarded to ``count_samples``.

    Returns:
        tuple: ``(basename, entry)`` where ``entry`` is a dict with the keys
        ``"shard"`` (shard name without extension) and ``"num_sequences"``
        (the value returned by ``count_samples``).
    """
    basename, data_dir, tmp_dir = input_data
    shard_path = data_dir / basename

    if basename.endswith(".tar"):
        # Strip only the trailing ".tar". Splitting on the first dot would
        # collapse names such as "part.0.tar" and "part.1.tar" to "part".
        shard_name = basename[: -len(".tar")]
    else:
        shard_name = basename.split(".")[0]

    return (
        basename,
        {
            "shard": shard_name,
            "num_sequences": count_samples(shard_path, tmp_dir),
        },
    )
|
|
|
|
|
def main(args):
    """Count the entries of every ``.tar`` shard and write a JSONL manifest.

    Args:
        args: List of command-line argument strings, handed to ``parse_args``.

    Every entry of ``--data-dir`` whose name ends in ``.tar`` is processed by
    ``worker_fn`` in a pool of ``--num-workers`` processes. The resulting
    entries are ordered by shard file name and written, one JSON object per
    line, to ``--manifest-filename`` inside the data directory.
    """
    args = parse_args(args)

    shards = sorted(x for x in args.data_dir.iterdir() if x.name.endswith(".tar"))
    input_data = [(shard.name, args.data_dir, args.tmp_dir) for shard in shards]

    print(f"Shards to process: {len(shards)}")
    print("Creating pool.")
    with mp.Pool(args.num_workers) as pool:
        # imap_unordered returns an iterator without a length, so tqdm needs
        # an explicit total to show progress and an ETA.
        results = list(tqdm(pool.imap_unordered(worker_fn, input_data), total=len(input_data)))

    # Results arrive in completion order. Sort on the basename only, so the
    # manifest dicts are never compared with each other.
    results.sort(key=lambda item: item[0])
    entries = [entry for _, entry in results]

    manifest_path = args.data_dir / args.manifest_filename
    with manifest_path.open("w") as fp:
        for entry in entries:
            simdjson.dump(entry, fp)
            fp.write("\n")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: pass everything after the program name to main().
    cli_args = sys.argv[1:]
    main(cli_args)
|
|