"""
TODO:
+ [x] Load Configuration
+ [ ] Checking
+ [ ] Better saving directory
"""
import sys
from pathlib import Path

import numpy as np
import torch
import torchaudio
from transformers import pipeline

from espnet2.bin.tts_inference import Text2Speech

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# local import
sys.path.append("src")
import gradio as gr
# ASR part
audio_files = [
str(x)
for x in sorted(
Path(
"/home/kevingeng/Disk2/laronix/laronix_automos/data/20230103_video"
).glob("**/*wav")
)
]
# audio_files = [str(x) for x in sorted(Path("./data/Patient_sil_trim_16k_normed_5_snr_40/Rainbow").glob("**/*wav"))]
# transcriber = pipeline(
# "automatic-speech-recognition",
# model="KevinGeng/PAL_John_128_train_dev_test_seed_1",
# )
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

# Fine-tuned Whisper ASR model; reuse the loaded model/processor in the pipeline so the checkpoint is only loaded once
processor = AutoProcessor.from_pretrained("KevinGeng/whipser_medium_en_PAL300_step25")
model = AutoModelForSpeechSeq2Seq.from_pretrained("KevinGeng/whipser_medium_en_PAL300_step25")
transcriber = pipeline("automatic-speech-recognition", model=model, tokenizer=processor.tokenizer, feature_extractor=processor.feature_extractor)
# Text2Mel models
# @title English multi-speaker pretrained model { run: "auto" }
lang = "English"
vits_tag = "kan-bayashi/libritts_xvector_vits"
ft2_tag = "kan-bayashi/libritts_xvector_conformer_fastspeech2" #@param ["kan-bayashi/vctk_gst_tacotron2", "kan-bayashi/vctk_gst_transformer", "kan-bayashi/vctk_xvector_tacotron2", "kan-bayashi/vctk_xvector_transformer", "kan-bayashi/vctk_xvector_conformer_fastspeech2", "kan-bayashi/vctk_gst+xvector_tacotron2", "kan-bayashi/vctk_gst+xvector_transformer", "kan-bayashi/vctk_gst+xvector_conformer_fastspeech2", "kan-bayashi/vctk_multi_spk_vits", "kan-bayashi/vctk_full_band_multi_spk_vits", "kan-bayashi/libritts_xvector_transformer"] {type:"string"}
transformer_tag = "kan-bayashi/libritts_xvector_transformer"
# !!! vits needs no vocoder !!!
# Local Text2Mel models
vits_config_local = "TTS_models/libritts_xvector_vits/config.yaml"
vits_model_local = "TTS_models/libritts_xvector_vits/train.total_count.ave_10best.pth"
# TODO
ft2_config_local = ""
ft2_model_local = ""
transformer_config_local = ""
transformer_model_local = ""
# Vocoders
vocoder_tag = "parallel_wavegan/vctk_parallel_wavegan.v1.long" # @param ["none", "parallel_wavegan/vctk_parallel_wavegan.v1.long", "parallel_wavegan/vctk_multi_band_melgan.v2", "parallel_wavegan/vctk_style_melgan.v1", "parallel_wavegan/vctk_hifigan.v1", "parallel_wavegan/libritts_parallel_wavegan.v1.long", "parallel_wavegan/libritts_multi_band_melgan.v2", "parallel_wavegan/libritts_hifigan.v1", "parallel_wavegan/libritts_style_melgan.v1"] {type:"string"}
hifigan_vocoder_tag = "parallel_wavegan/libritts_hifigan.v1" # @param ["none", "parallel_wavegan/vctk_parallel_wavegan.v1.long", "parallel_wavegan/vctk_multi_band_melgan.v2", "parallel_wavegan/vctk_style_melgan.v1", "parallel_wavegan/vctk_hifigan.v1", "parallel_wavegan/libritts_parallel_wavegan.v1.long", "parallel_wavegan/libritts_multi_band_melgan.v2", "parallel_wavegan/libritts_hifigan.v1", "parallel_wavegan/libritts_style_melgan.v1"] {type:"string"}
# Local Vocoders
## Make sure to use the "parallel_wavegan/" prefix so the vocoder is loaded through ParallelWaveGAN (PWG)
vocoder_tag_local = "parallel_wavegan/vctk_parallel_wavegan.v1.long"
hifigan_vocoder_tag_local = "parallel_wavegan/libritts_hifigan.v1"
from espnet2.utils.types import str_or_none

# Local VITS model (VITS is end-to-end, so no separate vocoder is attached)
text2speech = Text2Speech.from_pretrained(
    train_config=vits_config_local,
model_file=vits_model_local,
device="cuda",
use_att_constraint=False,
backward_window=1,
forward_window=3,
speed_control_alpha=1.0,
)
# Fastspeech2
ft2_text2speech = Text2Speech.from_pretrained(
model_tag=ft2_tag,
vocoder_tag=str_or_none(vocoder_tag_local),
device="cuda",
use_att_constraint=False,
backward_window=1,
forward_window=3,
speed_control_alpha=1.0,
)
# Fastspeech2 + hifigan
ft2_text2speech_hifi = Text2Speech.from_pretrained(
model_tag=ft2_tag,
vocoder_tag=str_or_none(hifigan_vocoder_tag_local),
device="cuda",
use_att_constraint=False,
backward_window=1,
forward_window=3,
speed_control_alpha=1.0,
)
# transformer tag
transformer_text2speech = Text2Speech.from_pretrained(
model_tag=transformer_tag,
vocoder_tag=str_or_none(vocoder_tag_local),
device="cuda",
use_att_constraint=False,
backward_window=1,
forward_window=3,
speed_control_alpha=1.0,
)
import glob
import os

import kaldiio
# Get model directory path
# from espnet_model_zoo.downloader import ModelDownloader
# d = ModelDownloader()
# model_dir = os.path.dirname(d.download_and_unpack(tag)["train_config"])
# Speaker x-vector selection
xvector_ark = [
p
for p in glob.glob(
f"xvector/test-clean/spk_xvector.ark", recursive=True
)
if "test" in p
][0]
xvectors = {k: v for k, v in kaldiio.load_ark(xvector_ark)}
spks = list(xvectors.keys())
male_spks = {
"Male1": "260_123286",
"Male2": "1320_122612",
"Male3": "672_122797"
}
female_spks = {"Female1": "5683_32865",
"Female2": "121_121726",
"Female3": "8463_287645"}
spks = dict(male_spks, **female_spks)
spk_names = sorted(spks.keys())
def ASRTTS(audio_file, spk_name, ref_text=""):
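    """Transcribe `audio_file` with the Whisper ASR pipeline (or use `ref_text` if
    provided) and re-synthesize it with the local VITS model, conditioned on the
    selected speaker's x-vector. Returns the saved wav path and the text used."""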
spk = spks[spk_name]
spembs = xvectors[spk]
if ref_text == "":
reg_text = transcriber(audio_file)["text"]
else:
reg_text = ref_text
speech, sr = torchaudio.load(
audio_file, channels_first=True
) # Mono channel
wav_tensor_spembs = text2speech(
text=reg_text, speech=speech, spembs=spembs
)["wav"]
save_id = (
"./wav/" + Path(audio_file).stem + "_" + spk_name + "_spkembs.wav"
)
torchaudio.save(
save_id,
src=wav_tensor_spembs.unsqueeze(0).to("cpu"),
sample_rate=22050,
)
return save_id, reg_text
def ASRTTS_clean(audio_file, spk_name):
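    """Same as ASRTTS without a reference-text override: transcribe, re-synthesize
    with the local VITS model, and return only the saved wav path."""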
spk = spks[spk_name]
spembs = xvectors[spk]
reg_text = transcriber(audio_file)["text"]
speech, sr = torchaudio.load(
audio_file, channels_first=True
) # Mono channel
wav_tensor_spembs = text2speech(
text=reg_text, speech=speech, spembs=spembs
)["wav"]
save_id = (
"./wav/" + Path(audio_file).stem + "_" + spk_name + "_spkembs.wav"
)
torchaudio.save(
save_id,
src=wav_tensor_spembs.unsqueeze(0).to("cpu"),
sample_rate=22050,
)
return save_id
def ft2_ASRTTS_clean(audio_file, spk_name):
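    """Transcribe and re-synthesize with the FastSpeech2 model (`ft2_text2speech`,
    ParallelWaveGAN vocoder); returns the saved wav path."""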
spk = spks[spk_name]
spembs = xvectors[spk]
reg_text = transcriber(audio_file)["text"]
speech, sr = torchaudio.load(
audio_file, channels_first=True
) # Mono channel
wav_tensor_spembs = ft2_text2speech(
text=reg_text, speech=speech, spembs=spembs
)["wav"]
save_id = (
"./wav/" + Path(audio_file).stem + "_fs2_" + spk_name + "_spkembs.wav"
)
torchaudio.save(
save_id,
src=wav_tensor_spembs.unsqueeze(0).to("cpu"),
sample_rate=22050,
)
return save_id
def ft2_ASRTTS_clean_hifi(audio_file, spk_name):
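    """Transcribe and re-synthesize with the FastSpeech2 model and the HiFi-GAN
    vocoder (`ft2_text2speech_hifi`); returns the saved wav path."""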
spk = spks[spk_name]
spembs = xvectors[spk]
reg_text = transcriber(audio_file)["text"]
speech, sr = torchaudio.load(
audio_file, channels_first=True
) # Mono channel
wav_tensor_spembs = ft2_text2speech_hifi(
text=reg_text, speech=speech, spembs=spembs
)["wav"]
save_id = (
"./wav/" + Path(audio_file).stem + "_fs2_hifi_" + spk_name + "_spkembs.wav"
)
torchaudio.save(
save_id,
src=wav_tensor_spembs.unsqueeze(0).to("cpu"),
sample_rate=22050,
)
return save_id
def transformer_ASRTTS_clean(audio_file, spk_name):
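    """Transcribe and re-synthesize with the Transformer TTS model
    (`transformer_text2speech`); returns the saved wav path."""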
spk = spks[spk_name]
spembs = xvectors[spk]
reg_text = transcriber(audio_file)["text"]
speech, sr = torchaudio.load(
audio_file, channels_first=True
) # Mono channel
wav_tensor_spembs = transformer_text2speech(
text=reg_text, speech=speech, spembs=spembs
)["wav"]
save_id = (
"./wav/" + Path(audio_file).stem + "_transformer_" + spk_name + "_spkembs.wav"
)
torchaudio.save(
save_id,
src=wav_tensor_spembs.unsqueeze(0).to("cpu"),
sample_rate=22050,
)
return save_id
# def google_ASRTTS_clean(audio_file, spk_name):
# spk = spks[spk_name]
# spembs = xvectors[spk]
# reg_text = transcriber(audio_file)["text"]
# # pdb.set_trace()
# synthesis_input = texttospeech.SynthesisInput(text=reg_text)
# voice = texttospeech.VoiceSelectionParams(
# language_code="en-US", ssml_gender=texttospeech.SsmlVoiceGender.NEUTRAL
# )
# audio_config = texttospeech.AudioConfig(
# audio_encoding=texttospeech.AudioEncoding.MP3
# )
# response = Google_TTS_client.synthesize_speech(
# input=synthesis_input, voice=voice, audio_config=audio_config
# )
# save_id = (
# "./wav/" + Path(audio_file).stem + "_google_" + spk_name + "_spkembs.wav"
# )
# with open(save_id, "wb") as out_file:
# out_file.write(response.audio_content)
# return save_id
reference_textbox = gr.Textbox(
value="",
placeholder="Input reference here",
label="Reference",
)
recognition_textbox = gr.Textbox(
    value="",
    placeholder="Recognition output appears here",
    label="Recognition",
)
speaker_option = gr.Radio(choices=spk_names, label="Speaker")
input_audio = gr.Audio(
source="upload", type="filepath", label="Audio_to_Evaluate"
)
output_audio = gr.Audio(
source="upload", file="filepath", label="Synthesized Audio"
)
examples = [
["./samples/001.wav", "M1", ""],
["./samples/002.wav", "M2", ""],
["./samples/003.wav", "F1", ""],
["./samples/004.wav", "F2", ""],
]
def change_audiobox(choice):
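    """Switch the input audio widget between microphone and file upload, or hide it."""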
if choice == "upload":
input_audio = gr.Audio.update(source="upload", visible=True)
elif choice == "microphone":
input_audio = gr.Audio.update(source="microphone", visible=True)
else:
input_audio = gr.Audio.update(visible=False)
return input_audio
def show_icon(choice):
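    """Show the avatar image that matches the selected voice profile."""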
if choice == "Male1":
spk_icon = gr.Image.update(value="speaker_icons/male1.png", visible=True)
elif choice == "Male2":
spk_icon = gr.Image.update(value="speaker_icons/male2.png", visible=True)
elif choice == "Male3":
spk_icon = gr.Image.update(value="speaker_icons/male3.png", visible=True)
elif choice == "Female1":
spk_icon = gr.Image.update(value="speaker_icons/female1.png", visible=True)
elif choice == "Female2":
spk_icon = gr.Image.update(value="speaker_icons/female2.png", visible=True)
elif choice == "Female3":
spk_icon = gr.Image.update(value="speaker_icons/female3.png", visible=True)
return spk_icon
def get_download_file(audio_file=None):
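    """Show the download widget only when a synthesized audio file is available."""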
    if audio_file is None:
output_audio_file = gr.File.update(visible=False)
else:
output_audio_file = gr.File.update(visible=True)
return output_audio_file
def download_file(audio_file):
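    """Wrap an audio path in a gr.File component so it can be downloaded."""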
return gr.File(value=audio_file)
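# Gradio UI: an "Open Version" tab (local VITS only) and a "Test Version" tab
# exposing the FastSpeech2, FastSpeech2 + HiFi-GAN, and Transformer back-ends.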
with gr.Blocks(
analytics_enabled=False,
css=".gradio-container {background-color: #78BD91}",
) as demo:
# Open Version
with gr.Tab("Open Version"):
with gr.Column(elem_id="Column"):
input_format = gr.Radio(
choices=["microphone", "upload"], label="Choose your input format", elem_id="input_format"
)
input_audio = gr.Audio(
source="microphone",
type="filepath",
label="Input Audio",
interactive=True,
visible=False,
elem_id="input_audio"
)
input_format.change(
fn=change_audiobox, inputs=input_format, outputs=input_audio
)
speaker_option = gr.Radio(choices=spk_names, value="Male1", label="Choose your voice profile")
spk_icon = gr.Image(value="speaker_icons/male1.png",
type="filepath",
image_mode="RGB",
source="upload",
shape=[50, 50],
interactive=True,
visible=True)
speaker_option.change(
fn=show_icon, inputs=speaker_option, outputs=spk_icon
)
b = gr.Button("Convert")
output_audio = gr.Audio(
source="upload", file="filepath", label="Converted Audio", interactive=False
)
b.click(
ASRTTS_clean,
inputs=[input_audio, speaker_option],
outputs=output_audio,
api_name="convert"
)
# Tab selection:
with gr.Tab("Test Version: Multi TTS model"):
with gr.Column(elem_id="Column"):
input_format = gr.Radio(
choices=["microphone", "upload"], label="Choose your input format", elem_id="input_format"
)
input_audio = gr.Audio(
source="microphone",
type="filepath",
label="Input Audio",
interactive=True,
visible=False,
elem_id="input_audio"
)
input_format.change(
fn=change_audiobox, inputs=input_format, outputs=input_audio
)
speaker_option = gr.Radio(choices=spk_names, value="Male1", label="Choose your voice profile")
spk_icon = gr.Image(value="speaker_icons/male1.png",
type="filepath",
image_mode="RGB",
source="upload",
shape=[50, 50],
interactive=True,
visible=True)
speaker_option.change(
fn=show_icon, inputs=speaker_option, outputs=spk_icon
)
with gr.Column():
with gr.Row():
b2 = gr.Button("Convert")
output_audio = gr.Audio(
source="upload", file="filepath", label="Converted Audio", interactive=False
)
b2.click(
ASRTTS_clean,
inputs=[input_audio, speaker_option],
outputs=output_audio,
api_name="convert_"
)
with gr.Row():
# Fastspeech2 + PWG [under construction]
b_ft2 = gr.Button("Convert_fastspeech2")
output_audio_ft2= gr.Audio(
source="upload", file="filepath", label="Converted Audio", interactive=False
)
b_ft2.click(
ft2_ASRTTS_clean,
inputs=[input_audio, speaker_option],
outputs=output_audio_ft2,
api_name="convert_ft2"
)
with gr.Row():
# Fastspeech2 + hifigan [under construction]
b_ft2_hifi = gr.Button("Convert_fastspeech2+HifiGAN")
output_audio_ft2_hifi= gr.Audio(
source="upload", file="filepath", label="Converted Audio", interactive=False
)
b_ft2_hifi.click(
ft2_ASRTTS_clean_hifi,
inputs=[input_audio, speaker_option],
outputs=output_audio_ft2_hifi,
api_name="convert_ft2_hifi"
)
with gr.Row():
# transformer [TODO]
b_transformer = gr.Button("Convert_transformer")
output_audio_transformer= gr.Audio(
source="upload", file="filepath", label="Converted Audio", interactive=False
)
b_transformer.click(
transformer_ASRTTS_clean,
inputs=[input_audio, speaker_option],
outputs=output_audio_transformer,
api_name="convert_trans"
)
# google tts [TODO]
# b_google = gr.Button("Convert_googleTTS")
# output_audio_google= gr.Audio(
# source="upload", file="filepath", label="Converted Audio", interactive=False
# )
# b_google.click(
# google_ASRTTS_clean,
# inputs=[input_audio, speaker_option],
# outputs=output_audio_google,
# api_name="convert"
# )
demo.launch(share=False)