# VALL-E-X / utils/generation.py
# Last commit: f330917 "Fix OOM" (by Plachta)
import os
import torch
import gdown
import logging
import psutil
import langid
# Restrict langid's language identification to English, Chinese and Japanese,
# so language="auto" can only ever resolve to one of these codes.
langid.set_languages(['en', 'zh', 'ja'])
import pathlib
import platform
# Alias the "foreign" pathlib class to the native one on Windows and Linux
# (other platforms, e.g. macOS, are left untouched). `temp` keeps a reference
# to the class that was replaced.
# NOTE(review): presumably this lets torch.load unpickle checkpoints that
# contain Path objects created on the other OS — confirm.
if platform.system().lower() == 'windows':
    temp, pathlib.PosixPath = pathlib.PosixPath, pathlib.WindowsPath
elif platform.system().lower() == 'linux':
    temp, pathlib.WindowsPath = pathlib.WindowsPath, pathlib.PosixPath
import numpy as np
from data.tokenizer import (
AudioTokenizer,
tokenize_audio,
)
from data.collation import get_text_token_collater
from models.vallex import VALLE
from utils.g2p import PhonemeBpeTokenizer
from utils.sentence_cutter import split_text_into_sentences
from macros import *
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
# Local location of the VALL-E X checkpoint.
checkpoints_dir = "./checkpoints/"
model_checkpoint_name = "vallex-checkpoint.pt"
# Google Drive share link for the checkpoint (preload_models downloads it by the same file id).
url = 'https://drive.google.com/file/d/10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl/view?usp=sharing'

# Populated by preload_models(); None until then.
model = codec = None

# Text front-end shared by the generation functions.
text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
text_collater = get_text_token_collater()
def preload_models():
    """Load the VALL-E X model and the audio codec into the module globals.

    Creates ``checkpoints_dir`` if needed and downloads the checkpoint from
    Google Drive (via gdown) when ``checkpoints_dir/model_checkpoint_name`` is
    missing. Then builds ``VALLE`` on ``device``, loads the checkpoint's
    ``"model"`` state dict strictly, switches it to eval mode, and constructs
    ``AudioTokenizer(device)``.

    Side effects:
        Sets the module-level ``model`` and ``codec``. ``model`` is only
        replaced once its weights have loaded successfully.

    Raises:
        RuntimeError: from ``load_state_dict`` if the checkpoint keys do not
            exactly match the model (``strict=True``).
    """
    global model, codec
    checkpoint_path = os.path.join(checkpoints_dir, model_checkpoint_name)
    # exist_ok avoids the exists()/mkdir() race and tolerates an existing directory.
    os.makedirs(checkpoints_dir, exist_ok=True)
    if not os.path.exists(checkpoint_path):
        gdown.download(id="10gdQWvP-K_e1undkvv0p2b7SU6I4Egyl", output=checkpoint_path, quiet=False)

    # VALL-E
    vallex = VALLE(
        N_DIM,
        NUM_HEAD,
        NUM_LAYERS,
        norm_first=True,
        add_prenet=False,
        prefix_mode=PREFIX_MODE,
        share_embedding=True,
        nar_scale_factor=1.0,
        prepend_bos=True,
        num_quantizers=NUM_QUANTIZERS,
    ).to(device)
    # NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
    # Weights are staged on the CPU and copied into the model's parameters.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # strict=True already raises RuntimeError on missing or unexpected keys.
    vallex.load_state_dict(checkpoint["model"], strict=True)
    # Release the CPU copy of the weights before the codec is loaded to lower peak memory.
    del checkpoint
    vallex.eval()
    model = vallex

    # Encodec
    codec = AudioTokenizer(device)
