Spaces:
Runtime error
Runtime error
File size: 5,844 Bytes
ee21b96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import itertools
import logging
import re
import time
from g2p_en import G2p
# Module-level logger; used to warn when the optional submitit dependency is missing.
logger = logging.getLogger(__name__)
# Sentinel token for a sentence treated as failed: if any segment's phoneme
# list is exactly [FAIL_SENT], the whole sentence collapses to this one token,
# and punctuation removal is told never to strip it.
FAIL_SENT = "FAILED_SENTENCE"
def parse(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, exactly as before; passing
            an explicit list lets tests and other scripts reuse the parser.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-path", type=str, required=True,
        help="input text file, one sentence per line",
    )
    parser.add_argument(
        "--out-path", type=str, required=True,
        help="output file for the phoneme sequences",
    )
    parser.add_argument(
        "--lower-case", action="store_true",
        help="lowercase text before g2p (reserved words are kept as-is)",
    )
    parser.add_argument(
        "--do-filter", action="store_true",
        help="replace '-' and '—' with spaces before g2p",
    )
    parser.add_argument(
        "--use-word-start", action="store_true",
        help="prefix the first phoneme of each word with '▁'",
    )
    parser.add_argument(
        "--dup-vowel", default=1, type=int,
        help="emit each vowel phoneme this many times (extra copies get a '-<i>P' suffix)",
    )
    parser.add_argument(
        "--dup-consonant", default=1, type=int,
        help="emit each consonant phoneme this many times (extra copies get a '-<i>P' suffix)",
    )
    parser.add_argument(
        "--no-punc", action="store_true",
        help="drop punctuation tokens from the output",
    )
    parser.add_argument(
        "--reserve-word", type=str, default="",
        help="file of 'word phoneme' pairs that bypass g2p",
    )
    parser.add_argument(
        "--reserve-first-column",
        action="store_true",
        help="first column is sentence id",
    )
    # Options for optional parallel processing through submitit.
    parser.add_argument(
        "--parallel-process-num", default=1, type=int,
        help="number of submitit jobs to split the input across (1 = run locally)",
    )
    parser.add_argument("--logdir", default="", help="submitit log folder")
    return parser.parse_args(argv)
def process_sent(sent, g2p, res_wrds, args):
    """Convert one sentence into a space-joined phoneme string.

    The sentence is first split into segments by ``pre_process_sent`` (reserved
    words become segments of their own). Each segment is converted with
    ``do_g2p``. Then the optional post-processing steps selected on the command
    line run in this order: punctuation removal, phoneme duplication and
    word-start marking.

    Args:
        sent: raw input sentence.
        g2p: callable turning a text segment into a list of phoneme tokens.
        res_wrds: reserved word -> phoneme token mapping (may be empty).
        args: parsed command-line namespace.

    Returns:
        The resulting phoneme tokens joined with single spaces.
    """
    segments = pre_process_sent(sent, args.do_filter, args.lower_case, res_wrds)
    per_segment = []
    for idx, segment in enumerate(segments):
        per_segment.append(do_g2p(g2p, segment, res_wrds, idx == 0))

    # A segment that came back as the bare failure sentinel poisons the
    # whole sentence: the output becomes that single token.
    if [FAIL_SENT] in per_segment:
        phonemes = [FAIL_SENT]
    else:
        phonemes = [tok for chunk in per_segment for tok in chunk]

    if args.no_punc:
        phonemes = remove_punc(phonemes)
    if args.dup_vowel > 1 or args.dup_consonant > 1:
        phonemes = dup_pho(phonemes, args.dup_vowel, args.dup_consonant)
    if args.use_word_start:
        phonemes = add_word_start(phonemes)
    return " ".join(phonemes)
def remove_punc(sent):
    """Drop punctuation tokens from a phoneme sequence.

    A token is kept when it contains no character other than ASCII letters,
    digits and spaces. The ``FAIL_SENT`` sentinel is always kept, since its
    underscore would otherwise get it dropped. A " " token is also skipped
    when it would be the first kept token or would directly follow another
    kept " ". This way, removing punctuation never leaves a leading or
    doubled separator; a trailing one may remain.

    Args:
        sent: list of phoneme tokens.

    Returns:
        A new, filtered list of tokens.
    """
    non_alnum = re.compile("[^a-zA-Z0-9 ]")
    kept = []
    for tok in sent:
        if non_alnum.search(tok) and tok != FAIL_SENT:
            continue  # punctuation (or any other symbol) token
        if tok == " " and (not kept or kept[-1] == " "):
            continue  # would be a leading or repeated separator
        kept.append(tok)
    return kept
