# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import math
import os
import sys
import time
from dataclasses import dataclass, field

import torch as th
from torch import distributed, nn
from torch.nn.parallel.distributed import DistributedDataParallel

from .augment import FlipChannels, FlipSign, Remix, Scale, Shift
from .compressed import get_compressed_datasets
from .model import Demucs
from .parser import get_name, get_parser
from .raw import Rawset
from .repitch import RepitchedWrapper
from .pretrained import load_pretrained, SOURCES
from .tasnet import ConvTasNet
from .test import evaluate
from .train import train_model, validate_model
from .utils import (human_seconds, load_model, save_model, get_state, save_state,
                    sizeof_fmt, get_quantizer)
from .wav import get_wav_datasets, get_musdb_wav_datasets


@dataclass
class SavedState:
    # Training state that gets pickled into the checkpoint file.
    metrics: list = field(default_factory=list)
    last_state: dict = None
    best_state: dict = None
    optimizer: dict = None


def main():
    parser = get_parser()
    args = parser.parse_args()
    name = get_name(parser, args)
    print(f"Experiment {name}")

    if args.musdb is None and args.rank == 0:
        print(
            "You must provide the path to the MusDB dataset with the --musdb flag. "
            "To download the MusDB dataset, see https://sigsep.github.io/datasets/musdb.html.",
            file=sys.stderr)
        sys.exit(1)

    eval_folder = args.evals / name
    eval_folder.mkdir(exist_ok=True, parents=True)
    args.logs.mkdir(exist_ok=True)
    metrics_path = args.logs / f"{name}.json"
    args.checkpoints.mkdir(exist_ok=True, parents=True)
    args.models.mkdir(exist_ok=True, parents=True)

    if args.device is None:
        device = "cpu"
        if th.cuda.is_available():
            device = "cuda"
    else:
        device = args.device

    th.manual_seed(args.seed)
    # Prevent too many threads from being started when running `museval`,
    # as it can be quite inefficient on NUMA architectures.
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"

    if args.world_size > 1:
        if device != "cuda" and args.rank == 0:
            print("Error: distributed training is only available with cuda device",
                  file=sys.stderr)
            sys.exit(1)
        # One process per GPU, NCCL backend, rendezvous over TCP at `args.master`.
        th.cuda.set_device(args.rank % th.cuda.device_count())
        distributed.init_process_group(backend="nccl",
                                       init_method="tcp://" + args.master,
                                       rank=args.rank,
                                       world_size=args.world_size)

    checkpoint = args.checkpoints / f"{name}.th"
    checkpoint_tmp = args.checkpoints / f"{name}.th.tmp"
    if args.restart and checkpoint.exists() and args.rank == 0:
        checkpoint.unlink()

    if args.test or args.test_pretrained:
        args.epochs = 1
        args.repeat = 0
        if args.test:
            model = load_model(args.models / args.test)
        else:
            model = load_pretrained(args.test_pretrained)
    elif args.tasnet:
        model = ConvTasNet(audio_channels=args.audio_channels,
                           samplerate=args.samplerate,
                           X=args.X,
                           segment_length=4 * args.samples,
                           sources=SOURCES)
    else:
        model = Demucs(
            audio_channels=args.audio_channels,
            channels=args.channels,
            context=args.context,
            depth=args.depth,
            glu=args.glu,
            growth=args.growth,
            kernel_size=args.kernel_size,
            lstm_layers=args.lstm_layers,
            rescale=args.rescale,
            rewrite=args.rewrite,
            stride=args.conv_stride,
            resample=args.resample,
            normalize=args.normalize,
            samplerate=args.samplerate,
            segment_length=4 * args.samples,
            sources=SOURCES,
        )
    model.to(device)
    if args.init:
        model.load_state_dict(load_pretrained(args.init).state_dict())

    if args.show:
        print(model)
        size = sizeof_fmt(4 * sum(p.numel() for p in model.parameters()))
        print(f"Model size {size}")
        return

    try:
        saved = th.load(checkpoint, map_location='cpu')
    except IOError:
        saved = SavedState()

    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)
    quantizer = get_quantizer(model, args, optimizer)

    if saved.last_state is not None:
        model.load_state_dict(saved.last_state, strict=False)
    if saved.optimizer is not None:
        optimizer.load_state_dict(saved.optimizer)

    model_name = f"{name}.th"
    if args.save_model:
        if args.rank == 0:
            model.to("cpu")
            model.load_state_dict(saved.best_state)
            save_model(model, quantizer, args, args.models / model_name)
        return
    elif args.save_state:
        model_name = f"{args.save_state}.th"
        if args.rank == 0:
            model.to("cpu")
            model.load_state_dict(saved.best_state)
            state = get_state(model, quantizer)
            save_state(state, args.models / model_name)
        return

    if args.rank == 0:
        done = args.logs / f"{name}.done"
        if done.exists():
            done.unlink()

    augment = [Shift(args.data_stride)]
    if args.augment:
        augment += [FlipSign(), FlipChannels(), Scale(),
                    Remix(group_size=args.remix_group_size)]
    augment = nn.Sequential(*augment).to(device)
    print("Augmentation pipeline:", augment)

    if args.mse:
        criterion = nn.MSELoss()
    else:
        criterion = nn.L1Loss()

    # Adjust the number of samples so that all convolution windows are full.
    # This prevents a hard-to-debug mistake where the prediction is shifted
    # relative to the input mixture.
    samples = model.valid_length(args.samples)
    print(f"Number of training samples adjusted to {samples}")
    samples = samples + args.data_stride
    if args.repitch:
        # We need a bit more audio to account for a potential tempo change:
        # speeding the tempo up by `max_tempo` percent shortens the excerpt,
        # so read more input samples to still obtain `samples` frames.
        samples = math.ceil(samples / (1 - 0.01 * args.max_tempo))

    args.metadata.mkdir(exist_ok=True, parents=True)
    if args.raw:
        train_set = Rawset(args.raw / "train",
                           samples=samples,
                           channels=args.audio_channels,
                           streams=range(1, len(model.sources) + 1),
                           stride=args.data_stride)
        valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
    elif args.wav:
        train_set, valid_set = get_wav_datasets(args, samples, model.sources)
    elif args.is_wav:
        train_set, valid_set = get_musdb_wav_datasets(args, samples, model.sources)
    else:
        train_set, valid_set = get_compressed_datasets(args, samples)

    if args.repitch:
        train_set = RepitchedWrapper(
            train_set,
            proba=args.repitch,
            max_tempo=args.max_tempo)

    # Replay the metrics of a resumed run so `best_loss` and the logs stay consistent.
    best_loss = float("inf")
    for epoch, metrics in enumerate(saved.metrics):
        print(f"Epoch {epoch:03d}: "
              f"train={metrics['train']:.8f} "
              f"valid={metrics['valid']:.8f} "
              f"best={metrics['best']:.4f} "
              f"ms={metrics.get('true_model_size', 0):.2f}MB "
              f"cms={metrics.get('compressed_model_size', 0):.2f}MB "
              f"duration={human_seconds(metrics['duration'])}")
        best_loss = metrics['best']

    if args.world_size > 1:
        dmodel = DistributedDataParallel(model,
                                         device_ids=[th.cuda.current_device()],
                                         output_device=th.cuda.current_device())
    else:
        dmodel = model

    # Main training loop, resuming after the last epoch recorded in the checkpoint.
    for epoch in range(len(saved.metrics), args.epochs):
        begin = time.time()
        model.train()
        train_loss, model_size = train_model(
            epoch, train_set, dmodel, criterion, optimizer, augment,
            quantizer=quantizer,
            batch_size=args.batch_size,
            device=device,
            repeat=args.repeat,
            seed=args.seed,
            diffq=args.diffq,
            workers=args.workers,
            world_size=args.world_size)
        model.eval()
        valid_loss = validate_model(
            epoch, valid_set, model, criterion,
            device=device,
            rank=args.rank,
            split=args.split_valid,
            overlap=args.overlap,
            world_size=args.world_size)

        ms = 0
        cms = 0
        if quantizer and args.rank == 0:
            ms = quantizer.true_model_size()
            cms = quantizer.compressed_model_size(num_workers=min(40, args.world_size * 10))

        duration = time.time() - begin
        if valid_loss < best_loss and ms <= args.ms_target:
            best_loss = valid_loss
            saved.best_state = {
                key: value.to("cpu").clone()
                for key, value in model.state_dict().items()
            }

        saved.metrics.append({
            "train": train_loss,
            "valid": valid_loss,
            "best": best_loss,
            "duration": duration,
            "model_size": model_size,
            "true_model_size": ms,
            "compressed_model_size": cms,
        })
        if args.rank == 0:
            with open(metrics_path, "w") as metrics_file:
                json.dump(saved.metrics, metrics_file)

        saved.last_state = model.state_dict()
        saved.optimizer = optimizer.state_dict()
        if args.rank == 0 and not args.test:
            # Write to a temporary file first, then rename, so an interrupted
            # save never leaves a corrupted checkpoint behind.
            th.save(saved, checkpoint_tmp)
            checkpoint_tmp.rename(checkpoint)

        print(f"Epoch {epoch:03d}: "
              f"train={train_loss:.8f} valid={valid_loss:.8f} best={best_loss:.4f} ms={ms:.2f}MB "
              f"cms={cms:.2f}MB "
              f"duration={human_seconds(duration)}")

    if args.world_size > 1:
        distributed.barrier()

    # Final evaluation on MusDB with the weights that achieved the best validation loss.
    del dmodel
    model.load_state_dict(saved.best_state)
    if args.eval_cpu:
        device = "cpu"
        model.to(device)
    model.eval()
    evaluate(model, args.musdb, eval_folder,
             is_wav=args.is_wav,
             rank=args.rank,
             world_size=args.world_size,
             device=device,
             save=args.save,
             split=args.split_valid,
             shifts=args.shifts,
             overlap=args.overlap,
             workers=args.eval_workers)
    model.to("cpu")
    if args.rank == 0:
        if not (args.test or args.test_pretrained):
            save_model(model, quantizer, args, args.models / model_name)
        print("done")
        done.write_text("done")


if __name__ == "__main__":
    main()
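# Example (comments only, not executed): how a checkpoint written by the loop
# above can be inspected offline. This is a minimal sketch; the path is
# hypothetical and the SavedState class must be importable for unpickling.
#
#     import torch as th
#     saved = th.load("checkpoints/my_experiment.th", map_location="cpu")
#     for epoch, metrics in enumerate(saved.metrics):
#         print(epoch, metrics["train"], metrics["valid"], metrics["best"])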