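"""Gradio demo for zero-shot TTS with Amphion's NaturalSpeech2.

Given a reference speech prompt and a target text, the script encodes the
prompt with EnCodec, runs NaturalSpeech2 diffusion inference, and decodes the
predicted latent back into a 24 kHz waveform.
"""
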
import gradio as gr
import argparse
import os
import torch
import soundfile as sf
import numpy as np

from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from encodec import EncodecModel
from encodec.utils import convert_audio
from utils.util import load_config

from text import text_to_sequence
from text.cmudict import valid_symbols
from text.g2p import preprocess_english, read_lexicon

import torchaudio


def build_codec(device):
    # Pretrained 24 kHz EnCodec model used as the neural audio codec.
    encodec_model = EncodecModel.encodec_model_24khz()
    encodec_model = encodec_model.to(device=device)
    # A 12 kbps target bandwidth corresponds to 16 residual codebooks per frame.
    encodec_model.set_target_bandwidth(12.0)
    return encodec_model

def build_model(cfg, device):
    # Restore NaturalSpeech2 from the released checkpoint.
    model = NaturalSpeech2(cfg.model)
    model.load_state_dict(
        torch.load(
            "ckpts/ns2/pytorch_model.bin",
            map_location="cpu",
        )
    )
    model = model.to(device=device)
    return model


def ns2_inference(
    prompt_audio_path,
    text,
    diffusion_steps=100,
):
    # Fetch the CMU pronouncing dictionary needed for grapheme-to-phoneme conversion.
    try:
        import nltk

        nltk.download("cmudict")
    except Exception:
        pass

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    os.environ["WORK_DIR"] = "./"
    cfg = load_config("egs/tts/NaturalSpeech2/exp_config.json")

    model = build_model(cfg, device)
    codec = build_codec(device)

    ref_wav_path = prompt_audio_path
    ref_wav, sr = torchaudio.load(ref_wav_path)
    ref_wav = convert_audio(
        ref_wav, sr, codec.sample_rate, codec.channels
    )
    ref_wav = ref_wav.unsqueeze(0).to(device=device)

    # Encode the reference prompt into discrete EnCodec codes (the acoustic prompt).
    with torch.no_grad():
        encoded_frames = codec.encode(ref_wav)
        ref_code = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)

    ref_mask = torch.ones(ref_code.shape[0], ref_code.shape[-1]).to(ref_code.device)

    # Phone vocabulary: CMUdict symbols plus silence and sentence-boundary tokens.
    symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
    phone2id = {s: i for i, s in enumerate(symbols)}
    id2phone = {i: s for s, i in phone2id.items()}

    lexicon = read_lexicon(cfg.preprocess.lexicon_path)
    phone_seq = preprocess_english(text, lexicon)

    # Map the G2P phone sequence onto vocabulary ids.
    phone_id = np.array(
        [
            *map(
                phone2id.get,
                phone_seq.replace("{", "").replace("}", "").split(),
            )
        ]
    )
    phone_id = torch.from_numpy(phone_id).unsqueeze(0).to(device=device)

    # Run NS2 diffusion inference conditioned on the prompt codes and phone ids.
    x0, prior_out = model.inference(ref_code, phone_id, ref_mask, diffusion_steps)

    # Reference latent recovered from the prompt codes (not used further below).
    latent_ref = codec.quantizer.vq.decode(ref_code.transpose(0, 1))
    # Decode the predicted latent x0 back into a waveform with the EnCodec decoder.
    rec_wav = codec.decoder(x0)

    # Write the generated waveform under result/, at the codec's 24 kHz sample rate.
    os.makedirs("result", exist_ok=True)
    prompt_name = os.path.splitext(os.path.basename(prompt_audio_path))[0]
    result_file = "result/{}_zero_shot_result.wav".format(prompt_name)
    sf.write(
        result_file,
        rec_wav[0, 0].detach().cpu().numpy(),
        samplerate=24000,
    )
    return result_file


demo_inputs = [
    gr.Audio(
        sources=["upload", "microphone"],
        label="Upload a reference speech you want to clone timbre",
        type="filepath",
    ),
    gr.Textbox(
        value="Amphion is a toolkit that can speak, make sounds, and sing.",
        label="Text you want to generate",
        type="text",
    ),
    gr.Slider(
        10,
        1000,
        value=200,
        step=1,
        label="Diffusion Inference Steps",
        info="As the step number increases, the synthesis quality will be better while the inference speed will be lower",
    ),
]
demo_outputs = gr.Audio(label="Generated speech")

demo = gr.Interface(
    fn=ns2_inference,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title="Amphion Zero-Shot TTS NaturalSpeech2",
    description="Note that the current model is only trained on libritts, and the amount of training data is much less than the 5.5w hours of the original paper. We will soon introduce models trained on large-scale data. Please stay tuned."
)

if __name__ == "__main__":
    demo.launch()
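    # For remote access, Gradio launch options such as share=True or
    # server_name="0.0.0.0" can be passed to demo.launch(). The synthesis
    # function can also be called without the UI, e.g.
    #     ns2_inference("ref.wav", "Hello world.", diffusion_steps=200)
    # (the reference path here is illustrative).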