Spaces:
Sleeping
Sleeping
| import os | |
| from numpy import pad | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| import phonemizer | |
| import yaml | |
| from split_audio.models import load_ASR_models, load_F0_models, build_model | |
| from split_audio.utils import mask_from_lens, maximum_path | |
| from split_audio.utils import length_to_mask, recursive_munch | |
| from split_audio.plbert.plbert import load_plbert | |
| from split_audio.text_utils import TextCleaner | |
| import librosa | |
| import numpy as np | |
| import torchaudio | |
| import soundfile as sf | |
# Parameters of the torchaudio MelSpectrogram front-end used by wav_to_mel.
N_MELS = 80   # number of mel bands
N_FFT = 2048  # FFT size
WIN = 1200    # analysis window length, in samples
HOP = 300     # hop length: samples advanced per mel frame

# Log-mel normalization applied in wav_to_mel: (log(mel) - MEAN) / STD.
MEAN = -4.0
STD = 4.0

# Number of zero samples added on each side of the waveform before the STFT.
PAD = 5000
class AudioSplitter:
    """Cut the part of a recording that corresponds to a fragment of its transcript.

    Both the full transcript and the fragment are phonemized with espeak and
    turned into token ids by ``TextCleaner``. The fragment is located as a
    contiguous run inside the full token sequence. The text aligner of the
    loaded model then aligns the full token sequence to the mel spectrogram,
    and the alignment is hardened with ``maximum_path``. Finally, the frames
    aligned to the fragment's tokens are converted back to a sample range of
    the waveform.
    """

    # Suffixes appended to ``model_name`` to form the files fetched from the Hub.
    _MODEL_FILE_SUFFIXES = (".pth", ".yml", "_asr.yml", "_plbert.yml")

    def __init__(self, language: str, model_name: str = "phoaudio_single_v1", device: str = "cpu"):
        """Build the phonemizer, fetch the model files and load the weights.

        Args:
            language: espeak language code for the phonemizer (e.g. "vi").
            model_name: base name of the checkpoint/config files in the
                ``presencesw/tts`` Hub repo and in the local ``Models`` dir.
            device: torch device the model sub-modules (and, at inference,
                the inputs) are moved to.
        """
        self.language = language
        self.model_name = model_name
        self.backend_phonemizer = phonemizer.backend.EspeakBackend(
            language=language,
            preserve_punctuation=True,
            with_stress=True,
        )
        self.device = device
        self.textcleaner = TextCleaner()
        # Built once here instead of on every wav_to_mel call.
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=N_MELS, n_fft=N_FFT, win_length=WIN, hop_length=HOP
        )

        # Best-effort fetch; on failure we fall back to whatever is already in Models/.
        self._download_model_files()

        config_path = os.path.join("Models", self.model_name + ".yml")
        with open(config_path, encoding="utf-8") as f:
            self.config = yaml.safe_load(f)

        text_aligner = load_ASR_models(self.config.get("ASR_path"), self.config.get("ASR_config"))
        pitch_extractor = load_F0_models(self.config.get("F0_path"))
        plbert = load_plbert(self.config.get("PLBERT_dir"))
        model_params = recursive_munch(self.config["model_params"])
        self.model = build_model(model_params, text_aligner, pitch_extractor, plbert)
        for key in self.model:
            self.model[key].eval()
            self.model[key].to(self.device)

        self._load_weights()
        for key in self.model:
            self.model[key].eval()

        # The aligner works on a time axis down-sampled by self.d = 2**n_down:
        # one aligned step corresponds to self.d mel frames (see split_audio).
        self.n_down = self.model.text_aligner.n_down
        self.d = 2 ** self.n_down

    def _download_model_files(self):
        """Try to download every model file into ``Models/``; print failures and continue."""
        token = os.getenv("HF_TOKEN", None)
        for suffix in self._MODEL_FILE_SUFFIXES:
            try:
                hf_hub_download(
                    repo_id="presencesw/tts",
                    filename=self.model_name + suffix,
                    local_dir="Models",
                    token=token,
                )
            except Exception as e:
                print(f"Error downloading model: {e}")

    def _load_weights(self):
        """Load the ``net`` sub-dicts of ``Models/<model_name>.pth`` into ``self.model``."""
        from collections import OrderedDict

        checkpoint_path = os.path.join("Models", self.model_name + ".pth")
        params = torch.load(checkpoint_path, map_location="cpu")["net"]
        prefix = "module."
        for key in self.model:
            if key not in params:
                continue
            print('%s loaded' % key)
            try:
                self.model[key].load_state_dict(params[key])
            except RuntimeError:
                # Keys did not match exactly. This typically happens with a
                # checkpoint saved from a DataParallel-style wrapper, whose keys
                # start with "module.". Strip that prefix only where it is
                # present, then load non-strictly.
                state_dict = OrderedDict(
                    (k[len(prefix):] if k.startswith(prefix) else k, v)
                    for k, v in params[key].items()
                )
                self.model[key].load_state_dict(state_dict, strict=False)

    def find_subsequence(self, seq, subseq):
        """Return the first index where ``subseq`` occurs contiguously in ``seq``.

        Returns None when ``subseq`` is empty, longer than ``seq``, or absent.
        """
        n, m = len(seq), len(subseq)
        if m == 0 or m > n:
            return None
        for i in range(n - m + 1):
            if seq[i:i + m] == subseq:
                return i
        return None

    def to_tokens(self, txt: str):
        """Phonemize ``txt`` with espeak and convert the phonemes to token ids via TextCleaner."""
        ps = self.backend_phonemizer.phonemize([txt])[0].strip()
        # Remove the "(en)"/"(vi)" language-switch tags from the phoneme string.
        ps = ps.replace("(en)", "").replace("(vi)", "")
        return self.textcleaner(ps)

    def wav_to_mel(self, wave_1d: np.ndarray):
        """Pad a waveform and compute its normalized log-mel spectrogram.

        The waveform is zero-padded with PAD samples on both sides, passed
        through the MelSpectrogram front-end, log-compressed and normalized
        with MEAN/STD. The frame axis is truncated to a multiple of ``self.d``,
        presumably so the aligner's down-sampling divides it evenly. The padded
        waveform is cut to ``T_trim * HOP`` samples to roughly match.

        Returns:
            (wave_pad, mel): the padded np.ndarray and a torch.Tensor of shape
            [N_MELS, T_trim].
        """
        zeros = np.zeros(PAD, dtype=wave_1d.dtype)
        wave_pad = np.concatenate([zeros, wave_1d, zeros])
        mel = self.to_mel(torch.from_numpy(wave_pad).float())  # [n_mels, T]
        mel = (torch.log(1e-5 + mel) - MEAN) / STD
        n_frames = mel.shape[1]
        n_keep = n_frames - (n_frames % self.d)
        if n_keep != n_frames:
            mel = mel[:, :n_keep]
            wave_pad = wave_pad[: n_keep * HOP]  # keep the time span in sync
        return wave_pad, mel

    def cal_attn(self, mel_len, text_len, mel, tokens):
        """Compute a hard text-to-mel alignment with the model's text aligner.

        Args:
            mel_len: LongTensor [1] holding the number of mel frames.
            text_len: LongTensor [1] holding the number of tokens.
            mel: mel spectrogram [N_MELS, T] of a single utterance.
            tokens: LongTensor [1, text_len] of token ids.

        Returns:
            The output of ``maximum_path``. ``split_audio`` indexes it as
            [batch, token, down-sampled frame] (TODO confirm against maximum_path).
        """
        mel_len_down = mel_len // self.d
        # Masks are built from CPU lengths (as before), then moved to the model device.
        mask_mel = length_to_mask(mel_len_down).to(self.device)
        text_mask = length_to_mask(text_len).to(self.device)
        mels_in = mel.unsqueeze(0).to(self.device)  # [1, N_MELS, T]
        tokens = tokens.to(self.device)

        # Inference only: no autograd graph is needed for the aligner pass.
        with torch.no_grad():
            _, _, s2s_attn = self.model.text_aligner(mels_in, mask_mel, tokens)
            # Drop index 0 of the second-to-last axis (same transpose/slice/transpose as before).
            s2s_attn = s2s_attn.transpose(-1, -2)[..., 1:].transpose(-1, -2)

            # Zero attention outside the valid (token, frame) rectangle.
            valid_mel = (~mask_mel).unsqueeze(-1).expand(
                mask_mel.shape[0], mask_mel.shape[1], text_mask.shape[-1]
            ).float().transpose(-1, -2)
            valid_text = (~text_mask).unsqueeze(-1).expand(
                text_mask.shape[0], text_mask.shape[1], mask_mel.shape[-1]
            ).float()
            attn_mask = (valid_mel * valid_text) < 1
            s2s_attn.masked_fill_(attn_mask, 0.0)

            # NOTE(review): lengths stay on CPU here as in the original; verify
            # mask_from_lens handles a device-resident s2s_attn with CPU lengths.
            mask_ST = mask_from_lens(s2s_attn, text_len, mel_len_down)
            s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
        return s2s_attn_mono

    def convert_sr(self, wav, orig_sr, target_sr):
        """Resample ``wav`` from ``orig_sr`` to ``target_sr`` with librosa (no-op if equal)."""
        if orig_sr != target_sr:
            wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=target_sr)
        return wav

    def load_audio(self, audio_input, target_sr=24000, orig_sr=None):
        """Load a waveform and resample it to ``target_sr``.

        Args:
            audio_input: path (str or os.PathLike) to an audio file, or an
                already-loaded 1-D waveform array.
            target_sr: output sample rate.
            orig_sr: sample rate of an in-memory ``audio_input``. If None, the
                array is assumed to already be at ``target_sr``. Ignored for paths,
                whose native rate is read from the file.

        Returns:
            (waveform, target_sr)
        """
        if isinstance(audio_input, (str, os.PathLike)):
            wav, sr = librosa.load(audio_input, sr=None)
        else:
            wav = audio_input
            sr = target_sr if orig_sr is None else orig_sr
        wav = self.convert_sr(wav, orig_sr=sr, target_sr=target_sr)
        return wav, target_sr

    @staticmethod
    def _token_span_to_samples(token_per_frame_down, cut_start, cut_end, d, hop, pad, n_samples):
        """Map the token range [cut_start, cut_end) to a sample range of the unpadded audio.

        ``token_per_frame_down[i]`` is the token aligned to down-sampled frame i.
        Each down-sampled frame spans ``d`` mel frames of ``hop`` samples,
        measured on audio that was left-padded with ``pad`` samples.

        Returns:
            (start_sample, end_sample), clamped to [0, n_samples]. The range is
            at least one sample long unless it is clamped at the end.

        Raises:
            ValueError: if no frame is aligned to any token of the range.
        """
        in_span = (token_per_frame_down >= cut_start) & (token_per_frame_down < cut_end)
        frames = np.flatnonzero(in_span)
        if frames.size == 0:
            raise ValueError("alignment assigns no frames to the requested text span")
        start_sample = max(0, int(frames[0]) * d * hop - pad)
        end_sample = max(start_sample + 1, (int(frames[-1]) + 1) * d * hop - pad)
        # Clamp to the real audio length so trailing zero padding is never returned.
        end_sample = min(end_sample, n_samples)
        return start_sample, end_sample

    def split_audio(self, str_raw: str, str_trunc: str, audio_input):
        """Return the slice of ``audio_input`` spoken for the fragment ``str_trunc``.

        Args:
            str_raw: full transcript of ``audio_input``.
            str_trunc: fragment of ``str_raw`` to cut out. Its tokens must occur
                as a contiguous run in the tokens of ``str_raw``.
            audio_input: 1-D waveform, presumably at the rate produced by
                ``load_audio`` (24 kHz by default).

        Returns:
            np.ndarray copy of the matching samples.

        Raises:
            ValueError: if the fragment is not found in the transcript's tokens,
                or the alignment assigns no frame to it.
        """
        tokens_trunc = self.to_tokens(str_trunc)
        tokens_raw = self.to_tokens(str_raw)
        cut_start = self.find_subsequence(tokens_raw, tokens_trunc)
        if cut_start is None:
            raise ValueError(
                f"tokens of {str_trunc!r} were not found in the tokens of the full transcript"
            )
        cut_end = cut_start + len(tokens_trunc)

        audio = np.asarray(audio_input)
        _, mel = self.wav_to_mel(audio)  # frame count is already a multiple of self.d

        tokens = torch.LongTensor(tokens_raw).unsqueeze(0)
        mel_len = torch.tensor([mel.shape[1]], dtype=torch.long)
        text_len = torch.tensor([tokens.shape[1]], dtype=torch.long)
        s2s_attn_mono = self.cal_attn(mel_len=mel_len, text_len=text_len, mel=mel, tokens=tokens)

        # For each down-sampled frame, the token it is aligned to.
        token_per_frame_down = torch.argmax(s2s_attn_mono[0], dim=0).cpu().numpy()
        start_sample, end_sample = self._token_span_to_samples(
            token_per_frame_down, cut_start, cut_end, self.d, HOP, PAD, len(audio)
        )
        return audio[start_sample:end_sample].copy()
if __name__ == "__main__":
    # Demo: cut one sentence out of a longer Vietnamese narration.
    splitter = AudioSplitter(language="vi", model_name="phoaudio_single_v1", device="cpu")

    # Full transcript of the recording, and the sentence to extract from it.
    str_raw = "mệt mỏi vì lo lắng. họ ngủ một cách lơ mơ với cái ý thức cảnh giác cố hữu. nhà ga nào cũng đầy bọn trộm cắp. lâu lắm mới nghe tiếng chân của một người."
    str_trunc = "họ ngủ một cách lơ mơ với cái ý thức cảnh giác cố hữu"

    audio_input, sr = splitter.load_audio("Đào_Hiếu.wav", target_sr=24000)
    segment = splitter.split_audio(str_raw, str_trunc, audio_input)

    # Trim leading/trailing silence (librosa.effects.trim, top_db=15), then save.
    trimmed, _ = librosa.effects.trim(segment, top_db=15)
    sf.write("example_cut.wav", trimmed, sr)