@torch.no_grad()
def generate_audio(text, prompt=None, language='auto', accent='no-accent'):
    """Synthesize ``text`` with VALL-E X and decode the generated tokens with the codec.

    Requires ``preload_models()`` to have populated ``model`` and ``codec``.

    Args:
        text: Text to speak. Newlines are removed and surrounding spaces are
            stripped before tokenization.
        prompt: Optional speaker prompt. Tried first as a literal path, then as
            ``./presets/<prompt>.npz``, then ``./customs/<prompt>.npz``. ``None``
            or ``""`` synthesizes without a prompt, as in
            ``generate_audio_from_long_text``.
        language: Language code looked up in ``lang2token``. ``"auto"`` detects
            it with ``langid.classify``.
        accent: With ``"no-accent"``, the tokenizer's per-token languages are passed
            as the text language. Any other value is looked up in
            ``langdropdown2token`` and that language is used instead.

    Returns:
        ``samples[0][0]`` from ``codec.decode`` as a NumPy array (presumably the
        waveform of the single batch item; confirm against AudioTokenizer).

    Raises:
        ValueError: if a non-empty ``prompt`` does not resolve to an existing file.
    """
    global model, codec, text_tokenizer, text_collater
    text = text.replace("\n", "").strip(" ")
    # detect language
    if language == "auto":
        language = langid.classify(text)[0]
    lang_token = lang2token[language]
    lang = token2lang[lang_token]
    text = lang_token + text + lang_token

    # load prompt
    if prompt is not None and prompt != "":
        prompt_path = prompt
        if not os.path.exists(prompt_path):
            prompt_path = "./presets/" + prompt + ".npz"
        if not os.path.exists(prompt_path):
            prompt_path = "./customs/" + prompt + ".npz"
        if not os.path.exists(prompt_path):
            raise ValueError(f"Cannot find prompt {prompt}")
        # np.load on an .npz returns an NpzFile that keeps the archive open until closed.
        with np.load(prompt_path) as prompt_data:
            audio_prompts = prompt_data['audio_tokens']
            text_prompts = prompt_data['text_tokens']
            lang_pr = code2lang[int(prompt_data['lang_code'])]
        # numpy to tensor
        audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device)
        text_prompts = torch.tensor(text_prompts).type(torch.int32)
    else:
        # No prompt: zero-length token sequences; the prompt language follows the text.
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        lang_pr = lang if lang != 'mix' else 'en'

    enroll_x_lens = text_prompts.shape[-1]
    logging.info(f"synthesize text: {text}")
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
    text_tokens, text_tokens_lens = text_collater([phone_tokens])
    # Prepend the prompt transcript; enroll_x_lens marks where the new text starts.
    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
    text_tokens_lens += enroll_x_lens
    # accent control
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
    )
    samples = codec.decode(
        [(encoded_frames.transpose(2, 1), None)]
    )
    return samples[0][0].cpu().numpy()
def _find_long_text_prompt(prompt):
    """Resolve ``prompt`` to an existing file.

    Tries ``prompt`` as a literal path, then ``./presets/<prompt>.npz``, then
    ``./customs/<prompt>.npz``.

    Raises:
        ValueError: if none of the candidates exists.
    """
    if os.path.exists(prompt):
        return prompt
    for folder in ("./presets/", "./customs/"):
        candidate = folder + prompt + ".npz"
        if os.path.exists(candidate):
            return candidate
    raise ValueError(f"Cannot find prompt {prompt}")


def _load_long_text_prompt(prompt, language):
    """Return ``(audio_prompts, text_prompts, prompt_language)`` for long-text synthesis.

    With no prompt (``None`` or ``""``), the token tensors are zero-length and
    the prompt language is ``language`` (``'mix'`` falls back to ``'en'``).
    Otherwise the tokens and language code are read from the resolved file.
    """
    if prompt is None or prompt == "":
        audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
        text_prompts = torch.zeros([1, 0]).type(torch.int32)
        return audio_prompts, text_prompts, (language if language != 'mix' else 'en')
    prompt_path = _find_long_text_prompt(prompt)
    # np.load on an .npz returns an NpzFile that keeps the archive open until closed.
    with np.load(prompt_path) as prompt_data:
        audio_tokens = prompt_data['audio_tokens']
        text_tokens = prompt_data['text_tokens']
        lang_pr = code2lang[int(prompt_data['lang_code'])]
    audio_prompts = torch.tensor(audio_tokens).type(torch.int32).to(device)
    text_prompts = torch.tensor(text_tokens).type(torch.int32)
    return audio_prompts, text_prompts, lang_pr


