Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Dora the Explorer, special thanks to @pierrestock.
""" | |
import argparse | |
import json | |
import logging | |
import shlex | |
import subprocess as sp | |
import time | |
from collections import namedtuple | |
from functools import partial | |
from itertools import product # noqa | |
from pathlib import Path | |
import treetable as tt # really great package for ascii art tables | |
from demucs.parser import get_name, get_parser | |
logger = logging.getLogger(__name__)
# Parser for the training script's arguments; used to derive each job's name.
parser = get_parser()

# Directory holding every per-job artifact (see ``fname``); created on import.
logs = Path("logs")
logs.mkdir(exist_ok=True)

# One grid entry: its full argument list, derived name and Slurm id (or None).
Job = namedtuple("Job", ["args", "name", "sid"])
def fname(name, kind):
    """Return the path ``logs/<name>.<kind>`` of one of job *name*'s artifacts."""
    return logs.joinpath("{}.{}".format(name, kind))
def get_sid(name):
    """Return the Slurm job id recorded for job *name*, or None if none is recorded.

    Returns None when ``logs/<name>.sid`` cannot be opened or read.
    Raises ValueError if the file exists but does not contain an integer.
    """
    try:
        # ``with`` closes the handle instead of leaking it until garbage collection.
        with open(fname(name, "sid")) as f:
            return int(f.read().strip())
    except OSError:  # IOError is an alias of OSError on Python 3.
        return None
def cancel(sid):
    """Run ``scancel`` on Slurm job *sid*; raise CalledProcessError if it fails."""
    scancel_cmd = ["scancel"]
    scancel_cmd.append(str(sid))
    sp.run(scancel_cmd, check=True)
def reset_job(name):
    """Delete job *name*'s ``.sid`` file (if it is a regular file).

    Without a recorded sid, the next grid invocation will schedule the job again.
    """
    path = fname(name, "sid")
    if not path.is_file():
        return
    path.unlink()
def get_done(name):
    """Return True when the ``logs/<name>.done`` marker exists for job *name*."""
    marker = fname(name, "done")
    return marker.exists()
def get_metrics(name):
    """Return the list of per-epoch metric dicts recorded for job *name*.

    Returns an empty list when ``logs/<name>.json`` cannot be read or does not
    contain valid JSON (for instance a partially written file). In the latter
    case a warning is logged instead of crashing the monitoring loop.
    """
    json_file = fname(name, "json")
    try:
        with open(json_file) as f:
            return json.load(f)
    except OSError:
        return []
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        logger.warning("Could not parse metrics file %s", json_file)
        return []
def schedule(name, args, nodes=1, partition="priority", time=2 * 24 * 60, large=True, gpus=8):
    """Submit job *name* with ``sbatch`` and record its Slurm id.

    The batch script runs ``python3 run_slurm.py <args>`` under ``srun``.

    Args:
        name: job name; also the stem of its files under ``logs/``.
        args: command-line arguments for ``run_slurm.py``. Each is shell-quoted,
            so values containing spaces or shell metacharacters arrive intact.
        nodes: value for ``--nodes``.
        partition: Slurm partition.
        time: value for ``sbatch --time`` (default 2 * 24 * 60, i.e. two days
            when read as minutes).
        large: if True, add ``--constraint=volta32gb``.
        gpus: GPUs requested. CPUs per task are set to ``8 * gpus``.

    Returns:
        The Slurm job id, also written to ``logs/<name>.sid``.
    """
    log = fname(name, "log")
    command = [
        "sbatch",
        f"--job-name={name}",
        f"--output={log}.%t",
        "--mem=460G",
        f"--cpus-per-task={8*gpus}",
        f"--gpus={gpus}",
        f"--nodes={nodes}",
        "--tasks-per-node=1",
        f"--partition={partition}",
        # "--exclude=learnfair0748,learnfair0821",
        "--comment=Old codebase, not requeue, very few jobs",
        f"--time={time}",
    ]
    if large:
        command += ["--constraint=volta32gb"]

    srun_flags = f"--output={shlex.quote(str(log))}.%t"
    # Quote every argument: the script is interpreted by bash, so unquoted
    # spaces or metacharacters would split or alter the arguments.
    quoted_args = " ".join(shlex.quote(arg) for arg in args)
    run_cmd = ["#!/bin/bash"]
    run_cmd.append(f"srun {srun_flags} python3 run_slurm.py " + quoted_args)

    result = sp.run(command, stdout=sp.PIPE, input="\n".join(run_cmd).encode('utf-8'),
                    check=True).stdout.decode('utf-8')
    # The job id is the last space-separated token of sbatch's output.
    sid = int(result.strip().rsplit(' ', 1)[1])
    fname(name, "sid").write_text(str(sid))
    return sid
