# Source: promptttspp / app.py
# Last commit: "Add github code and paper url" (36a43ba), by MasayaKawamura
# Copyright 2024 LY Corporation
# LY Corporation licenses this file to you under the Apache License,
# version 2.0 (the "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at:
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import gradio as gr
import hydra
import matplotlib.pyplot as plt
import torch
import torchaudio
from g2p_en import G2p
from hydra.utils import instantiate
from omegaconf import OmegaConf
from promptttspp.text.eng import symbols, text_to_sequence
from promptttspp.utils.model import lowpass_filter
import nltk
def load_model(model_cfg, model_ckpt_path, vocoder_cfg, vocoder_ckpt_path):
    """Instantiate the acoustic model and the vocoder and restore their weights.

    Both networks are built from their configs with ``instantiate``, their
    checkpoints are read onto the CPU first, and the modules are then moved to
    CUDA when it is available (CPU otherwise) and switched to eval mode.

    Args:
        model_cfg: Config passed to ``instantiate`` to build the acoustic model.
        model_ckpt_path: Checkpoint whose ``"model"`` entry holds the acoustic
            model's state dict.
        vocoder_cfg: Config passed to ``instantiate`` to build the vocoder.
        vocoder_ckpt_path: Checkpoint whose ``"generator"`` entry holds the
            vocoder's state dict.

    Returns:
        The ``(model, vocoder)`` pair, ready for inference.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _restore(cfg, ckpt_path, state_key):
        # Build the module, load the requested state dict, then move it to
        # the inference device in eval mode.
        module = instantiate(cfg)
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        module.load_state_dict(checkpoint[state_key])
        return module.to(target).eval()

    restored_model = _restore(model_cfg, model_ckpt_path, "model")
    restored_vocoder = _restore(vocoder_cfg, vocoder_ckpt_path, "generator")
    return restored_model, restored_vocoder
def build_ui(g2p, model, vocoder, to_mel, mel_stats):
    """Build the Gradio demo page, wire up both synthesis tabs and launch it.

    The page has a shared content-prompt textbox and two tabs: one that
    conditions synthesis on a natural-language style prompt and one that
    conditions it on a reference recording. Each tab outputs the synthesized
    audio and a plot of its mel-spectrogram.

    Args:
        g2p: Callable that maps the content text to a sequence of phoneme
            tokens.
        model: Acoustic model; its ``infer`` method is called with the phoneme
            ids plus either ``style_prompt`` or ``reference_mel``.
        vocoder: Callable invoked as ``vocoder(mel, f0)`` to produce the
            waveform.
        to_mel: Callable that converts a waveform tensor to a mel-spectrogram;
            its ``sample_rate`` attribute is reported as the output rate.
        mel_stats: Mapping with ``"mean"`` and ``"std"`` entries used to
            normalize reference mels and de-normalize the model output.
    """
    # Figures are created without pyplot so that they are not kept alive by
    # pyplot's global figure registry: this handler runs once per click in a
    # long-lived server, and pyplot figures that are never closed accumulate.
    from matplotlib.figure import Figure

    content_placeholder = (
        "This is text to speech demo, which allows you to control the speaker identity "
        "in natural language as follows."
    )
    style_placeholder = "A man speaks slowly in a low tone."
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @torch.no_grad()
    def onclick_synthesis(content_prompt, style_prompt=None, reference_mel=None):
        """Synthesize a waveform for ``content_prompt``.

        Exactly one conditioning source is used: ``style_prompt`` when it is
        given, otherwise ``reference_mel``.

        Raises:
            ValueError: If neither ``style_prompt`` nor ``reference_mel`` is
                given.
            gr.Error: If the content text yields no known phoneme.
        """
        if style_prompt is None and reference_mel is None:
            raise ValueError("Either style_prompt or reference_mel must be given.")

        phonemes = g2p(content_prompt)
        # Commas and periods become the "sil" token; anything that is not a
        # known symbol is dropped.
        phonemes = ["sil" if p in (",", ".") else p for p in phonemes]
        phonemes = [p for p in phonemes if p in symbols]
        if not phonemes:
            # NOTE(review): an empty phoneme sequence is assumed to be
            # unusable by the model; report it to the user instead.
            raise gr.Error("The content prompt contains no text that can be spoken.")
        phoneme_ids = text_to_sequence(" ".join(phonemes))
        phoneme_ids = torch.LongTensor(phoneme_ids)[None, :].to(device)

        if style_prompt is not None:
            condition = {"style_prompt": style_prompt}
        else:
            # The reference mel is normalized with the same statistics that
            # are used below to de-normalize the model output.
            reference_mel = (reference_mel - mel_stats["mean"]) / mel_stats["std"]
            condition = {"reference_mel": reference_mel.to(device)}
        dec, log_cf0, vuv = model.infer(
            phoneme_ids,
            use_max=True,
            noise_scale=0.5,
            return_f0=True,
            **condition,
        )

        # 1 / (10 * 0.001) = 100: rate of the F0 contour, presumably derived
        # from a 10 ms frame shift -- TODO confirm against the model config.
        modfs = int(1.0 / (10 * 0.001))
        log_cf0 = lowpass_filter(log_cf0, modfs, cutoff=20)
        f0 = log_cf0.exp()
        # Frames whose voicing score is below 0.5 get F0 = 0.
        f0[vuv < 0.5] = 0
        dec = dec * mel_stats["std"] + mel_stats["mean"]
        wav = vocoder(dec, f0).squeeze(1).cpu()
        return wav

    def _to_outputs(wav):
        """Return the ``(sample_rate, samples)`` audio pair and a mel plot."""
        mel = to_mel(wav)
        fig = Figure(figsize=(12, 8))
        ax = fig.subplots()
        ax.imshow(mel.squeeze().numpy(), aspect="auto", origin="lower")
        return (to_mel.sample_rate, wav.squeeze().numpy()), fig

    def onclick_with_style_prompt(content_prompt, style_prompt):
        """Click handler of the "Style prompt" tab."""
        wav = onclick_synthesis(
            content_prompt=content_prompt, style_prompt=style_prompt
        )
        return _to_outputs(wav)

    def onclick_with_reference_mel(content_prompt, reference_wav_path):
        """Click handler of the "Reference wav" tab."""
        if not reference_wav_path:
            # No audio was uploaded or recorded; nothing to load.
            raise gr.Error("Please provide a reference wav.")
        wav, sample_rate = torchaudio.load(reference_wav_path)
        if wav.size(0) > 1:
            # Downmix multi-channel audio so a single reference is used.
            wav = wav.mean(dim=0, keepdim=True)
        if sample_rate != to_mel.sample_rate:
            # The mel transform is tied to ``to_mel.sample_rate``; bring the
            # uploaded audio to that rate before extracting features.
            wav = torchaudio.functional.resample(
                wav, sample_rate, to_mel.sample_rate
            )
        ref_mel = to_mel(wav)
        wav = onclick_synthesis(content_prompt=content_prompt, reference_mel=ref_mel)
        return _to_outputs(wav)

    with gr.Blocks() as demo:
        gr.Markdown("# PromptTTS++: Controlling Speaker Identity in Prompt-Based Text-to-Speech Using Natural Language Descriptions")
        gr.Markdown("### You can check the [paper](https://arxiv.org/abs/2309.08140) and [code](https://github.com/line/promptttspp).")
        gr.Markdown("### NOTE: Please do not enter personal information.")
        content_prompt = gr.Textbox(
            content_placeholder, lines=3, label="Content prompt"
        )
        with gr.Tabs():
            with gr.TabItem("Style prompt"):
                style_prompt = gr.Textbox(
                    style_placeholder, lines=3, label="Style prompt"
                )
                syn_button1 = gr.Button("Synthesize")
                wav1 = gr.Audio(label="Output wav", elem_id="prompt")
                plot1 = gr.Plot(label="Output mel", elem_id="prompt")
            with gr.TabItem("Reference wav"):
                ref_wav_path = gr.Audio(
                    type="filepath", label="Reference wav", elem_id="ref"
                )
                syn_button2 = gr.Button("Synthesize")
                wav2 = gr.Audio(label="Output wav", elem_id="ref")
                plot2 = gr.Plot(label="Output mel", elem_id="ref")
        syn_button1.click(
            onclick_with_style_prompt,
            inputs=[content_prompt, style_prompt],
            outputs=[wav1, plot1],
        )
        syn_button2.click(
            onclick_with_reference_mel,
            inputs=[content_prompt, ref_wav_path],
            outputs=[wav2, plot2],
        )
    demo.launch()
@hydra.main(version_base=None, config_path="egs/proposed/bin/conf", config_name="demo")
def main(cfg):
    """Load every demo component described by ``cfg`` and start the UI.

    Args:
        cfg: Hydra config providing ``model``, ``model_ckpt_path``,
            ``vocoder``, ``vocoder_ckpt_path``, ``transforms`` and
            ``mel_stats_file``.
    """
    acoustic_model, neural_vocoder = load_model(
        model_cfg=cfg.model,
        model_ckpt_path=cfg.model_ckpt_path,
        vocoder_cfg=cfg.vocoder,
        vocoder_ckpt_path=cfg.vocoder_ckpt_path,
    )
    mel_transform = instantiate(cfg.transforms)

    # With NLTK 3.9.1 this tagger resource may have to be downloaded
    # explicitly; it is fetched before the G2p converter is constructed.
    nltk.download("averaged_perceptron_tagger_eng")
    grapheme_to_phoneme = G2p()

    normalization_stats = OmegaConf.load(cfg.mel_stats_file)
    build_ui(
        g2p=grapheme_to_phoneme,
        model=acoustic_model,
        vocoder=neural_vocoder,
        to_mel=mel_transform,
        mel_stats=normalization_stats,
    )


if __name__ == "__main__":
    main()