def do_g2p(g2p, sent, res_wrds, is_first_sent):
    """Convert one text segment to phoneme tokens.

    A segment that is itself a reserved word becomes the single token stored
    for it in ``res_wrds``. Any other segment goes through ``g2p``. Every
    segment except the first is prefixed with a " " token, so segments stay
    separated once the caller concatenates them.

    Args:
        g2p: callable mapping a string to a list of phoneme tokens.
        sent: the text segment.
        res_wrds: reserved word -> phoneme token mapping (dict, or an empty list).
        is_first_sent: True for the first segment of a sentence.

    Returns:
        List of phoneme tokens for the segment.
    """
    if sent not in res_wrds:
        phonemes = g2p(sent)
    else:
        phonemes = [res_wrds[sent]]
    if is_first_sent:
        return phonemes
    return [" "] + phonemes  # separator between consecutive segments
def pre_process_sent(sent, do_filter, lower_case, res_wrds):
    """Normalize a sentence and cut it into g2p segments.

    Steps:
    * With ``do_filter``, every "-" and "—" is replaced by a space.
    * When ``res_wrds`` is non-empty, the sentence is split on whitespace.
      Each token that is exactly a reserved word (case-sensitive) becomes its
      own segment; runs of other tokens are re-joined with single spaces into
      one segment. Empty segments are dropped.
    * Otherwise, the whole (unstripped) sentence is the only segment.
    * With ``lower_case``, every segment that is not a reserved word is
      lowercased; reserved words keep their original case.

    Args:
        sent: input sentence.
        do_filter: replace dashes with spaces first.
        lower_case: lowercase non-reserved segments.
        res_wrds: reserved word -> phoneme mapping (dict, or an empty list).

    Returns:
        List of text segments.
    """
    if do_filter:
        sent = sent.replace("-", " ").replace("—", " ")

    if len(res_wrds) == 0:
        sents = [sent]
    else:
        # Surround reserved words with a marker, then split on it so that each
        # reserved word ends up isolated in its own segment.
        tagged = []
        for word in sent.split():
            if word in res_wrds:
                tagged.append(f"SPLIT_ME {word} SPLIT_ME")
            else:
                tagged.append(word)
        sents = []
        for chunk in " ".join(tagged).split("SPLIT_ME"):
            chunk = chunk.strip()
            if chunk != "":
                sents.append(chunk)

    if lower_case:
        sents = [s if s in res_wrds else s.lower() for s in sents]
    return sents
def dup_pho(sent, dup_v_num, dup_c_num):
    """Repeat vowel and consonant phonemes (CMUdict-style tokens).

    See http://www.speech.cs.cmu.edu/cgi-bin/cmudict for the phoneme set.
    A token ending in a digit (CMUdict marks vowel stress with a trailing
    digit) is treated as a vowel and emitted ``dup_v_num`` times in total.
    Any other token containing a word character is treated as a consonant
    and emitted ``dup_c_num`` times. The i-th extra copy of token ``p`` is
    written as ``"{p}-{i}P"``. Tokens without word characters (spaces,
    punctuation) are never repeated.

    Args:
        sent: list of phoneme tokens.
        dup_v_num: total count per vowel token.
        dup_c_num: total count per consonant token.

    Returns:
        The input list itself when both counts are 1; otherwise a new list.
    """
    if dup_v_num == 1 and dup_c_num == 1:
        return sent
    is_vowel = re.compile(r"\d$")
    has_word_char = re.compile(r"\w")
    out = []
    for pho in sent:
        out.append(pho)
        if is_vowel.search(pho):
            copies = dup_v_num
        elif has_word_char.search(pho):
            copies = dup_c_num
        else:
            continue
        out.extend(f"{pho}-{k}P" for k in range(1, copies))
    return out
def add_word_start(sent):
    """Mark the first phoneme of every word with the "▁" word-start symbol.

    " " tokens act as word separators. They are removed from the output, and
    the next non-space token gets the "▁" prefix. The first non-space token
    of the sequence is always prefixed. A leading separator, or a run of
    consecutive separators, counts as a single word boundary, so no "▁ "
    token is ever emitted and the following word keeps its marker.

    Args:
        sent: list of phoneme tokens with words separated by " " tokens.

    Returns:
        New list without " " tokens, word-initial tokens prefixed with "▁".
    """
    ws = "▁"
    ns = []
    do_add = True
    for p in sent:
        if p == " ":
            # Separator: drop it and flag the next real token as word-initial.
            # This check must come before prefixing; otherwise a separator
            # seen while do_add is set would itself become "▁ ".
            do_add = True
            continue
        if do_add:
            p = ws + p
            do_add = False
        ns.append(p)
    return ns
