# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf and Alexandre Defossez (adefossez)

import argparse
from concurrent.futures import ProcessPoolExecutor
import json
import logging
import sys

import numpy as np
from pesq import pesq
from pystoi import stoi
import torch

from .models.sisnr_loss import cal_loss
from .data.data import Validset
from . import distrib
from .utils import bold, deserialize_model, LogProgress

logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(
    'Evaluate separation performance using MulCat blocks')
parser.add_argument('model_path',
                    help='Path to model file created by training')
parser.add_argument('data_dir',
                    help='directory including mix.json, s1.json, s2.json, ... files')
parser.add_argument('--device', default="cuda")
parser.add_argument('--sdr', type=int, default=0)
parser.add_argument('--sample_rate', default=16000, type=int,
                    help='Sample rate')
parser.add_argument('--num_workers', type=int, default=5)
parser.add_argument('-v', '--verbose', action='store_const',
                    const=logging.DEBUG, default=logging.INFO,
                    help="More loggging")


def evaluate(args, model=None, data_loader=None, sr=None):
    """Run separation evaluation and report averaged SI-SNRi / PESQ / STOI.

    Args:
        args: parsed CLI namespace; must provide ``device``, ``model_path``,
            ``data_dir``, ``num_workers`` and ``sample_rate``.
        model: optional pre-loaded model. When ``None``, the model is loaded
            (and deserialized) from ``args.model_path``.
        data_loader: optional data loader yielding ``(mixture, lengths,
            sources)`` batches. When ``None``, one is built from
            ``args.data_dir`` with batch size 1.
        sr: sample rate forwarded to the metric workers; set from
            ``args.sample_rate`` when the loader is built here.

    Returns:
        Tuple ``(sisnr, pesq, stoi)`` of dataset averages, reduced across
        distributed workers via ``distrib.average``.
    """
    total_sisnr = 0
    total_pesq = 0
    total_stoi = 0
    total_cnt = 0
    updates = 5

    # Load model from checkpoint unless the caller supplied one.
    if not model:
        pkg = torch.load(args.model_path, map_location=args.device)
        if 'model' in pkg:
            model = pkg['model']
        else:
            model = pkg
        model = deserialize_model(model)
        if 'best_state' in pkg:
            model.load_state_dict(pkg['best_state'])
        logger.debug(model)
    model.eval()
    model.to(args.device)

    # Build the data loader unless the caller supplied one.
    if not data_loader:
        dataset = Validset(args.data_dir)
        data_loader = distrib.loader(
            dataset, batch_size=1, num_workers=args.num_workers)
        sr = args.sample_rate

    pendings = []
    with ProcessPoolExecutor(args.num_workers) as pool:
        # One outer no_grad covers the whole inference loop (the original
        # nested a redundant second no_grad inside it).
        with torch.no_grad():
            iterator = LogProgress(logger, data_loader, name="Eval estimates")
            for i, data in enumerate(iterator):
                # Get batch data.
                mixture, lengths, sources = [x.to(args.device) for x in data]
                # Forward pass. NOTE: normalization mutates the batch tensor
                # in place, scaling by the max over the whole batch.
                mixture /= mixture.max()
                estimate = model(mixture)[-1]
                sisnr_loss, snr, estimate, reorder_estimate = cal_loss(
                    sources, estimate, lengths)
                reorder_estimate = reorder_estimate.cpu()
                sources = sources.cpu()
                mixture = mixture.cpu()
                # Metrics are CPU-bound; score each batch in a worker process.
                pendings.append(
                    pool.submit(_run_metrics, sources, reorder_estimate,
                                mixture, None, sr=sr))
                total_cnt += sources.shape[0]

        for pending in LogProgress(logger, pendings, updates,
                                   name="Eval metrics"):
            sisnr_i, pesq_i, stoi_i = pending.result()
            total_sisnr += sisnr_i
            total_pesq += pesq_i
            total_stoi += stoi_i

    metrics = [total_sisnr, total_pesq, total_stoi]
    sisnr, pesq, stoi = distrib.average(
        [m / total_cnt for m in metrics], total_cnt)
    logger.info(
        bold(f'Test set performance: SISNRi={sisnr:.2f} PESQ={pesq}, STOI={stoi}.'))
    return sisnr, pesq, stoi


def _run_metrics(clean, estimate, mix, model, sr, pesq=False):
    """Score one batch: SI-SNRi always; PESQ/STOI only when requested.

    Args:
        clean: reference sources tensor, [B, C, T].
        estimate: estimated sources tensor, [B, C, T] (PIT-reordered).
        mix: mixture tensor.
        model: optional enhancement model applied to ``estimate`` first.
        sr: sample rate for PESQ/STOI.
        pesq: when truthy, also compute PESQ and STOI. WARNING: this flag
            shadows the module-level ``pesq`` function within this scope
            (``cal_PESQ`` still resolves the global correctly); the name is
            kept for backward compatibility with existing keyword callers.

    Returns:
        Tuple ``(sisnr_mean, pesq_value, stoi_value)``; the last two are 0
        when ``pesq`` is false.
    """
    if model is not None:
        # Keep each worker single-threaded: parallelism comes from the pool.
        torch.set_num_threads(1)
        with torch.no_grad():
            estimate = model(estimate)[-1]
    estimate = estimate.numpy()
    clean = clean.numpy()
    mix = mix.numpy()
    sisnr = cal_SISNRi(clean, estimate, mix)
    if pesq:
        pesq_i = cal_PESQ(clean, estimate, sr=sr)
        stoi_i = cal_STOI(clean, estimate, sr=sr)
    else:
        pesq_i = 0
        stoi_i = 0
    return sisnr.mean(), pesq_i, stoi_i


