| |
| |
| |
| |
| |
|
|
| import gzip |
| import sys |
| from concurrent import futures |
|
|
| import musdb |
| import museval |
| import torch as th |
| import tqdm |
| from scipy.io import wavfile |
| from torch import distributed |
|
|
| from .audio import convert_audio |
| from .utils import apply_model |
|
|
|
|
def evaluate(model,
             musdb_path,
             eval_folder,
             workers=2,
             device="cpu",
             rank=0,
             save=False,
             shifts=0,
             split=False,
             overlap=0.25,
             is_wav=False,
             world_size=1):
    """
    Evaluate model using museval. Run the model
    on a single GPU, the bottleneck being the call to museval.

    Per-track metrics are written as gzip-compressed JSON to
    `eval_folder / "results/test" / "<track name>.json.gz"`. A track whose
    result file already exists is skipped, so an interrupted evaluation can
    be resumed.

    Args:
        model: separation model. Must expose `parameters()`, `samplerate`,
            `audio_channels` and `sources` (the ordered source names).
        musdb_path: root of the MusDB dataset, forwarded to `musdb.DB`.
        eval_folder: output root; must support `/` joining and `mkdir`
            (presumably a `pathlib.Path`).
        workers: size of the process pool running `museval.evaluate`.
            A falsy value runs museval synchronously in the calling process.
        device: device the mixture is moved to before calling `apply_model`.
        rank: index of this process; it handles tracks
            `rank, rank + world_size, rank + 2 * world_size, ...`.
        save: if true, also write each estimated source as a WAV file under
            `eval_folder / "wav/test" / <track name>`.
        shifts: forwarded to `apply_model`.
        split: forwarded to `apply_model`.
        overlap: forwarded to `apply_model`.
        is_wav: forwarded to `musdb.DB`.
        world_size: total number of processes sharing the evaluation. When
            greater than 1, `distributed.barrier()` is called at the end.
    """
    output_dir = eval_folder / "results"
    output_dir.mkdir(exist_ok=True, parents=True)
    json_folder = eval_folder / "results/test"
    json_folder.mkdir(exist_ok=True, parents=True)

    test_set = musdb.DB(musdb_path, subsets=["test"], is_wav=is_wav)
    # NOTE(review): assumes every MusDB track is sampled at 44.1 kHz - confirm.
    src_rate = 44100

    # Gradients are never needed for evaluation; drop any stored ones.
    for p in model.parameters():
        p.requires_grad = False
        p.grad = None

    # museval windows span one second of audio at the model sample rate.
    win = int(1. * model.samplerate)
    hop = int(1. * model.samplerate)

    pendings = []
    with futures.ProcessPoolExecutor(workers or 1) as pool:
        for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
            track = test_set.tracks[index]

            # Resume support: skip tracks that already have a result file.
            out = json_folder / f"{track.name}.json.gz"
            if out.exists():
                continue

            mix = th.from_numpy(track.audio).t().float()
            # Normalize with the statistics of the channel-averaged mixture,
            # and undo that normalization on the model output.
            ref = mix.mean(dim=0)
            mix = (mix - ref.mean()) / ref.std()
            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
            estimates = apply_model(model, mix.to(device),
                                    shifts=shifts, split=split, overlap=overlap)
            estimates = estimates * ref.std() + ref.mean()

            # Swap the last two axes of estimates and references before
            # handing them to museval.
            estimates = estimates.transpose(1, 2)
            references = th.stack(
                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
            references = convert_audio(references, src_rate,
                                       model.samplerate, model.audio_channels)
            references = references.transpose(1, 2).numpy()
            estimates = estimates.cpu().numpy()

            if save:
                folder = eval_folder / "wav/test" / track.name
                folder.mkdir(exist_ok=True, parents=True)
                for name, estimate in zip(model.sources, estimates):
                    # The mixture was converted to `model.samplerate` before
                    # separation, so label the file with that rate instead of
                    # a hard-coded 44100.
                    wavfile.write(str(folder / (name + ".wav")),
                                  int(model.samplerate), estimate)

            if workers:
                pendings.append((track.name, pool.submit(
                    museval.evaluate, references, estimates, win=win, hop=hop)))
            else:
                pendings.append((track.name, museval.evaluate(
                    references, estimates, win=win, hop=hop)))
            # Release the large per-track arrays before loading the next track.
            del references, mix, estimates, track

        for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
            if workers:
                pending = pending.result()
            sdr, isr, sir, sar = pending
            # NOTE(review): win/hop are hard-coded to 44100 here while the
            # metrics above use `model.samplerate` - confirm the unit
            # TrackStore expects before changing.
            track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
            for idx, target in enumerate(model.sources):
                values = {
                    "SDR": sdr[idx].tolist(),
                    "SIR": sir[idx].tolist(),
                    "ISR": isr[idx].tolist(),
                    "SAR": sar[idx].tolist()
                }

                track_store.add_target(target_name=target, values=values)

            # Write once per track, and close the handle explicitly so the
            # gzip stream is flushed and finalized.
            json_path = json_folder / f"{track_name}.json.gz"
            with gzip.open(json_path, "w") as json_file:
                json_file.write(track_store.json.encode('utf-8'))

    if world_size > 1:
        # Wait until every process has finished its share of tracks.
        distributed.barrier()
|
|