# File: olm-chat-7b/open_lm/utils/make_wds_manifest.py (2.43 kB)
# Uploaded by henhenhahi111112 ("Upload folder using huggingface_hub"), commit af6e330 (verified).
import argparse
import re
import simdjson
import sys
import subprocess
import multiprocessing as mp
from pathlib import Path
from cloudpathlib import CloudPath
from tqdm import tqdm
def path_or_cloudpath(s):
    """Turn a string into a path object, choosing the type from its form.

    Strings that begin with a ``scheme://`` prefix are wrapped in a
    ``CloudPath``; every other string is wrapped in a local ``Path``.
    """
    has_scheme = re.match(r"^\w+://", s) is not None
    return CloudPath(s) if has_scheme else Path(s)
def parse_args(args):
    """Parse the command-line options of the manifest builder.

    Args:
        args: List of argument strings (typically ``sys.argv[1:]``).

    Returns:
        The ``argparse.Namespace`` with ``data_dir``, ``manifest_filename``,
        ``tmp_dir`` and ``num_workers``.
    """
    cli = argparse.ArgumentParser()

    # Where the shards live; converted to Path / CloudPath by path_or_cloudpath.
    cli.add_argument(
        "--data-dir",
        required=True,
        type=path_or_cloudpath,
        help="Directory containing a dataset in webdataset format.",
    )
    # Name of the manifest file written next to the shards.
    cli.add_argument(
        "--manifest-filename",
        default="manifest.jsonl",
        type=str,
        help="Filename for the manifest that will be stored in the webdataset directory.",
    )
    cli.add_argument(
        "--tmp-dir",
        default=None,
        type=str,
        help="Temporary directory.",
    )
    cli.add_argument(
        "--num-workers",
        default=2,
        type=int,
        help="Number of workers.",
    )

    return cli.parse_args(args)
def count_samples(shard_path, tmp_dir):
    """Count the entries listed in a tar shard.

    The count is the number of lines printed by ``tar tf`` for the archive.

    Args:
        shard_path: Location of the ``.tar`` shard. A ``CloudPath`` is first
            downloaded into ``tmp_dir``; any other path is read in place.
        tmp_dir: Directory that receives the temporary copy of a cloud shard.
            When ``None``, the system temporary directory is used. Ignored
            for local shards.

    Returns:
        int: Number of lines in the ``tar tf`` listing.

    Raises:
        subprocess.CalledProcessError: If ``tar`` fails to list the archive
            (e.g. the file is missing or corrupt).
    """
    is_remote = isinstance(shard_path, CloudPath)
    if is_remote:
        if tmp_dir is None:
            # --tmp-dir defaults to None; Path(None) would raise TypeError.
            import tempfile

            tmp_dir = tempfile.gettempdir()
        temp_shard_path = Path(tmp_dir) / shard_path.name
    else:
        temp_shard_path = shard_path

    try:
        if is_remote:
            shard_path.download_to(temp_shard_path)
        # Run tar directly with an argument list instead of a shell pipeline:
        # paths with spaces or shell metacharacters are passed verbatim, and a
        # failing tar raises instead of being masked by the exit status of
        # ``wc`` (which silently produced a count of 0).
        listing = subprocess.check_output(["tar", "tf", str(temp_shard_path)])
    finally:
        # Remove the downloaded copy even when the download or listing failed.
        if is_remote and temp_shard_path.exists():
            temp_shard_path.unlink()

    # Equivalent to ``wc -l``: one newline per listed entry.
    return listing.count(b"\n")
def worker_fn(input_data):
    """Build the manifest entry for a single shard.

    Args:
        input_data: Tuple ``(basename, data_dir, tmp_dir)`` — the shard's file
            name, the directory that contains it, and the temporary directory
            forwarded to ``count_samples``.

    Returns:
        Tuple ``(basename, entry)`` where ``entry`` is a dict with the keys
        ``"shard"`` (the file name up to its first ``"."``) and
        ``"num_sequences"`` (the value returned by ``count_samples``).
    """
    shard_filename, shard_dir, scratch_dir = input_data

    # Text before the first dot; the whole name when there is no dot.
    shard_id = shard_filename.partition(".")[0]
    sequence_count = count_samples(shard_dir / shard_filename, scratch_dir)

    entry = {"shard": shard_id, "num_sequences": sequence_count}
    return shard_filename, entry
def main(args):
    """Write a JSON-lines manifest for every ``.tar`` shard in a directory.

    Lists the ``.tar`` files in ``--data-dir``, runs ``worker_fn`` on each one
    in a multiprocessing pool, and writes one JSON object per shard (ordered
    by shard file name) to ``--manifest-filename`` inside the data directory.

    Args:
        args: List of command-line argument strings.
    """
    args = parse_args(args)

    tar_files = [entry for entry in args.data_dir.iterdir() if entry.name.endswith(".tar")]
    tar_files.sort()
    jobs = [(tar_file.name, args.data_dir, args.tmp_dir) for tar_file in tar_files]

    print(f"Shards to process: {len(tar_files)}")
    print("Creating pool.")
    with mp.Pool(args.num_workers) as pool:
        # Results arrive in completion order; they are re-sorted below.
        results = list(tqdm(pool.imap_unordered(worker_fn, jobs)))

    # Tuples sort by their first element, the shard file name.
    results.sort()
    entries = [entry for _, entry in results]

    manifest_path = args.data_dir / args.manifest_filename
    with manifest_path.open("w") as fp:
        for entry in entries:
            simdjson.dump(entry, fp)
            fp.write("\n")
# Script entry point: forward the command-line arguments (without the program
# name) to main() only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main(sys.argv[1:])