def load_reserve_word(reserve_word):
    """Load the reserved-word table whose entries bypass g2p.

    Each non-blank line must hold exactly two whitespace-separated fields: a
    word and the phoneme token to emit for it. If a word appears more than
    once, the last entry wins.

    Args:
        reserve_word: path to the UTF-8 table, or "" to disable reserved words.

    Returns:
        dict mapping word -> phoneme token. It is empty when ``reserve_word``
        is "", so the return type is the same in both cases.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields.
            (Raised explicitly rather than via ``assert``, which ``python -O``
            would strip.)
    """
    if reserve_word == "":
        return {}
    res_wrds = {}
    with open(reserve_word, "r", encoding="utf-8") as fp:
        for line_no, line in enumerate(fp, start=1):
            fields = line.split()
            if not fields:
                continue  # skip blank lines
            if len(fields) != 2:
                raise ValueError(
                    f"{reserve_word}:{line_no}: expected 'word phoneme', got {line.strip()!r}"
                )
            word, phoneme = fields
            res_wrds[word] = phoneme
    return res_wrds
def process_sents(sents, args):
    """Convert a list of input lines into phoneme strings.

    One ``G2p`` converter is built and the reserved-word table is loaded once.
    Then ``process_sent`` runs on every line. With ``--reserve-first-column``,
    the first whitespace-separated field of each line is treated as an id and
    written back verbatim in front of the converted text.

    Args:
        sents: input lines.
        args: parsed command-line namespace.

    Returns:
        List of output lines, in input order.
    """
    g2p = G2p()
    res_wrds = load_reserve_word(args.reserve_word)
    out_sents = []
    for sent in sents:
        col1 = ""
        if args.reserve_first_column:
            # A blank line has no fields and an id-only line has no text field.
            # Unpacking split(None, 1) directly into two names would raise
            # ValueError on such lines.
            fields = sent.split(None, 1)
            col1 = fields[0] if fields else ""
            sent = fields[1] if len(fields) > 1 else ""
        sent = process_sent(sent, g2p, res_wrds, args)
        if col1 != "":
            sent = f"{col1} {sent}"
        out_sents.append(sent)
    return out_sents
def _process_in_parallel(sent_list, args, submitit):
    """Run ``process_sents`` over contiguous chunks as submitit jobs; return results in input order."""
    num_jobs = args.parallel_process_num
    # Ceiling division so no job is empty; max(1, ...) guards an empty input.
    chunk_size = max(1, -(-len(sent_list) // num_jobs))
    executor = submitit.AutoExecutor(folder=args.logdir)
    executor.update_parameters(timeout_min=1000, cpus_per_task=4)
    jobs = [
        executor.submit(process_sents, sent_list[start : start + chunk_size], args)
        for start in range(0, len(sent_list), chunk_size)
    ]
    # Job.result() blocks until the job has finished, so collecting results
    # in submission order needs no separate polling loop.
    return list(itertools.chain.from_iterable(job.result() for job in jobs))


def main():
    """Command-line entry point: convert every line of --data-path into --out-path.

    Lines are processed in this process, unless --parallel-process-num is
    greater than 1 and submitit can be imported; in that case chunks of lines
    are dispatched as submitit jobs. Output lines keep the input order.
    """
    args = parse()
    with open(args.data_path, "r", encoding="utf-8") as fp:
        sent_list = [line.strip() for line in fp]

    # Bound up front so the check below cannot hit an unbound name for any
    # value of --parallel-process-num (including 0 or negative values).
    submitit = None
    if args.parallel_process_num > 1:
        try:
            import submitit
        except ImportError:
            logger.warning(
                "submitit is not found and only one job is used to process the data"
            )

    if submitit is None:
        out_sents = process_sents(sent_list, args)
    else:
        out_sents = _process_in_parallel(sent_list, args, submitit)

    # Output may contain the non-ASCII word-start symbol "▁", so write UTF-8
    # explicitly instead of relying on the platform default encoding.
    with open(args.out_path, "w", encoding="utf-8") as fp:
        fp.write("\n".join(out_sents) + "\n")
# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|