# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import librosa
import numpy as np
import torch
import torchcrepe


def _extract_periodicity(audio, fs, hop_length, device):
    """Return the silence-thresholded CREPE periodicity of one waveform.

    Args:
        audio: torch tensor holding the waveform with a leading batch
            dimension of size 1 (as built by the caller).
        fs: sampling rate of ``audio``.
        hop_length: hop length passed to torchcrepe.
        device: torch device string used for CREPE inference.

    Returns:
        A NumPy array with the batch dimension squeezed away.
    """
    # The pitch track is returned alongside the periodicity but is not needed.
    _, periodicity = torchcrepe.predict(
        audio,
        sample_rate=fs,
        hop_length=hop_length,
        fmin=0,
        fmax=1500,
        model="full",
        return_periodicity=True,
        device=device,
    )

    # Apply torchcrepe's silence threshold to the periodicity track.
    # NOTE(review): presumably this suppresses periodicity in frames that
    # torchcrepe considers silent -- confirm against the torchcrepe docs.
    periodicity = torchcrepe.threshold.Silence()(
        periodicity,
        audio,
        fs,
        hop_length=hop_length,
    )
    return periodicity.squeeze(0).numpy()


def extract_f0_periodicity_rmse(
    audio_ref,
    audio_deg,
    fs=None,
    hop_length=256,
    method="dtw",
    device=None,
):
    """Compute f0 periodicity Root Mean Square Error (RMSE) between the
    predicted and the ground truth audio.

    Args:
        audio_ref: path to the ground truth audio.
        audio_deg: path to the predicted audio.
        fs: sampling rate. When ``None``, both files are loaded with
            librosa's default sampling rate.
        hop_length: hop length.
        method: "dtw" will use dtw algorithm to align the length of the
            ground truth and predicted audio. "cut" will cut both audios
            into a same length according to the one with the shorter length.
        device: torch device string for CREPE inference. When ``None``
            (default), "cuda:0" is used if CUDA is available, otherwise
            "cpu".

    Returns:
        The periodicity RMSE as a float, or ``0`` when either periodicity
        track has at most one frame.

    Raises:
        ValueError: if ``method`` is not a supported alignment method and the
            two periodicity tracks end up with different lengths.
    """
    # Previously hard-coded to "cuda:0", which fails on CPU-only machines.
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Load audio
    if fs is not None:
        audio_ref, _ = librosa.load(audio_ref, sr=fs)
        audio_deg, _ = librosa.load(audio_deg, sr=fs)
    else:
        audio_ref, fs = librosa.load(audio_ref)
        audio_deg, fs = librosa.load(audio_deg)

    # Convert to torch, adding the batch dimension
    audio_ref = torch.from_numpy(audio_ref).unsqueeze(0)
    audio_deg = torch.from_numpy(audio_deg).unsqueeze(0)

    # Get periodicity (silence-thresholded)
    periodicity_ref = _extract_periodicity(audio_ref, fs, hop_length, device)
    periodicity_deg = _extract_periodicity(audio_deg, fs, hop_length, device)

    # Avoid silence audio: with at most one frame there is nothing to compare.
    min_length = min(len(periodicity_ref), len(periodicity_deg))
    if min_length <= 1:
        return 0

    # Periodicity length alignment
    if method == "cut":
        periodicity_ref = periodicity_ref[:min_length]
        periodicity_deg = periodicity_deg[:min_length]
    elif method == "dtw":
        _, wp = librosa.sequence.dtw(periodicity_ref, periodicity_deg, backtrack=True)
        # Each row of the warping path pairs a reference frame index with a
        # degraded frame index; gather both tracks along the path at once.
        periodicity_ref = periodicity_ref[wp[:, 0]]
        periodicity_deg = periodicity_deg[wp[:, 1]]

    # Lengths can only differ here when `method` matched neither branch.
    # Raise explicitly instead of asserting, since asserts vanish under -O.
    if len(periodicity_ref) != len(periodicity_deg):
        raise ValueError(
            f"Unsupported alignment method {method!r}: expected 'dtw' or 'cut' "
            "to align periodicity tracks of different lengths."
        )

    # Compute RMSE
    periodicity_mse = np.square(np.subtract(periodicity_ref, periodicity_deg)).mean()
    return math.sqrt(periodicity_mse)