# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import sys
import time
from dataclasses import dataclass, field
from fractions import Fraction

import torch as th
from torch import distributed, nn
from torch.nn.parallel.distributed import DistributedDataParallel

from .augment import FlipChannels, FlipSign, Remix, Shift
from .compressed import StemsSet, build_musdb_metadata, get_musdb_tracks
from .model import Demucs
from .parser import get_name, get_parser
from .raw import Rawset
from .tasnet import ConvTasNet
from .test import evaluate
from .train import train_model, validate_model
from .utils import human_seconds, load_model, save_model, sizeof_fmt


@dataclass
class SavedState:
    metrics: list = field(default_factory=list)
    last_state: dict = None
    best_state: dict = None
    optimizer: dict = None


def main():
    parser = get_parser()
    args = parser.parse_args()
    name = get_name(parser, args)
    print(f"Experiment {name}")

    if args.musdb is None and args.rank == 0:
        print(
            "You must provide the path to the MusDB dataset with the --musdb flag. "
            "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
            file=sys.stderr)
        sys.exit(1)

    eval_folder = args.evals / name
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.logs.mkdir(exist_ok=True)
    metrics_path = args.logs / f"{name}.json"
    args.checkpoints.mkdir(exist_ok=True, parents=True)
    args.models.mkdir(exist_ok=True, parents=True)

    if args.device is None:
        device = "cpu"
        if th.cuda.is_available():
            device = "cuda"
    else:
        device = args.device

    th.manual_seed(args.seed)
    # Prevent too many threads from being started when running `museval`,
    # as it can be quite inefficient on NUMA architectures.
    os.environ["OMP_NUM_THREADS"] = "1"
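    # The distributed setup below uses torch.distributed's TCP rendezvous:
    # one process per GPU, all pointing at the same --master "host:port".
    # A hypothetical two-GPU launch on one machine could look like this
    # (package name and flag spellings are assumed from the args.* attributes
    # used in this file, not verified against parser.py):
    #
    #   python -m demucs --world_size 2 --rank 0 --master 127.0.0.1:12345 &
    #   python -m demucs --world_size 2 --rank 1 --master 127.0.0.1:12345 &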
os.environ["OMP_NUM_THREADS"] = "1" if args.world_size > 1: if device != "cuda" and args.rank == 0: print("Error: distributed training is only available with cuda device", file=sys.stderr) sys.exit(1) th.cuda.set_device(args.rank % th.cuda.device_count()) distributed.init_process_group(backend="nccl", init_method="tcp://" + args.master, rank=args.rank, world_size=args.world_size) checkpoint = args.checkpoints / f"{name}.th" checkpoint_tmp = args.checkpoints / f"{name}.th.tmp" if args.restart and checkpoint.exists(): checkpoint.unlink() if args.test: args.epochs = 1 args.repeat = 0 model = load_model(args.models / args.test) elif args.tasnet: model = ConvTasNet(audio_channels=args.audio_channels, samplerate=args.samplerate, X=args.X) else: model = Demucs( audio_channels=args.audio_channels, channels=args.channels, context=args.context, depth=args.depth, glu=args.glu, growth=args.growth, kernel_size=args.kernel_size, lstm_layers=args.lstm_layers, rescale=args.rescale, rewrite=args.rewrite, sources=4, stride=args.conv_stride, upsample=args.upsample, samplerate=args.samplerate ) model.to(device) if args.show: print(model) size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters())) print(f"Model size {size}") return optimizer = th.optim.Adam(model.parameters(), lr=args.lr) try: saved = th.load(checkpoint, map_location='cpu') except IOError: saved = SavedState() else: model.load_state_dict(saved.last_state) optimizer.load_state_dict(saved.optimizer) if args.save_model: if args.rank == 0: model.to("cpu") model.load_state_dict(saved.best_state) save_model(model, args.models / f"{name}.th") return if args.rank == 0: done = args.logs / f"{name}.done" if done.exists(): done.unlink() if args.augment: augment = nn.Sequential(FlipSign(), FlipChannels(), Shift(args.data_stride), Remix(group_size=args.remix_group_size)).to(device) else: augment = Shift(args.data_stride) if args.mse: criterion = nn.MSELoss() else: criterion = nn.L1Loss() # Setting number of samples so that all convolution windows are full. # Prevents hard to debug mistake with the prediction being shifted compared # to the input mixture. 
    # Set the number of samples so that all convolution windows are full.
    # This prevents a hard-to-debug mistake where the prediction ends up
    # shifted compared to the input mixture.
    samples = model.valid_length(args.samples)
    print(f"Number of training samples adjusted to {samples}")

    if args.raw:
        train_set = Rawset(args.raw / "train",
                           samples=samples + args.data_stride,
                           channels=args.audio_channels,
                           streams=[0, 1, 2, 3, 4],
                           stride=args.data_stride)
        valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
    else:
        if not args.metadata.is_file() and args.rank == 0:
            build_musdb_metadata(args.metadata, args.musdb, args.workers)
        if args.world_size > 1:
            distributed.barrier()
        with open(args.metadata) as file:
            metadata = json.load(file)
        duration = Fraction(samples + args.data_stride, args.samplerate)
        stride = Fraction(args.data_stride, args.samplerate)
        train_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="train"),
                             metadata,
                             duration=duration,
                             stride=stride,
                             samplerate=args.samplerate,
                             channels=args.audio_channels)
        valid_set = StemsSet(get_musdb_tracks(args.musdb, subsets=["train"], split="valid"),
                             metadata,
                             samplerate=args.samplerate,
                             channels=args.audio_channels)

    best_loss = float("inf")
    for epoch, metrics in enumerate(saved.metrics):
        print(f"Epoch {epoch:03d}: "
              f"train={metrics['train']:.8f} "
              f"valid={metrics['valid']:.8f} "
              f"best={metrics['best']:.4f} "
              f"duration={human_seconds(metrics['duration'])}")
        best_loss = metrics['best']

    if args.world_size > 1:
        dmodel = DistributedDataParallel(model,
                                         device_ids=[th.cuda.current_device()],
                                         output_device=th.cuda.current_device())
    else:
        dmodel = model

    for epoch in range(len(saved.metrics), args.epochs):
        begin = time.time()
        model.train()
        train_loss = train_model(epoch, train_set, dmodel, criterion, optimizer, augment,
                                 batch_size=args.batch_size,
                                 device=device,
                                 repeat=args.repeat,
                                 seed=args.seed,
                                 workers=args.workers,
                                 world_size=args.world_size)
        model.eval()
        valid_loss = validate_model(epoch, valid_set, model, criterion,
                                    device=device,
                                    rank=args.rank,
                                    split=args.split_valid,
                                    world_size=args.world_size)

        duration = time.time() - begin
        if valid_loss < best_loss:
            best_loss = valid_loss
            saved.best_state = {
                key: value.to("cpu").clone()
                for key, value in model.state_dict().items()
            }
        saved.metrics.append({
            "train": train_loss,
            "valid": valid_loss,
            "best": best_loss,
            "duration": duration
        })
        if args.rank == 0:
            with open(metrics_path, "w") as file:
                json.dump(saved.metrics, file)

        saved.last_state = model.state_dict()
        saved.optimizer = optimizer.state_dict()
        if args.rank == 0 and not args.test:
            th.save(saved, checkpoint_tmp)
            checkpoint_tmp.rename(checkpoint)

        print(f"Epoch {epoch:03d}: "
              f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} "
              f"duration={human_seconds(duration)}")
    del dmodel

    model.load_state_dict(saved.best_state)
    if args.eval_cpu:
        device = "cpu"
        model.to(device)
    model.eval()
    evaluate(model, args.musdb, eval_folder,
             rank=args.rank,
             world_size=args.world_size,
             device=device,
             save=args.save,
             split=args.split_valid,
             shifts=args.shifts,
             workers=args.eval_workers)
    model.to("cpu")
    save_model(model, args.models / f"{name}.th")
    if args.rank == 0:
        print("done")
        done.write_text("done")


if __name__ == "__main__":
    main()
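# Usage sketch, assuming this file is the package's __main__ and that the flag
# spellings below match parser.py (not verified here):
#
#   python -m demucs --musdb /path/to/musdb                  # train Demucs
#   python -m demucs --musdb /path/to/musdb --tasnet         # train Conv-TasNet instead
#   python -m demucs --musdb /path/to/musdb --test NAME.th   # single evaluation pass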