# Copyright 2024 LY Corporation
#
# LY Corporation licenses this file to you under the Apache License,
# version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at:
#
#   https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
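"""Gradio demo for PromptTTS++.

Synthesizes speech from a content prompt plus either a natural-language style
prompt or a reference wav (paper: https://arxiv.org/abs/2309.08140).
"""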
import gradio as gr
import hydra
import matplotlib.pyplot as plt
import nltk
import torch
import torchaudio
from g2p_en import G2p
from hydra.utils import instantiate
from omegaconf import OmegaConf
from promptttspp.text.eng import symbols, text_to_sequence
from promptttspp.utils.model import lowpass_filter


def load_model(model_cfg, model_ckpt_path, vocoder_cfg, vocoder_ckpt_path):
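    # Instantiate the acoustic model and the vocoder from their Hydra configs,
    # restore the trained weights, and put both in eval mode on GPU if available.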
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = instantiate(model_cfg)
    model.load_state_dict(torch.load(model_ckpt_path, map_location="cpu")["model"])
    model = model.to(device).eval()
    vocoder = instantiate(vocoder_cfg)
    vocoder.load_state_dict(
        torch.load(vocoder_ckpt_path, map_location="cpu")["generator"]
    )
    vocoder = vocoder.to(device).eval()
    return model, vocoder


def build_ui(g2p, model, vocoder, to_mel, mel_stats):
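    # Placeholder text shown in the prompt boxes when the demo loads.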
    content_placeholder = (
        "This is a text-to-speech demo, which allows you to control the "
        "speaker identity in natural language as follows."
    )
    style_placeholder = "A man speaks slowly in a low tone."

    def onclick_synthesis(content_prompt, style_prompt=None, reference_mel=None):
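        # Exactly one conditioning source must be given: a natural-language
        # style prompt or a reference mel-spectrogram.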
        assert style_prompt is not None or reference_mel is not None
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Convert the content prompt to phonemes; punctuation is mapped to
        # silence, and symbols outside the model's vocabulary are dropped.
        phonemes = g2p(content_prompt)
        phonemes = [p if p not in [",", "."] else "sil" for p in phonemes]
        phonemes = [p for p in phonemes if p in symbols]
        phoneme_ids = text_to_sequence(" ".join(phonemes))
        phoneme_ids = torch.LongTensor(phoneme_ids)[None, :].to(device)
        if style_prompt is not None:
            dec, log_cf0, vuv = model.infer(
                phoneme_ids,
                style_prompt=style_prompt,
                use_max=True,
                noise_scale=0.5,
                return_f0=True,
            )
        else:
            # Normalize the reference mel with the training statistics before
            # conditioning on it.
            reference_mel = (reference_mel - mel_stats["mean"]) / mel_stats["std"]
            reference_mel = reference_mel.to(device)
            dec, log_cf0, vuv = model.infer(
                phoneme_ids,
                reference_mel=reference_mel,
                use_max=True,
                noise_scale=0.5,
                return_f0=True,
            )
        # Smooth the predicted log-F0 with a 20 Hz low-pass filter at the
        # 100 Hz frame rate (10 ms shift), then zero out unvoiced frames.
        modfs = int(1.0 / (10 * 0.001))
        log_cf0 = lowpass_filter(log_cf0, modfs, cutoff=20)
        f0 = log_cf0.exp()
        f0[vuv < 0.5] = 0
        # De-normalize the predicted mel-spectrogram and vocode it to a waveform.
        dec = dec * mel_stats["std"] + mel_stats["mean"]
        wav = vocoder(dec, f0).squeeze(1).cpu()
        return wav

    def onclick_with_style_prompt(content_prompt, style_prompt):
        wav = onclick_synthesis(
            content_prompt=content_prompt, style_prompt=style_prompt
        )
        # Plot the mel-spectrogram of the synthesized waveform for display.
        mel = to_mel(wav)
        fig = plt.figure(figsize=(12, 8))
        plt.imshow(mel.squeeze().numpy(), aspect="auto", origin="lower")
        return (to_mel.sample_rate, wav.squeeze().numpy()), fig

    def onclick_with_reference_mel(content_prompt, reference_wav_path):
        # Extract a mel-spectrogram from the uploaded reference wav and use it
        # as the conditioning input.
        wav, _ = torchaudio.load(reference_wav_path)
        ref_mel = to_mel(wav)
        wav = onclick_synthesis(content_prompt=content_prompt, reference_mel=ref_mel)
        mel = to_mel(wav)
        fig = plt.figure(figsize=(12, 8))
        plt.imshow(mel.squeeze().numpy(), aspect="auto", origin="lower")
        return (to_mel.sample_rate, wav.squeeze().numpy()), fig
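
    # Assemble the Gradio UI: one tab conditions on a style prompt, the other
    # on a reference wav.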
    with gr.Blocks() as demo:
        gr.Markdown(
            "# PromptTTS++: Controlling Speaker Identity in Prompt-Based "
            "Text-to-Speech Using Natural Language Descriptions"
        )
        gr.Markdown(
            "### You can check the [paper](https://arxiv.org/abs/2309.08140) "
            "and [code](https://github.com/line/promptttspp)."
        )
        gr.Markdown("### NOTE: Please do not enter personal information.")
        content_prompt = gr.Textbox(
            content_placeholder, lines=3, label="Content prompt"
        )
        with gr.Tabs():
            with gr.TabItem("Style prompt"):
                style_prompt = gr.Textbox(
                    style_placeholder, lines=3, label="Style prompt"
                )
                syn_button1 = gr.Button("Synthesize")
                wav1 = gr.Audio(label="Output wav", elem_id="prompt")
                plot1 = gr.Plot(label="Output mel", elem_id="prompt")
            with gr.TabItem("Reference wav"):
                ref_wav_path = gr.Audio(
                    type="filepath", label="Reference wav", elem_id="ref"
                )
                syn_button2 = gr.Button("Synthesize")
                wav2 = gr.Audio(label="Output wav", elem_id="ref")
                plot2 = gr.Plot(label="Output mel", elem_id="ref")
        syn_button1.click(
            onclick_with_style_prompt,
            inputs=[content_prompt, style_prompt],
            outputs=[wav1, plot1],
        )
        syn_button2.click(
            onclick_with_reference_mel,
            inputs=[content_prompt, ref_wav_path],
            outputs=[wav2, plot2],
        )
    demo.launch()


# NOTE: main(cfg) is invoked with no arguments below, so it needs the Hydra
# entry-point decorator; config_path and config_name are assumptions here —
# point them at the repo's actual Hydra config if they differ.
@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg):
    model, vocoder = load_model(
        cfg.model, cfg.model_ckpt_path, cfg.vocoder, cfg.vocoder_ckpt_path
    )
    to_mel = instantiate(cfg.transforms)
    # With NLTK >= 3.9.1, g2p_en's POS tagger data may need to be downloaded
    # explicitly before G2p() can be used.
    nltk.download("averaged_perceptron_tagger_eng")
    g2p = G2p()
    mel_stats = OmegaConf.load(cfg.mel_stats_file)
    build_ui(g2p, model, vocoder, to_mel, mel_stats)


if __name__ == "__main__":
    main()
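
# Example invocation (hypothetical paths and script name; the override keys
# match the cfg fields read in main()):
#   python app.py model_ckpt_path=/path/to/model.pth \
#       vocoder_ckpt_path=/path/to/vocoder.pth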