Spaces:
Runtime error
Runtime error
| # Copyright 2024 LY Corporation | |
| # LY Corporation licenses this file to you under the Apache License, | |
| # version 2.0 (the "License"); you may not use this file except in compliance | |
| # with the License. You may obtain a copy of the License at: | |
| # https://www.apache.org/licenses/LICENSE-2.0 | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |
| # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | |
| # License for the specific language governing permissions and limitations | |
| # under the License. | |
| import gradio as gr | |
| import hydra | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import torchaudio | |
| from g2p_en import G2p | |
| from hydra.utils import instantiate | |
| from omegaconf import OmegaConf | |
| from promptttspp.text.eng import symbols, text_to_sequence | |
| from promptttspp.utils.model import lowpass_filter | |
| import nltk | |
def load_model(model_cfg, model_ckpt_path, vocoder_cfg, vocoder_ckpt_path):
    """Build the acoustic model and the vocoder and restore their weights.

    Each checkpoint is deserialized onto the CPU first. The network is then
    moved to CUDA when available (CPU otherwise) and switched to eval mode.
    The model is restored before the vocoder.

    Args:
        model_cfg: Config handed to ``instantiate`` to build the acoustic model.
        model_ckpt_path: Checkpoint file whose ``"model"`` entry is the model
            state dict.
        vocoder_cfg: Config handed to ``instantiate`` to build the vocoder.
        vocoder_ckpt_path: Checkpoint file whose ``"generator"`` entry is the
            vocoder state dict.

    Returns:
        Tuple ``(model, vocoder)``.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _restore(cfg, ckpt_path, state_key):
        # torch.load unpickles the file; only point these paths at trusted checkpoints.
        network = instantiate(cfg)
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        network.load_state_dict(checkpoint[state_key])
        return network.to(target_device).eval()

    restored_model = _restore(model_cfg, model_ckpt_path, "model")
    restored_vocoder = _restore(vocoder_cfg, vocoder_ckpt_path, "generator")
    return restored_model, restored_vocoder
def _mel_figure(mel):
    """Render a mel spectrogram as a standalone matplotlib figure for ``gr.Plot``.

    The figure is created from ``matplotlib.figure.Figure`` directly instead of
    ``pyplot``. It is therefore never registered in pyplot's global figure
    manager, so figures do not accumulate across clicks and pyplot's global
    state is not shared between requests.

    Args:
        mel: Tensor produced by ``to_mel``. It is squeezed before plotting and
            must reduce to a 2-D array (presumably (n_mels, frames) — TODO
            confirm).

    Returns:
        The populated ``matplotlib.figure.Figure``.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(12, 8))
    ax = fig.add_subplot(111)
    ax.imshow(mel.squeeze().numpy(), aspect="auto", origin="lower")
    return fig


