import re

from pypinyin import Style
from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin
from pypinyin.converter import DefaultConverter
from pypinyin.core import Pinyin

from text import pinyin_dict
from bert import TTSProsody


class MyConverter(NeutralToneWith5Mixin, DefaultConverter):
    """pypinyin default converter combined with ``NeutralToneWith5Mixin``."""


def is_chinese(uchar):
    """Return True if *uchar* falls in the range U+4E00..U+9FA5."""
    return '\u4e00' <= uchar <= '\u9fa5'


def clean_chinese(text: str):
    """Keep only Chinese characters, collapsing everything else to commas.

    Every run of non-Chinese characters that follows a Chinese character is
    replaced by a single ``','``; leading and trailing commas are stripped.

    Args:
        text: Raw input text.

    Returns:
        A string made of Chinese characters separated by single commas.
    """
    kept = []
    for char in text.strip():
        if is_chinese(char):
            kept.append(char)
        elif kept and is_chinese(kept[-1]):
            # Emit one separator per run of non-Chinese characters.  The
            # non-empty check (rather than ``len(...) > 1``) keeps the
            # separator after a single leading character, e.g. "我,你好".
            kept.append(',')
    return ''.join(kept).strip(',')


class VITS_PinYin:
    """Convert Chinese text to a phoneme string plus prosody embeddings."""

    def __init__(self, bert_path, device):
        """Build the pinyin parser and the prosody model.

        Args:
            bert_path: Forwarded to ``TTSProsody``.
            device: Forwarded to ``TTSProsody``.
        """
        self.pinyin_parser = Pinyin(MyConverter())
        self.prosody = TTSProsody(bert_path, device)

    def get_phoneme4pinyin(self, pinyins):
        """Split each pinyin syllable into two phonemes via ``pinyin_dict``.

        The last character of each syllable is treated as its tone and is
        appended to the second phoneme. Syllables whose toneless form is not
        a key of ``pinyin_dict`` are skipped.

        Args:
            pinyins: Iterable of pinyin strings ending in a tone character.

        Returns:
            A ``(phonemes, count_phone)`` tuple: the flat phoneme list and a
            list holding ``2`` for every syllable that was converted.
        """
        result = []
        count_phone = []
        for pinyin in pinyins:
            base = pinyin[:-1]
            # NOTE(review): a skipped syllable adds no entry to count_phone,
            # while chinese_to_phonemes still passes the full text to
            # get_char_embeds -- confirm expand_for_phone tolerates the
            # resulting length difference.
            if base not in pinyin_dict:
                continue
            tone = pinyin[-1]
            first, second = pinyin_dict[base]
            result += [first, second + tone]
            count_phone.append(2)
        return result, count_phone

    def chinese_to_phonemes(self, text):
        """Convert *text* to a phoneme string and per-phone embeddings.

        The text is cleaned with ``clean_chinese`` and split on commas. The
        phoneme sequence starts and ends with ``"sil"`` and every non-empty
        segment is followed by ``"sp"``. The matching character string uses
        ``'[PAD]'`` for each ``"sil"`` and ``','`` for each ``"sp"``.

        Args:
            text: Raw input text.

        Returns:
            A ``(phoneme_string, char_embeds)`` tuple where the phonemes are
            joined by single spaces and ``char_embeds`` is whatever
            ``self.prosody.expand_for_phone`` returns.
        """
        text = clean_chinese(text)
        phonemes = ["sil"]
        chars = ['[PAD]']
        count_phone = [1]
        for subtext in text.split(","):
            if not subtext:
                continue
            pinyins = self.correct_pinyin_tone3(subtext)
            sub_p, sub_c = self.get_phoneme4pinyin(pinyins)
            phonemes.extend(sub_p)
            phonemes.append("sp")
            count_phone.extend(sub_c)
            count_phone.append(1)
            chars.append(subtext)
            chars.append(',')
        phonemes.append("sil")
        count_phone.append(1)
        chars.append('[PAD]')
        chars = "".join(chars)
        char_embeds = self.prosody.get_char_embeds(chars)
        char_embeds = self.prosody.expand_for_phone(char_embeds, count_phone)
        return " ".join(phonemes), char_embeds

    def correct_pinyin_tone3(self, text):
        """Return TONE3-style pinyin for *text* with adjacent tone 3 adjusted.

        For every adjacent pair of syllables whose first digit is ``'3'`` in
        both, every ``'3'`` in the earlier syllable is replaced by ``'2'``.
        Pairs in which either syllable contains no digit are left unchanged.

        Args:
            text: Text handed to the pinyin parser.

        Returns:
            List of pinyin strings, one per item returned by the parser.
        """
        pinyin_list = [p[0] for p in self.pinyin_parser.pinyin(
            text, style=Style.TONE3, strict=False,
            neutral_tone_with_five=True)]
        for i in range(1, len(pinyin_list)):
            prev_tone = re.search(r'\d', pinyin_list[i - 1])
            cur_tone = re.search(r'\d', pinyin_list[i])
            if prev_tone is None or cur_tone is None:
                continue
            if prev_tone.group() == '3' and cur_tone.group() == '3':
                pinyin_list[i - 1] = pinyin_list[i - 1].replace('3', '2')
        return pinyin_list