# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import gzip
import sys
from concurrent import futures

import musdb
import museval
import torch as th
import tqdm
from scipy.io import wavfile
from torch import distributed

from .audio import convert_audio
from .utils import apply_model


def evaluate(model,
             musdb_path,
             eval_folder,
             workers=2,
             device="cpu",
             rank=0,
             save=False,
             shifts=0,
             split=False,
             overlap=0.25,
             is_wav=False,
             world_size=1):
    """
    Evaluate model using museval. Run the model on a single GPU, the bottleneck
    being the call to museval.

    For every test track handled by this rank, the mixture is separated with
    `apply_model`, the estimates are scored with `museval.evaluate` and the
    per-target scores are stored as `<eval_folder>/results/test/<track>.json.gz`.
    Tracks whose result file already exists are skipped.

    Args:
        model: separation model; must expose `parameters()`, `samplerate`,
            `audio_channels` and `sources`. Its parameters are frozen
            (`requires_grad = False`, `grad = None`) as a side effect.
        musdb_path: root of the MUSDB dataset, forwarded to `musdb.DB`.
        eval_folder: output folder; must support `/` and `mkdir`
            (presumably a `pathlib.Path`).
        workers: number of processes used to run `museval.evaluate`.
            A falsy value runs the metric computation in the current process.
        device: device the mixture is moved to before calling `apply_model`.
        rank: index of this process; it handles tracks
            `rank, rank + world_size, ...`.
        save: if true, write each estimated source as a wav file under
            `<eval_folder>/wav/test/<track>/<source>.wav`.
        shifts: forwarded to `apply_model`.
        split: forwarded to `apply_model`.
        overlap: forwarded to `apply_model`.
        is_wav: forwarded to `musdb.DB`.
        world_size: total number of processes; when greater than 1, a
            `distributed.barrier()` is issued before returning.
    """
    output_dir = eval_folder / "results"
    output_dir.mkdir(exist_ok=True, parents=True)
    json_folder = eval_folder / "results/test"
    json_folder.mkdir(exist_ok=True, parents=True)

    # we load tracks from the original musdb set
    test_set = musdb.DB(musdb_path, subsets=["test"], is_wav=is_wav)
    src_rate = 44100  # hardcoded for now...

    # Evaluation only: no gradients are needed, drop any that are stored.
    for p in model.parameters():
        p.requires_grad = False
        p.grad = None

    # Metric window and hop: one second worth of samples at the model rate.
    win = int(1. * model.samplerate)
    hop = int(1. * model.samplerate)

    pendings = []
    with futures.ProcessPoolExecutor(workers or 1) as pool:
        for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
            track = test_set.tracks[index]

            # Result already on disk: skip, which makes the evaluation resumable.
            out = json_folder / f"{track.name}.json.gz"
            if out.exists():
                continue

            mix = th.from_numpy(track.audio).t().float()
            ref = mix.mean(dim=0)  # mono mixture
            ref_mean = ref.mean()
            ref_std = ref.std()
            # Normalize with the mono mixture statistics; undone on the estimates.
            mix = (mix - ref_mean) / ref_std
            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
            estimates = apply_model(model, mix.to(device),
                                    shifts=shifts, split=split, overlap=overlap)
            estimates = estimates * ref_std + ref_mean

            # Swap the last two axes of estimates and references before scoring
            # (presumably to the time-major layout museval expects -- TODO confirm).
            estimates = estimates.transpose(1, 2)
            references = th.stack(
                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
            references = convert_audio(
                references, src_rate, model.samplerate, model.audio_channels)
            references = references.transpose(1, 2).numpy()
            estimates = estimates.cpu().numpy()

            if save:
                folder = eval_folder / "wav/test" / track.name
                folder.mkdir(exist_ok=True, parents=True)
                for name, estimate in zip(model.sources, estimates):
                    # The estimates are scored against references resampled to
                    # model.samplerate, so that is the rate to store in the file.
                    wavfile.write(str(folder / (name + ".wav")), model.samplerate, estimate)

            if workers:
                pendings.append((track.name, pool.submit(
                    museval.evaluate, references, estimates, win=win, hop=hop)))
            else:
                pendings.append((track.name, museval.evaluate(
                    references, estimates, win=win, hop=hop)))
            # Release the large arrays before loading the next track.
            del references, mix, estimates, track

        for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
            if workers:
                pending = pending.result()
            sdr, isr, sir, sar = pending
            # NOTE(review): win/hop are given here as 44100 while the metrics
            # above use model.samplerate samples; check the unit TrackStore
            # expects. Kept unchanged so the emitted JSON stays the same.
            track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
            for idx, target in enumerate(model.sources):
                values = {
                    "SDR": sdr[idx].tolist(),
                    "SIR": sir[idx].tolist(),
                    "ISR": isr[idx].tolist(),
                    "SAR": sar[idx].tolist()
                }
                track_store.add_target(target_name=target, values=values)

            # Write once all targets are added, and close the file explicitly so
            # the archive is fully flushed: a truncated file would be skipped by
            # the `out.exists()` check on the next run.
            json_path = json_folder / f"{track_name}.json.gz"
            with gzip.open(json_path, "w") as stream:
                stream.write(track_store.json.encode('utf-8'))

    if world_size > 1:
        distributed.barrier()