Spaces:
Runtime error
Runtime error
File size: 4,256 Bytes
2b7bf83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import torch
import numpy as np
import soundfile as sf
from model_encoder import Encoder, Encoder_lf0
from model_decoder import Decoder_ac
from model_encoder import SpeakerEncoder as Encoder_spk
import os
import subprocess
from spectrogram import logmelspectrogram
import kaldiio
import resampy
import pyworld as pw
import argparse
def extract_logmel(wav_path, mean, std, sr=16000):
    """Load a waveform and extract normalized log-mel and log-F0 features.

    The audio is read with soundfile, down-mixed to mono if it has more than
    one channel, resampled to ``sr`` if needed, and peak-normalized when its
    absolute peak exceeds 1.0. A log-mel spectrogram is then computed with
    fixed analysis settings (80 mels, 400-sample FFT/window, 160-sample hop,
    80-7600 Hz) and normalized with the supplied statistics. F0 is estimated
    with WORLD (DIO + StoneMask) on the same 160-sample hop; voiced frames
    are log-transformed and z-scored with this utterance's own statistics.

    Args:
        wav_path: Path to an audio file readable by ``soundfile``.
        mean: Mel normalization mean, broadcast against the mel matrix
            (presumably one value per mel bin -- confirm against stats.npy).
        std: Mel normalization standard deviation, same shape as ``mean``.
        sr: Target sample rate. The analysis settings are tuned for 16 kHz,
            so only 16000 is accepted.

    Returns:
        ``(mel, lf0)``: ``mel`` is the normalized log-mel array returned by
        ``logmelspectrogram`` (its first axis is used as the frame axis);
        ``lf0`` is a 1-D float32 array with at most ``mel.shape[0]`` entries,
        0 for unvoiced frames and z-scored log-F0 for voiced frames.

    Raises:
        ValueError: If ``sr`` is not 16000.
    """
    # n_fft=400 / hop=160 / fmax=7600 below only make sense at 16 kHz; check
    # explicitly (an assert would be stripped under ``python -O``).
    if sr != 16000:
        raise ValueError(f"extract_logmel only supports sr=16000, got {sr}")

    hop_length = 160

    wav, fs = sf.read(wav_path)
    # soundfile returns (frames, channels) for multi-channel files, but WORLD
    # and the mel front-end need a 1-D signal: average the channels.
    if wav.ndim > 1:
        wav = wav.mean(axis=1)
    if fs != sr:
        wav = resampy.resample(wav, fs, sr, axis=0)
        fs = sr

    # Only attenuate clipping-range signals; quieter audio is left as-is.
    peak = np.abs(wav).max()
    if peak > 1.0:
        wav = wav / peak

    mel = logmelspectrogram(
        x=wav,
        fs=fs,
        n_mels=80,
        n_fft=400,
        n_shift=hop_length,
        win_length=400,
        window='hann',
        fmin=80,
        fmax=7600,
    )
    mel = (mel - mean) / (std + 1e-8)
    n_frames = mel.shape[0]

    # WORLD wants a C-contiguous float64 signal and the frame period in ms;
    # using the mel hop keeps F0 frames aligned with mel frames.
    wav64 = np.ascontiguousarray(wav, dtype=np.float64)
    frame_period = hop_length / fs * 1000
    f0, timeaxis = pw.dio(wav64, fs, frame_period=frame_period)
    f0 = pw.stonemask(wav64, f0, timeaxis, fs)
    f0 = f0[:n_frames].reshape(-1).astype('float32')

    # Unvoiced frames (f0 == 0) stay 0; voiced frames become z-scored log-F0.
    lf0 = f0.copy()
    voiced = np.nonzero(f0)
    if voiced[0].size > 0:  # fully unvoiced input: avoid mean/std of empty array
        lf0[voiced] = np.log(f0[voiced])  # for f0 (Hz) > 1, lf0 > 0
        lf0_mean = np.mean(lf0[voiced])
        lf0_std = np.std(lf0[voiced])
        lf0[voiced] = (lf0[voiced] - lf0_mean) / (lf0_std + 1e-8)
    return mel, lf0
