# API_MC_AI / VietTTS / frontend.py
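"""Frontend for VietTTS.

Prepares model inputs for text-to-speech (frontend_tts) and voice conversion
(frontend_vc): text tokens from the tokenizer, discrete speech tokens from an
ONNX speech tokenizer, speaker embeddings from an ONNX embedding model, and
mel-spectrogram prompt features computed at 22.05 kHz.
"""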
import os
import torch
import torchaudio
import whisper
import onnxruntime
import numpy as np
import torchaudio.compliance.kaldi as kaldi
from typing import Callable, List, Union
from functools import partial
from loguru import logger
from VietTTS.utils.frontend_utils import split_text, normalize_text, mel_spectrogram
from VietTTS.tokenizer.tokenizer import get_tokenizer

class TTSFrontEnd:
    def __init__(
        self,
        speech_embedding_model: str,
        speech_tokenizer_model: str,
    ):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = get_tokenizer()
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        # The speaker-embedding model runs on CPU; the speech tokenizer uses CUDA when available.
        self.speech_embedding_session = onnxruntime.InferenceSession(
            speech_embedding_model,
            sess_options=option,
            providers=["CPUExecutionProvider"]
        )
        self.speech_tokenizer_session = onnxruntime.InferenceSession(
            speech_tokenizer_model,
            sess_options=option,
            providers=["CUDAExecutionProvider" if torch.cuda.is_available() else "CPUExecutionProvider"]
        )
        self.spk2info = {}

    def _extract_text_token(self, text: str):
        """Tokenize text and return the token ids and their length as tensors."""
        text_token = self.tokenizer.encode(text, allowed_special='all')
        text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
        return text_token, text_token_len

    def _extract_speech_token(self, speech: torch.Tensor):
        """Extract discrete speech tokens from 16 kHz audio with the ONNX speech tokenizer."""
        # Cap the prompt at 30 seconds of 16 kHz audio before tokenization.
        if speech.shape[1] / 16000 > 30:
            speech = speech[:, :int(16000 * 30)]
        feat = whisper.log_mel_spectrogram(speech, n_mels=128)
        speech_token = self.speech_tokenizer_session.run(
            None,
            {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
             self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)}
        )[0].flatten().tolist()
        speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
        speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
        return speech_token, speech_token_len

    def _extract_spk_embedding(self, speech: torch.Tensor):
        """Compute a speaker embedding from 16 kHz audio via the ONNX embedding model."""
        feat = kaldi.fbank(
            waveform=speech,
            num_mel_bins=80,
            dither=0,
            sample_frequency=16000
        )
        # Mean-normalize the filterbank features before the embedding model.
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.speech_embedding_session.run(
            None,
            {self.speech_embedding_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()}
        )[0].flatten().tolist()
        embedding = torch.tensor([embedding]).to(self.device)
        return embedding

    def _extract_speech_feat(self, speech: torch.Tensor):
        """Compute an 80-bin mel spectrogram (expects 22.05 kHz audio) and its frame length."""
        speech_feat = mel_spectrogram(
            y=speech,
            n_fft=1024,
            num_mels=80,
            sampling_rate=22050,
            hop_size=256,
            win_size=1024,
            fmin=0,
            fmax=8000,
            center=False
        ).squeeze(dim=0).transpose(0, 1).to(self.device)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
        return speech_feat, speech_feat_len

    def preprocess_text(self, text, split=True) -> Union[str, List[str]]:
        """Normalize input text and optionally split it into token-length-bounded chunks."""
        text = normalize_text(text)
        if split:
            text = list(split_text(
                text=text,
                tokenize=partial(self.tokenizer.encode, allowed_special='all'),
                token_max_n=30,
                token_min_n=10,
                merge_len=5,
                comma_split=False
            ))
        return text

    def frontend_tts(
        self,
        text: str,
        prompt_speech_16k: Union[np.ndarray, torch.Tensor]
    ) -> dict:
        """Build the model input dict for TTS from text and a 16 kHz prompt utterance."""
        if isinstance(prompt_speech_16k, np.ndarray):
            prompt_speech_16k = torch.from_numpy(prompt_speech_16k)
        text_token, text_token_len = self._extract_text_token(text)
        speech_token, speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
        speech_feat, speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        model_input = {
            'text': text_token,
            'text_len': text_token_len,
            'flow_prompt_speech_token': speech_token,
            'flow_prompt_speech_token_len': speech_token_len,
            'prompt_speech_feat': speech_feat,
            'prompt_speech_feat_len': speech_feat_len,
            'llm_embedding': embedding,
            'flow_embedding': embedding
        }
        return model_input

    def frontend_vc(
        self,
        source_speech_16k: Union[np.ndarray, torch.Tensor],
        prompt_speech_16k: Union[np.ndarray, torch.Tensor]
    ) -> dict:
        """Build the model input dict for voice conversion from 16 kHz source and prompt audio."""
        if isinstance(source_speech_16k, np.ndarray):
            source_speech_16k = torch.from_numpy(source_speech_16k)
        if isinstance(prompt_speech_16k, np.ndarray):
            prompt_speech_16k = torch.from_numpy(prompt_speech_16k)
        prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
        prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
        prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
        embedding = self._extract_spk_embedding(prompt_speech_16k)
        source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
        model_input = {
            'source_speech_token': source_speech_token,
            'source_speech_token_len': source_speech_token_len,
            'flow_prompt_speech_token': prompt_speech_token,
            'flow_prompt_speech_token_len': prompt_speech_token_len,
            'prompt_speech_feat': prompt_speech_feat,
            'prompt_speech_feat_len': prompt_speech_feat_len,
            'flow_embedding': embedding
        }
        return model_input
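

# Minimal usage sketch (not part of the original module). The ONNX model paths,
# the prompt wav file, and the example text below are placeholders; point them
# at the speaker-embedding model, speech-tokenizer model, and a real 16 kHz
# mono prompt recording shipped with or prepared for VietTTS.
if __name__ == '__main__':
    frontend = TTSFrontEnd(
        speech_embedding_model='models/speech_embedding.onnx',  # assumed path
        speech_tokenizer_model='models/speech_tokenizer.onnx',  # assumed path
    )
    prompt, sr = torchaudio.load('prompt.wav')  # assumed mono prompt recording
    if sr != 16000:
        prompt = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(prompt)
    # Normalize and split the text, then build model inputs for the first chunk.
    chunks = frontend.preprocess_text('Xin chào, đây là một ví dụ.', split=True)
    model_input = frontend.frontend_tts(chunks[0], prompt)
    print({k: tuple(v.shape) for k, v in model_input.items()})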