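# VALL-E-X / app.py
# Hugging Face file-view header captured with the source:
#   author: Plachta (avatar), commit b1e1a76 "initial commit"
#   links: raw | history | blame; scan: "No virus"; size: 13.1 kB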
import argparse
import logging
import os
import pathlib
import time
import tempfile
from pathlib import Path
pathlib.PosixPath = pathlib.PosixPath
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import torch
import torchaudio
import random
import numpy as np
from data.tokenizer import (
AudioTokenizer,
tokenize_audio,
)
from data.collation import get_text_token_collater
from models.vallex import VALLE
from utils.g2p import PhonemeBpeTokenizer
import gradio as gr
import whisper
# Run CPU work on one intra-op thread and one inter-op thread.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
# Switch off TorchScript's profiling executor, profiling mode and
# graph-executor optimizations; all three private toggles receive False.
for _jit_toggle in (
    torch._C._jit_set_profiling_executor,
    torch._C._jit_set_profiling_mode,
    torch._C._set_graph_executor_optimize,
):
    _jit_toggle(False)
del _jit_toggle
# Seeding is left disabled; uncomment to fix torch's global RNG state:
# torch.manual_seed(42)
# Language code (as returned by Whisper detection) -> tag wrapped around text
# before tokenization.
lang2token = {
    "zh": "[ZH]",
    "ja": "[JA]",
    "en": "[EN]",
}
# Language code -> integer id stored as `lang_code` in saved .npz prompts.
lang2code = {
    "zh": 0,
    "ja": 1,
    "en": 2,
}
# Reverse lookups, derived from the forward tables so they cannot drift apart.
token2lang = {tag: code for code, tag in lang2token.items()}
code2lang = {num: code for code, num in lang2code.items()}
# UI dropdown label -> language tag; "mix" adds no tag at all.
langdropdown2token = {
    "English": "[EN]",
    "中文": "[ZH]",
    "日本語": "[JA]",
    "mix": "",
}
# Text tokenizer built from the bundled vocabulary file (presumably
# phoneme + BPE, per its name). The path is relative to the working directory.
text_tokenizer = PhonemeBpeTokenizer(
    tokenizer_path="./utils/g2p/bpe_69.json",
)
# Batches tokenized text; each call returns a (tokens, lengths) tensor pair.
text_collater = get_text_token_collater()
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
# VALL-E-X model
# NOTE(review): the positional args (1024, 16, 12) are presumably model width,
# attention heads and layer count — confirm against models.vallex.VALLE.
model = VALLE(
    1024,
    16,
    12,
    norm_first=True,
    add_prenet=False,
    prefix_mode=1,
    share_embedding=True,
    nar_scale_factor=1.0,
    prepend_bos=True,
    num_quantizers=8,
)
# Load the weights on the CPU. The inference functions move the model to
# `device` only while they run.
checkpoint = torch.load("./epoch-10.pt", map_location="cpu")
# strict=True already raises RuntimeError on any missing or unexpected key,
# so the returned key lists are always empty and need no extra check.
model.load_state_dict(checkpoint["model"], strict=True)
# load_state_dict copies the tensors into the model's own parameters. Drop the
# checkpoint so a second copy of the weights (plus anything else the file
# stores) is not kept alive for the lifetime of the process.
del checkpoint
model.to("cpu")
model.eval()
# Encodec model
# Audio codec tokenizer placed on `device` (see data.tokenizer.AudioTokenizer).
audio_tokenizer = AudioTokenizer(device)
# ASR
# Whisper "medium" is kept on the CPU between calls; make_prompt moves it to
# `device` only while transcribing. whisper.load_model defaults to CUDA when it
# is available, so load straight onto the CPU instead of loading on the GPU and
# then calling .cpu(), which briefly occupied (and cached) GPU memory.
whisper_model = whisper.load_model("medium", device="cpu")
def clear_prompts(max_age=60.0, directory=None, suffix=".npz"):
    """Best-effort removal of stale prompt files.

    Deletes every regular file in ``directory`` whose name ends with
    ``suffix`` and whose modification time is more than ``max_age`` seconds
    in the past. Filesystem errors are ignored so that this housekeeping can
    never break a request. A file that disappears or cannot be removed is
    skipped; it does not abort the rest of the sweep.

    Args:
        max_age: Minimum age in seconds before a file is removed.
        directory: Directory to scan. Defaults to ``tempfile.gettempdir()``,
            which is where ``make_npz_prompt`` writes its ``.npz`` prompts.
        suffix: Filename suffix selecting which files are eligible.

    NOTE(review): the default directory is the shared system temp dir, so
    matching files created by other programs are deleted as well. Consider
    writing prompts to a dedicated subdirectory.
    """
    path = tempfile.gettempdir() if directory is None else directory
    try:
        entries = os.listdir(path)
    except OSError:
        return
    # Compute the age cut-off once rather than once per file.
    cutoff = time.time() - max_age
    for entry in entries:
        filename = os.path.join(path, entry)
        if not filename.endswith(suffix):
            continue
        try:
            # Another request may remove or replace the file concurrently.
            if os.path.isfile(filename) and os.stat(filename).st_mtime < cutoff:
                os.remove(filename)
        except OSError:
            continue
def transcribe_one(model, audio_path):
    """Detect the spoken language of an audio file and transcribe it with Whisper.

    The audio is padded or trimmed to Whisper's 30-second window. A period is
    appended when the transcript does not already end in ASCII or full-width
    CJK punctuation. An empty transcript therefore becomes ``"."`` instead of
    raising ``IndexError``.

    Args:
        model: A loaded Whisper model; its ``device`` is used for the mel input.
        audio_path: Path of the audio file to transcribe.

    Returns:
        ``(lang, text)``: the most probable detected language code and the
        transcript.
    """
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.pad_or_trim(whisper.load_audio(audio_path))
    # make log-Mel spectrogram and move to the same device as the model
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    # detect the spoken language (computed once, printed once)
    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)
    print(f"Detected language: {lang}")
    # decode the audio
    options = whisper.DecodingOptions(beam_size=5)
    result = whisper.decode(model, mel, options)
    # print the recognized text
    print(result.text)
    text_pr = result.text
    stripped = text_pr.strip(" ")
    # ASCII ?!., plus CJK 。，？！、 — the full-width forms are what Whisper
    # emits for Chinese/Japanese, so they must not get an extra ".".
    if not stripped or stripped[-1] not in "?!.,。，？！、":
        text_pr += "."
    return lang, text_pr
def make_npz_prompt(name, uploaded_audio, recorded_audio):
    """Build a reusable ``.npz`` voice prompt from uploaded or recorded audio.

    Processing steps:

    * The uploaded clip wins when both clips are given.
    * The clip is transcribed via ``make_prompt``.
    * Its audio is encoded with ``audio_tokenizer``, and its tagged transcript
      with ``text_tokenizer``/``text_collater``.
    * The audio tokens, text tokens and integer language code are saved to
      ``<tempdir>/<name>.npz``.

    Stale ``.npz`` files are cleared first via ``clear_prompts``.

    Args:
        name: Prompt name typed by the user. Only its final path component is
            used, so the name cannot escape the target directories.
        uploaded_audio: ``(sample_rate, samples)`` from the upload widget, or
            ``None``.
        recorded_audio: ``(sample_rate, samples)`` from the microphone widget,
            or ``None``.

    Returns:
        ``(message, npz_path)`` for the Gradio textbox and file widget.

    Raises:
        gr.Error: If neither audio clip was provided.
    """
    clear_prompts()
    audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
    if audio_prompt is None:
        raise gr.Error("Please upload or record an audio prompt first.")
    # `name` comes straight from a web textbox: keep only the final path
    # component so values like "../../x" cannot write or delete files
    # outside the temp/prompts directories.
    name = os.path.basename(str(name).strip()) or "prompt"
    sr, wav_pr = audio_prompt
    # Scale to [-1, 1); assumes 16-bit integer samples — TODO confirm.
    wav_pr = torch.FloatTensor(wav_pr) / 32768
    # Down-mix two-channel audio (channels on the last axis) to mono.
    if wav_pr.size(-1) == 2:
        wav_pr = wav_pr.mean(-1, keepdim=False)
    text_pr, lang_pr = make_prompt(name, wav_pr, sr, save=False)
    # tokenize audio
    encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr.unsqueeze(0), sr))
    audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()
    # tokenize text
    text_tokens, _ = text_collater(
        [text_tokenizer.tokenize(text=f"{text_pr}".strip())]
    )
    message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n"
    # save as npz file
    npz_path = os.path.join(tempfile.gettempdir(), f"{name}.npz")
    np.savez(npz_path,
             audio_tokens=audio_tokens, text_tokens=text_tokens, lang_code=lang2code[lang_pr])
    return message, npz_path
def make_prompt(name, wav, sr, save=True):
    """Transcribe a waveform with Whisper and wrap the text in language tags.

    The waveform is written to ``./prompts/{name}.wav`` and the tagged
    transcript to ``./prompts/{name}.txt``; the directory is created if
    needed. Both files are removed again unless ``save`` is true, including
    when transcription fails. Whisper is moved to ``device`` for the
    transcription and always moved back to the CPU afterwards.

    Args:
        name: Base filename for the files under ``./prompts``.
        wav: Mono waveform, or one with two channels on the last axis, as a
            tensor or array-like. It is converted to float32 on the CPU and
            peak-normalized if any sample exceeds 1 in magnitude. The
            caller's tensor is never modified.
        sr: Sample rate of ``wav``.
        save: Keep the ``.wav``/``.txt`` files on disk when true.

    Returns:
        ``(text, lang)``: transcript wrapped in the language tag (e.g.
        ``"[EN]Hello.[EN]"``) and the detected language code.

    Raises:
        ValueError: If the waveform is not single-channel after down-mixing,
            or the detected language is not in ``lang2token``.
    """
    # Out-of-place conversion: integer arrays used to crash on the former
    # in-place division, and float tensors were mutated for the caller.
    wav = torch.as_tensor(wav, dtype=torch.float32, device="cpu")
    if wav.abs().max() > 1:
        wav = wav / wav.abs().max()
    if wav.size(-1) == 2:
        wav = wav.mean(-1, keepdim=False)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    if wav.ndim != 2 or wav.size(0) != 1:
        raise ValueError(f"expected a single-channel waveform, got shape {tuple(wav.shape)}")
    os.makedirs("./prompts", exist_ok=True)
    wav_path = f"./prompts/{name}.wav"
    txt_path = f"./prompts/{name}.txt"
    whisper_model.to(device)
    try:
        torchaudio.save(wav_path, wav, sr)
        lang, text = transcribe_one(whisper_model, wav_path)
        if lang not in lang2token:
            raise ValueError(
                f"unsupported prompt language {lang!r}; expected one of {sorted(lang2token)}"
            )
        lang_token = lang2token[lang]
        text = lang_token + text + lang_token
        # Explicit UTF-8: transcripts may contain CJK text.
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(text)
    finally:
        if not save:
            for scratch in (wav_path, txt_path):
                try:
                    os.remove(scratch)
                except FileNotFoundError:
                    pass
        whisper_model.cpu()
        torch.cuda.empty_cache()
    return text, lang
@torch.no_grad()
def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt):
    """Synthesize ``text`` in the voice of an uploaded or recorded audio clip.

    The clip is transcribed via ``make_prompt`` to get the prompt's tagged
    text and language. It is then encoded into audio tokens and passed to
    ``model.inference`` together with the tokenized prompt + target text.
    The model is moved to ``device`` only for the duration of synthesis and
    is moved back to the CPU even if synthesis fails.

    Args:
        text: Text to synthesize.
        language: Dropdown label ("English", "中文", "日本語" or "mix") choosing
            the tag wrapped around ``text``.
        accent: "no-accent", or a language label whose language overrides the
            text language passed to the model.
        audio_prompt: ``(sample_rate, samples)`` from the upload widget, or
            ``None``. Takes precedence over ``record_audio_prompt``.
        record_audio_prompt: ``(sample_rate, samples)`` from the microphone
            widget, or ``None``.

    Returns:
        ``(message, (24000, waveform))`` for the Gradio textbox and player.

    Raises:
        gr.Error: If neither audio prompt was provided.
    """
    audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
    if audio_prompt is None:
        raise gr.Error("Please upload or record an audio prompt first.")
    sr, wav_pr = audio_prompt
    # Scale to [-1, 1); assumes 16-bit integer samples — TODO confirm.
    wav_pr = torch.FloatTensor(wav_pr) / 32768
    # Down-mix two-channel audio (channels on the last axis) to mono.
    if wav_pr.size(-1) == 2:
        wav_pr = wav_pr.mean(-1, keepdim=False)
    # Random scratch name; make_prompt removes its files again (save=False).
    text_pr, lang_pr = make_prompt(str(random.randint(0, 10000000)), wav_pr, sr, save=False)

    lang_token = langdropdown2token[language]
    # "mix" maps to the empty tag, which has no entry in token2lang (this used
    # to raise KeyError). Fall back to the prompt's language.
    # NOTE(review): confirm this is the intended language embedding for mixed text.
    lang = token2lang.get(lang_token, lang_pr)
    text = lang_token + text + lang_token
    # accent control: a chosen accent overrides the text language
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]

    # onload model; the finally-block offloads it even if synthesis fails
    model.to(device)
    try:
        # tokenize audio
        encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr.unsqueeze(0), sr))
        audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
        # tokenize text: prompt transcript followed by the text to synthesize
        logging.info(f"synthesize text: {text}")
        text_tokens, text_tokens_lens = text_collater(
            [text_tokenizer.tokenize(text=f"{text_pr}{text}".strip())]
        )
        enroll_x_lens = None
        if text_pr:
            # Length of the prompt-only token sequence.
            _, enroll_x_lens = text_collater(
                [text_tokenizer.tokenize(text=f"{text_pr}".strip())]
            )
        encoded_frames = model.inference(
            text_tokens.to(device),
            text_tokens_lens.to(device),
            audio_prompts,
            enroll_x_lens=enroll_x_lens,
            top_k=-100,
            temperature=1,
            prompt_language=lang_pr,
            text_language=lang,
        )
        samples = audio_tokenizer.decode(
            [(encoded_frames.transpose(2, 1), None)]
        )
    finally:
        # offload model
        model.to("cpu")
        torch.cuda.empty_cache()
    message = f"text prompt: {text_pr}\nsythesized text: {text}"
    return message, (24000, samples[0][0].cpu().numpy())
@torch.no_grad()
def infer_from_prompt(text, language, accent, prompt_file):
    """Synthesize ``text`` from a previously saved ``.npz`` voice prompt.

    The prompt's stored text tokens are concatenated with the tokenized
    target text, and its audio tokens are used as the acoustic prompt for
    ``model.inference``. The model is moved to ``device`` only for the
    duration of synthesis and is moved back to the CPU even if synthesis
    fails.

    Args:
        text: Text to synthesize.
        language: Dropdown label ("English", "中文", "日本語" or "mix").
        accent: "no-accent", or a language label overriding the text language.
        prompt_file: Uploaded file object whose ``.name`` is the path of an
            ``.npz`` containing ``audio_tokens``, ``text_tokens`` and
            ``lang_code``.

    Returns:
        ``(message, (24000, waveform))`` for the Gradio textbox and player.

    Raises:
        gr.Error: If no prompt file was uploaded.
    """
    if prompt_file is None:
        raise gr.Error("Please upload a .npz prompt file first.")
    clear_prompts()
    # load prompt. The file is user-uploaded, so pickled objects are refused
    # explicitly; the with-block closes the NpzFile's underlying file handle.
    with np.load(prompt_file.name, allow_pickle=False) as prompt_data:
        audio_prompts = prompt_data['audio_tokens']
        text_prompts = prompt_data['text_tokens']
        lang_pr = code2lang[int(prompt_data['lang_code'])]

    # text to synthesize
    lang_token = langdropdown2token[language]
    # "mix" maps to the empty tag, which has no entry in token2lang (this used
    # to raise KeyError). Fall back to the prompt's language.
    # NOTE(review): confirm this is the intended language embedding for mixed text.
    lang = token2lang.get(lang_token, lang_pr)
    text = lang_token + text + lang_token
    # accent control
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]

    # numpy to tensor
    text_prompts = torch.tensor(text_prompts).type(torch.int32)
    enroll_x_lens = text_prompts.shape[-1]
    logging.info(f"synthesize text: {text}")
    text_tokens, text_tokens_lens = text_collater(
        [text_tokenizer.tokenize(text=f"_{text}".strip())]
    )
    # Prepend the stored prompt tokens to the target text tokens.
    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
    text_tokens_lens += enroll_x_lens

    # onload model; the finally-block offloads it even if synthesis fails
    model.to(device)
    try:
        audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
        encoded_frames = model.inference(
            text_tokens.to(device),
            text_tokens_lens.to(device),
            audio_prompts,
            enroll_x_lens=enroll_x_lens,
            top_k=-100,
            temperature=1,
            prompt_language=lang_pr,
            text_language=lang,
        )
        samples = audio_tokenizer.decode(
            [(encoded_frames.transpose(2, 1), None)]
        )
    finally:
        # offload model
        model.to("cpu")
        torch.cuda.empty_cache()
    message = f"sythesized text: {text}"
    return message, (24000, samples[0][0].cpu().numpy())
def main():
    """Build the three-tab Gradio UI and launch it.

    * "Infer from audio": synthesize text in the voice of an uploaded or
      recorded clip (``infer_from_audio``). The same clip can also be saved
      as a named ``.npz`` prompt (``make_npz_prompt``).
    * "Make prompt": turn an uploaded or recorded clip into a downloadable
      ``.npz`` prompt (``make_npz_prompt``).
    * "Infer from prompt": synthesize text from a previously saved ``.npz``
      prompt (``infer_from_prompt``).

    NOTE(review): the elem_id values "tts-input", "prompt-name" and
    "tts-audio" are each used in two places. HTML element ids are expected to
    be unique, so rename them if any CSS/JS ever targets them.
    """
    app = gr.Blocks()
    with app:
        with gr.Tab("Infer from audio"):
            with gr.Row():
                # Inputs: text, language/accent choice and the voice clip.
                with gr.Column():
                    textbox = gr.TextArea(label="Text",
                                          placeholder="Type your sentence here",
                                          value="Hello, it's nice to meet you.", elem_id=f"tts-input")
                    language_dropdown = gr.Dropdown(choices=['English', '中文', '日本語', 'mix'], value='English', label='language')
                    accent_dropdown = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent', label='accent')
                    # The handlers prefer the uploaded clip over the recorded one.
                    upload_audio_prompt = gr.Audio(label='uploaded audio prompt', source='upload', interactive=True)
                    record_audio_prompt = gr.Audio(label='recorded audio prompt', source='microphone', interactive=True)
                # Outputs and actions.
                with gr.Column():
                    text_output = gr.Textbox(label="Message")
                    audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
                    btn = gr.Button("Generate!")
                    btn.click(infer_from_audio,
                              inputs=[textbox, language_dropdown, accent_dropdown, upload_audio_prompt, record_audio_prompt],
                              outputs=[text_output, audio_output])
                    # Save the same clip as a reusable .npz prompt.
                    textbox_mp = gr.TextArea(label="Prompt name",
                                             placeholder="Name your prompt here",
                                             value="prompt_1", elem_id=f"prompt-name")
                    btn_mp = gr.Button("Make prompt!")
                    prompt_output = gr.File(interactive=False)
                    btn_mp.click(make_npz_prompt,
                                 inputs=[textbox_mp, upload_audio_prompt, record_audio_prompt],
                                 outputs=[text_output, prompt_output])
        with gr.Tab("Make prompt"):
            with gr.Row():
                # Inputs: prompt name and the voice clip.
                with gr.Column():
                    textbox2 = gr.TextArea(label="Prompt name",
                                           placeholder="Name your prompt here",
                                           value="prompt_1", elem_id=f"prompt-name")
                    upload_audio_prompt_2 = gr.Audio(label='uploaded audio prompt', source='upload', interactive=True)
                    record_audio_prompt_2 = gr.Audio(label='recorded audio prompt', source='microphone', interactive=True)
                # Outputs: status message and the downloadable .npz file.
                with gr.Column():
                    text_output_2 = gr.Textbox(label="Message")
                    prompt_output_2 = gr.File(interactive=False)
                    btn_2 = gr.Button("Make!")
                    btn_2.click(make_npz_prompt,
                                inputs=[textbox2, upload_audio_prompt_2, record_audio_prompt_2],
                                outputs=[text_output_2, prompt_output_2])
        with gr.Tab("Infer from prompt"):
            with gr.Row():
                # Inputs: text, language/accent choice and a saved .npz prompt.
                with gr.Column():
                    textbox_3 = gr.TextArea(label="Text",
                                            placeholder="Type your sentence here",
                                            value="Hello, it's nice to meet you.", elem_id=f"tts-input")
                    language_dropdown_3 = gr.Dropdown(choices=['English', '中文', '日本語', 'mix'], value='English',
                                                      label='language')
                    accent_dropdown_3 = gr.Dropdown(choices=['no-accent', 'English', '中文', '日本語'], value='no-accent',
                                                    label='accent')
                    prompt_file = gr.File(file_count='single', file_types=['.npz'], interactive=True)
                # Outputs and action.
                with gr.Column():
                    text_output_3 = gr.Textbox(label="Message")
                    audio_output_3 = gr.Audio(label="Output Audio", elem_id="tts-audio")
                    btn_3 = gr.Button("Generate!")
                    btn_3.click(infer_from_prompt,
                                inputs=[textbox_3, language_dropdown_3, accent_dropdown_3, prompt_file],
                                outputs=[text_output_3, audio_output_3])
    app.launch()
if __name__ == "__main__":
    # Timestamped log lines tagged with level and source location, at INFO.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
    )
    main()