import logging

import numpy as np
import torch

from bert_vits2 import commons
from bert_vits2 import utils as bert_vits2_utils
from bert_vits2.clap_wrapper import get_clap_audio_feature, get_clap_text_feature
from bert_vits2.get_emo import get_emo
from bert_vits2.models import SynthesizerTrn
from bert_vits2.models_v230 import SynthesizerTrn as SynthesizerTrn_v230
from bert_vits2.models_ja_extra import SynthesizerTrn as SynthesizerTrn_ja_extra
from bert_vits2.text import *
from bert_vits2.text.cleaner import clean_text
from bert_vits2.utils import process_legacy_versions
from contants import config
from utils import get_hparams_from_file
from utils.sentence import split_languages


class Bert_VITS2:
    """Inference wrapper for Bert-VITS2 checkpoints of several versions.

    ``__init__`` inspects the version reported by ``process_legacy_versions``
    and derives the per-version settings (symbol table, tone count, enabled
    languages, BERT model names, text/BERT suffixes, emotion-embedding mode).
    ``load_model`` must be called before any of the inference methods, because
    it creates ``self.net_g`` and stores ``self.model_handler``.
    """

    def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs):
        """Read the hyper-parameters and derive the version-specific settings.

        Args:
            model_path: Checkpoint path, later passed to ``load_checkpoint``.
            config: Either a path (``str``) handed to ``get_hparams_from_file``
                or an already loaded hyper-parameter object.
            device: Torch device used for the network and all input tensors.
            **kwargs: Accepted for call compatibility and ignored.
        """
        self.model_path = model_path
        self.hps_ms = get_hparams_from_file(config) if isinstance(config, str) else config
        self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
        # Speaker names ordered by their numeric id.
        spk2id = getattr(self.hps_ms.data, 'spk2id', {'0': 0})
        self.speakers = [name for name, _ in sorted(spk2id.items(), key=lambda item: item[1])]
        self.symbols = symbols
        self.sampling_rate = self.hps_ms.data.sampling_rate

        self.bert_model_names = {}
        self.zh_bert_extra = False
        self.ja_bert_extra = False
        self.ja_bert_dim = 1024
        self.num_tones = num_tones
        self.pinyinPlus = None

        # Compatible with legacy versions.
        # The version may be absent (None); in that case fall through to the
        # "newest version" branch below instead of failing on ``None.lower()``.
        raw_version = process_legacy_versions(self.hps_ms)
        self.version = None if raw_version is None else raw_version.lower().replace("-", "_")

        # Suffixes appended to the language code for text cleaning / BERT features.
        self.text_extra_str_map = {"zh": "", "ja": "", "en": ""}
        self.bert_extra_str_map = {"zh": "", "ja": "", "en": ""}
        # None: no emotion input, 1: _get_emo, 2: _get_clap (see _get_emotion_embedding).
        self.hps_ms.model.emotion_embedding = None

        if self.version in ["1.0", "1.0.0", "1.0.1"]:
            # chinese-roberta-wwm-ext-large
            self.version = "1.0"
            self.symbols = symbols_legacy
            self.hps_ms.model.n_layers_trans_flow = 3
            self.lang = getattr(self.hps_ms.data, "lang", ["zh"])
            self.ja_bert_dim = 768
            self.num_tones = num_tones_v111
            self.text_extra_str_map.update({"zh": "_v100"})

        elif self.version in ["1.1.0_transition"]:
            # chinese-roberta-wwm-ext-large
            # NOTE: hyphens were normalised to underscores above, so the
            # comparison has to use the underscore spelling; the canonical
            # value stored afterwards keeps the original hyphenated form.
            self.version = "1.1.0-transition"
            self.hps_ms.model.n_layers_trans_flow = 3
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja"])
            self.ja_bert_dim = 768
            self.num_tones = num_tones_v111
            self._register_bert_models(ja="BERT_BASE_JAPANESE_V3")
            self.text_extra_str_map.update({"zh": "_v100", "ja": "_v111"})
            self.bert_extra_str_map.update({"ja": "_v111"})

        elif self.version in ["1.1", "1.1.0", "1.1.1"]:
            # chinese-roberta-wwm-ext-large
            # bert-base-japanese-v3
            self.version = "1.1"
            self.hps_ms.model.n_layers_trans_flow = 6
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja"])
            self.ja_bert_dim = 768
            self.num_tones = num_tones_v111
            self._register_bert_models(ja="BERT_BASE_JAPANESE_V3")
            self.text_extra_str_map.update({"zh": "_v100", "ja": "_v111"})
            self.bert_extra_str_map.update({"ja": "_v111"})

        elif self.version in ["2.0", "2.0.0", "2.0.1", "2.0.2"]:
            # chinese-roberta-wwm-ext-large
            # deberta-v2-large-japanese
            # deberta-v3-large
            self.version = "2.0"
            self.hps_ms.model.n_layers_trans_flow = 4
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"])
            self.num_tones = num_tones
            self._register_bert_models(ja="DEBERTA_V2_LARGE_JAPANESE", en="DEBERTA_V3_LARGE")
            self.text_extra_str_map.update({"zh": "_v100", "ja": "_v200", "en": "_v200"})
            self.bert_extra_str_map.update({"ja": "_v200", "en": "_v200"})

        elif self.version in ["2.1", "2.1.0"]:
            # chinese-roberta-wwm-ext-large
            # deberta-v2-large-japanese-char-wwm
            # deberta-v3-large
            # wav2vec2-large-robust-12-ft-emotion-msp-dim
            self.version = "2.1"
            self.hps_ms.model.n_layers_trans_flow = 4
            self.hps_ms.model.emotion_embedding = 1
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"])
            self.num_tones = num_tones
            self._register_bert_models(ja="DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM", en="DEBERTA_V3_LARGE")

        elif self.version in ["2.2", "2.2.0"]:
            # chinese-roberta-wwm-ext-large
            # deberta-v2-large-japanese-char-wwm
            # deberta-v3-large
            # clap-htsat-fused
            self.version = "2.2"
            self.hps_ms.model.n_layers_trans_flow = 4
            self.hps_ms.model.emotion_embedding = 2
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"])
            self.num_tones = num_tones
            self._register_bert_models(ja="DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM", en="DEBERTA_V3_LARGE")

        elif self.version in ["2.3", "2.3.0"]:
            # chinese-roberta-wwm-ext-large
            # deberta-v2-large-japanese-char-wwm
            # deberta-v3-large
            self._apply_v230_settings()

        elif self.version in ["extra", "zh_clap"]:
            # Erlangshen-MegatronBert-1.3B-Chinese
            # clap-htsat-fused
            self.version = "extra"
            self.hps_ms.model.emotion_embedding = 2
            self.hps_ms.model.n_layers_trans_flow = 6
            self.lang = ["zh"]
            self.num_tones = num_tones
            self.zh_bert_extra = True
            self.bert_model_names.update({"zh": "Erlangshen_MegatronBert_1.3B_Chinese"})
            self.bert_extra_str_map.update({"zh": "_extra"})

        elif self.version in ["extra_fix", "2.4", "2.4.0"]:
            # Erlangshen-MegatronBert-1.3B-Chinese
            # clap-htsat-fused
            self.version = "2.4"
            self.hps_ms.model.emotion_embedding = 2
            self.hps_ms.model.n_layers_trans_flow = 6
            self.lang = ["zh"]
            self.num_tones = num_tones
            self.zh_bert_extra = True
            self.bert_model_names.update({"zh": "Erlangshen_MegatronBert_1.3B_Chinese"})
            self.bert_extra_str_map.update({"zh": "_extra"})
            self.text_extra_str_map.update({"zh": "_v240"})

        elif self.version in ["ja_extra"]:
            # deberta-v2-large-japanese-char-wwm
            self.version = "ja_extra"
            self.hps_ms.model.emotion_embedding = 2
            self.hps_ms.model.n_layers_trans_flow = 6
            self.lang = ["ja"]
            self.num_tones = num_tones
            self.ja_bert_extra = True
            self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"})
            self.bert_extra_str_map.update({"ja": "_extra"})
            self.text_extra_str_map.update({"ja": "_extra"})

        else:
            logging.debug("Version information not found. Loaded as the newest version: v2.3.")
            self._apply_v230_settings()

        # Chinese falls back to the default BERT unless a branch chose another one.
        if "zh" in self.lang and "zh" not in self.bert_model_names:
            self.bert_model_names.update({"zh": "CHINESE_ROBERTA_WWM_EXT_LARGE"})

        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}

        self.device = device

    def _register_bert_models(self, ja=None, en=None):
        """Record the ja/en BERT model names for languages present in ``self.lang``."""
        if ja is not None and "ja" in self.lang:
            self.bert_model_names.update({"ja": ja})
        if en is not None and "en" in self.lang:
            self.bert_model_names.update({"en": en})

    def _apply_v230_settings(self):
        """Apply the v2.3 settings (also used when no known version is found)."""
        self.version = "2.3"
        self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"])
        self.num_tones = num_tones
        self.text_extra_str_map.update({"en": "_v230"})
        self._register_bert_models(ja="DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM", en="DEBERTA_V3_LARGE")

    def load_model(self, model_handler):
        """Build the synthesizer for ``self.version`` and load the checkpoint.

        Args:
            model_handler: Object providing the BERT / emotion / CLAP models
                used by the other methods; stored as ``self.model_handler``.
        """
        self.model_handler = model_handler

        if self.version in ["2.3", "extra", "2.4"]:
            Synthesizer = SynthesizerTrn_v230
        elif self.version == "ja_extra":
            Synthesizer = SynthesizerTrn_ja_extra
        else:
            Synthesizer = SynthesizerTrn

        if self.version == "2.4":
            self.pinyinPlus = self.model_handler.get_pinyinPlus()

        self.net_g = Synthesizer(
            len(self.symbols),
            self.hps_ms.data.filter_length // 2 + 1,
            self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
            n_speakers=self.hps_ms.data.n_speakers,
            symbols=self.symbols,
            ja_bert_dim=self.ja_bert_dim,
            num_tones=self.num_tones,
            zh_bert_extra=self.zh_bert_extra,
            **self.hps_ms.model).to(self.device)
        _ = self.net_g.eval()
        bert_vits2_utils.load_checkpoint(self.model_path, self.net_g, None, skip_optimizer=True,
                                         version=self.version)

    def get_speakers(self):
        """Return the speaker names ordered by speaker id."""
        return self.speakers

    def get_text(self, text, language_str, hps, style_text=None, style_weight=0.7):
        """Convert ``text`` into the id tensors and BERT features for one language.

        Args:
            text: Raw input text.
            language_str: Language code; must be a key of ``self.bert_model_names``.
            hps: Hyper-parameters; only ``hps.data.add_blank`` is read here.
            style_text: Optional style reference text. An empty string, or any
                value when ``self.zh_bert_extra`` is set, is treated as ``None``.
            style_weight: Weight forwarded to ``get_bert_feature``.

        Returns:
            ``(zh_bert, ja_bert, en_bert, phone, tone, language)``. For the
            "extra" model variants the unused BERT slots are ``None``;
            otherwise they are zero tensors.
        """
        clean_text_lang_str = language_str + self.text_extra_str_map.get(language_str, "")
        bert_feature_lang_str = language_str + self.bert_extra_str_map.get(language_str, "")

        bert_model_name = self.bert_model_names[language_str]
        tokenizer, _ = self.model_handler.get_bert_model(bert_model_name)

        norm_text, phone, tone, word2ph = clean_text(text, clean_text_lang_str, tokenizer, self.pinyinPlus)

        phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, self._symbol_to_id)

        if hps.data.add_blank:
            phone = commons.intersperse(phone, 0)
            tone = commons.intersperse(tone, 0)
            language = commons.intersperse(language, 0)
            # Keep word2ph aligned with the interspersed phone sequence.
            for i, count in enumerate(word2ph):
                word2ph[i] = count * 2
            word2ph[0] += 1

        if style_text == "" or self.zh_bert_extra:
            style_text = None

        bert = self.model_handler.get_bert_feature(norm_text, word2ph, bert_feature_lang_str,
                                                   bert_model_name, style_text, style_weight)
        del word2ph
        assert bert.shape[-1] == len(phone), f"Bert seq len {bert.shape[-1]} != {len(phone)}"

        seq_len = len(phone)
        if self.zh_bert_extra:
            zh_bert = bert
            ja_bert, en_bert = None, None
        elif self.ja_bert_extra:
            ja_bert = bert
            zh_bert, en_bert = None, None
        elif language_str == "zh":
            zh_bert = bert
            ja_bert = torch.zeros(self.ja_bert_dim, seq_len)
            en_bert = torch.zeros(1024, seq_len)
        elif language_str == "ja":
            zh_bert = torch.zeros(1024, seq_len)
            ja_bert = bert
            en_bert = torch.zeros(1024, seq_len)
        elif language_str == "en":
            zh_bert = torch.zeros(1024, seq_len)
            ja_bert = torch.zeros(self.ja_bert_dim, seq_len)
            en_bert = bert
        else:
            zh_bert = torch.zeros(1024, seq_len)
            ja_bert = torch.zeros(self.ja_bert_dim, seq_len)
            en_bert = torch.zeros(1024, seq_len)

        phone = torch.LongTensor(phone)
        tone = torch.LongTensor(tone)
        language = torch.LongTensor(language)
        return zh_bert, ja_bert, en_bert, phone, tone, language

    def _get_emo(self, reference_audio, emotion):
        """Return the emotion tensor used when ``emotion_embedding == 1``.

        Uses ``get_emo`` on ``reference_audio`` when one is supplied, otherwise
        wraps the numeric ``emotion`` (``0`` when ``None``) in a tensor.
        """
        # A plain truth test on a multi-element ndarray raises ValueError, so
        # arrays are checked by size; other inputs keep the truthiness test.
        if isinstance(reference_audio, np.ndarray):
            has_reference = reference_audio.size > 0
        else:
            has_reference = bool(reference_audio)

        if has_reference:
            emo = torch.from_numpy(
                get_emo(reference_audio, self.model_handler.emotion_model,
                        self.model_handler.emotion_processor))
        else:
            if emotion is None:
                emotion = 0
            emo = torch.Tensor([emotion])

        return emo

    def _get_clap(self, reference_audio, text_prompt):
        """Return the CLAP feature used when ``emotion_embedding == 2``.

        An ``np.ndarray`` reference audio takes priority; otherwise the text
        prompt is used, defaulting to ``config.bert_vits2_config.text_prompt``.
        """
        if isinstance(reference_audio, np.ndarray):
            emo = get_clap_audio_feature(reference_audio, self.model_handler.clap_model,
                                         self.model_handler.clap_processor, self.device)
        else:
            if text_prompt is None:
                text_prompt = config.bert_vits2_config.text_prompt
            emo = get_clap_text_feature(text_prompt, self.model_handler.clap_model,
                                        self.model_handler.clap_processor, self.device)
        emo = torch.squeeze(emo, dim=1).unsqueeze(0)
        return emo

    def _get_emotion_embedding(self, reference_audio, emotion, text_prompt):
        """Resolve the ``emo`` input according to ``hps_ms.model.emotion_embedding``."""
        mode = self.hps_ms.model.emotion_embedding
        if mode == 1:
            return self._get_emo(reference_audio, emotion).to(self.device).unsqueeze(0)
        if mode == 2:
            return self._get_clap(reference_audio, text_prompt)
        return None

    def _infer(self, id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_ratio, noise, noisew, length,
               emo=None):
        """Run the synthesizer on prepared tensors and return the audio as a NumPy array.

        Only the BERT features the loaded variant uses are moved to the device:
        ``zh_bert`` for ``zh_bert_extra``, ``ja_bert`` for ``ja_bert_extra``,
        all three otherwise. The remaining ones are passed through unchanged.
        """
        with torch.no_grad():
            x_tst = phones.to(self.device).unsqueeze(0)
            tones = tones.to(self.device).unsqueeze(0)
            lang_ids = lang_ids.to(self.device).unsqueeze(0)
            if self.zh_bert_extra:
                zh_bert = zh_bert.to(self.device).unsqueeze(0)
            elif self.ja_bert_extra:
                ja_bert = ja_bert.to(self.device).unsqueeze(0)
            else:
                zh_bert = zh_bert.to(self.device).unsqueeze(0)
                ja_bert = ja_bert.to(self.device).unsqueeze(0)
                en_bert = en_bert.to(self.device).unsqueeze(0)
            x_tst_lengths = torch.LongTensor([phones.size(0)]).to(self.device)
            speakers = torch.LongTensor([int(id)]).to(self.device)
            audio = self.net_g.infer(x_tst,
                                     x_tst_lengths,
                                     speakers,
                                     tones,
                                     lang_ids,
                                     zh_bert=zh_bert,
                                     ja_bert=ja_bert,
                                     en_bert=en_bert,
                                     sdp_ratio=sdp_ratio,
                                     noise_scale=noise,
                                     noise_scale_w=noisew,
                                     length_scale=length,
                                     emo=emo
                                     )[0][0, 0].data.cpu().float().numpy()

        torch.cuda.empty_cache()
        return audio

    def infer(self, text, id, lang, sdp_ratio, noise, noisew, length, reference_audio=None, emotion=None,
              text_prompt=None, style_text=None, style_weigth=0.7, **kwargs):
        """Synthesize ``text`` written in the single language ``lang``.

        Note: the keyword is spelled ``style_weigth`` and is kept as is for
        compatibility with existing callers.
        """
        zh_bert, ja_bert, en_bert, phones, tones, lang_ids = self.get_text(text, lang, self.hps_ms,
                                                                           style_text, style_weigth)
        emo = self._get_emotion_embedding(reference_audio, emotion, text_prompt)

        return self._infer(id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_ratio, noise, noisew, length,
                           emo)

    @staticmethod
    def _slice_feature(feature, keep):
        """Slice the last axis of a 2-D feature, passing ``None`` through."""
        return None if feature is None else feature[:, keep]

    @staticmethod
    def _concat_features(chunks):
        """Concatenate per-segment features along dim 1.

        Returns ``None`` when every segment left this slot unused (the "extra"
        model variants); any other input goes to ``torch.cat`` unchanged.
        """
        if chunks and all(chunk is None for chunk in chunks):
            return None
        return torch.cat(chunks, dim=1)

    def infer_multilang(self, text, id, lang, sdp_ratio, noise, noisew, length, reference_audio=None, emotion=None,
                        text_prompt=None, style_text=None, style_weigth=0.7, **kwargs):
        """Synthesize mixed-language ``text`` by splitting it per language.

        The text is split with ``split_languages`` using ``self.lang`` (the
        ``lang`` argument is not consulted), each segment is converted with
        ``get_text``, and the trimmed segments are concatenated before a
        single ``_infer`` call.
        """
        sentences_list = split_languages(text, self.lang, expand_abbreviations=True, expand_hyphens=True)

        emo = self._get_emotion_embedding(reference_audio, emotion, text_prompt)

        phones, tones, lang_ids, zh_bert, ja_bert, en_bert = [], [], [], [], [], []
        last_idx = len(sentences_list) - 1

        for idx, (_text, segment_lang) in enumerate(sentences_list):
            _zh_bert, _ja_bert, _en_bert, _phones, _tones, _lang_ids = self.get_text(_text, segment_lang,
                                                                                     self.hps_ms,
                                                                                     style_text, style_weigth)

            # Every segment but the first drops its first 3 positions and every
            # segment but the last drops its final 2, so the joined sequence has
            # no duplicated boundary entries (presumably start/end markers plus
            # interspersed blanks - TODO confirm against the text front-end).
            keep = slice(3 if idx != 0 else None, -2 if idx != last_idx else None)

            phones.append(_phones[keep])
            tones.append(_tones[keep])
            lang_ids.append(_lang_ids[keep])
            # BERT slots may be None for the "extra" variants; keep them as None.
            zh_bert.append(self._slice_feature(_zh_bert, keep))
            ja_bert.append(self._slice_feature(_ja_bert, keep))
            en_bert.append(self._slice_feature(_en_bert, keep))

        zh_bert = self._concat_features(zh_bert)
        ja_bert = self._concat_features(ja_bert)
        en_bert = self._concat_features(en_bert)
        phones = torch.cat(phones, dim=0)
        tones = torch.cat(tones, dim=0)
        lang_ids = torch.cat(lang_ids, dim=0)

        audio = self._infer(id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_ratio, noise,
                            noisew, length, emo)

        return audio