# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


"""
Signal processing-based evaluation using waveforms
"""

import csv
import logging
import numpy as np
import os.path as op

import torch
import tqdm
from tabulate import tabulate
import torchaudio

from examples.speech_synthesis.utils import batch_mel_spectral_distortion
from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion


logger = logging.getLogger(__name__)


def load_eval_spec(path):
    """Read a tab-separated evaluation spec into a list of row dicts.

    Args:
        path: path to a TSV file with a header row. Downstream code in this
            file reads the ``ref`` and ``syn`` columns of every row.

    Returns:
        A list with one dict per data row, keyed by the header names.
    """
    # newline="" lets the csv module do its own newline handling, as its
    # documentation recommends for file objects passed to readers.
    with open(path, newline="") as f:
        return list(csv.DictReader(f, delimiter='\t'))


def eval_distortion(samples, distortion_fn, device="cuda"):
    """Compute an alignment-based distortion for each (ref, syn) audio pair.

    Args:
        samples: iterable of mappings holding ``ref`` and ``syn`` audio paths.
        distortion_fn: callable invoked as
            ``distortion_fn([yref], [ysyn], sr, None)``; its first element
            must be a ``(distortion, extra)`` pair where ``extra`` is a
            6-tuple whose last item is the alignment path map.
        device: torch device the waveforms are moved to before evaluation.

    Returns:
        A list aligned with ``samples``. Each entry is ``None`` when either
        audio file is missing, otherwise the tuple
        ``(distortion, n_ref_frames, n_syn_frames, path_len, n_ins, n_del)``.

    Raises:
        AssertionError: if the two files of a pair differ in sample rate.
    """
    nmiss = 0
    results = []
    for sample in tqdm.tqdm(samples):
        if not op.isfile(sample["ref"]) or not op.isfile(sample["syn"]):
            nmiss += 1
            results.append(None)
            continue
        yref, sr = torchaudio.load(sample["ref"])
        ysyn, _sr = torchaudio.load(sample["syn"])
        assert sr == _sr, f"{sr} != {_sr}"
        # assume single channel: only the first channel is evaluated
        yref, ysyn = yref[0].to(device), ysyn[0].to(device)

        distortion, extra = distortion_fn([yref], [ysyn], sr, None)[0]
        pathmap = extra[-1]
        nins = torch.sum(pathmap.sum(dim=1) - 1)  # extra frames in syn
        ndel = torch.sum(pathmap.sum(dim=0) - 1)  # missing frames from syn
        results.append(
            (
                distortion.item(),  # path distortion
                pathmap.size(0),  # yref num frames
                pathmap.size(1),  # ysyn num frames
                pathmap.sum().item(),  # path length
                nins.item(),  # insertion
                ndel.item(),  # deletion
            )
        )
    if nmiss > 0:
        # Surface skipped pairs; otherwise they silently shrink the
        # population the aggregate statistics are computed over.
        logger.warning(
            "%d of %d samples skipped: missing ref or syn audio file",
            nmiss, len(results),
        )
    return results


def eval_mel_cepstral_distortion(samples, device="cuda"):
    """Run ``eval_distortion`` with ``batch_mel_cepstral_distortion``."""
    return eval_distortion(samples, batch_mel_cepstral_distortion, device)


def eval_mel_spectral_distortion(samples, device="cuda"):
    """Run ``eval_distortion`` with ``batch_mel_spectral_distortion``."""
    return eval_distortion(samples, batch_mel_spectral_distortion, device)


def print_results(results, show_bin):
    """Print aggregate distortion statistics as a table.

    Args:
        results: list as returned by ``eval_distortion``; ``None`` entries
            (skipped samples) are ignored.
        show_bin: when true, additionally print one table per bin of
            reference length (column 1 of each result, in frames).
    """
    valid = [r for r in results if r is not None]
    np.set_printoptions(precision=3)
    if not valid:
        # With no rows the column-wise sum below collapses to a scalar and
        # cannot be unpacked, so report and stop instead of crashing.
        print("No valid results to report")
        return
    results = np.array(valid)

    def _print_result(results):
        # Column-wise totals over all utterances in this group.
        dist, dur_ref, dur_syn, dur_ali, nins, ndel = results.sum(axis=0)
        res = {
            "nutt": len(results),
            "dist": dist,
            "dur_ref": int(dur_ref),
            "dur_syn": int(dur_syn),
            "dur_ali": int(dur_ali),
            "dist_per_ref_frm": dist/dur_ref,
            "dist_per_syn_frm": dist/dur_syn,
            "dist_per_ali_frm": dist/dur_ali,
            "ins": nins/dur_ref,
            "del": ndel/dur_ref,
        }
        print(tabulate(
            [list(res.values())],
            list(res.keys()),
            floatfmt=".4f"
        ))

    print(">>>> ALL")
    _print_result(results)

    if show_bin:
        edges = [0, 200, 400, 600, 800, 1000, 2000, 4000]
        ref_frames = results[:, 1]
        # Bins are half-open [lo, hi); utterances at or beyond the last edge
        # fall into no bin and only show up in the ALL table.
        for lo, hi in zip(edges, edges[1:]):
            mask = np.logical_and(ref_frames >= lo, ref_frames < hi)
            if not mask.any():
                continue
            print(f">>>> ({lo}, {hi})")
            _print_result(results[mask])


def main(eval_spec, mcd, msd, show_bin, device="cpu"):
    """Evaluate the pairs listed in ``eval_spec`` and print the results.

    Args:
        eval_spec: path to the TSV spec consumed by ``load_eval_spec``.
        mcd: whether to run the mel cepstral distortion evaluation.
        msd: whether to run the mel spectral distortion evaluation.
        show_bin: forwarded to ``print_results``.
        device: torch device used for evaluation (defaults to ``"cpu"``).

    Nothing is evaluated when both ``mcd`` and ``msd`` are false.
    """
    samples = load_eval_spec(eval_spec)
    if mcd:
        print("===== Evaluate Mean Cepstral Distortion =====")
        results = eval_mel_cepstral_distortion(samples, device)
        print_results(results, show_bin)
    if msd:
        print("===== Evaluate Mean Spectral Distortion =====")
        results = eval_mel_spectral_distortion(samples, device)
        print_results(results, show_bin)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("eval_spec")
    parser.add_argument("--mcd", action="store_true")
    parser.add_argument("--msd", action="store_true")
    parser.add_argument("--show-bin", action="store_true")
    parser.add_argument(
        "--device", default="cpu",
        help="torch device used for evaluation (default: cpu)"
    )
    args = parser.parse_args()

    main(args.eval_spec, args.mcd, args.msd, args.show_bin, args.device)