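"""Tokenize texts and optionally build a vocabulary.

Usage sketch (the file names and invocation below are illustrative
assumptions, not part of the original script):

    # Tokenize everything from the 2nd column onward into characters
    python tokenize_text.py -i text -o tokens.txt -f 2- -t char

    # Build a word vocabulary of 5000 entries plus special symbols
    python tokenize_text.py -i text -o vocab.txt -f 2- -t word \
        --write_vocabulary true --vocabulary_size 5000 \
        --add_symbol '<blank>:0' --add_symbol '<unk>:1'
"""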
import argparse
from collections import Counter
import logging
from pathlib import Path
import sys
from typing import List
from typing import Optional

from typeguard import check_argument_types

from espnet.utils.cli_utils import get_commandline_args
from espnet2.text.build_tokenizer import build_tokenizer
from espnet2.text.cleaner import TextCleaner
from espnet2.utils.types import str2bool
from espnet2.utils.types import str_or_none


def field2slice(field: Optional[str]) -> slice:
    """Convert a field string to a slice.

    Note that the field string accepts 1-based integers.

    Examples:
        >>> field2slice("1-")
        slice(0, None, None)
        >>> field2slice("1-3")
        slice(0, 3, None)
        >>> field2slice("-3")
        slice(None, 3, None)
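
        A single 1-based column (equivalent to "2-2") selects just that field:

        >>> field2slice("2")
        slice(1, 2, None)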

    """
    field = field.strip()
    try:
        if "-" in field:
            # e.g. "2-", "2-5", or "-7"
            s1, s2 = field.split("-", maxsplit=1)
            if s1.strip() == "":
                s1 = None
            else:
                s1 = int(s1)
                if s1 == 0:
                    raise ValueError("1-based string")
            if s2.strip() == "":
                s2 = None
            else:
                s2 = int(s2)
        else:
            # e.g. "2": a single column, equivalent to "2-2"
            s1 = int(field)
            s2 = s1
            if s1 == 0:
                raise ValueError("must be 1 or more value")
    except ValueError:
        raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")

    if s1 is None:
        # e.g. "-3": only the end column was given
        slic = slice(None, s2)
    else:
        # -1 because the field is a 1-based integer, following the "cut" command
        # e.g. "1-3" -> slice(0, 3)
        slic = slice(s1 - 1, s2)
    return slic


def tokenize(
    input: str,
    output: str,
    field: Optional[str],
    delimiter: Optional[str],
    token_type: str,
    space_symbol: str,
    non_linguistic_symbols: Optional[str],
    bpemodel: Optional[str],
    log_level: str,
    write_vocabulary: bool,
    vocabulary_size: int,
    remove_non_linguistic_symbols: bool,
    cutoff: int,
    add_symbol: List[str],
    cleaner: Optional[str],
    g2p: Optional[str],
):
    assert check_argument_types()

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if input == "-":
        fin = sys.stdin
    else:
        fin = Path(input).open("r", encoding="utf-8")
    if output == "-":
        fout = sys.stdout
    else:
        p = Path(output)
        p.parent.mkdir(parents=True, exist_ok=True)
        fout = p.open("w", encoding="utf-8")

    cleaner = TextCleaner(cleaner)
    tokenizer = build_tokenizer(
        token_type=token_type,
        bpemodel=bpemodel,
        delimiter=delimiter,
        space_symbol=space_symbol,
        non_linguistic_symbols=non_linguistic_symbols,
        remove_non_linguistic_symbols=remove_non_linguistic_symbols,
        g2p_type=g2p,
    )
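
    # Token counts are accumulated only in --write_vocabulary mode (see the
    # loop below); in plain tokenization mode the counter stays empty.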
    counter = Counter()
    if field is not None:
        field = field2slice(field)

    for line in fin:
        line = line.rstrip()
        if field is not None:
            # e.g. field="2-": keep only the target columns,
            # "uttidA hello world!!" -> "hello world!!"
            tokens = line.split(delimiter)
            tokens = tokens[field]
            if delimiter is None:
                line = " ".join(tokens)
            else:
                line = delimiter.join(tokens)

        line = cleaner(line)
        tokens = tokenizer.text2tokens(line)
        if not write_vocabulary:
            fout.write(" ".join(tokens) + "\n")
        else:
            for t in tokens:
                counter[t] += 1

    if not write_vocabulary:
        return

    # write_vocabulary mode from here on:
    # sort by frequency in descending order and keep only the words whose
    # count exceeds the cutoff value
    words_and_counts = list(
        filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
    )

    # Restrict the vocabulary size, reserving slots for --add_symbol entries
    if vocabulary_size > 0:
        if vocabulary_size < len(add_symbol):
            raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
        words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]

    # Parse the values given via --add_symbol, e.g. "<blank>:0"
    for symbol_and_id in add_symbol:
        # e.g. symbol="<blank>", idx=0
        try:
            symbol, idx = symbol_and_id.split(":")
            idx = int(idx)
        except ValueError:
            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
        symbol = symbol.strip()

        # e.g. idx=0  -> insert as the first symbol,
        #      idx=-1 -> insert as the last symbol
        if idx < 0:
            idx = len(words_and_counts) + 1 + idx
        words_and_counts.insert(idx, (symbol, None))

    # Write the vocabulary, one token per line
    for w, c in words_and_counts:
        fout.write(w + "\n")

    # Report the OOV rate; added symbols carry a count of None and are excluded
    total_count = sum(counter.values())
    invocab_count = sum(c for w, c in words_and_counts if c is not None)
    logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")
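    # For example: 1000 total tokens with 950 covered by the written vocabulary
    # gives an OOV rate of (1000 - 950) / 1000 * 100 = 5.0 %.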


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Tokenize texts",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument(
        "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
    )
    parser.add_argument(
        "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
    )
    parser.add_argument(
        "--field",
        "-f",
        help="The target columns of the input text as 1-based integers. e.g. 2-",
    )
    parser.add_argument(
        "--token_type",
        "-t",
        default="char",
        choices=["char", "bpe", "word", "phn"],
        help="Token type",
    )
    parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
    parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
    parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
    parser.add_argument(
        "--non_linguistic_symbols",
        type=str_or_none,
        help="non_linguistic_symbols file path",
    )
    parser.add_argument(
        "--remove_non_linguistic_symbols",
        type=str2bool,
        default=False,
        help="Remove non-linguistic symbols from tokens",
    )
    parser.add_argument(
        "--cleaner",
        type=str_or_none,
        choices=[None, "tacotron", "jaconv", "vietnamese"],
        default=None,
        help="Apply text cleaning",
    )
    parser.add_argument(
        "--g2p",
        type=str_or_none,
        choices=[
            None,
            "g2p_en",
            "g2p_en_no_space",
            "pyopenjtalk",
            "pyopenjtalk_kana",
            "pyopenjtalk_accent",
            "pyopenjtalk_accent_with_pause",
            "pypinyin_g2p",
            "pypinyin_g2p_phone",
            "espeak_ng_arabic",
        ],
        default=None,
        help="Specify g2p method if --token_type=phn",
    )

    group = parser.add_argument_group("write_vocabulary mode related")
    group.add_argument(
        "--write_vocabulary",
        type=str2bool,
        default=False,
        help="Write tokens list instead of tokenized text per line",
    )
    group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
    group.add_argument(
        "--cutoff",
        default=0,
        type=int,
        help="cut-off frequency used for write-vocabulary mode",
    )
    group.add_argument(
        "--add_symbol",
        type=str,
        default=[],
        action="append",
        help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
    )

    return parser
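
# Programmatic usage sketch (the input file name "text" and the module path
# espnet2.bin.tokenize_text are assumptions for illustration):
#
#     from espnet2.bin.tokenize_text import main
#     main(["-i", "text", "-o", "-", "-f", "2-", "-t", "char"])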


def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    tokenize(**kwargs)


if __name__ == "__main__":
    main()