# Copyright (c) Meta Platforms, Inc. and affiliates.
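"""Submit one SLURM array task per input shard (via submitit) to precompute
entropies with bytelatent.preprocess.preprocess_entropies, skipping any shard
whose .complete marker already exists in the output directory."""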
import subprocess
from pathlib import Path

import submitit
import typer


class PreprocessEntropiesJob(submitit.helpers.Checkpointable):
    """Checkpointable, so submitit can requeue the task if it is preempted."""

    def __init__(self) -> None:
        pass

    def __call__(self, shard_file: str, output_filename: str):
        # Run the single-shard preprocessing in a fresh interpreter;
        # -u keeps output unbuffered so logs stream into the SLURM log files.
        subprocess.run(
            [
                "python",
                "-u",
                "-m",
                "bytelatent.preprocess.preprocess_entropies",
                str(shard_file),
                str(output_filename),
            ],
            check=True,
        )
        return True


def chunk(items, size):
    # Yield successive size-length slices of items (the last may be shorter).
    for i in range(0, len(items), size):
        yield items[i : i + size]


def main(
    job_folder: str,
    input_dir: str,
    output_dir: str,
    qos: str = "explore",
    slurm_batch_size: int = 1000,
    check_only: bool = False,
    wait: bool = False,
):
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    # Collect input shards, ignoring any COMPLETE marker files.
    shard_files = [
        p for p in input_dir.glob("*.jsonl.shard*") if "COMPLETE" not in p.name
    ]
    if check_only:
        # Report which shards already have outputs, then exit without submitting.
        exist = []
        missing = []
        for shard_file in shard_files:
            complete_file = output_dir / f"{shard_file.name}.arrow.complete"
            if complete_file.exists():
                exist.append(complete_file)
            else:
                missing.append(complete_file)
        print("Checked for output files for input_dir=", input_dir)
        print("Exist:", len(exist))
        print("Missing:", len(missing))
        print(missing)
        return
print("Running parallel job over N files=", len(shard_files)) | |
print("Input Directory:", input_dir) | |
print("Output Directory:", output_dir) | |
output_dir.mkdir(exist_ok=True, parents=True) | |
executor = submitit.SlurmExecutor(job_folder) | |
executor.update_parameters( | |
# 12 hours in minutes | |
time=60 * 12, | |
qos=qos, | |
exclusive="user", | |
cpus_per_task=4, | |
num_gpus=1, | |
mem_per_gpu="80G", | |
array_parallelism=slurm_batch_size, | |
) | |
    jobs = []
    n_batches = 0
    n_skipped = 0
    n_launched = 0
    # Submit each chunk of shards as one SLURM job array.
    for file_batch in chunk(shard_files, slurm_batch_size):
        with executor.batch():
            for shard_file in file_batch:
                output_filename = output_dir / f"{shard_file.name}.arrow"
                complete_output_filename = (
                    output_dir / f"{shard_file.name}.arrow.complete"
                )
                if complete_output_filename.exists():
                    # Already processed on a previous run.
                    n_skipped += 1
                else:
                    job = executor.submit(
                        PreprocessEntropiesJob(), str(shard_file), str(output_filename)
                    )
                    n_launched += 1
                    jobs.append(job)
        n_batches += 1
    print("launched array jobs n=", n_launched)
    print("skipped (completed) array jobs n=", n_skipped)
    print("number of slurm batches=", n_batches)
    if wait:
        # Block until every submitted job finishes; each returns True on success.
        output = [job.result() for job in jobs]
        assert all(output)


if __name__ == "__main__":
    typer.run(main)
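

# Example invocation (a sketch; the script name and all paths below are
# illustrative placeholders, not part of the original file):
#   python parallel_entropies.py JOB_FOLDER INPUT_DIR OUTPUT_DIR \
#       --qos explore --slurm-batch-size 1000 --wait
# Use --check-only to report which shards already have .complete markers
# without submitting any SLURM jobs.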