David Portes committed on
Commit 65ce00b · 1 Parent(s): bc2a85a

text_to_seq

Files changed (1)
  1. text_to_sequence.py +309 -0
text_to_sequence.py ADDED
@@ -0,0 +1,309 @@
+ """ from https://github.com/keithito/tacotron """
+ # *****************************************************************************
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+ #     * Redistributions of source code must retain the above copyright
+ #       notice, this list of conditions and the following disclaimer.
+ #     * Redistributions in binary form must reproduce the above copyright
+ #       notice, this list of conditions and the following disclaimer in the
+ #       documentation and/or other materials provided with the distribution.
+ #     * Neither the name of the NVIDIA CORPORATION nor the
+ #       names of its contributors may be used to endorse or promote products
+ #       derived from this software without specific prior written permission.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ # DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+ # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ #
+ # *****************************************************************************
+ import re
+
+
+ valid_symbols = [
+     "AA",
+     "AA0",
+     "AA1",
+     "AA2",
+     "AE",
+     "AE0",
+     "AE1",
+     "AE2",
+     "AH",
+     "AH0",
+     "AH1",
+     "AH2",
+     "AO",
+     "AO0",
+     "AO1",
+     "AO2",
+     "AW",
+     "AW0",
+     "AW1",
+     "AW2",
+     "AY",
+     "AY0",
+     "AY1",
+     "AY2",
+     "B",
+     "CH",
+     "D",
+     "DH",
+     "EH",
+     "EH0",
+     "EH1",
+     "EH2",
+     "ER",
+     "ER0",
+     "ER1",
+     "ER2",
+     "EY",
+     "EY0",
+     "EY1",
+     "EY2",
+     "F",
+     "G",
+     "HH",
+     "IH",
+     "IH0",
+     "IH1",
+     "IH2",
+     "IY",
+     "IY0",
+     "IY1",
+     "IY2",
+     "JH",
+     "K",
+     "L",
+     "M",
+     "N",
+     "NG",
+     "OW",
+     "OW0",
+     "OW1",
+     "OW2",
+     "OY",
+     "OY0",
+     "OY1",
+     "OY2",
+     "P",
+     "R",
+     "S",
+     "SH",
+     "T",
+     "TH",
+     "UH",
+     "UH0",
+     "UH1",
+     "UH2",
+     "UW",
+     "UW0",
+     "UW1",
+     "UW2",
+     "V",
+     "W",
+     "Y",
+     "Z",
+     "ZH",
+ ]
+
+
+ """
+ Defines the set of symbols used in text input to the model.
+ The default is a set of ASCII characters that works well for English. For other data, you can modify _letters. See TRAINING_DATA.md for details.
+ """
+
+
+ _pad = "_"
+ _punctuation = "!'(),.:;? "
+ _special = "-"
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz^*"
+
+ # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same
+ # as uppercase letters):
+ _arpabet = ["@" + s for s in valid_symbols]
+
+ # Export all symbols:
+ symbols = (
+     [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
+ )
+
+
+ # Mappings from symbol to numeric ID and vice versa:
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
+
+ # Regular expression matching text enclosed in curly braces:
+ _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
+
+
+ # Regular expression matching whitespace:
+ _whitespace_re = re.compile(r"\s+")
+
+ # List of (regular expression, replacement) pairs for abbreviations:
+ _abbreviations = [
+     (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+     for x in [
+         ("mrs", "misess"),
+         ("mr", "mister"),
+         ("dr", "doctor"),
+         ("st", "saint"),
+         ("co", "company"),
+         ("jr", "junior"),
+         ("maj", "major"),
+         ("gen", "general"),
+         ("drs", "doctors"),
+         ("rev", "reverend"),
+         ("lt", "lieutenant"),
+         ("hon", "honorable"),
+         ("sgt", "sergeant"),
+         ("capt", "captain"),
+         ("esq", "esquire"),
+         ("ltd", "limited"),
+         ("col", "colonel"),
+         ("ft", "fort"),
+     ]
+ ]
+
+
+ def expand_abbreviations(text):
+     """Expands the predefined abbreviations in the text."""
+     for regex, replacement in _abbreviations:
+         text = re.sub(regex, replacement, text)
+     return text
+
+
+ # def expand_numbers(text):
+ #     return normalize_numbers(text)
+
+
+ def lowercase(text):
+     """Lowercases the text."""
+     return text.lower()
+
+
+ def collapse_whitespace(text):
+     """Collapses runs of whitespace into a single space."""
+     return re.sub(_whitespace_re, " ", text)
+
+
+ def convert_to_ascii(text):
+     """Converts text to ASCII, dropping characters that cannot be encoded."""
+     text_encoded = text.encode("ascii", "ignore")
+     return text_encoded.decode()
+
+
+ def basic_cleaners(text):
+     """Basic pipeline that lowercases and collapses whitespace without transliteration."""
+     text = lowercase(text)
+     text = collapse_whitespace(text)
+     return text
+
+
+ def transliteration_cleaners(text):
+     """Pipeline for non-English text that transliterates to ASCII."""
+     text = convert_to_ascii(text)
+     text = lowercase(text)
+     text = collapse_whitespace(text)
+     return text
+
+
+ def english_cleaners(text):
+     """Pipeline for English text, including abbreviation expansion (number expansion is disabled here)."""
+     text = convert_to_ascii(text)
+     text = lowercase(text)
+     text = expand_abbreviations(text)
+     text = collapse_whitespace(text)
+     return text
+
+
+ def text_to_sequence(text, cleaner_names):
+     """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+
+     The text can optionally have ARPAbet sequences enclosed in curly braces embedded
+     in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
+
+     Arguments
+     ---------
+     text : str
+         string to convert to a sequence
+     cleaner_names : list
+         names of the cleaner functions to run the text through
+
+     Returns
+     -------
+     list
+         list of integers corresponding to the symbols in the text
+     """
+     sequence = []
+
+     # Check for curly braces and treat their contents as ARPAbet:
+     while len(text):
+         m = _curly_re.match(text)
+         if not m:
+             sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
+             break
+         sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
+         sequence += _arpabet_to_sequence(m.group(2))
+         text = m.group(3)
+
+     return sequence
+
+
+ def sequence_to_text(sequence):
+     """Converts a sequence of IDs back to a string."""
+     result = ""
+     for symbol_id in sequence:
+         if symbol_id in _id_to_symbol:
+             s = _id_to_symbol[symbol_id]
+             # Enclose ARPAbet back in curly braces:
+             if len(s) > 1 and s[0] == "@":
+                 s = "{%s}" % s[1:]
+             result += s
+     return result.replace("}{", " ")
+
+
+ def _clean_text(text, cleaner_names):
+     """Applies the cleaning pipelines selected by cleaner_names to the text."""
+     for name in cleaner_names:
+         # Resolve each cleaner by name; unknown names raise an error instead of
+         # silently reusing the previous cleaner or failing with a NameError.
+         cleaner = None
+         if name == "english_cleaners":
+             cleaner = english_cleaners
+         if name == "transliteration_cleaners":
+             cleaner = transliteration_cleaners
+         if name == "basic_cleaners":
+             cleaner = basic_cleaners
+         if not cleaner:
+             raise Exception("Unknown cleaner: %s" % name)
+         text = cleaner(text)
+     return text
+
+
+ def _symbols_to_sequence(symbols):
+     """Converts a list of symbols into a sequence of IDs, skipping symbols that should not be kept."""
+     return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
+
+
+ def _arpabet_to_sequence(text):
+     """Prepends "@" to each ARPAbet token to keep it distinct from plain letters, then converts to IDs."""
+     return _symbols_to_sequence(["@" + s for s in text.split()])
+
+
+ def _should_keep_symbol(s):
+     """Returns True if the symbol is known and is neither "_" (padding) nor "~"."""
+     return s in _symbol_to_id and s != "_" and s != "~"
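
Usage note (a minimal sketch, not part of the diff): assuming the added file is importable as text_to_sequence, the snippet below round-trips a sentence with an embedded ARPAbet sequence through text_to_sequence and sequence_to_text using the "english_cleaners" pipeline defined above.

import text_to_sequence as tts

# Text with an embedded ARPAbet sequence in curly braces, as described in the
# text_to_sequence docstring.
text = "Turn left on {HH AW1 S S T AH0 N} Street."

# Clean the plain-text parts with english_cleaners and map every symbol to its
# integer ID; the braced ARPAbet tokens are looked up as "@"-prefixed symbols.
ids = tts.text_to_sequence(text, ["english_cleaners"])

# Map the IDs back to symbols; consecutive ARPAbet tokens come back wrapped in
# a single pair of curly braces, e.g. "turn left on {HH AW1 S S T AH0 N} street."
print(tts.sequence_to_text(ids))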