def _load_models(checkpoint_path, device):
    """Build the conversion networks, restore their weights and set eval mode.

    Returns:
        Tuple ``(encoder, encoder_lf0, encoder_spk, decoder)`` moved to ``device``.
    """
    encoder = Encoder(in_channels=80, channels=512, n_embeddings=512, z_dim=64, c_dim=256)
    encoder_lf0 = Encoder_lf0()
    encoder_spk = Encoder_spk()
    decoder = Decoder_ac(dim_neck=64)

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    # Tensors are mapped to CPU; load_state_dict copies them into the modules.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    encoder.load_state_dict(checkpoint["encoder"])
    encoder_spk.load_state_dict(checkpoint["encoder_spk"])
    decoder.load_state_dict(checkpoint["decoder"])
    # NOTE(review): encoder_lf0 is not restored from the checkpoint (the
    # script never did this) -- presumably it has no trained weights;
    # confirm against model_encoder.Encoder_lf0.

    models = (encoder, encoder_lf0, encoder_spk, decoder)
    for model in models:
        model.to(device)
        # Every network, including encoder_lf0, runs in inference mode.
        model.eval()
    return models


def convert(args):
    """Convert ``args.source_wav`` to the voice of ``args.reference_wav``.

    Extracts features from both files, runs the content / lf0 / speaker
    encoders and the decoder, writes the converted, source and reference
    mel features to ``<out_dir>/feats.1.ark`` + ``feats.1.scp`` under the keys
    ``<stem>_converted``, ``<stem>_source`` and ``<stem>_reference``, then
    calls ``parallel-wavegan-decode`` on that scp with ``--outdir <out_dir>``.

    Args:
        args: Namespace with ``source_wav``, ``reference_wav``,
            ``converted_wav_path`` (output directory) and ``model_path``.

    Raises:
        subprocess.CalledProcessError: If the vocoder command exits non-zero.
    """
    src_wav_path = args.source_wav
    ref_wav_path = args.reference_wav
    out_dir = str(args.converted_wav_path)
    os.makedirs(out_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder, encoder_lf0, encoder_spk, decoder = _load_models(args.model_path, device)

    mel_stats = np.load('./mel_stats/stats.npy')
    mean, std = mel_stats[0], mel_stats[1]

    src_mel, src_lf0 = extract_logmel(src_wav_path, mean, std)
    ref_mel, _ = extract_logmel(ref_wav_path, mean, std)
    # The encoders take (batch, n_mels, frames), hence the transpose.
    src_mel = torch.FloatTensor(src_mel.T).unsqueeze(0).to(device)
    src_lf0 = torch.FloatTensor(src_lf0).unsqueeze(0).to(device)
    ref_mel = torch.FloatTensor(ref_mel.T).unsqueeze(0).to(device)

    # splitext keeps dotted stems intact ("a.b.wav" -> "a.b"); whitespace is
    # replaced because Kaldi keys must be a single token in the scp file.
    stem = os.path.splitext(os.path.basename(src_wav_path))[0]
    out_filename = "_".join(stem.split())

    with torch.no_grad():
        z, _, _, _ = encoder.encode(src_mel)
        lf0_embs = encoder_lf0(src_lf0)
        spk_emb = encoder_spk(ref_mel)
        output = decoder(z, lf0_embs, spk_emb)

    # The context manager guarantees the ark/scp pair is flushed and closed
    # even if a write fails.
    feats_prefix = os.path.join(out_dir, 'feats.1')
    with kaldiio.WriteHelper(f"ark,scp:{feats_prefix}.ark,{feats_prefix}.scp") as feat_writer:
        feat_writer[out_filename + '_converted'] = output.squeeze(0).cpu().numpy()
        feat_writer[out_filename + '_source'] = src_mel.squeeze(0).cpu().numpy().T
        feat_writer[out_filename + '_reference'] = ref_mel.squeeze(0).cpu().numpy().T

    print('synthesize waveform...')
    cmd = [
        'parallel-wavegan-decode',
        '--checkpoint', './vocoder/checkpoint-3000000steps.pkl',
        '--feats-scp', f'{feats_prefix}.scp',
        '--outdir', out_dir,
    ]
    # check=True surfaces vocoder failures instead of silently ignoring the
    # exit status (subprocess.call discarded it).
    subprocess.run(cmd, check=True)
if __name__ == "__main__":
    # Command-line entry point. Each row: (long flag, short flag, extra
    # add_argument options); every option is parsed as a string.
    _CLI_OPTIONS = (
        ('--source_wav', '-s', {'required': True}),
        ('--reference_wav', '-r', {'required': True}),
        ('--converted_wav_path', '-c', {'default': 'converted'}),
        ('--model_path', '-m', {'required': True}),
    )
    parser = argparse.ArgumentParser()
    for long_flag, short_flag, extra in _CLI_OPTIONS:
        parser.add_argument(long_flag, short_flag, type=str, **extra)
    cli_args = parser.parse_args()
    convert(cli_args)
|