Spaces:
Runtime error
Runtime error
File size: 3,820 Bytes
9bd9742 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import json
import os.path
from collections import defaultdict
from random import shuffle
from typing import Optional
from tqdm import tqdm
import click
from text.cleaner import clean_text
@click.command()
@click.option(
    "--transcription-path",
    default="filelists/output_fixed.txt",
    type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@click.option("--cleaned-path", default=None)
@click.option("--train-path", default="filelists/train.list")
@click.option("--val-path", default="filelists/val.list")
@click.option(
    "--config-path",
    default="configs/config.json",
    type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@click.option("--val-per-spk", default=1)
@click.option("--max-val-total", default=20)
@click.option("--clean/--no-clean", default=True)
def main(
    transcription_path: str,
    cleaned_path: Optional[str],
    train_path: str,
    val_path: str,
    config_path: str,
    val_per_spk: int,
    max_val_total: int,
    clean: bool,
):
    """Clean a transcription list, split it into train/val lists, and
    store the speaker-id map in the model config.

    Input lines are ``utt|spk|language|text``; cleaning (via
    ``clean_text``) appends ``phones|tones|word2ph`` columns. Duplicate
    utterances and utterances with no audio file under ``filelists/``
    are dropped. ``val_per_spk`` lines per speaker go to validation,
    capped at ``max_val_total`` overall (overflow returns to training).
    """
    if cleaned_path is None:
        cleaned_path = transcription_path + ".cleaned"

    if clean:
        # Context managers guarantee both handles are closed even if an
        # exception escapes the loop (the original leaked both files).
        with open(cleaned_path, "w", encoding="utf-8") as out_file, open(
            transcription_path, encoding="utf-8"
        ) as in_file:
            for line in tqdm(in_file.readlines()):
                try:
                    utt, spk, language, text = line.strip().split("|")
                    norm_text, phones, tones, word2ph = clean_text(text, language)
                    out_file.write(
                        "{}|{}|{}|{}|{}|{}|{}\n".format(
                            utt,
                            spk,
                            language,
                            norm_text,
                            " ".join(phones),
                            " ".join(str(i) for i in tones),
                            " ".join(str(i) for i in word2ph),
                        )
                    )
                except Exception as error:
                    # Deliberately best-effort: report the bad line and
                    # keep processing the rest of the dataset.
                    print("err!", line, error)

        transcription_path = cleaned_path

    spk_utt_map = defaultdict(list)  # speaker -> cleaned lines
    spk_id_map = {}  # speaker -> integer id (insertion order)
    current_sid = 0
    with open(transcription_path, encoding="utf-8") as f:
        audioPaths = set()
        countSame = 0
        countNotFound = 0
        for line in f:
            utt, spk, language, text, phones, tones, word2ph = line.strip().split("|")
            if utt in audioPaths:
                # Dataset bug filter: the same audio mapped to multiple
                # texts breaks the later BERT feature extraction step.
                print(f"重复音频文本:{line}")
                countSame += 1
                continue
            if not os.path.isfile("filelists/" + utt):
                print(f"没有找到对应的音频:{utt}")
                countNotFound += 1
                continue
            audioPaths.add(utt)
            spk_utt_map[spk].append(line)
            if spk not in spk_id_map:
                spk_id_map[spk] = current_sid
                current_sid += 1
    print(f"总重复音频数:{countSame},总未找到的音频数:{countNotFound}")

    train_list = []
    val_list = []
    for spk, utts in spk_utt_map.items():
        shuffle(utts)
        val_list += utts[:val_per_spk]
        train_list += utts[val_per_spk:]
    # Cap the validation set; the overflow goes back into training.
    if len(val_list) > max_val_total:
        train_list += val_list[max_val_total:]
        val_list = val_list[:max_val_total]

    # Lines already end with "\n", so writelines reproduces them as-is.
    with open(train_path, "w", encoding="utf-8") as f:
        f.writelines(train_list)
    with open(val_path, "w", encoding="utf-8") as f:
        f.writelines(val_list)

    # Persist the speaker-id map into the training config (the original
    # left the read handle open).
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    config["data"]["spk2id"] = spk_id_map
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
# Script entry point; removes the stray " |" extraction artifact that
# made this line a syntax error.
if __name__ == "__main__":
    main()