def _synthesize_long_text_sentence(sentence, language, accent, audio_prompts, text_prompts, lang_pr):
    """Run ``model.inference`` for one sentence, conditioned on the given prompt.

    Returns ``None`` if the sentence is empty after removing newlines and
    stripping spaces. Otherwise returns ``(encoded_frames, sentence_text_tokens)``,
    where ``sentence_text_tokens`` is this sentence's text tokens without the
    prompt prefix.
    """
    sentence = sentence.replace("\n", "").strip(" ")
    if sentence == "":
        return None
    lang_token = lang2token[language]
    lang = token2lang[lang_token]
    sentence = lang_token + sentence + lang_token
    enroll_x_lens = text_prompts.shape[-1]
    logging.info(f"synthesize text: {sentence}")
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{sentence}".strip())
    text_tokens, text_tokens_lens = text_collater([phone_tokens])
    # Prepend the prompt transcript; enroll_x_lens marks where this sentence starts.
    text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
    text_tokens_lens += enroll_x_lens
    # accent control
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
    )
    return encoded_frames, text_tokens[:, enroll_x_lens:]


@torch.no_grad()
def generate_audio_from_long_text(text, prompt=None, language='auto', accent='no-accent', mode='sliding-window'):
    """Synthesize long text sentence by sentence and decode it as one waveform.

    Two modes are available:

    fixed-prompt: every sentence is generated with the prompt the user provided.
    sliding-window: after each sentence, with probability 0.5 the next sentence
        is prompted with the sentence just generated; otherwise it goes back to
        the user's prompt. Speaker consistency may drift in this mode.

    If no prompt is given (``None`` or ``""``), sliding-window mode is used
    regardless of ``mode``.

    Args:
        text: Full text. It is split with ``split_text_into_sentences``, and empty
            sentences are skipped.
        prompt: Optional speaker prompt: a literal path, or a name resolved as
            ``./presets/<prompt>.npz`` then ``./customs/<prompt>.npz``.
        language: Language code for ``lang2token``. ``"auto"`` detects it from the
            whole text with ``langid.classify``.
        accent: ``"no-accent"`` or a key of ``langdropdown2token``, as in
            ``generate_audio``.
        mode: ``"fixed-prompt"`` or ``"sliding-window"``.

    Returns:
        ``samples[0][0]`` from ``codec.decode`` of all sentences' tokens, as a
        NumPy array.

    Raises:
        ValueError: if ``mode`` is unknown or ``prompt`` cannot be found.
    """
    global model, codec, text_tokenizer, text_collater
    if prompt is None or prompt == "":
        mode = 'sliding-window'  # If no prompt is given, use sliding-window mode
    # Reject an unknown mode before any splitting, detection or prompt loading.
    if mode not in ('fixed-prompt', 'sliding-window'):
        raise ValueError(f"No such mode {mode}")

    sentences = split_text_into_sentences(text)
    # detect language
    if language == "auto":
        language = langid.classify(text)[0]
    audio_prompts, text_prompts, lang_pr = _load_long_text_prompt(prompt, language)
    original_audio_prompts, original_text_prompts = audio_prompts, text_prompts

    # Collect per-sentence tokens and concatenate once, rather than re-copying
    # a growing tensor for every sentence.
    token_chunks = [torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device)]
    for sentence in sentences:
        result = _synthesize_long_text_sentence(
            sentence, language, accent, audio_prompts, text_prompts, lang_pr
        )
        if result is None:
            continue
        encoded_frames, sentence_text_tokens = result
        token_chunks.append(encoded_frames.transpose(2, 1))
        if mode == 'sliding-window':
            # NOTE(review): lang_pr keeps the original prompt's language even when
            # the prompt becomes the previous generated sentence, whose text is in
            # `language`. Confirm whether prompt_language should follow it.
            if torch.rand(1) < 0.5:
                audio_prompts = encoded_frames[:, :, -NUM_QUANTIZERS:]
                text_prompts = sentence_text_tokens
            else:
                audio_prompts = original_audio_prompts
                text_prompts = original_text_prompts

    complete_tokens = torch.cat(token_chunks, dim=-1)
    samples = codec.decode(
        [(complete_tokens, None)]
    )
    return samples[0][0].cpu().numpy()