# Forked from the original Facebook Research demucs codebase (commit fe84f3e).
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Dora the Explorer, special thanks to @pierrestock.
"""
import argparse
import json
import logging
import shlex
import subprocess as sp
import time
from collections import namedtuple
from functools import partial
from itertools import product # noqa
from pathlib import Path
import treetable as tt # really great package for ascii art tables
from demucs.parser import get_name, get_parser
# Module-level logger following the standard `__name__` convention.
logger = logging.getLogger(__name__)
# Shared demucs argument parser; experiment names are derived from parsed args.
parser = get_parser()
# All per-experiment files (<name>.log / .sid / .done / .json) live under ./logs.
logs = Path("logs")
logs.mkdir(exist_ok=True)
# Lightweight record for a scheduled experiment: its args, name and Slurm job id.
Job = namedtuple("Job", "args name sid")
def fname(name, kind):
    """Return the path ``logs/<name>.<kind>`` for the given experiment file."""
    return logs.joinpath(f"{name}.{kind}")
def get_sid(name):
    """Return the saved Slurm job id for experiment `name`, or None.

    None means no sid file exists (job never scheduled, or it was reset).
    """
    sid_file = fname(name, "sid")
    try:
        # read_text() closes the file handle, unlike the previous bare open().read().
        return int(sid_file.read_text().strip())
    except OSError:
        return None
def cancel(sid):
    """Cancel the Slurm job with the given id; raises if scancel fails."""
    sp.run(["scancel", f"{sid}"], check=True)
def reset_job(name):
    """Delete the stored Slurm id file for `name`, if one exists.

    After a reset the next scheduling pass will resubmit the job.
    """
    path = fname(name, "sid")
    if path.is_file():
        path.unlink()
def get_done(name):
    """True once experiment `name` has written its `done` marker file."""
    return fname(name, "done").exists()
def get_metrics(name):
    """Load the metrics history for `name` (one entry per epoch).

    Returns an empty list when the json file is missing or unreadable.
    """
    json_file = fname(name, "json")
    try:
        # read_text() closes the file handle, unlike the previous json.load(open(...)).
        return json.loads(json_file.read_text())
    except OSError:
        return []
def schedule(name, args, nodes=1, partition="priority", time=2 * 24 * 60, large=True, gpus=8):
    """Submit experiment `name` to Slurm via sbatch and return its job id.

    Args:
        name: experiment name, used for the job name and the log/sid files.
        args: extra command line arguments forwarded to run_slurm.py.
        nodes: number of nodes to request.
        partition: Slurm partition to submit to.
        time: time limit in minutes (NOTE: shadows the stdlib `time` module
            locally; kept for backward compatibility with keyword callers).
        large: if True, constrain the job to 32GB Volta GPUs.
        gpus: GPUs per node; 8 CPUs are requested per GPU.

    Side effect: writes the new job id to the experiment's sid file.
    """
    log = fname(name, "log")
    command = [
        "sbatch",
        f"--job-name={name}",
        f"--output={log}.%t",
        "--mem=460G",
        f"--cpus-per-task={8*gpus}",
        f"--gpus={gpus}",
        f"--nodes={nodes}",
        "--tasks-per-node=1",
        f"--partition={partition}",
        "--comment=Old codebase, not requeue, very few jobs",
        f"--time={time}",
    ]
    if large:
        command += ["--constraint=volta32gb"]
    srun_flags = f"--output={shlex.quote(str(log))}.%t"
    # The batch script is fed to sbatch on stdin.
    run_cmd = ["#!/bin/bash"]
    run_cmd.append(f"srun {srun_flags} python3 run_slurm.py " + " ".join(args))
    result = sp.run(command, stdout=sp.PIPE, input="\n".join(run_cmd).encode('utf-8'),
                    check=True).stdout.decode('utf-8')
    # sbatch prints e.g. "Submitted batch job 12345"; take the trailing id.
    sid = int(result.strip().rsplit(' ', 1)[1])
    # write_text() closes and flushes the file, unlike the previous bare open(...).write().
    fname(name, "sid").write_text(str(sid))
    return sid
def _check(sids):
    """Query squeue and map each Slurm id in `sids` to a lowercase status.

    Ids squeue does not report (e.g. jobs already purged from the queue)
    are marked as 'failed'.
    """
    joined = ','.join(str(sid) for sid in sids)
    proc = sp.run(['squeue', f'-j{joined}', '-o%A,%T,%P', '--noheader'],
                  check=True,
                  capture_output=True)
    statuses = {}
    for row in proc.stdout.decode('utf-8').strip().split('\n'):
        row = row.strip()
        if not row:
            continue
        sid, state, _partition = row.split(',', 2)
        statuses[int(sid)] = state.lower()
    # Anything we asked about but got no answer for is considered failed.
    for sid in sids:
        statuses.setdefault(sid, 'failed')
    return statuses
class Monitor:
    """Schedule, track, reset and garbage collect a set of Slurm experiments."""

    def __init__(self, cancel=False, base=None):
        """
        Args:
            cancel: if True, never submit new jobs (cancel/inspection mode).
            base: list of arguments prepended to every scheduled job.
        """
        self.cancel = cancel
        # `base=None` avoids the mutable-default-argument pitfall of `base=[]`.
        self.base = [] if base is None else base
        self.jobs = []

    def schedule(self, args, *vargs, **kwargs):
        """Register a job built from `base + args`, submitting it if needed.

        The experiment name is derived from the parsed demucs arguments, so
        identical argument sets always map to the same job.
        """
        args = self.base + args
        name = get_name(parser, parser.parse_args(args))
        sid = get_sid(name)
        if sid is None and not self.cancel:
            sid = schedule(name, args, *vargs, **kwargs)
        self.jobs.append(Job(sid=sid, name=name, args=args))

    def gc(self):
        """Cancel and forget any sid file whose name is no longer tracked."""
        names = set(job.name for job in self.jobs)
        for f in logs.iterdir():
            if "." not in f.name:
                # Not a `<name>.<kind>` file managed by this script.
                continue
            stem, suffix = f.name.rsplit(".", 1)
            if suffix == "sid" and stem not in names:
                sid = get_sid(stem)
                if sid is not None:
                    print(f"GCing {stem} / {sid}")
                    cancel(sid)
                f.unlink()

    def check(self, trim=None, reset=False):
        """Print an ascii table with the status and latest metrics of all jobs.

        Args:
            trim: index of a reference job; every job's metrics are trimmed to
                that job's epoch count so the comparison is fair.
            reset: if True, failed/completing jobs have their sid file removed
                so the next invocation reschedules them.
        """
        to_check = []
        statuses = {}
        for job in self.jobs:
            if get_done(job.name):
                statuses[job.sid] = "done"
            elif job.sid is not None:
                to_check.append(job.sid)
        if to_check:
            # Guard: `squeue -j` with an empty id list would error out.
            statuses.update(_check(to_check))
        if trim is not None:
            trim = len(get_metrics(self.jobs[trim].name))
        lines = []
        for index, job in enumerate(self.jobs):
            status = statuses.get(job.sid, "failed")
            if status in ["failed", "completing"] and reset:
                reset_job(job.name)
                status = "reset"
            meta = {'name': job.name, 'sid': job.sid, 'status': status[:2], "index": index}
            metrics = get_metrics(job.name)
            if trim is not None:
                metrics = metrics[:trim]
            meta["epoch"] = len(metrics)
            # Only the most recent epoch's metrics are displayed.
            if metrics:
                metrics = metrics[-1]
            else:
                metrics = {}
            lines.append({'meta': meta, 'metrics': metrics})
        table = tt.table(shorten=True,
                         groups=[
                             tt.group("meta", [
                                 tt.leaf("index", align=">"),
                                 tt.leaf("name"),
                                 tt.leaf("sid", align=">"),
                                 tt.leaf("status"),
                                 tt.leaf("epoch", align=">")
                             ]),
                             tt.group("metrics", [
                                 tt.leaf("train", ".2%"),
                                 tt.leaf("valid", ".2%"),
                                 tt.leaf("best", ".2%"),
                                 tt.leaf("true_model_size", ".2f"),
                                 tt.leaf("compressed_model_size", ".2f"),
                             ])
                         ])
        print(tt.treetable(lines, table, colors=["0", "38;5;245"]))
def main():
    """Declare the experiment grid, then monitor (or cancel/reset) the jobs."""
    # NOTE: this local `parser` intentionally shadows the module-level demucs
    # parser; it only parses grid.py's own flags.
    parser = argparse.ArgumentParser("grid.py")
    parser.add_argument("-c", "--cancel", action="store_true", help="Cancel all jobs")
    parser.add_argument(
        "-r",
        "--reset",
        action="store_true",
        help="Will reset the state of failed jobs. Next invocation will reschedule them")
    parser.add_argument("-t", "--trim", type=int, help="Trim metrics to match job with given index")
    args = parser.parse_args()
    monitor = Monitor(base=[], cancel=args.cancel)
    sched = partial(monitor.schedule, nodes=1)
    tasnet = ["--tasnet", "--split_valid", "--samples=80000", "--X=10", "-b", "32"]
    extra_path = Path.home() / "musdb_raw_44_allstems"
    extra = [f"--raw={extra_path}"]
    sched([])
    sched(extra)
    sched(tasnet)
    sched(tasnet + ["--repitch=0"])
    sched(tasnet + extra + ["--repitch=0"])
    ch48 = ["--channels=48"]
    sched(ch48)
    ch32 = ["--channels=32"]
    sched(ch32)
    # Main models
    for seed in [43, 44]:
        base = [f"--seed={seed}"]
        sched(base)
        sched(base + extra)
        sched(base + tasnet)
        sched(base + tasnet + extra)
    # Ablation study
    sched(["--no_glu"])
    sched(["--no_rewrite"])
    sched(["--context=1"])
    sched(["--rescale=0"])
    sched(["--mse"])
    sched(["--lstm_layers=0"])
    sched(["--lstm_layers=0", "--depth=7"])
    sched(["--no_resample"])
    sched(["--repitch=0"])
    # Quantization
    sched(["--diffq=0.0003"])
    if args.cancel:
        for job in monitor.jobs:
            if job.sid is not None:
                print(f"Canceling {job.name}/{job.sid}")
                cancel(job.sid)
        return
    names = [job.name for job in monitor.jobs]
    # `with` closes (and flushes) the file; json.dump(names, open(...)) leaked it.
    with open(logs / "herd.json", "w") as fp:
        json.dump(names, fp)
    # Cancel any running job that was removed from the above sched calls.
    monitor.gc()
    while True:
        if args.reset:
            monitor.check(reset=True)
            return
        monitor.check(trim=args.trim)
        time.sleep(5 * 60)
# Script entry point: build the grid and monitor it until interrupted.
if __name__ == "__main__":
    main()