def build_ui(g2p, model, vocoder, to_mel, mel_stats):
    """Build the PromptTTS++ Gradio demo and launch it.

    There are two tabs. One synthesizes speech conditioned on a natural-language
    style prompt; the other conditions on the mel spectrogram of an uploaded
    reference recording. Both return the synthesized audio and a plot of its
    mel spectrogram. The function ends by calling ``demo.launch()``.

    Args:
        g2p: Grapheme-to-phoneme converter. It is called with the content text
            and returns a list of phoneme strings.
        model: Acoustic model exposing ``infer(...)``. ``infer`` returns the
            normalized mel, log continuous F0, and a V/UV score.
        vocoder: Callable invoked as ``vocoder(mel, f0)``.
        to_mel: Waveform-to-mel transform. Its ``sample_rate`` attribute is
            also the sample rate reported for the output audio.
        mel_stats: Mapping whose ``"mean"`` and ``"std"`` entries (de)normalize
            mel spectrograms.
    """
    # Resolved once instead of on every click; availability does not change at runtime.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    content_placeholder = (
        "This is text to speech demo, which allows you to control the speaker identity "
        "in natural language as follows."
    )
    style_placeholder = "A man speaks slowly in a low tone."

    # Inference only. Without no_grad the vocoder output tracks gradients
    # (module parameters require grad by default; eval() does not change that),
    # and calling .numpy() on such a tensor raises.
    @torch.no_grad()
    def onclick_synthesis(content_prompt, style_prompt=None, reference_mel=None):
        """Synthesize a waveform from text plus one conditioning signal.

        ``style_prompt`` is used when it is not ``None``. Otherwise
        ``reference_mel`` (an un-normalized mel from ``to_mel``) is normalized
        with ``mel_stats`` and used instead.

        Returns:
            CPU waveform tensor: the vocoder output with dim 1 squeezed.

        Raises:
            ValueError: If neither conditioning signal is given.
            gr.Error: If the text yields no phonemes the model knows.
        """
        if style_prompt is None and reference_mel is None:
            raise ValueError("Either style_prompt or reference_mel is required.")

        phonemes = g2p(content_prompt)
        # Punctuation becomes a pause; symbols outside the model's inventory are dropped.
        phonemes = ["sil" if p in (",", ".") else p for p in phonemes]
        phonemes = [p for p in phonemes if p in symbols]
        if not phonemes:
            raise gr.Error("The content prompt contains no text that can be pronounced.")
        phoneme_ids = torch.LongTensor(text_to_sequence(" ".join(phonemes)))
        phoneme_ids = phoneme_ids[None, :].to(device)  # add a batch dimension

        infer_kwargs = dict(use_max=True, noise_scale=0.5, return_f0=True)
        if style_prompt is not None:
            dec, log_cf0, vuv = model.infer(
                phoneme_ids, style_prompt=style_prompt, **infer_kwargs
            )
        else:
            reference_mel = (reference_mel - mel_stats["mean"]) / mel_stats["std"]
            dec, log_cf0, vuv = model.infer(
                phoneme_ids, reference_mel=reference_mel.to(device), **infer_kwargs
            )

        # Low-pass the predicted log-F0 contour (20 Hz cutoff). modfs is the
        # contour's frame rate, 1 / 10 ms = 100 Hz. This assumes a 10 ms frame
        # shift — TODO confirm against the feature-extraction config.
        modfs = int(1.0 / (10 * 0.001))
        log_cf0 = lowpass_filter(log_cf0, modfs, cutoff=20)
        f0 = log_cf0.exp()
        f0[vuv < 0.5] = 0  # frames scored below 0.5 get F0 = 0
        dec = dec * mel_stats["std"] + mel_stats["mean"]  # undo mel normalization
        return vocoder(dec, f0).squeeze(1).cpu()

    @torch.no_grad()
    def render_outputs(wav):
        """Return the ``(sample_rate, samples)`` audio tuple and a mel plot of ``wav``."""
        mel = to_mel(wav)
        return (to_mel.sample_rate, wav.squeeze().numpy()), _mel_figure(mel)

    def onclick_with_style_prompt(content_prompt, style_prompt):
        wav = onclick_synthesis(
            content_prompt=content_prompt, style_prompt=style_prompt
        )
        return render_outputs(wav)

    @torch.no_grad()
    def onclick_with_reference_mel(content_prompt, reference_wav_path):
        # The Audio input yields None when nothing has been uploaded.
        if reference_wav_path is None:
            raise gr.Error("Please upload a reference wav first.")
        ref_wav, ref_sr = torchaudio.load(reference_wav_path)
        # Mels are computed at to_mel.sample_rate, the rate the synthesized
        # audio is played back at. Convert the reference to that rate and mix
        # it down to a single channel, like the synthesized waveform.
        if ref_wav.size(0) > 1:
            ref_wav = ref_wav.mean(dim=0, keepdim=True)
        if ref_sr != to_mel.sample_rate:
            ref_wav = torchaudio.functional.resample(
                ref_wav, ref_sr, to_mel.sample_rate
            )
        ref_mel = to_mel(ref_wav)
        wav = onclick_synthesis(content_prompt=content_prompt, reference_mel=ref_mel)
        return render_outputs(wav)

    with gr.Blocks() as demo:
        gr.Markdown("# PromptTTS++: Controlling Speaker Identity in Prompt-Based Text-to-Speech Using Natural Language Descriptions")
        gr.Markdown("### You can check the [paper](https://arxiv.org/abs/2309.08140) and [code](https://github.com/line/promptttspp).")
        gr.Markdown("### NOTE: Please do not enter personal information.")
        content_prompt = gr.Textbox(
            content_placeholder, lines=3, label="Content prompt"
        )
        with gr.Tabs():
            with gr.TabItem("Style prompt"):
                style_prompt = gr.Textbox(
                    style_placeholder, lines=3, label="Style prompt"
                )
                syn_button1 = gr.Button("Synthesize")
                wav1 = gr.Audio(label="Output wav", elem_id="prompt")
                plot1 = gr.Plot(label="Output mel", elem_id="prompt")
            with gr.TabItem("Reference wav"):
                ref_wav_path = gr.Audio(
                    type="filepath", label="Reference wav", elem_id="ref"
                )
                syn_button2 = gr.Button("Synthesize")
                wav2 = gr.Audio(label="Output wav", elem_id="ref")
                plot2 = gr.Plot(label="Output mel", elem_id="ref")
        syn_button1.click(
            onclick_with_style_prompt,
            inputs=[content_prompt, style_prompt],
            outputs=[wav1, plot1],
        )
        syn_button2.click(
            onclick_with_reference_mel,
            inputs=[content_prompt, ref_wav_path],
            outputs=[wav2, plot2],
        )
    demo.launch()
def main(cfg):
    """Load every artifact named in ``cfg`` and hand them to the demo UI.

    Keys read from ``cfg``:
    - ``model``, ``model_ckpt_path``, ``vocoder``, ``vocoder_ckpt_path``: passed
      to ``load_model``.
    - ``transforms``: instantiated as the waveform-to-mel transform.
    - ``mel_stats_file``: file loaded with ``OmegaConf.load``. It must provide
      ``"mean"`` and ``"std"`` entries.

    NOTE(review): the ``__main__`` guard in this file calls ``main()`` with no
    argument, and no decorator is visible to inject ``cfg``. ``hydra`` is
    imported but otherwise unused here, so a ``@hydra.main(...)`` decorator
    was presumably intended. Confirm the config path/name and add it.
    """
    model, vocoder = load_model(
        model_cfg=cfg.model,
        model_ckpt_path=cfg.model_ckpt_path,
        vocoder_cfg=cfg.vocoder,
        vocoder_ckpt_path=cfg.vocoder_ckpt_path,
    )
    to_mel = instantiate(cfg.transforms)

    # With NLTK 3.9.1 this tagger resource may have to be fetched explicitly
    # before G2p can be used.
    nltk.download("averaged_perceptron_tagger_eng")
    g2p = G2p()

    mel_stats = OmegaConf.load(cfg.mel_stats_file)
    build_ui(g2p, model, vocoder, to_mel, mel_stats)
if __name__ == "__main__":
    # Script entry point.
    # NOTE(review): ``main`` requires a ``cfg`` argument, and no ``@hydra.main``
    # decorator is visible to supply it, so this call raises TypeError as
    # written. Presumably a Hydra decorator on ``main`` was meant to inject the
    # demo config — confirm the config path/name and add it.
    main()