def _check(sids): | |
cs_ids = ','.join(map(str, sids)) | |
result = sp.run(['squeue', f'-j{cs_ids}', '-o%A,%T,%P', '--noheader'], | |
check=True, | |
capture_output=True) | |
lines = result.stdout.decode('utf-8').strip().split('\n') | |
results = {} | |
for line in lines: | |
line = line.strip() | |
if not line: | |
continue | |
sid, status, partition = line.split(',', 2) | |
sid = int(sid) | |
results[sid] = status.lower() | |
for sid in sids: | |
if sid not in results: | |
results[sid] = 'failed' | |
return results | |
class Monitor:
    """Track the grid's jobs: schedule missing ones, GC removed ones, report status.

    Args:
        cancel: when True, ``schedule`` never submits new jobs; it only looks
            up sids that are already recorded.
        base: arguments prepended to every job's argument list.
    """

    def __init__(self, cancel=False, base=None):
        self.cancel = cancel
        # None instead of a mutable ``[]`` default shared between instances.
        self.base = [] if base is None else base
        self.jobs = []

    def schedule(self, args, *vargs, **kwargs):
        """Register the job for ``base + args``, submitting it if it has no sid yet.

        Extra positional/keyword arguments are forwarded to the module-level
        ``schedule`` function (e.g. ``nodes``).
        """
        args = self.base + args
        name = get_name(parser, parser.parse_args(args))
        sid = get_sid(name)
        if sid is None and not self.cancel:
            sid = schedule(name, args, *vargs, **kwargs)
        self.jobs.append(Job(sid=sid, name=name, args=args))

    def gc(self):
        """Cancel and forget jobs that have a ``.sid`` file but are no longer in the grid."""
        names = {job.name for job in self.jobs}
        for f in logs.iterdir():
            # rpartition instead of ``stem, suffix = rsplit(".", 1)``: a file
            # without any dot used to crash the unpacking with ValueError.
            stem, dot, suffix = f.name.rpartition(".")
            if not dot or suffix != "sid" or stem in names:
                continue
            sid = get_sid(stem)
            if sid is not None:
                print(f"GCing {stem} / {sid}")
                cancel(sid)
            f.unlink()

    def _statuses(self):
        """Map job names to "done" or their lowercased squeue state.

        Jobs that are neither done nor have a sid are absent from the mapping.
        """
        statuses = {}
        to_query = []
        for job in self.jobs:
            if get_done(job.name):
                statuses[job.name] = "done"
            elif job.sid is not None:
                to_query.append(job)
        if to_query:
            slurm = _check([job.sid for job in to_query])
            for job in to_query:
                statuses[job.name] = slurm.get(job.sid, "failed")
        return statuses

    def check(self, trim=None, reset=False):
        """Print a table with every job's status and its latest metrics.

        Args:
            trim: if given, the index of a job in ``self.jobs``. Every job's
                metrics are truncated to that job's number of epochs so that
                the rows are comparable.
            reset: if True, clear the sid of "failed"/"completing" jobs so that
                the next invocation reschedules them. Their status shows as "reset".
        """
        # Keyed by job name, not sid: every unscheduled job has sid None, and
        # keying by sid let one done sid-less job mark all the others "done".
        statuses = self._statuses()
        if trim is not None:
            trim = len(get_metrics(self.jobs[trim].name))

        lines = []
        for index, job in enumerate(self.jobs):
            status = statuses.get(job.name, "failed")
            if status in ["failed", "completing"] and reset:
                reset_job(job.name)
                status = "reset"
            meta = {'name': job.name, 'sid': job.sid, 'status': status[:2], "index": index}
            metrics = get_metrics(job.name)
            if trim is not None:
                metrics = metrics[:trim]
            meta["epoch"] = len(metrics)
            # Only the most recent epoch's metrics are displayed.
            metrics = metrics[-1] if metrics else {}
            lines.append({'meta': meta, 'metrics': metrics})

        table = tt.table(shorten=True,
                         groups=[
                             tt.group("meta", [
                                 tt.leaf("index", align=">"),
                                 tt.leaf("name"),
                                 tt.leaf("sid", align=">"),
                                 tt.leaf("status"),
                                 tt.leaf("epoch", align=">")
                             ]),
                             tt.group("metrics", [
                                 tt.leaf("train", ".2%"),
                                 tt.leaf("valid", ".2%"),
                                 tt.leaf("best", ".2%"),
                                 tt.leaf("true_model_size", ".2f"),
                                 tt.leaf("compressed_model_size", ".2f"),
                             ])
                         ])
        print(tt.treetable(lines, table, colors=["0", "38;5;245"]))
