japanese-tts / app.py
esnya
:tada: feat!: First commit
5034c86
raw
history blame contribute delete
No virus
2.72 kB
from typing import cast
import gradio as gr
import numpy as np
import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5Processor, SpeechT5HifiGan
from speecht5_openjtalk_tokenizer import SpeechT5OpenjtalkTokenizer
import pandas as pd
import transformers
# Register the custom OpenJTalk tokenizer as an attribute of the
# `transformers` module so that `from_pretrained` can resolve it when it
# looks the tokenizer class up by name.
setattr(transformers, SpeechT5OpenjtalkTokenizer.__name__, SpeechT5OpenjtalkTokenizer)


class SpeechT5OpenjtalkProcessor(SpeechT5Processor):
    """SpeechT5 processor wired to the OpenJTalk tokenizer for Japanese text."""

    # Class *name* (a string) — transformers resolves it via the module
    # attribute registered with setattr above.
    tokenizer_class = SpeechT5OpenjtalkTokenizer.__name__
# Load the Japanese SpeechT5 TTS model, its matching text processor, and
# the HiFi-GAN vocoder. Weights are downloaded from the Hugging Face Hub
# on first run. The isinstance asserts narrow the `from_pretrained`
# return types for the type checker.
model = SpeechT5ForTextToSpeech.from_pretrained("esnya/japanese_speecht5_tts")
assert isinstance(model, SpeechT5ForTextToSpeech)

processor = SpeechT5OpenjtalkProcessor.from_pretrained("esnya/japanese_speecht5_tts")
assert isinstance(processor, SpeechT5OpenjtalkProcessor)

vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
assert isinstance(vocoder, SpeechT5HifiGan)

# Run inference on the GPU when one is available.
if torch.cuda.is_available():
    model = model.cuda()
    vocoder = vocoder.cuda()
def convert_float32_to_int16(wav: np.ndarray) -> np.ndarray:
    """Convert a float32 waveform in [-1, 1] to 16-bit PCM samples.

    Samples are scaled by 32768 and clamped to the int16 range before the
    dtype conversion, so out-of-range floats saturate instead of wrapping.
    """
    assert wav.dtype == np.float32
    scaled = wav * 32768.0
    clamped = np.clip(scaled, -32768.0, 32767.0)
    return clamped.astype(np.int16)
@torch.inference_mode()
def text_to_speech(
    text: str,
    threshold: float = 0.5,
    minlenratio: float = 0.0,
    maxlenratio: float = 10.0,
):
    """Synthesize Japanese speech from *text* with a random speaker embedding.

    Args:
        text: Input text to synthesize.
        threshold: Stop-token probability threshold passed to
            ``generate_speech``.
        minlenratio: Minimum output length relative to the input length.
        maxlenratio: Maximum output length relative to the input length.

    Returns:
        A two-element list: ``(sampling_rate, int16_waveform)`` for the
        Gradio audio component, and a DataFrame of the speaker-embedding
        values (columns ``dim`` and ``value``) for the bar plot.

    Raises:
        ValueError: If the processor produces no output for *text*.
    """
    # Draw a speaker embedding uniformly from [-1, 1); every call uses a
    # fresh random "voice".
    speaker_embeddings = (
        torch.rand(
            (1, model.config.speaker_embedding_dim),
            dtype=torch.float32,
            device=model.device,
        )
        * 2
        - 1
    )

    inputs = processor(text=text, return_tensors="pt")
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would turn this into an AttributeError below.
    if inputs is None:
        raise ValueError("processor returned no output for the given text")
    input_ids = inputs.input_ids.to(model.device)

    speaker_embeddings = cast(torch.FloatTensor, speaker_embeddings)
    wav = model.generate_speech(
        input_ids,
        speaker_embeddings,
        threshold=threshold,
        minlenratio=minlenratio,
        maxlenratio=maxlenratio,
        vocoder=vocoder,
    )
    wav = cast(torch.FloatTensor, wav)
    # The vocoder output is float32; Gradio's audio widget takes a
    # (sample_rate, int16 ndarray) pair.
    pcm = convert_float32_to_int16(wav.reshape(-1).cpu().float().numpy())

    return [
        (vocoder.config.sampling_rate, pcm),
        pd.DataFrame(
            {
                "dim": range(speaker_embeddings.shape[-1]),
                "value": speaker_embeddings[0].cpu().float().numpy(),
            }
        ),
    ]
# Gradio UI: a text box plus three generation-control sliders in; audio
# plus a bar plot of the randomly generated speaker embedding out.
demo = gr.Interface(
    fn=text_to_speech,
    inputs=[
        "text",
        # NOTE(review): the slider range is [0, 0.5] with default 0.5,
        # which pins the handle at its maximum — confirm the upper bound
        # is intended rather than e.g. 1.0.
        gr.Slider(0, 0.5, 0.5, label="threshold"),
        gr.Slider(0, 100, 0, label="minlenratio"),
        gr.Slider(0, 100, 10, label="maxlenratio"),
    ],
    outputs=[
        "audio",
        gr.BarPlot(
            label="speaker_embedding (random generated)",
            x="dim",
            y="value",
            y_lim=[-1, 1],
        ),
    ],
)

demo.launch()