JustinLin610's picture
first commit
ee21b96
raw
history blame
No virus
8.33 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Signal processing-based evaluation using waveforms
"""
import numpy as np
import os.path as op
import torchaudio
import tqdm
from tabulate import tabulate
from examples.speech_synthesis.utils import (
gross_pitch_error, voicing_decision_error, f0_frame_error
)
from examples.speech_synthesis.evaluation.eval_sp import load_eval_spec
def difference_function(x, n, tau_max):
    """
    Compute the YIN difference function of ``x`` via an FFT autocorrelation.

    For every lag ``0 <= tau < tau_max`` this returns
    ``d(tau) = sum_{j=0}^{w-1-tau} (x[j] - x[j + tau]) ** 2`` with ``w = len(x)``.

    :param x: audio data (anything convertible to a float64 NumPy array)
    :param n: length of data (unused; the length is taken from ``x``)
    :param tau_max: number of lags to compute; clamped to ``len(x)``
    :return: difference function values for lags ``0 .. tau_max - 1``
    :rtype: numpy.ndarray
    """
    x = np.array(x, np.float64)
    w = x.size
    tau_max = min(tau_max, w)
    # x_cumsum[k] == sum(x[:k] ** 2), with a leading 0 for k == 0.
    x_cumsum = np.concatenate((np.array([0.]), (x * x).cumsum()))
    # Zero-pad to at least w + tau_max so the circular autocorrelation from
    # the FFT equals the linear one for every lag < tau_max, rounding the
    # length up to a product of small primes to keep the FFT fast.
    size = w + tau_max
    p2 = (size // 32).bit_length()
    nice_numbers = (16, 18, 20, 24, 25, 27, 30, 32)
    size_pad = min(k * 2 ** p2 for k in nice_numbers if k * 2 ** p2 >= size)
    fc = np.fft.rfft(x, size_pad)
    # The output length must be passed explicitly: by default irfft assumes an
    # even length (2 * (len(fc) - 1)), which is wrong when size_pad is odd.
    conv = np.fft.irfft(fc * fc.conjugate(), size_pad)[:tau_max]
    # sum_{j<w-tau} x_j^2 + sum_{j>=tau} x_j^2 - 2 * autocorrelation(tau)
    return (x_cumsum[w:w - tau_max:-1] + x_cumsum[w] - x_cumsum[:tau_max]
            - 2 * conv)
def cumulative_mean_normalized_difference_function(df, n):
"""
Compute cumulative mean normalized difference function (CMND).
:param df: Difference function
:param n: length of data
:return: cumulative mean normalized difference function
:rtype: list
"""
# scipy method
cmn_df = df[1:] * range(1, n) / np.cumsum(df[1:]).astype(float)
return np.insert(cmn_df, 0, 1)
def get_pitch(cmdf, tau_min, tau_max, harmo_th=0.1):
    """
    Pick the fundamental period of a frame from its CMND function.

    Lags are scanned upward from ``tau_min``. At the first lag whose CMND
    value is below ``harmo_th``, the curve is followed downhill to its local
    minimum (staying below ``tau_max``), and that lag is returned.

    :param cmdf: Cumulative Mean Normalized Difference function
    :param tau_min: first lag (in samples) to examine
    :param tau_max: exclusive upper bound on the lag
    :param harmo_th: harmonicity threshold a CMND value must fall below for
        the frame to be considered voiced
    :return: period in samples, or 0 when no value is below the threshold
    """
    lag = tau_min
    while lag < tau_max:
        if cmdf[lag] < harmo_th:
            break
        lag += 1
    else:
        # Loop finished without finding a dip below the threshold.
        return 0  # unvoiced
    # Descend to the bottom of the dip we just entered.
    while lag + 1 < tau_max and cmdf[lag + 1] < cmdf[lag]:
        lag += 1
    return lag
def compute_yin(sig, sr, w_len=512, w_step=256, f0_min=100, f0_max=500,
                harmo_thresh=0.1):
    """
    Compute the YIN algorithm. Return fundamental frequency and harmonic rate.

    https://github.com/NVIDIA/mellotron adaption of
    https://github.com/patriceguyot/Yin

    :param sig: 1-D audio signal supporting slicing (list, NumPy array, or
        1-D torch tensor)
    :param sr: sampling rate (int)
    :param w_len: size of the analysis window (samples)
    :param w_step: size of the lag between two consecutive windows (samples)
    :param f0_min: Minimum fundamental frequency that can be detected (hertz).
        The corresponding lag is capped at ``w_len``, so frequencies below
        ``sr / w_len`` cannot be detected.
    :param f0_max: Maximum fundamental frequency that can be detected (hertz)
    :param harmo_thresh: Threshold of detection. The algorithm returns the
        first minimum of the CMND function below this threshold.
    :returns:
        * pitches: fundamental frequency (Hz) per frame, 0.0 when unvoiced
        * harmonic_rates: CMND value at the detected period, or the minimum
          CMND value when no period was found (= confidence value)
        * argmins: ``sr / argmin(CMND)`` per frame, 0.0 when that lag is not
          greater than ``tau_min``
        * times: start time (seconds) of each analysis window
    :rtype: tuple
    """
    tau_min = int(sr / f0_max)
    # difference_function can only produce lags shorter than the window, so
    # clamp here. The CMND length and get_pitch's search range then match its
    # output; otherwise a low f0_min or high sr crashes on mismatched lengths.
    tau_max = min(int(sr / f0_min), w_len)
    # Start sample of each analysis window. The exclusive stop drops the
    # window ending exactly at len(sig); kept so frame counts stay unchanged.
    time_scale = range(0, len(sig) - w_len, w_step)
    times = [t / float(sr) for t in time_scale]
    frames = [sig[t:t + w_len] for t in time_scale]
    n_frames = len(time_scale)
    pitches = [0.0] * n_frames
    harmonic_rates = [0.0] * n_frames
    argmins = [0.0] * n_frames
    for i, frame in enumerate(frames):
        # Compute YIN
        df = difference_function(frame, w_len, tau_max)
        cm_df = cumulative_mean_normalized_difference_function(df, tau_max)
        p = get_pitch(cm_df, tau_min, tau_max, harmo_thresh)
        # Get results
        best_lag = np.argmin(cm_df)
        if best_lag > tau_min:
            argmins[i] = float(sr / best_lag)
        if p != 0:  # A pitch was found
            pitches[i] = float(sr / p)
            harmonic_rates[i] = cm_df[p]
        else:  # No pitch, but we compute a value of the harmonic rate
            harmonic_rates[i] = min(cm_df)
    return pitches, harmonic_rates, argmins, times
def extract_f0(samples):
    """
    Run YIN on the reference and synthesized audio of every sample.

    :param samples: iterable of dicts with ``"ref"`` and ``"syn"`` audio paths
    :return: list aligned with ``samples``. An entry is ``None`` if either
        file is missing; otherwise it is ``{"ref": ..., "syn": ...}``, where
        each value is the tuple returned by ``compute_yin``.
    :raises ValueError: if the two files of a sample have different sample
        rates
    """
    f0_samples = []
    for sample in tqdm.tqdm(samples):
        if not op.isfile(sample["ref"]) or not op.isfile(sample["syn"]):
            f0_samples.append(None)
            continue
        yref, sr = torchaudio.load(sample["ref"])
        ysyn, syn_sr = torchaudio.load(sample["syn"])
        # Only the first channel is analyzed; extra channels are ignored.
        yref, ysyn = yref[0], ysyn[0]
        if sr != syn_sr:
            # Explicit check rather than ``assert`` so it survives ``python -O``.
            raise ValueError(
                f"sample rate mismatch between {sample['ref']} and "
                f"{sample['syn']}: {sr} != {syn_sr}"
            )
        f0_samples.append({
            "ref": compute_yin(yref, sr),
            "syn": compute_yin(ysyn, sr),
        })
    return f0_samples
def eval_f0_error(samples, distortion_fn):
    """
    Apply an F0 distortion metric to every pre-extracted sample.

    :param samples: output of ``extract_f0``; ``None`` entries pass through
    :param distortion_fn: callable taking
        ``(ref_times, ref_f0, syn_times, syn_f0)`` as NumPy arrays and
        returning an object with ``.item()``
    :return: list aligned with ``samples``, holding ``None`` or
        ``(distortion, n_ref_frames, n_syn_frames)``
    """
    scored = []
    for entry in tqdm.tqdm(samples):
        if entry is not None:
            # Each side is (pitches, harmonic_rates, argmins, times).
            ref_pitch, _, _, ref_times = entry["ref"]
            syn_pitch, _, _, syn_times = entry["syn"]
            ref_pitch, ref_times = np.array(ref_pitch), np.array(ref_times)
            syn_pitch, syn_times = np.array(syn_pitch), np.array(syn_times)
            score = distortion_fn(ref_times, ref_pitch, syn_times, syn_pitch)
            scored.append((score.item(), len(ref_pitch), len(syn_pitch)))
        else:
            scored.append(None)
    return scored
def eval_gross_pitch_error(samples):
    """Score each extracted sample with Gross Pitch Error (GPE)."""
    return eval_f0_error(samples, distortion_fn=gross_pitch_error)
def eval_voicing_decision_error(samples):
    """Score each extracted sample with Voicing Decision Error (VDE)."""
    return eval_f0_error(samples, distortion_fn=voicing_decision_error)
def eval_f0_frame_error(samples):
    """Score each extracted sample with F0 Frame Error (FFE)."""
    return eval_f0_error(samples, distortion_fn=f0_frame_error)
def print_results(results, show_bin):
    """
    Print summary tables for per-utterance F0 error results.

    :param results: list of ``(error, n_ref_frames, n_syn_frames)`` tuples as
        produced by ``eval_f0_error``; ``None`` entries are ignored
    :param show_bin: if True, also print one table per bucket of reference
        length in frames, using half-open buckets ``[lo, hi)``
    """
    results = np.array([r for r in results if r is not None])
    np.set_printoptions(precision=3)
    if results.size == 0:
        # Nothing to summarize (e.g. every audio file was missing). The 2-D
        # column indexing below would raise IndexError on an empty 1-D array.
        print("No valid results to report.")
        return

    def _print_result(rows):
        res = {
            "nutt": len(rows),
            "error": rows[:, 0].mean(),
            "std": rows[:, 0].std(),
            "dur_ref": int(rows[:, 1].sum()),
            "dur_syn": int(rows[:, 2].sum()),
        }
        print(tabulate([res.values()], res.keys(), floatfmt=".4f"))

    print(">>>> ALL")
    _print_result(results)
    if show_bin:
        edges = [0, 200, 400, 600, 800, 1000, 2000, 4000]
        for lo, hi in zip(edges[:-1], edges[1:]):
            mask = (results[:, 1] >= lo) & (results[:, 1] < hi)
            if not mask.any():
                continue
            print(f">>>> ({lo}, {hi})")
            _print_result(results[mask])
def main(eval_f0, gpe, vde, ffe, show_bin):
    """
    Load the evaluation spec and report each requested F0 metric.

    F0 contours are extracted once and shared by all enabled metrics, which
    run in the order GPE, VDE, FFE.
    """
    samples = load_eval_spec(eval_f0)
    if not (gpe or vde or ffe):
        return
    f0_samples = extract_f0(samples)
    requested = (
        (gpe, "===== Evaluate Gross Pitch Error =====",
         eval_gross_pitch_error),
        (vde, "===== Evaluate Voicing Decision Error =====",
         eval_voicing_decision_error),
        (ffe, "===== Evaluate F0 Frame Error =====",
         eval_f0_frame_error),
    )
    for enabled, banner, evaluate in requested:
        if not enabled:
            continue
        print(banner)
        print_results(evaluate(f0_samples), show_bin)
if __name__ == "__main__":
    import argparse

    # Positional eval spec path plus one boolean switch per metric.
    parser = argparse.ArgumentParser()
    parser.add_argument("eval_f0")
    for switch in ("--gpe", "--vde", "--ffe", "--show-bin"):
        parser.add_argument(switch, action="store_true")
    args = parser.parse_args()
    main(
        eval_f0=args.eval_f0,
        gpe=args.gpe,
        vde=args.vde,
        ffe=args.ffe,
        show_bin=args.show_bin,
    )