Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	File size: 4,068 Bytes
			
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import gzip
import sys
from concurrent import futures
import musdb
import museval
import torch as th
import tqdm
from scipy.io import wavfile
from torch import distributed
from .audio import convert_audio
from .utils import apply_model
def evaluate(model,
             musdb_path,
             eval_folder,
             workers=2,
             device="cpu",
             rank=0,
             save=False,
             shifts=0,
             split=False,
             overlap=0.25,
             is_wav=False,
             world_size=1):
    """
    Evaluate model using museval. Run the model
    on a single GPU, the bottleneck being the call to museval.

    Per-track metrics are written to
    `eval_folder/results/test/<track name>.json.gz`. Tracks whose result
    file already exists are skipped, so an interrupted run can be resumed.

    Args:
        model: separation model. This function reads `model.parameters()`,
            `model.samplerate`, `model.audio_channels` and `model.sources`.
        musdb_path: root of the MusDB dataset, forwarded to `musdb.DB`.
        eval_folder: path-like object supporting the `/` operator
            (e.g. `pathlib.Path`) under which results are stored.
        workers: number of processes used to run `museval.evaluate`.
            A falsy value runs museval inline in the current process.
        device: device the mixture is moved to before calling `apply_model`.
        rank: index of the first track handled by this process.
        save: if true, also write the estimates as wav files to
            `eval_folder/wav/test/<track name>/<source>.wav`.
        shifts: forwarded to `apply_model`.
        split: forwarded to `apply_model`.
        overlap: forwarded to `apply_model`.
        is_wav: forwarded to `musdb.DB`.
        world_size: stride between the tracks handled by this process; when
            greater than 1, `distributed.barrier()` is called before returning.
    """
    output_dir = eval_folder / "results"
    output_dir.mkdir(exist_ok=True, parents=True)
    json_folder = eval_folder / "results/test"
    json_folder.mkdir(exist_ok=True, parents=True)

    # we load tracks from the original musdb set
    test_set = musdb.DB(musdb_path, subsets=["test"], is_wav=is_wav)
    src_rate = 44100  # hardcoded for now...

    # Inference only: no gradients are needed on the parameters.
    for p in model.parameters():
        p.requires_grad = False
        p.grad = None

    # museval window and hop: one second worth of samples at the model rate.
    # Loop invariant, so computed once.
    win = int(1. * model.samplerate)
    hop = int(1. * model.samplerate)

    pendings = []
    with futures.ProcessPoolExecutor(workers or 1) as pool:
        # Each rank handles tracks rank, rank + world_size, ...
        for index in tqdm.tqdm(range(rank, len(test_set), world_size), file=sys.stdout):
            track = test_set.tracks[index]

            out = json_folder / f"{track.name}.json.gz"
            if out.exists():
                continue

            mix = th.from_numpy(track.audio).t().float()
            ref = mix.mean(dim=0)  # mono mixture
            ref_mean = ref.mean()
            ref_std = ref.std()
            # Normalize with the mono mixture statistics; undone on the estimates.
            mix = (mix - ref_mean) / ref_std
            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
            estimates = apply_model(model, mix.to(device),
                                    shifts=shifts, split=split, overlap=overlap)
            estimates = estimates * ref_std + ref_mean

            estimates = estimates.transpose(1, 2)
            references = th.stack(
                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
            references = convert_audio(references, src_rate,
                                       model.samplerate, model.audio_channels)
            references = references.transpose(1, 2).numpy()
            estimates = estimates.cpu().numpy()

            if save:
                folder = eval_folder / "wav/test" / track.name
                folder.mkdir(exist_ok=True, parents=True)
                for name, estimate in zip(model.sources, estimates):
                    # The estimates are produced at the model sample rate, so
                    # that is the rate the wav header must advertise.
                    wavfile.write(str(folder / (name + ".wav")), model.samplerate, estimate)

            if workers:
                pendings.append((track.name, pool.submit(
                    museval.evaluate, references, estimates, win=win, hop=hop)))
            else:
                pendings.append((track.name, museval.evaluate(
                    references, estimates, win=win, hop=hop)))
            # Drop the large per-track buffers before processing the next track.
            del references, mix, estimates, track

        for track_name, pending in tqdm.tqdm(pendings, file=sys.stdout):
            if workers:
                pending = pending.result()
            sdr, isr, sir, sar = pending
            # NOTE(review): win/hop are kept at the original literal 44100;
            # confirm against museval's TrackStore whether they should follow
            # the win/hop passed to museval.evaluate above.
            track_store = museval.TrackStore(win=44100, hop=44100, track_name=track_name)
            for idx, target in enumerate(model.sources):
                values = {
                    "SDR": sdr[idx].tolist(),
                    "SIR": sir[idx].tolist(),
                    "ISR": isr[idx].tolist(),
                    "SAR": sar[idx].tolist()
                }
                track_store.add_target(target_name=target, values=values)

            # Write once per track, after every target has been added, and
            # close the file deterministically so the gzip stream is complete.
            json_path = json_folder / f"{track_name}.json.gz"
            with gzip.open(json_path, "w") as json_file:
                json_file.write(track_store.json.encode('utf-8'))

    if world_size > 1:
        distributed.barrier()
 | 
