|
""" from https://github.com/keithito/tacotron """ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re |
|
|
|
|
|
# CMU ARPAbet phoneme inventory accepted inside {...} spans of input text.
# Each vowel appears in four forms: bare (no stress mark) plus stress
# variants 0 (unstressed), 1 (primary), 2 (secondary); consonants have no
# stress variants.  Sorting reproduces the canonical alphabetical order.
_vowels = [
    "AA", "AE", "AH", "AO", "AW", "AY", "EH", "ER", "EY", "IH", "IY",
    "OW", "OY", "UH", "UW",
]
_consonants = [
    "B", "CH", "D", "DH", "F", "G", "HH", "JH", "K", "L", "M", "N", "NG",
    "P", "R", "S", "SH", "T", "TH", "V", "W", "Y", "Z", "ZH",
]
valid_symbols = sorted(
    [vowel + stress for vowel in _vowels for stress in ("", "0", "1", "2")]
    + _consonants
)
|
|
|
|
|
"""
Defines the set of symbols used in text input to the model.

The default is a set of ASCII characters that works well for English. For other data, you can modify _characters. See TRAINING_DATA.md for details.
"""


# Special tokens and the plain-character inventory.
_pad = "_"
_punctuation = "!'(),.:;? "
_special = "-"
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz^*"


# ARPAbet phonemes carry an "@" prefix so they can never collide with the
# raw single-character symbols above.
_arpabet = ["@" + phoneme for phoneme in valid_symbols]


# Full exported symbol inventory; padding comes first so it maps to ID 0.
symbols = [_pad, *_special, *_punctuation, *_letters, *_arpabet]


# Bidirectional symbol <-> integer-ID lookup tables.
_symbol_to_id = {symbol: index for index, symbol in enumerate(symbols)}
_id_to_symbol = dict(enumerate(symbols))
|
|
|
|
|
# Splits text of the form "before{ARPABET}after": group(1) is the text before
# the first curly-brace span, group(2) the brace contents (an ARPAbet
# sequence), group(3) everything after the closing brace.
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")




# One or more consecutive whitespace characters; used to collapse runs of
# whitespace into a single space.
_whitespace_re = re.compile(r"\s+")
|
|
|
|
|
_abbreviations = [ |
|
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) |
|
for x in [ |
|
("mrs", "misess"), |
|
("mr", "mister"), |
|
("dr", "doctor"), |
|
("st", "saint"), |
|
("co", "company"), |
|
("jr", "junior"), |
|
("maj", "major"), |
|
("gen", "general"), |
|
("drs", "doctors"), |
|
("rev", "reverend"), |
|
("lt", "lieutenant"), |
|
("hon", "honorable"), |
|
("sgt", "sergeant"), |
|
("capt", "captain"), |
|
("esq", "esquire"), |
|
("ltd", "limited"), |
|
("col", "colonel"), |
|
("ft", "fort"), |
|
] |
|
] |
|
|
|
|
|
def expand_abbreviations(text):
    """Replace each known abbreviation (e.g. "mr.") with its spelled-out form."""
    for pattern, expansion in _abbreviations:
        text = pattern.sub(expansion, text)
    return text
|
|
|
|
|
|
|
|
|
|
|
|
|
def lowercase(text):
    """Return *text* with every cased character converted to lower case."""
    return text.lower()
|
|
|
|
|
def collapse_whitespace(text):
    """Collapse every run of whitespace in *text* into a single space."""
    return _whitespace_re.sub(" ", text)
|
|
|
|
|
def convert_to_ascii(text):
    """Return *text* with all non-ASCII characters silently dropped."""
    return text.encode("ascii", "ignore").decode()
|
|
|
|
|
def basic_cleaners(text):
    """Basic pipeline: lowercase then collapse whitespace (no transliteration)."""
    return collapse_whitespace(lowercase(text))
|
|
|
|
|
def transliteration_cleaners(text):
    """Pipeline for non-English text: ASCII-fold, lowercase, collapse whitespace."""
    return collapse_whitespace(lowercase(convert_to_ascii(text)))
|
|
|
|
|
def english_cleaners(text):
    """English pipeline: ASCII-fold, lowercase, expand abbreviations,
    collapse whitespace.
    """
    # Apply each stage in order; order matters (abbreviations are matched
    # against the lowercased text).
    for stage in (convert_to_ascii, lowercase, expand_abbreviations, collapse_whitespace):
        text = stage(text)
    return text
|
|
|
|
|
def text_to_sequence(text, cleaner_names):
    """Convert a string of text to a list of symbol IDs.

    The text may embed ARPAbet sequences in curly braces, e.g.
    "Turn left on {HH AW1 S S T AH0 N} Street."; brace contents are looked
    up as phonemes while the surrounding text is run through the cleaners.

    Arguments
    ---------
    text : str
        string to convert to a sequence
    cleaner_names : list
        names of the cleaner functions to run the text through

    """
    sequence = []
    remaining = text
    # Peel off one {...} span per iteration until none remain.
    while remaining:
        match = _curly_re.match(remaining)
        if match is None:
            # No more brace spans: clean and encode the rest in one go.
            sequence.extend(_symbols_to_sequence(_clean_text(remaining, cleaner_names)))
            break
        before, arpabet_span, remaining = match.groups()
        sequence.extend(_symbols_to_sequence(_clean_text(before, cleaner_names)))
        sequence.extend(_arpabet_to_sequence(arpabet_span))

    return sequence
|
|
|
|
|
def sequence_to_text(sequence):
    """Convert a sequence of symbol IDs back into a readable string.

    ARPAbet symbols are re-wrapped in curly braces; adjacent ARPAbet
    symbols share one brace pair separated by spaces.  Unknown IDs are
    skipped.
    """
    pieces = []
    for symbol_id in sequence:
        symbol = _id_to_symbol.get(symbol_id)
        if symbol is None:
            continue
        # Restore the curly-brace notation for "@"-prefixed ARPAbet symbols.
        if len(symbol) > 1 and symbol[0] == "@":
            symbol = "{%s}" % symbol[1:]
        pieces.append(symbol)
    return "".join(pieces).replace("}{", " ")
|
|
|
|
|
def _clean_text(text, cleaner_names): |
|
"""apply different cleaning pipeline according to cleaner_names |
|
""" |
|
for name in cleaner_names: |
|
if name == "english_cleaners": |
|
cleaner = english_cleaners |
|
if name == "transliteration_cleaners": |
|
cleaner = transliteration_cleaners |
|
if name == "basic_cleaners": |
|
cleaner = basic_cleaners |
|
if not cleaner: |
|
raise Exception("Unknown cleaner: %s" % name) |
|
text = cleaner(text) |
|
return text |
|
|
|
|
|
def _symbols_to_sequence(symbols):
    """Map each keepable symbol in *symbols* to its integer ID."""
    ids = []
    for symbol in symbols:
        if _should_keep_symbol(symbol):
            ids.append(_symbol_to_id[symbol])
    return ids
|
|
|
|
|
def _arpabet_to_sequence(text):
    """Encode a space-separated ARPAbet string; "@" is prepended to each
    phoneme so it hits the phoneme (not character) entries of the table.
    """
    prefixed = ["@" + phoneme for phoneme in text.split()]
    return _symbols_to_sequence(prefixed)
|
|
|
|
|
def _should_keep_symbol(s):
    """Return True unless *s* is unknown or one of the filtered tokens
    ("_" padding, "~" end-of-sequence).
    """
    if s in ("_", "~"):
        return False
    return s in _symbol_to_id
|
|