Spaces:
Running
Running
import logging | |
import numpy as np | |
import torch | |
from bert_vits2 import commons | |
from bert_vits2 import utils as bert_vits2_utils | |
from bert_vits2.clap_wrapper import get_clap_audio_feature, get_clap_text_feature | |
from bert_vits2.get_emo import get_emo | |
from bert_vits2.models import SynthesizerTrn | |
from bert_vits2.models_v230 import SynthesizerTrn as SynthesizerTrn_v230 | |
from bert_vits2.models_ja_extra import SynthesizerTrn as SynthesizerTrn_ja_extra | |
from bert_vits2.text import * | |
from bert_vits2.text.cleaner import clean_text | |
from bert_vits2.utils import process_legacy_versions | |
from contants import config | |
from utils import get_hparams_from_file | |
from utils.sentence import split_languages | |
class Bert_VITS2:
    """Configure, load and run a Bert-VITS2 synthesizer for one checkpoint.

    The constructor inspects the model version reported by
    ``process_legacy_versions`` and derives the per-version settings
    (symbol table, tone count, BERT model names, text-cleaner suffixes,
    emotion embedding mode).  ``load_model`` then builds the network, and
    ``infer`` / ``infer_multilang`` synthesize audio.
    """

    def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs):
        """Read the hyper-parameters and set up version-specific options.

        Args:
            model_path: Checkpoint path, later passed to ``load_checkpoint``.
            config: Either a path (``str``) handed to
                ``get_hparams_from_file`` or an already loaded hparams object.
                NOTE: this parameter shadows the module-level ``config``
                inside ``__init__`` only.
            device: Torch device the network and inputs are moved to.
            **kwargs: Accepted for call compatibility; unused here.
        """
        self.model_path = model_path
        self.hps_ms = get_hparams_from_file(config) if isinstance(config, str) else config
        self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
        # Speaker names ordered by their numeric id.
        spk2id = getattr(self.hps_ms.data, 'spk2id', {'0': 0})
        self.speakers = [name for name, _ in sorted(spk2id.items(), key=lambda x: x[1])]
        self.symbols = symbols
        self.sampling_rate = self.hps_ms.data.sampling_rate
        self.bert_model_names = {}
        self.zh_bert_extra = False
        self.ja_bert_extra = False
        self.ja_bert_dim = 1024
        self.num_tones = num_tones
        self.pinyinPlus = None

        # Compatible with legacy versions.  The raw version may be missing,
        # in which case the newest-version defaults below are used.
        self.version = self._normalize_version(process_legacy_versions(self.hps_ms))
        self.text_extra_str_map = {"zh": "", "ja": "", "en": ""}
        self.bert_extra_str_map = {"zh": "", "ja": "", "en": ""}
        self.hps_ms.model.emotion_embedding = None

        if self.version in ("1.0", "1.0.0", "1.0.1"):
            # chinese-roberta-wwm-ext-large
            self.version = "1.0"
            self.symbols = symbols_legacy
            self.hps_ms.model.n_layers_trans_flow = 3
            self.lang = getattr(self.hps_ms.data, "lang", ["zh"])
            self.ja_bert_dim = 768
            self.num_tones = num_tones_v111
            self.text_extra_str_map.update({"zh": "_v100"})
        elif self.version in ("1.1.0_transition",):
            # chinese-roberta-wwm-ext-large
            # The normalized form uses "_" (see _normalize_version); matching
            # on the hyphenated spelling could never succeed.
            self.version = "1.1.0-transition"
            self.hps_ms.model.n_layers_trans_flow = 3
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja"])
            self.ja_bert_dim = 768
            self.num_tones = num_tones_v111
            self._register_bert("ja", "BERT_BASE_JAPANESE_V3")
            self.text_extra_str_map.update({"zh": "_v100", "ja": "_v111"})
            self.bert_extra_str_map.update({"ja": "_v111"})
        elif self.version in ("1.1", "1.1.0", "1.1.1"):
            # chinese-roberta-wwm-ext-large
            # bert-base-japanese-v3
            self.version = "1.1"
            self.hps_ms.model.n_layers_trans_flow = 6
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja"])
            self.ja_bert_dim = 768
            self.num_tones = num_tones_v111
            self._register_bert("ja", "BERT_BASE_JAPANESE_V3")
            self.text_extra_str_map.update({"zh": "_v100", "ja": "_v111"})
            self.bert_extra_str_map.update({"ja": "_v111"})
        elif self.version in ("2.0", "2.0.0", "2.0.1", "2.0.2"):
            # chinese-roberta-wwm-ext-large
            # deberta-v2-large-japanese
            # deberta-v3-large
            self.version = "2.0"
            self.hps_ms.model.n_layers_trans_flow = 4
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"])
            self._register_bert("ja", "DEBERTA_V2_LARGE_JAPANESE")
            self._register_bert("en", "DEBERTA_V3_LARGE")
            self.text_extra_str_map.update({"zh": "_v100", "ja": "_v200", "en": "_v200"})
            self.bert_extra_str_map.update({"ja": "_v200", "en": "_v200"})
        elif self.version in ("2.1", "2.1.0"):
            # chinese-roberta-wwm-ext-large
            # deberta-v2-large-japanese-char-wwm
            # deberta-v3-large
            # wav2vec2-large-robust-12-ft-emotion-msp-dim
            self.version = "2.1"
            self.hps_ms.model.n_layers_trans_flow = 4
            self.hps_ms.model.emotion_embedding = 1
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"])
            self._register_bert("ja", "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM")
            self._register_bert("en", "DEBERTA_V3_LARGE")
        elif self.version in ("2.2", "2.2.0"):
            # chinese-roberta-wwm-ext-large
            # deberta-v2-large-japanese-char-wwm
            # deberta-v3-large
            # clap-htsat-fused
            self.version = "2.2"
            self.hps_ms.model.n_layers_trans_flow = 4
            self.hps_ms.model.emotion_embedding = 2
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"])
            self._register_bert("ja", "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM")
            self._register_bert("en", "DEBERTA_V3_LARGE")
        elif self.version in ("extra", "zh_clap"):
            # Erlangshen-MegatronBert-1.3B-Chinese
            # clap-htsat-fused
            self.version = "extra"
            self.hps_ms.model.emotion_embedding = 2
            self.hps_ms.model.n_layers_trans_flow = 6
            self.lang = ["zh"]
            self.zh_bert_extra = True
            self.bert_model_names.update({"zh": "Erlangshen_MegatronBert_1.3B_Chinese"})
            self.bert_extra_str_map.update({"zh": "_extra"})
        elif self.version in ("extra_fix", "2.4", "2.4.0"):
            # Erlangshen-MegatronBert-1.3B-Chinese
            # clap-htsat-fused
            self.version = "2.4"
            self.hps_ms.model.emotion_embedding = 2
            self.hps_ms.model.n_layers_trans_flow = 6
            self.lang = ["zh"]
            self.zh_bert_extra = True
            self.bert_model_names.update({"zh": "Erlangshen_MegatronBert_1.3B_Chinese"})
            self.bert_extra_str_map.update({"zh": "_extra"})
            self.text_extra_str_map.update({"zh": "_v240"})
        elif self.version in ("ja_extra",):
            # deberta-v2-large-japanese-char-wwm
            self.version = "ja_extra"
            self.hps_ms.model.emotion_embedding = 2
            self.hps_ms.model.n_layers_trans_flow = 6
            self.lang = ["ja"]
            self.ja_bert_extra = True
            self.bert_model_names.update({"ja": "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM"})
            self.bert_extra_str_map.update({"ja": "_extra"})
            self.text_extra_str_map.update({"ja": "_extra"})
        else:
            # Explicit 2.3 and the "unknown version" fallback share settings:
            # chinese-roberta-wwm-ext-large
            # deberta-v2-large-japanese-char-wwm
            # deberta-v3-large
            if self.version not in ("2.3", "2.3.0"):
                logging.debug("Version information not found. Loaded as the newest version: v2.3.")
            self.version = "2.3"
            self.lang = getattr(self.hps_ms.data, "lang", ["zh", "ja", "en"])
            self.text_extra_str_map.update({"en": "_v230"})
            self._register_bert("ja", "DEBERTA_V2_LARGE_JAPANESE_CHAR_WWM")
            self._register_bert("en", "DEBERTA_V3_LARGE")

        # Chinese falls back to the default BERT unless a branch set one.
        if "zh" in self.lang and "zh" not in self.bert_model_names:
            self.bert_model_names.update({"zh": "CHINESE_ROBERTA_WWM_EXT_LARGE"})

        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
        self.device = device

    @staticmethod
    def _normalize_version(raw_version):
        """Lower-case a version string and replace ``-`` with ``_``.

        Returns ``None`` unchanged so that a missing version reaches the
        newest-version fallback instead of raising ``AttributeError``.
        """
        if raw_version is None:
            return None
        return raw_version.lower().replace("-", "_")

    @staticmethod
    def _has_reference_audio(reference_audio):
        """Return True when a usable reference audio was supplied.

        A plain truth test on a multi-element NumPy array raises
        ``ValueError``, so arrays are checked by size; every other value
        keeps ordinary truthiness (``None`` and ``""`` mean "no reference").
        """
        if reference_audio is None:
            return False
        if isinstance(reference_audio, np.ndarray):
            return reference_audio.size > 0
        return bool(reference_audio)

    @staticmethod
    def _trim_columns(feature, head, tail):
        """Slice ``feature[:, head:tail]``; pass ``None`` through untouched."""
        if feature is None:
            return None
        return feature[:, head:tail]

    @staticmethod
    def _cat_optional(features, dim):
        """Concatenate features, or return ``None`` if every entry is ``None``.

        An empty list is still handed to ``torch.cat`` so the original error
        for "no segments" is preserved.
        """
        if features and all(item is None for item in features):
            return None
        return torch.cat(features, dim=dim)

    def _register_bert(self, lang, model_name):
        """Record ``model_name`` for ``lang`` if this model supports ``lang``."""
        if lang in self.lang:
            self.bert_model_names.update({lang: model_name})

    def load_model(self, model_handler):
        """Build the synthesizer for this version and load the checkpoint.

        Args:
            model_handler: Object providing the auxiliary models; it is stored
                on the instance and used by the other methods.
        """
        self.model_handler = model_handler

        if self.version in ("2.3", "extra", "2.4"):
            Synthesizer = SynthesizerTrn_v230
        elif self.version == "ja_extra":
            Synthesizer = SynthesizerTrn_ja_extra
        else:
            Synthesizer = SynthesizerTrn

        if self.version == "2.4":
            self.pinyinPlus = self.model_handler.get_pinyinPlus()

        self.net_g = Synthesizer(
            len(self.symbols),
            self.hps_ms.data.filter_length // 2 + 1,
            self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
            n_speakers=self.hps_ms.data.n_speakers,
            symbols=self.symbols,
            ja_bert_dim=self.ja_bert_dim,
            num_tones=self.num_tones,
            zh_bert_extra=self.zh_bert_extra,
            **self.hps_ms.model).to(self.device)
        _ = self.net_g.eval()
        bert_vits2_utils.load_checkpoint(self.model_path, self.net_g, None, skip_optimizer=True,
                                         version=self.version)

    def get_speakers(self):
        """Return the speaker names ordered by speaker id."""
        return self.speakers

    def get_text(self, text, language_str, hps, style_text=None, style_weight=0.7):
        """Convert text into phone/tone/language ids plus BERT features.

        Args:
            text: Input text for a single language.
            language_str: Language key; must be present in
                ``self.bert_model_names`` (otherwise ``KeyError``).
            hps: Hyper-parameters; only ``hps.data.add_blank`` is read here.
            style_text: Optional style reference text.  An empty string is
                treated as "none", and it is ignored when ``zh_bert_extra``
                is set.
            style_weight: Weight forwarded to ``get_bert_feature``.

        Returns:
            ``(zh_bert, ja_bert, en_bert, phone, tone, language)``.  For the
            "extra" model variants the unused BERT features are ``None``;
            otherwise they are zero tensors.
        """
        clean_text_lang_str = language_str + self.text_extra_str_map.get(language_str, "")
        bert_feature_lang_str = language_str + self.bert_extra_str_map.get(language_str, "")
        bert_model_name = self.bert_model_names[language_str]

        tokenizer, _ = self.model_handler.get_bert_model(bert_model_name)
        norm_text, phone, tone, word2ph = clean_text(text, clean_text_lang_str, tokenizer, self.pinyinPlus)
        phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, self._symbol_to_id)

        if hps.data.add_blank:
            phone = commons.intersperse(phone, 0)
            tone = commons.intersperse(tone, 0)
            language = commons.intersperse(language, 0)
            # Keep word2ph aligned with the interspersed sequence.
            for i, count in enumerate(word2ph):
                word2ph[i] = count * 2
            word2ph[0] += 1

        if style_text == "" or self.zh_bert_extra:
            style_text = None

        bert = self.model_handler.get_bert_feature(norm_text, word2ph, bert_feature_lang_str,
                                                   bert_model_name, style_text, style_weight)
        n_phones = len(phone)
        assert bert.shape[-1] == n_phones, f"Bert seq len {bert.shape[-1]} != {n_phones}"

        if self.zh_bert_extra:
            zh_bert, ja_bert, en_bert = bert, None, None
        elif self.ja_bert_extra:
            zh_bert, ja_bert, en_bert = None, bert, None
        else:
            # The active language gets the real feature; the rest are zeros.
            zh_bert = bert if language_str == "zh" else torch.zeros(1024, n_phones)
            ja_bert = bert if language_str == "ja" else torch.zeros(self.ja_bert_dim, n_phones)
            en_bert = bert if language_str == "en" else torch.zeros(1024, n_phones)

        phone = torch.LongTensor(phone)
        tone = torch.LongTensor(tone)
        language = torch.LongTensor(language)
        return zh_bert, ja_bert, en_bert, phone, tone, language

    def _get_emo(self, reference_audio, emotion):
        """Return the emotion tensor used when ``emotion_embedding == 1``.

        With a reference audio the result comes from ``get_emo``; otherwise
        it is a one-element tensor holding ``emotion`` (``0`` when ``None``).
        """
        if self._has_reference_audio(reference_audio):
            return torch.from_numpy(
                get_emo(reference_audio, self.model_handler.emotion_model,
                        self.model_handler.emotion_processor))
        if emotion is None:
            emotion = 0
        return torch.Tensor([emotion])

    def _get_clap(self, reference_audio, text_prompt):
        """Return the CLAP feature used when ``emotion_embedding == 2``.

        A NumPy ``reference_audio`` selects the audio feature; anything else
        falls back to the text prompt (the configured default when ``None``).
        """
        if isinstance(reference_audio, np.ndarray):
            emo = get_clap_audio_feature(reference_audio, self.model_handler.clap_model,
                                         self.model_handler.clap_processor, self.device)
        else:
            if text_prompt is None:
                text_prompt = config.bert_vits2_config.text_prompt
            emo = get_clap_text_feature(text_prompt, self.model_handler.clap_model,
                                        self.model_handler.clap_processor, self.device)
        return torch.squeeze(emo, dim=1).unsqueeze(0)

    def _infer(self, id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_ratio, noise, noisew, length,
               emo=None):
        """Run the network on prepared inputs and return audio as a NumPy array.

        Only the BERT features the model variant actually uses are moved to
        the device; for the "extra" variants the others may be ``None`` and
        are passed through unchanged.
        """
        with torch.no_grad():
            x_tst = phones.to(self.device).unsqueeze(0)
            tones = tones.to(self.device).unsqueeze(0)
            lang_ids = lang_ids.to(self.device).unsqueeze(0)

            if self.zh_bert_extra:
                zh_bert = zh_bert.to(self.device).unsqueeze(0)
            elif self.ja_bert_extra:
                ja_bert = ja_bert.to(self.device).unsqueeze(0)
            else:
                zh_bert = zh_bert.to(self.device).unsqueeze(0)
                ja_bert = ja_bert.to(self.device).unsqueeze(0)
                en_bert = en_bert.to(self.device).unsqueeze(0)

            x_tst_lengths = torch.LongTensor([phones.size(0)]).to(self.device)
            speakers = torch.LongTensor([int(id)]).to(self.device)
            audio = self.net_g.infer(x_tst,
                                     x_tst_lengths,
                                     speakers,
                                     tones,
                                     lang_ids,
                                     zh_bert=zh_bert,
                                     ja_bert=ja_bert,
                                     en_bert=en_bert,
                                     sdp_ratio=sdp_ratio,
                                     noise_scale=noise,
                                     noise_scale_w=noisew,
                                     length_scale=length,
                                     emo=emo
                                     )[0][0, 0].data.cpu().float().numpy()

        torch.cuda.empty_cache()
        return audio

    def _resolve_emo(self, reference_audio, emotion, text_prompt):
        """Pick the emotion input matching ``hps_ms.model.emotion_embedding``."""
        mode = self.hps_ms.model.emotion_embedding
        if mode == 1:
            return self._get_emo(reference_audio, emotion).to(self.device).unsqueeze(0)
        if mode == 2:
            return self._get_clap(reference_audio, text_prompt)
        return None

    def infer(self, text, id, lang, sdp_ratio, noise, noisew, length, reference_audio=None, emotion=None,
              text_prompt=None, style_text=None, style_weigth=0.7, **kwargs):
        """Synthesize ``text`` in a single language ``lang``.

        ``style_weigth`` (sic) keeps its original spelling so that existing
        keyword callers continue to work.
        """
        zh_bert, ja_bert, en_bert, phones, tones, lang_ids = self.get_text(text, lang, self.hps_ms, style_text,
                                                                           style_weigth)
        emo = self._resolve_emo(reference_audio, emotion, text_prompt)
        return self._infer(id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_ratio, noise, noisew, length,
                           emo)

    def infer_multilang(self, text, id, lang, sdp_ratio, noise, noisew, length, reference_audio=None, emotion=None,
                        text_prompt=None, style_text=None, style_weigth=0.7, **kwargs):
        """Synthesize text that may mix several of the model's languages.

        The text is split with ``split_languages`` over ``self.lang``; the
        ``lang`` argument itself is not consulted.  Each segment is processed
        by ``get_text`` and the pieces are joined into one sequence.

        ``style_weigth`` (sic) keeps its original spelling for compatibility.
        """
        sentences_list = split_languages(text, self.lang, expand_abbreviations=True, expand_hyphens=True)
        emo = self._resolve_emo(reference_audio, emotion, text_prompt)

        phones, tones, lang_ids, zh_bert, ja_bert, en_bert = [], [], [], [], [], []
        last_idx = len(sentences_list) - 1
        for idx, (segment, segment_lang) in enumerate(sentences_list):
            _zh_bert, _ja_bert, _en_bert, _phones, _tones, _lang_ids = self.get_text(segment, segment_lang,
                                                                                     self.hps_ms,
                                                                                     style_text, style_weigth)
            # Drop 3 leading entries on every segment but the first and 2
            # trailing entries on every segment but the last
            # (presumably the per-segment boundary tokens — TODO confirm).
            head = 3 if idx != 0 else 0
            tail = -2 if idx != last_idx else None

            phones.append(_phones[head:tail])
            tones.append(_tones[head:tail])
            lang_ids.append(_lang_ids[head:tail])
            # BERT features may be None for the "extra" model variants.
            zh_bert.append(self._trim_columns(_zh_bert, head, tail))
            ja_bert.append(self._trim_columns(_ja_bert, head, tail))
            en_bert.append(self._trim_columns(_en_bert, head, tail))

        zh_bert = self._cat_optional(zh_bert, dim=1)
        ja_bert = self._cat_optional(ja_bert, dim=1)
        en_bert = self._cat_optional(en_bert, dim=1)
        phones = torch.cat(phones, dim=0)
        tones = torch.cat(tones, dim=0)
        lang_ids = torch.cat(lang_ids, dim=0)

        audio = self._infer(id, phones, tones, lang_ids, zh_bert, ja_bert, en_bert, sdp_ratio, noise,
                            noisew, length, emo)
        return audio