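"""Gradio demo for PitchVC voice conversion (https://github.com/OlaWod/PitchVC)."""
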
import json
import math

import torch
import torch.nn.functional as F
import librosa
import numpy as np
import soundfile as sf
import gradio as gr
from transformers import WavLMModel

from env import AttrDict
from meldataset import mel_spectrogram, MAX_WAV_VALUE
from models import Generator
from Utils.JDC.model import JDCNet


# files: checkpoint, config, speaker/F0 statistics, speaker embeddings, reference audio
hpfile = "config_v1_16k.json"
ptfile = "exp/default/g_00700000"
spk2id_path = "filelists/spk2id.json"
f0_stats_path = "filelists/f0_stats.json"
spk_stats_path = "filelists/spk_stats.json"
spk_emb_dir = "dataset/spk"
spk_wav_dir = "dataset/audio"

# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load config
with open(hpfile) as f:
    h = AttrDict(json.load(f))

# load models
F0_model = JDCNet(num_class=1, seq_len=192)  # JDC pitch (F0) extractor
generator = Generator(h, F0_model).to(device)

state_dict_g = torch.load(ptfile, map_location=device)
generator.load_state_dict(state_dict_g['generator'], strict=True)
generator.remove_weight_norm()
_ = generator.eval()

# WavLM encoder for content features
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")
wavlm.eval()
wavlm.to(device)

# load stats
with open(spk2id_path) as f:
    spk2id = json.load(f)
with open(f0_stats_path) as f:
    f0_stats = json.load(f)
with open(spk_stats_path) as f:
    spk_stats = json.load(f)

# tune f0
threshold = 10  # F0 (Hz) at or below this is treated as unvoiced
step = (math.log(1100) - math.log(50)) / 256  # presumably one bin of a 256-bin log-F0 scale over 50-1100 Hz
def tune_f0(initial_f0, i):
    """Shift voiced F0 by `i` log-F0 bins; leave unvoiced frames unchanged."""
    if i == 0:
        return initial_f0
    voiced = initial_f0 > threshold
    initial_lf0 = torch.log(initial_f0)
    lf0 = initial_lf0 + step * i
    f0 = torch.exp(lf0)
    f0 = torch.where(voiced, f0, initial_f0)
    return f0
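
# Example: a 5-bin shift raises voiced F0 by ~6.2% (exp(5 * step)); 0 Hz
# (unvoiced) frames pass through:
#   tune_f0(torch.tensor([0.0, 220.0]), 5)  # -> approx [0.0, 233.7]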

# convert function
def convert(tgt_spk, src_wav, f0_shift=0):
    """Convert `src_wav` to the voice of `tgt_spk`; `f0_shift` is in log-F0 bins."""
    tgt_ref = spk_stats[tgt_spk]["best_spk_emb"]
    tgt_emb = f"{spk_emb_dir}/{tgt_spk}/{tgt_ref}.npy"

    with torch.no_grad():
        # target: speaker id, speaker embedding, mean F0
        spk_id = spk2id[tgt_spk]
        spk_id = torch.LongTensor([spk_id]).unsqueeze(0).to(device)
        
        spk_emb = np.load(tgt_emb)
        spk_emb = torch.from_numpy(spk_emb).unsqueeze(0).to(device)

        f0_mean_tgt = f0_stats[tgt_spk]["mean"]
        f0_mean_tgt = torch.FloatTensor([f0_mean_tgt]).unsqueeze(0).to(device)

        # source: load at 16 kHz and compute the mel spectrogram
        wav, sr = librosa.load(src_wav, sr=16000)
        wav = torch.FloatTensor(wav).to(device)
        mel = mel_spectrogram(wav.unsqueeze(0), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
        
        # WavLM content features; pad (or trim, via negative pad) along time
        # to match the mel frame count
        x = wavlm(wav.unsqueeze(0)).last_hidden_state
        x = x.transpose(1, 2)  # (B, C, T)
        x = F.pad(x, (0, mel.size(2) - x.size(2)), 'constant')

        # convert: predict F0 from the source mel conditioned on the target
        # speaker's mean F0, apply the optional shift, then synthesize
        f0 = generator.get_f0(mel, f0_mean_tgt)
        f0 = tune_f0(f0, f0_shift)
        x = generator.get_x(x, spk_emb, spk_id)
        y = generator.infer(x, f0)
        
        # peak-normalize to 0.95 full scale and convert to 16-bit PCM
        audio = y.squeeze()
        audio = audio / torch.max(torch.abs(audio)) * 0.95
        audio = audio * MAX_WAV_VALUE
        audio = audio.cpu().numpy().astype('int16')

        sf.write("out.wav", audio, h.sampling_rate, "PCM_16")

    out_wav = "out.wav"
    return out_wav
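
# Programmatic use (bypassing the UI), assuming the example assets below exist:
#   convert("p225", "dataset/audio/p226/p226_341.wav", f0_shift=0)  # -> "out.wav"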

# change spk
def change_spk(tgt_spk):
    """Return the reference clip of the selected target speaker for preview."""
    tgt_ref = spk_stats[tgt_spk]["best_spk_emb"]
    tgt_wav = f"{spk_wav_dir}/{tgt_spk}/{tgt_ref}.wav"
    return tgt_wav

# interface
with gr.Blocks() as demo:
    gr.Markdown("# PitchVC")
    gr.Markdown("Gradio Demo for PitchVC. ([Github Repo](https://github.com/OlaWod/PitchVC))")

    with gr.Row():
        with gr.Column():
            tgt_spk = gr.Dropdown(choices=list(spk2id.keys()), type="value", label="Target Speaker")
            ref_audio = gr.Audio(label="Reference Audio", type='filepath')
            src_audio = gr.Audio(label="Source Audio", type='filepath')
            f0_shift = gr.Slider(minimum=-30, maximum=30, value=0, step=1, label="F0 Shift")
        with gr.Column():
            out_audio = gr.Audio(label="Output Audio", type='filepath')
            submit = gr.Button(value="Submit")

    # selecting a target speaker previews that speaker's reference clip
    tgt_spk.change(fn=change_spk, inputs=[tgt_spk], outputs=[ref_audio])
    submit.click(convert, [tgt_spk, src_audio, f0_shift], [out_audio])

    examples = gr.Examples(
        examples=[["p225", 'dataset/audio/p226/p226_341.wav', 0],
                  ["p226", 'dataset/audio/p225/p225_220.wav', -5]],
        inputs=[tgt_spk, src_audio, f0_shift])

# serves on http://127.0.0.1:7860 by default; pass share=True for a public link
demo.launch()