|
|
|
|
| import torch |
| import torchaudio |
| import librosa |
| import yaml |
| import numpy as np |
| import soundfile as sf |
| import phonemizer |
| from munch import Munch |
| import os |
| import time |
|
|
| |
| from models import * |
| from utils import * |
| from text_utils import TextCleaner |
| from Utils.PLBERT.util import load_plbert |
| from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule |
|
|
| |
# Locations of the fine-tuned StyleTTS2 config/checkpoint and the
# voice-cloning reference/output audio used by the demo below.
CONFIG_PATH = "/workspace/trainTTS/StyleTTS2_custom/Configs/config_ft.yml"
MODEL_PATH = "/workspace/trainTTS/StyleTTS2_custom/Models/mix5voice/model_iter_00032000.pth"
REF_AUDIO_PATH = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai.wav"
OUTPUT_WAV = "/workspace/trainTTS/StyleTTS2_custom/test_voice_clone/hue_ban_mai_cut.wav"

# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
| |
|
|
class StyleTTS2Inference:
    """Load a fine-tuned StyleTTS2 model and synthesize speech in a reference voice.

    Text is phonemized with espeak (Vietnamese). The speaking style is taken
    from a reference recording, blended with a style sampled by the diffusion
    model, and decoded into a waveform. Audio is handled at ``SAMPLE_RATE``
    (24 kHz).
    """

    # Sample rate used to load reference audio and to size silence padding.
    SAMPLE_RATE = 24000

    def __init__(self, config_path, model_path, device=DEVICE):
        """Build the model graph and load the checkpoint weights.

        Args:
            config_path: Path to the YAML config. It must provide ``ASR_path``,
                ``ASR_config``, ``F0_path``, ``PLBERT_dir`` and ``model_params``.
            model_path: Path to a checkpoint holding one state dict per model
                module, either at top level or nested under ``'net'``.
            device: Torch device string that loaded modules are moved to.
        """
        self.device = device
        # Use a context manager so the config file handle is closed promptly.
        with open(config_path, encoding="utf-8") as config_file:
            self.config = yaml.safe_load(config_file)

        self.phonemizer = phonemizer.backend.EspeakBackend(
            language='vi', preserve_punctuation=True, with_stress=True
        )
        self.text_cleaner = TextCleaner()

        # Mel front-end for reference audio. It is built once here instead of
        # on every preprocess_audio() call.
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=80, n_fft=2048, win_length=1200, hop_length=300
        )

        # Auxiliary pretrained networks that build_model() needs.
        text_aligner = load_ASR_models(self.config['ASR_path'], self.config['ASR_config'])
        pitch_extractor = load_F0_models(self.config['F0_path'])
        plbert = load_plbert(self.config['PLBERT_dir'])

        model_params = recursive_munch(self.config['model_params'])
        self.model = build_model(model_params, text_aligner, pitch_extractor, plbert)

        self._load_checkpoint(model_path)

        self.sampler = DiffusionSampler(
            self.model.diffusion.diffusion,
            sampler=ADPM2Sampler(),
            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
            clamp=False
        )
        print("Model initialization complete.\n")

    def _load_checkpoint(self, model_path):
        """Load per-module weights from ``model_path`` into ``self.model``.

        Modules that are missing from the checkpoint are skipped with a warning.
        Loaded modules are switched to eval mode and moved to ``self.device``.
        """
        print(f"Loading model from: {model_path}")
        params = torch.load(model_path, map_location='cpu')

        # Some checkpoints keep the per-module state dicts under 'net'.
        if 'net' in params:
            params = params['net']

        prefix = "module."
        for key in self.model:
            if key not in params:
                print(f"⚠️ Bỏ qua module '{key}' (không tìm thấy trong checkpoint - OK với model inference)")
                continue

            # Drop the 'module.' key prefix that (Distributed)DataParallel
            # wrappers add, so keys match the bare module.
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in params[key].items()
            }
            self.model[key].load_state_dict(state_dict, strict=True)
            self.model[key].eval().to(self.device)
            print(f"✅ Loaded module: {key}")

    def preprocess_audio(self, audio_path):
        """Convert a reference recording into a style vector.

        The audio is loaded (resampled) at 24 kHz and trimmed of leading and
        trailing silence (``top_db=30``). Its normalized log-mel spectrogram is
        then fed to both the style encoder and the predictor encoder.

        Args:
            audio_path: Path to the reference audio file.

        Returns:
            ``torch.cat([style_encoder(mel), predictor_encoder(mel)], dim=1)``.
            ``inference`` uses columns ``[:128]`` as the acoustic reference and
            ``[128:]`` as the prosodic style. The shape is presumably (1, 256);
            confirm against the encoder configs.

        Raises:
            ValueError: If nothing is left of the audio after trimming.
        """
        wave, _ = librosa.load(audio_path, sr=self.SAMPLE_RATE)
        audio, _ = librosa.effects.trim(wave, top_db=30)
        if audio.size == 0:
            raise ValueError(f"Reference audio is empty after silence trimming: {audio_path}")

        mel = self.to_mel(torch.from_numpy(audio).float())
        # Log-mel normalized with fixed mean -4 and std 4.
        mel = (torch.log(1e-5 + mel.unsqueeze(0)) - (-4)) / 4
        mel = mel.to(self.device).unsqueeze(1)

        with torch.no_grad():
            ref_s = self.model.style_encoder(mel)
            ref_p = self.model.predictor_encoder(mel)
            return torch.cat([ref_s, ref_p], dim=1)

    def preprocess_text(self, text):
        """Phonemize and tokenize ``text``.

        Returns:
            A (1, T) LongTensor of token ids on ``self.device`` whose first id is
            0, or ``None`` when ``text`` is empty or whitespace-only.
        """
        text = text.strip()
        if not text:
            return None

        phonemes = self.phonemizer.phonemize([text])[0]
        tokens = torch.LongTensor(self.text_cleaner(phonemes)).to(self.device).unsqueeze(0)
        # Prepend a 0 id ahead of the real tokens.
        start = torch.LongTensor([0]).to(self.device).unsqueeze(0)
        return torch.cat([start, tokens], dim=-1)

    def inference(self, text, ref_style, diffusion_steps=5, alpha=0.3, beta=0.7):
        """Synthesize a single sentence in the style given by ``ref_style``.

        Args:
            text: Sentence to speak.
            ref_style: Style vector returned by ``preprocess_audio``.
            diffusion_steps: Number of diffusion sampler steps.
            alpha: Weight of the diffusion-predicted acoustic style
                (columns ``[:128]``). The remaining ``1 - alpha`` comes from
                the reference.
            beta: Weight of the diffusion-predicted prosodic style
                (columns ``[128:]``). The remaining ``1 - beta`` comes from
                the reference.

        Returns:
            The decoder output as a numpy array (presumably a 1-D waveform)
            with its last 50 samples dropped, or ``None`` for empty text.
        """
        tokens = self.preprocess_text(text)
        if tokens is None:
            return None

        num_tokens = tokens.shape[-1]
        input_lengths = torch.LongTensor([num_tokens]).to(self.device)
        text_mask = length_to_mask(input_lengths).to(self.device)

        with torch.no_grad():
            t_en = self.model.text_encoder(tokens, input_lengths, text_mask)
            bert_dur = self.model.bert(tokens, attention_mask=(~text_mask).int())
            d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

            s_pred = self.sampler(
                noise=torch.randn((1, 256)).unsqueeze(1).to(self.device),
                embedding=bert_dur,
                features=ref_style,
                num_steps=diffusion_steps
            ).squeeze(1)

            # Blend each half convexly, so the style magnitude is preserved for
            # any alpha/beta. alpha controls the acoustic half and beta the
            # prosodic half.
            ref = alpha * s_pred[:, :128] + (1 - alpha) * ref_style[:, :128]
            s = beta * s_pred[:, 128:] + (1 - beta) * ref_style[:, 128:]

            d = self.model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
            x, _ = self.model.predictor.lstm(d)
            duration = torch.sigmoid(self.model.predictor.duration_proj(x)).sum(axis=-1)
            # reshape(-1) instead of squeeze(): a one-token input must still
            # give a 1-D tensor of per-token frame counts, not a 0-d scalar.
            pred_dur = torch.round(duration.reshape(-1)).clamp(min=1).long()

            # Hard alignment of shape (tokens x frames): token i covers
            # pred_dur[i] consecutive frames, so every column is one-hot.
            pred_aln_trg = torch.repeat_interleave(
                torch.eye(num_tokens, device=self.device), pred_dur, dim=1
            ).unsqueeze(0)

            en = d.transpose(-1, -2) @ pred_aln_trg
            F0_pred, N_pred = self.model.predictor.F0Ntrain(en, s)
            asr = t_en @ pred_aln_trg

            out = self.model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

        # Drop the last 50 output samples.
        return out.squeeze().cpu().numpy()[..., :-50]

    @staticmethod
    def _split_sentences(text):
        """Split ``text`` into non-empty sentences that end in punctuation.

        Text is split after '.', '!', '?' or '…' when whitespace follows, and
        on newlines. Decimals such as "3.5" are therefore not split. A '.' is
        appended to any piece that lacks sentence-final punctuation.
        """
        import re

        sentences = []
        for piece in re.split(r"(?<=[.!?…])\s+|\n+", text):
            piece = piece.strip()
            if not piece:
                continue
            if not piece.endswith((".", "!", "?", "…")):
                piece += "."
            sentences.append(piece)
        return sentences

    def generate_long_text(self, text, ref_audio_path):
        """Synthesize multi-sentence text by splitting it into sentences.

        Each sentence is synthesized separately with the same reference style
        and followed by 0.1 s of silence.

        Args:
            text: Text to speak.
            ref_audio_path: Reference recording whose voice is cloned.

        Returns:
            The concatenated numpy waveform, or an empty array if ``text``
            contains nothing to synthesize.
        """
        print(f"Processing audio ref: {ref_audio_path}")
        ref_style = self.preprocess_audio(ref_audio_path)

        sentences = self._split_sentences(text)
        silence = np.zeros(int(self.SAMPLE_RATE * 0.1))
        wavs = []

        start_time = time.time()
        print("Start synthesizing...")

        for sent in sentences:
            wav = self.inference(sent, ref_style)
            if wav is None:
                continue
            wavs.append(wav)
            wavs.append(silence)

        # np.concatenate([]) raises, so return an empty waveform instead.
        full_wav = np.concatenate(wavs) if wavs else np.zeros(0)
        print(f"Done! Total time: {time.time() - start_time:.2f}s")
        return full_wav
|
|
| |
if __name__ == "__main__":
    # Demo: clone the reference voice for each text and save everything as one WAV file.
    tts = StyleTTS2Inference(CONFIG_PATH, MODEL_PATH)

    list_texts = ["xin chào việt nam, hôm nay trời rất đẹp"]

    # Every synthesized text is followed by 0.5 s of silence (24 kHz).
    pause = np.zeros(int(24000 * 0.5))
    pieces = []
    for text in list_texts:
        pieces.extend([tts.generate_long_text(text, REF_AUDIO_PATH), pause])

    final_wav = np.concatenate(pieces)
    sf.write(OUTPUT_WAV, final_wav, 24000)
    print(f"File saved to: {OUTPUT_WAV}")