"""Command-line voice conversion: source content rendered with a reference voice.

The script extracts normalized log-mel and log-F0 features, runs the
content / F0 / speaker encoders and the decoder, stores the resulting
features in Kaldi ark/scp format and finally calls the external
``parallel-wavegan-decode`` tool to synthesize waveforms.
"""
import argparse
import os
import subprocess

import kaldiio
import numpy as np
import pyworld as pw
import resampy
import soundfile as sf
import torch

from model_decoder import Decoder_ac
from model_encoder import Encoder, Encoder_lf0
from model_encoder import SpeakerEncoder as Encoder_spk
from spectrogram import logmelspectrogram


def extract_logmel(wav_path, mean, std, sr=16000):
    """Extract a normalized log-mel spectrogram and normalized log-F0 track.

    Args:
        wav_path: Path of the audio file to read.
        mean: Statistics subtracted from the log-mel features.
        std: Statistics the centered log-mel features are divided by.
        sr: Target sampling rate. The mel/F0 settings below (hop of 160
            samples, fmax of 7600 Hz) are hard-coded for 16 kHz, so any other
            value is rejected.

    Returns:
        A ``(mel, lf0)`` tuple. ``mel`` is the normalized log-mel spectrogram
        (its first axis is the frame axis). ``lf0`` is a 1-D float32 array
        trimmed to at most that many frames; voiced frames hold the
        per-utterance z-normalized log-F0 and unvoiced frames stay at 0.

    Raises:
        ValueError: If the effective sampling rate is not 16000 Hz.
    """
    wav, fs = sf.read(wav_path)
    if wav.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel files, while
        # the F0 extractor below only accepts a mono signal: average channels.
        wav = wav.mean(axis=1)
    if fs != sr:
        wav = resampy.resample(wav, fs, sr, axis=0)
        fs = sr
    if fs != 16000:
        # Explicit check instead of `assert`, which is stripped under -O.
        raise ValueError(
            f"extract_logmel only supports 16000 Hz audio, got {fs} Hz")

    # Only attenuate clipping signals; quieter audio is left untouched.
    peak = np.abs(wav).max()
    if peak > 1.0:
        wav = wav / peak

    mel = logmelspectrogram(
        x=wav,
        fs=fs,
        n_mels=80,
        n_fft=400,
        n_shift=160,
        win_length=400,
        window='hann',
        fmin=80,
        fmax=7600,
    )
    mel = (mel - mean) / (std + 1e-8)
    tlen = mel.shape[0]

    # Use the same hop as the mel spectrogram (160 samples, expressed in ms).
    frame_period = 160 / fs * 1000
    wav64 = wav.astype('float64')
    f0, timeaxis = pw.dio(wav64, fs, frame_period=frame_period)
    f0 = pw.stonemask(wav64, f0, timeaxis, fs)
    f0 = f0[:tlen].reshape(-1).astype('float32')

    # Unvoiced frames have f0 == 0 and are kept at 0 in the output.
    voiced = np.nonzero(f0)
    lf0 = f0.copy()
    if voiced[0].size:
        # Dedicated names so the mel statistics passed in are not shadowed,
        # and no mean/std of an empty selection (NaN + RuntimeWarning) when
        # the utterance contains no voiced frame at all.
        log_f0 = np.log(f0[voiced])
        lf0_mean, lf0_std = np.mean(log_f0), np.std(log_f0)
        lf0[voiced] = (log_f0 - lf0_mean) / (lf0_std + 1e-8)
    return mel, lf0


def convert(args):
    """Convert the source utterance using the speaker of the reference one.

    Args:
        args: Namespace providing ``source_wav``, ``reference_wav``,
            ``converted_wav_path`` (output directory, created if missing) and
            ``model_path`` (checkpoint holding the ``encoder``,
            ``encoder_spk`` and ``decoder`` state dicts).

    Side effects:
        Writes ``feats.1.ark`` / ``feats.1.scp`` into the output directory
        (converted, source and reference features) and runs
        ``parallel-wavegan-decode`` on them with the same output directory.

    Raises:
        subprocess.CalledProcessError: If the vocoder exits with a non-zero
            status.
    """
    src_wav_path = args.source_wav
    ref_wav_path = args.reference_wav
    out_dir = args.converted_wav_path
    os.makedirs(out_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = Encoder(in_channels=80, channels=512, n_embeddings=512,
                      z_dim=64, c_dim=256)
    encoder_lf0 = Encoder_lf0()
    encoder_spk = Encoder_spk()
    decoder = Decoder_ac(dim_neck=64)
    modules = (encoder, encoder_lf0, encoder_spk, decoder)
    for module in modules:
        module.to(device)

    # NOTE(security): torch.load unpickles the file; only use checkpoints
    # coming from a trusted source.
    checkpoint = torch.load(args.model_path, map_location="cpu")
    encoder.load_state_dict(checkpoint["encoder"])
    encoder_spk.load_state_dict(checkpoint["encoder_spk"])
    decoder.load_state_dict(checkpoint["decoder"])
    # NOTE(review): no weights are loaded for encoder_lf0 -- presumably it has
    # no trainable parameters in its default configuration; confirm.

    # Inference mode for every module, including the F0 encoder.
    for module in modules:
        module.eval()

    mel_stats = np.load('./mel_stats/stats.npy')
    mean = mel_stats[0]
    std = mel_stats[1]

    src_mel, src_lf0 = extract_logmel(src_wav_path, mean, std)
    ref_mel, _ = extract_logmel(ref_wav_path, mean, std)
    # Mels are transposed and every feature gets a leading batch axis of 1.
    src_mel = torch.FloatTensor(src_mel.T).unsqueeze(0).to(device)
    src_lf0 = torch.FloatTensor(src_lf0).unsqueeze(0).to(device)
    ref_mel = torch.FloatTensor(ref_mel.T).unsqueeze(0).to(device)

    with torch.no_grad():
        z, _, _, _ = encoder.encode(src_mel)
        lf0_embs = encoder_lf0(src_lf0)
        spk_emb = encoder_spk(ref_mel)
        output = decoder(z, lf0_embs, spk_emb)

    out_filename = os.path.basename(src_wav_path).split('.')[0]
    feats_prefix = str(out_dir) + '/feats.1'
    # The context manager closes the ark/scp files even if a write fails, and
    # the files are only created once all features have been computed.
    with kaldiio.WriteHelper(
            "ark,scp:{o}.ark,{o}.scp".format(o=feats_prefix)) as feat_writer:
        feat_writer[out_filename + '_converted'] = \
            output.squeeze(0).cpu().numpy()
        # Undo the transpose applied above before storing the input mels.
        feat_writer[out_filename + '_source'] = \
            src_mel.squeeze(0).cpu().numpy().T
        feat_writer[out_filename + '_reference'] = \
            ref_mel.squeeze(0).cpu().numpy().T

    print('synthesize waveform...')
    cmd = [
        'parallel-wavegan-decode',
        '--checkpoint', './vocoder/checkpoint-3000000steps.pkl',
        '--feats-scp', f'{str(out_dir)}/feats.1.scp',
        '--outdir', str(out_dir),
    ]
    # check=True: a failed vocoder run must not look like a success.
    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--source_wav', '-s', type=str, required=True)
    parser.add_argument('--reference_wav', '-r', type=str, required=True)
    parser.add_argument('--converted_wav_path', '-c', type=str,
                        default='converted')
    parser.add_argument('--model_path', '-m', type=str, required=True)
    args = parser.parse_args()
    convert(args)