def main():
    """Define the experiment grid, (re)schedule its jobs and monitor them.

    With ``--cancel``, every job with a known sid is cancelled and the script
    exits. With ``--reset``, failed/completing jobs are reset once and the
    script exits. Otherwise the status table is reprinted every five minutes
    until the process is killed.
    """
    # Named ``cli`` so it does not shadow the module-level job-argument parser.
    cli = argparse.ArgumentParser("grid.py")
    cli.add_argument("-c", "--cancel", action="store_true", help="Cancel all jobs")
    cli.add_argument(
        "-r",
        "--reset",
        action="store_true",
        help="Will reset the state of failed jobs. Next invocation will reschedule them")
    cli.add_argument("-t", "--trim", type=int, help="Trim metrics to match job with given index")
    args = cli.parse_args()

    monitor = Monitor(base=[], cancel=args.cancel)
    sched = partial(monitor.schedule, nodes=1)

    tasnet = ["--tasnet", "--split_valid", "--samples=80000", "--X=10", "-b", "32"]
    extra_path = Path.home() / "musdb_raw_44_allstems"
    extra = [f"--raw={extra_path}"]

    sched([])
    sched(extra)
    sched(tasnet)
    sched(tasnet + ["--repitch=0"])
    sched(tasnet + extra + ["--repitch=0"])
    ch48 = ["--channels=48"]
    sched(ch48)
    ch32 = ["--channels=32"]
    sched(ch32)

    # Main models
    for seed in [43, 44]:
        base = [f"--seed={seed}"]
        sched(base)
        sched(base + extra)
        sched(base + tasnet)
        sched(base + tasnet + extra)

    # Ablation study
    sched(["--no_glu"])
    sched(["--no_rewrite"])
    sched(["--context=1"])
    sched(["--rescale=0"])
    sched(["--mse"])
    sched(["--lstm_layers=0"])
    sched(["--lstm_layers=0", "--depth=7"])
    sched(["--no_resample"])
    sched(["--repitch=0"])

    # Quantization
    sched(["--diffq=0.0003"])

    if args.cancel:
        for job in monitor.jobs:
            if job.sid is not None:
                print(f"Canceling {job.name}/{job.sid}")
                cancel(job.sid)
        return

    names = [job.name for job in monitor.jobs]
    # ``with`` guarantees herd.json is flushed and closed before monitoring starts.
    with open(logs / "herd.json", "w") as f:
        json.dump(names, f)

    # Cancel any running job that was removed from the above sched calls.
    monitor.gc()

    if args.reset:
        monitor.check(reset=True)
        return
    while True:
        monitor.check(trim=args.trim)
        time.sleep(5 * 60)


if __name__ == "__main__":
    main()