def cal_SISNR(ref_sig, out_sig, eps=1e-8):
    """Calculate Scale-Invariant Source-to-Noise Ratio (SI-SNR).

    Args:
        ref_sig: numpy.ndarray, [B, T] reference signals.
        out_sig: numpy.ndarray, [B, T] estimated signals.
        eps: numerical floor to avoid division by zero / log of zero.

    Returns:
        Mean SI-SNR (dB) over the batch, as a numpy scalar.
    """
    assert len(ref_sig) == len(out_sig)
    B, T = ref_sig.shape
    # Zero-mean both signals so the projection is scale-invariant.
    ref_sig = ref_sig - np.mean(ref_sig, axis=1).reshape(B, 1)
    out_sig = out_sig - np.mean(out_sig, axis=1).reshape(B, 1)
    ref_energy = (np.sum(ref_sig ** 2, axis=1) + eps).reshape(B, 1)
    # Project the estimate onto the reference; the residual is "noise".
    proj = (np.sum(ref_sig * out_sig, axis=1).reshape(B, 1)) * \
        ref_sig / ref_energy
    noise = out_sig - proj
    ratio = np.sum(proj ** 2, axis=1) / (np.sum(noise ** 2, axis=1) + eps)
    # np.log10 replaces the original np.log(x) / np.log(10.0) — identical math.
    sisnr = 10 * np.log10(ratio + eps)
    return sisnr.mean()


def cal_PESQ(ref_sig, out_sig, sr):
    """Calculate mean narrow-band PESQ over all (batch, channel) pairs.

    Args:
        ref_sig: numpy.ndarray, [B, C, T] reference signals.
        out_sig: numpy.ndarray, [B, C, T] estimated signals.
        sr: sample rate passed to ``pesq``.

    Returns:
        Mean PESQ score.
    """
    B, C, T = ref_sig.shape
    ref_sig = ref_sig.reshape(B * C, T)
    out_sig = out_sig.reshape(B * C, T)
    pesq_val = 0
    for i in range(len(ref_sig)):
        pesq_val += pesq(sr, ref_sig[i], out_sig[i], 'nb')
    return pesq_val / (B * C)


def cal_STOI(ref_sig, out_sig, sr):
    """Calculate mean STOI over all (batch, channel) pairs.

    Args:
        ref_sig: numpy.ndarray, [B, C, T] reference signals.
        out_sig: numpy.ndarray, [B, C, T] estimated signals.
        sr: sample rate passed to ``stoi``.

    Returns:
        Mean STOI score, or 0 if the computation fails (best-effort metric).
    """
    B, C, T = ref_sig.shape
    ref_sig = ref_sig.reshape(B * C, T)
    out_sig = out_sig.reshape(B * C, T)
    try:
        stoi_val = 0
        for i in range(len(ref_sig)):
            stoi_val += stoi(ref_sig[i], out_sig[i], sr, extended=False)
        return stoi_val / (B * C)
    except Exception:
        # Narrowed from a bare ``except`` (which also swallowed
        # KeyboardInterrupt/SystemExit). Keep the best-effort fallback,
        # but log it instead of failing silently.
        logger.exception('STOI computation failed, returning 0')
        return 0


def cal_SISNRi(src_ref, src_est, mix):
    """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi).

    Args:
        src_ref: numpy.ndarray, [B, C, T] reference sources.
        src_est: numpy.ndarray, [B, C, T], reordered by best PIT permutation.
        mix: numpy.ndarray, [B, T] mixture signal — ``cal_SISNR`` unpacks it
            as a 2-D batch (the original docstring's "[T]" was inaccurate).

    Returns:
        Average SI-SNRi over the C channels.
    """
    avg_SISNRi = 0.0
    B, C, T = src_ref.shape
    for c in range(C):
        sisnr = cal_SISNR(src_ref[:, c], src_est[:, c])
        # Baseline: SI-SNR of the raw mixture against the same reference.
        sisnrb = cal_SISNR(src_ref[:, c], mix)
        avg_SISNRi += (sisnr - sisnrb)
    avg_SISNRi /= C
    return avg_SISNRi


def main():
    """CLI entry point: evaluate a checkpoint and dump metrics as JSON."""
    args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=args.verbose)
    logger.debug(args)
    # Locals renamed so they no longer shadow the imported ``pesq``/``stoi``
    # metric functions. JSON keys are unchanged.
    sisnr_avg, pesq_avg, stoi_avg = evaluate(args)
    json.dump({'sisnr': sisnr_avg, 'pesq': pesq_avg, 'stoi': stoi_avg},
              sys.stdout)
    sys.stdout.write('\n')


if __name__ == '__main__':
    main()