# Source: Comparative-Analysis-of-Speech-Synthesis-Models/TensorFlowTTS/examples/mfa_extraction/fix_mismatch.py
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fix mismatches between the sum of durations and the mel lengths."""
import numpy as np | |
import os | |
from tqdm import tqdm | |
import click | |
import logging | |
import sys | |
# Configure the root logger once at import time: DEBUG-and-above records go
# to stdout, each prefixed with timestamp, "module:line" and level name.
logging.basicConfig(
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    stream=sys.stdout,
    level=logging.DEBUG,
)
def _shrink_tail(durations, excess):
    """Remove ``excess`` frames from the tail of ``durations`` in place.

    Walks backwards from the last entry, zeroing entries until the excess is
    absorbed. The first two entries are never modified (same bounds as the
    original loop, ``range(1, len(durations) - 1)``).

    Args:
        durations: 1-D mutable array of per-token durations; modified in place.
        excess: Number of frames to remove (expected to be positive).

    Returns:
        The number of frames that could NOT be removed; ``0`` means the
        excess was fully absorbed.
    """
    for j in range(1, len(durations) - 1):
        available = durations[-j]
        if available >= excess:
            durations[-j] -= excess
            return 0
        # This entry is too short to absorb everything: drop it entirely and
        # carry the rest over to the previous entry.
        durations[-j] = 0
        excess -= available
    return excess


def fix(base_path: str, dur_path: str, trimmed_dur_path: str, use_norm: str):
    """Make every duration array sum to the length of its mel feature array.

    For each of the ``train`` and ``valid`` subdirectories of ``base_path``,
    every entry of ``<split>/ids`` is processed: the matching features and
    durations are loaded, the durations are adjusted so their sum equals
    ``len(mel)``, and the result is written as int32 to
    ``<split>/fix_dur/<id>-durations.npy``. Summary statistics are logged per
    split.

    Args:
        base_path: Directory containing the ``train`` and ``valid`` splits.
        dur_path: Fallback directory holding ``<id>-durations.npy`` files.
        trimmed_dur_path: Directory tried first for ``<id>-durations.npy``.
        use_norm: ``"t"`` loads features from ``norm-feats``; any other value
            loads them from ``raw-feats``.

    Raises:
        FileNotFoundError: If a feature file is missing, or a durations file
            is missing from both ``trimmed_dur_path`` and ``dur_path``.
    """
    # Directory name and file-name suffix are the same string for features.
    feats_kind = "norm-feats" if use_norm == "t" else "raw-feats"

    for t in ["train", "valid"]:
        mfa_longer = []
        mfa_shorter = []
        big_diff = []
        not_fixed = []
        pre_path = os.path.join(base_path, t)
        fix_dir = os.path.join(pre_path, "fix_dur")
        os.makedirs(fix_dir, exist_ok=True)

        logging.info(f"FIXING {t} set ...\n")
        for i in tqdm(os.listdir(os.path.join(pre_path, "ids"))):
            # Files are keyed by the part of the ids file name before the
            # first "-".
            utt_id = i.split("-")[0]
            dur_name = f"{utt_id}-durations.npy"

            mel = np.load(
                os.path.join(pre_path, feats_kind, f"{utt_id}-{feats_kind}.npy")
            )

            # Prefer trimmed durations; fall back only when that file does
            # not exist, so real failures (corrupt file, Ctrl-C) are not
            # silently masked.
            try:
                dur = np.load(os.path.join(trimmed_dur_path, dur_name))
            except FileNotFoundError:
                dur = np.load(os.path.join(dur_path, dur_name))

            l_mel = len(mel)
            dur_s = np.sum(dur)
            cloned = np.array(dur, copy=True)
            diff = abs(l_mel - dur_s)

            # More than 30 frames; the original author equates this to
            # ~300 ms (depends on the hop size -- TODO confirm).
            if diff > 30:
                big_diff.append([i, diff])

            if dur_s > l_mel:
                # Durations are too long: trim from the end. Any leftover
                # (including arrays too short to be trimmed at all) means
                # the mismatch could not be fully repaired.
                if _shrink_tail(cloned, diff) > 0:
                    not_fixed.append(i)
                mfa_longer.append(diff)
            elif dur_s < l_mel:
                # Durations are too short: extend the last entry.
                cloned[-1] += diff
                mfa_shorter.append(diff)

            np.save(
                os.path.join(fix_dir, dur_name),
                cloned.astype(np.int32),
                allow_pickle=False,
            )

        logging.info(
            f"{t} stats: number of mfa with longer duration: {len(mfa_longer)}, total diff: {sum(mfa_longer)}"
            f", mean diff: {sum(mfa_longer)/len(mfa_longer) if len(mfa_longer) > 0 else 0}"
        )
        logging.info(
            f"{t} stats: number of mfa with shorter duration: {len(mfa_shorter)}, total diff: {sum(mfa_shorter)}"
            f", mean diff: {sum(mfa_shorter)/len(mfa_shorter) if len(mfa_shorter) > 0 else 0}"
        )
        logging.info(
            f"{t} stats: number of files with a ''big'' duration diff: {len(big_diff)} if number>1 you should check it"
        )
        logging.info(f"{t} stats: not fixed len: {len(not_fixed)}\n")
# Script entry point: run the fixer only when this file is executed directly,
# not when it is imported as a module.
# NOTE(review): fix() is invoked without arguments although its signature has
# four required parameters; `click` is imported but no click decorators are
# visible on fix() here -- confirm they exist in the real file.
if __name__ == "__main__":
    fix()