# M4Singer/data_gen/tts/binarizer_zh.py
# Uploaded by kevinwang676; duplicated from zlc99/M4Singer (commit 26925fd), 2.24 kB.
import os
os.environ["OMP_NUM_THREADS"] = "1"
from data_gen.tts.txt_processors.zh_g2pM import ALL_SHENMU
from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError
from data_gen.tts.data_gen_utils import get_mel2ph
from utils.hparams import set_hparams, hparams
import numpy as np
def _move_separator_voicing_to_previous(dur, ph_list, f0, min_frames=100):
    """Move voiced frames at the start of each separator onto the preceding phone.

    A separator is any phone whose first character is neither '<' nor
    alphabetic. Tokens starting with '<' are presumably special markers
    (e.g. <BOS>/<EOS>) — TODO confirm against the text processor.
    The original Chinese comment reads "assign the separator's duration
    to the final (yunmu)".

    Args:
        dur: per-phone durations, modified in place.
        ph_list: phones, one per entry of ``dur``.
        f0: pitch track indexed with the same units as ``dur``. Entries equal
            to 0 are treated as unvoiced. This is assumed to be frame-aligned
            with the mel — TODO confirm.
        min_frames: if fewer than this many frames remain in a separator after
            the voiced prefix is moved, the whole separator is merged into the
            preceding phone.
    """
    # Span boundaries come from the durations as they were on entry. Each
    # step only moves frames from phone i to phone i-1, so the cumulative sum
    # up to any later phone is unchanged and these bounds stay valid.
    bounds = np.pad(np.cumsum(dur), [1, 0], mode='constant', constant_values=0)
    for i, p in enumerate(ph_list):
        if p[0] == '<' or p[0].isalpha():
            continue  # not a separator
        if i == 0:
            # No preceding phone can absorb the frames. Previously dur[i - 1]
            # wrapped around to the last phone here.
            continue
        unvoiced = f0[bounds[i]:bounds[i + 1]] == 0
        # Count the leading voiced frames of this separator's span.
        j = 0
        while j < len(unvoiced) and not unvoiced[j]:
            j += 1
        dur[i - 1] += j
        dur[i] -= j
        if dur[i] < min_frames:
            # Too short to keep as its own segment: merge it entirely.
            dur[i - 1] += dur[i]
            dur[i] = 0


def _split_initial_final_evenly(dur, ph_list):
    """Give each initial (shengmu) and the final that follows it equal durations.

    The original Chinese comment reads "initial and final have equal length".
    The pair's total duration is preserved; the initial receives the floor
    half. Pairs that do not look like <initial><final> are skipped with a
    printed diagnostic. ``dur`` is modified in place.
    """
    n_ph = len(ph_list)
    for i, p in enumerate(ph_list):
        if p not in ALL_SHENMU:
            continue
        if i + 1 >= n_ph:
            # An initial in the last position has no final to share with.
            # Previously this raised IndexError.
            print(f"| Skip initial/final balancing: initial '{p}' is the last phone.")
            continue
        p_next = ph_list[i + 1]
        if not (dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU):
            print(f"assert dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU, "
                  f"dur[i]: {dur[i]}, p: {p}, p_next: {p_next}.")
            continue
        total = dur[i + 1] + dur[i]
        dur[i] = total // 2
        dur[i + 1] = total - dur[i]


class ZhBinarizer(BaseBinarizer):
    """Mandarin binarizer that post-processes alignment durations.

    It overrides ``get_align`` so that, after reading durations with
    ``get_mel2ph``, two adjustments are made:

    * voiced frames at the start of a separator are moved to the phone
      before it;
    * each initial (shengmu) shares its duration evenly with the final
      that follows it.
    """

    @staticmethod
    def get_align(tg_fn, ph, mel, phone_encoded, res):
        """Compute per-phone durations and the frame-to-phone map for one item.

        Args:
            tg_fn: alignment file passed to ``get_mel2ph``; may be None.
            ph: space-separated phone string.
            mel: mel spectrogram, forwarded to ``get_mel2ph``.
            phone_encoded: encoded phone sequence. Only its length is used,
                to validate the alignment.
            res: per-item result dict. It must already contain ``'f0'``. This
                method writes ``'mel2ph'`` (1-based phone index per frame, as
                an ndarray) and ``'dur'`` (the adjusted durations) into it.

        Raises:
            BinarizationError: if ``tg_fn`` is None or does not exist, or if
                the alignment references more phones than ``phone_encoded``
                contains.
            AssertionError: if ``get_mel2ph`` returns a different number of
                durations than there are phones in ``ph``.
        """
        if tg_fn is None or not os.path.exists(tg_fn):
            raise BinarizationError("Align not found")
        _, dur = get_mel2ph(tg_fn, ph, mel, hparams)
        ph_list = ph.split(" ")
        assert len(dur) == len(ph_list), \
            f"duration/phone count mismatch: {len(dur)} vs {len(ph_list)}"

        _move_separator_voicing_to_previous(dur, ph_list, res['f0'])
        _split_initial_final_evenly(dur, ph_list)

        # Expand durations into a per-frame map: phone k (1-based) repeated dur[k-1] times.
        mel2ph = np.array([idx + 1 for idx, d in enumerate(dur) for _ in range(d)])
        if mel2ph.max() - 1 >= len(phone_encoded):
            raise BinarizationError(f"| Align does not match: {(mel2ph.max() - 1, len(phone_encoded))}")
        res['mel2ph'] = mel2ph
        res['dur'] = dur
if __name__ == "__main__":
    # Initialise the shared hyper-parameters first, because ZhBinarizer reads
    # them (get_align passes `hparams` to get_mel2ph). set_hparams presumably
    # parses config/CLI arguments — confirm in utils.hparams.
    set_hparams()
    binarizer = ZhBinarizer()
    binarizer.process()