Spaces:
Runtime error
Runtime error
File size: 8,658 Bytes
d1b91e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 |
import io
import json
import re

import six
from six.moves import range  # pylint: disable=redefined-builtin
PAD = "<pad>"
EOS = "<EOS>"
UNK = "<UNK>"
SEG = "|"
# Punctuation characters (not referenced in this part of the module).
PUNCS = '!,.?;:'
# Reserved tokens occupy the lowest ids, in this order.
RESERVED_TOKENS = [PAD, EOS, UNK]
NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
PAD_ID = RESERVED_TOKENS.index(PAD)  # Normally 0
EOS_ID = RESERVED_TOKENS.index(EOS)  # Normally 1
UNK_ID = RESERVED_TOKENS.index(UNK)  # Normally 2
# Byte-string form of every reserved token, derived from RESERVED_TOKENS so the
# two lists can never drift apart. ``str.encode`` returns ``str`` (bytes) on
# Python 2 and ``bytes`` on Python 3, so no interpreter-version switch is needed.
RESERVED_TOKENS_BYTES = [tok.encode("ascii") for tok in RESERVED_TOKENS]
# Regular expression for unescaping token strings.
# '\u' is converted to '_'
# '\\' is converted to '\'
# '\213;' is converted to unichr(213)
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
_ESCAPE_CHARS = set(u"\\_u;0123456789")
def strip_ids(ids, ids_to_strip):
    """Return ``ids`` as a new list with trailing members of ``ids_to_strip`` removed.

    Only the tail is trimmed; matching ids earlier in the sequence are kept.

    Args:
      ids: iterable of ids.
      ids_to_strip: container of ids to trim from the end.
    Returns:
      list of the remaining ids.
    """
    result = list(ids)
    end = len(result)
    # Walk backwards past every trailing id that should be stripped.
    while end > 0 and result[end - 1] in ids_to_strip:
        end -= 1
    return result[:end]
class TextEncoder(object):
    """Base id <-> text converter whose lowest ids are reserved.

    Ids below ``num_reserved_ids`` stand for entries of RESERVED_TOKENS; every
    other id ``i`` corresponds to the integer token ``i - num_reserved_ids``.
    Subclasses override encode/decode for real vocabularies.
    """

    def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
        self._num_reserved_ids = num_reserved_ids

    @property
    def num_reserved_ids(self):
        """Number of low ids set aside for reserved tokens."""
        return self._num_reserved_ids

    def encode(self, s):
        """Turn a whitespace-separated string of integers into ids.

        Each integer is shifted up by ``num_reserved_ids`` so that the result is
        meant to lie in ``[num_reserved_ids, vocab_size)``; ids below that are
        reserved. No EOS id is appended.

        Args:
          s: string of whitespace-separated integer literals.
        Returns:
          list of int ids.
        """
        offset = self._num_reserved_ids
        ids = []
        for word in s.split():
            ids.append(int(word) + offset)
        return ids

    def decode(self, ids, strip_extraneous=False):
        """Turn a sequence of ids back into a space-separated string.

        Args:
          ids: sequence of int ids; EOS is not expected inside it.
          strip_extraneous: when True, trailing reserved ids (EOS, PAD, ...)
            are removed before decoding.
        Returns:
          the decoded tokens joined by single spaces.
        """
        if strip_extraneous:
            reserved_ids = list(range(self._num_reserved_ids or 0))
            ids = strip_ids(ids, reserved_ids)
        pieces = self.decode_list(ids)
        return " ".join(pieces)

    def decode_list(self, ids):
        """Map each id to its string form, one output string per input id.

        Handy for visualizing individual input/output ids. Reserved ids map to
        their RESERVED_TOKENS entry; any other id maps to the string form of
        ``id - num_reserved_ids``.

        Args:
          ids: sequence of int ids.
        Returns:
          list of strings, parallel to ``ids``.
        """
        offset = self._num_reserved_ids
        out = []
        for id_ in ids:
            in_reserved_range = 0 <= id_ < offset
            token = RESERVED_TOKENS[int(id_)] if in_reserved_range else id_ - offset
            out.append(str(token))
        return out

    @property
    def vocab_size(self):
        """Size of the vocabulary; must be provided by subclasses."""
        raise NotImplementedError()
class TokenTextEncoder(TextEncoder):
    """Encoder based on a user-supplied vocabulary (file or list)."""

    def __init__(self,
                 vocab_filename,
                 reverse=False,
                 vocab_list=None,
                 replace_oov=None,
                 num_reserved_ids=NUM_RESERVED_TOKENS):
        """Initialize from a file or list, one token per line.

        Handling of reserved tokens works as follows:
        - When initializing from a list, we add reserved tokens to the vocab.
        - When initializing from a file, we do not add reserved tokens to the vocab.
        - When saving vocab files, we save reserved tokens to the file.

        Args:
          vocab_filename: If not None, the full filename to read vocab from. If this
            is not None, then vocab_list should be None.
          reverse: Boolean indicating if tokens should be reversed during encoding
            and decoding.
          vocab_list: If not None, a list of elements of the vocabulary. If this is
            not None, then vocab_filename should be None.
          replace_oov: If not None, every out-of-vocabulary token seen when
            encoding will be replaced by this string (which must be in vocab).
          num_reserved_ids: Number of IDs to save for reserved tokens like <EOS>.
            It is stored on the base class; when building from a list, the
            vocabulary layout itself always reserves len(RESERVED_TOKENS) ids.
        Raises:
          KeyError: if the resulting vocabulary lacks PAD, EOS or UNK (possible
            when loading from a file, which must list them itself).
        """
        super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
        self._reverse = reverse
        self._replace_oov = replace_oov
        if vocab_filename:
            self._init_vocab_from_file(vocab_filename)
        else:
            assert vocab_list is not None
            self._init_vocab_from_list(vocab_list)
        self.pad_index = self.token_to_id[PAD]
        self.eos_index = self.token_to_id[EOS]
        self.unk_index = self.token_to_id[UNK]
        # Vocabularies without an explicit segment separator fall back to EOS.
        self.seg_index = self.token_to_id.get(SEG, self.eos_index)

    def encode(self, s):
        """Convert a whitespace-separated string of tokens to a list of ids.

        Out-of-vocabulary tokens are mapped to ``replace_oov`` when it was given;
        otherwise they raise KeyError. The id list is reversed when the encoder
        was built with ``reverse=True``.
        """
        tokens = s.strip().split()
        if self._replace_oov is not None:
            tokens = [t if t in self.token_to_id else self._replace_oov
                      for t in tokens]
        ids = [self.token_to_id[tok] for tok in tokens]
        return ids[::-1] if self._reverse else ids

    def decode(self, ids, strip_eos=False, strip_padding=False):
        """Convert a sequence of ids back into a space-separated string.

        Args:
          ids: iterable of integer ids.
          strip_eos: if True, drop the first EOS id and everything after it.
          strip_padding: if True, drop the first PAD id and everything after it.
            Applied before ``strip_eos``.
        Returns:
          The decoded tokens joined by single spaces.
        """
        # Materialize once so one-shot iterators work and we don't rebuild
        # the list for every membership test / index lookup.
        ids = list(ids)
        if strip_padding and self.pad() in ids:
            ids = ids[:ids.index(self.pad())]
        if strip_eos and self.eos() in ids:
            ids = ids[:ids.index(self.eos())]
        return " ".join(self.decode_list(ids))

    def decode_list(self, ids):
        """Map each id to its token string (reversed order if ``reverse``)."""
        seq = reversed(ids) if self._reverse else ids
        return [self._safe_id_to_token(i) for i in seq]

    @property
    def vocab_size(self):
        """Number of entries in the vocabulary, reserved tokens included."""
        return len(self.id_to_token)

    def __len__(self):
        return self.vocab_size

    def _safe_id_to_token(self, idx):
        """Look up ``idx``; unknown ids render as ``"ID_<idx>"`` instead of raising."""
        return self.id_to_token.get(idx, "ID_%d" % idx)

    def _init_vocab_from_file(self, filename):
        """Load vocab from a file, one token per line.

        The file is read as UTF-8 so non-ASCII tokens load identically on every
        platform, independent of the locale default. Reserved tokens are not
        added; the file must already contain them (``store_to_file`` writes them).

        Args:
          filename: The file to load vocabulary from.
        """
        with io.open(filename, encoding="utf-8") as f:
            tokens = [line.strip() for line in f]
        self._init_vocab(tokens, add_reserved_tokens=False)

    def _init_vocab_from_list(self, vocab_list):
        """Initialize tokens from a list of tokens.

        It is ok if reserved tokens appear in the vocab list. They will be
        removed (and re-added at the lowest ids). The set of tokens in
        vocab_list should be unique.

        Args:
          vocab_list: A list of tokens.
        """
        def token_gen():
            for token in vocab_list:
                if token not in RESERVED_TOKENS:
                    yield token

        self._init_vocab(token_gen())

    def _init_vocab(self, token_generator, add_reserved_tokens=True):
        """Initialize vocabulary with tokens from token_generator.

        When ``add_reserved_tokens`` is True, RESERVED_TOKENS take ids
        ``0..len(RESERVED_TOKENS)-1`` and the generated tokens follow.
        """
        self.id_to_token = {}
        non_reserved_start_index = 0
        if add_reserved_tokens:
            self.id_to_token.update(enumerate(RESERVED_TOKENS))
            non_reserved_start_index = len(RESERVED_TOKENS)
        self.id_to_token.update(
            enumerate(token_generator, start=non_reserved_start_index))
        # token_to_id is the reverse of id_to_token; for duplicate tokens the
        # id iterated last wins.
        self.token_to_id = dict((v, k) for k, v in six.iteritems(self.id_to_token))

    def pad(self):
        return self.pad_index

    def eos(self):
        return self.eos_index

    def unk(self):
        return self.unk_index

    def seg(self):
        return self.seg_index

    def store_to_file(self, filename):
        """Write vocab file to disk.

        Vocab files have one token per line, UTF-8 encoded. The file ends in a
        newline. Reserved tokens are written to the vocab file as well.

        Args:
          filename: Full path of the file to store the vocab to.
        """
        with io.open(filename, "w", encoding="utf-8") as f:
            for i in range(len(self.id_to_token)):
                f.write(self.id_to_token[i] + u"\n")

    def sil_phonemes(self):
        """Return every vocab token for which ``is_sil_phoneme`` is true.

        Reserved tokens such as "<pad>" are included, since they start with a
        non-alphabetic character.
        """
        return [p for p in self.id_to_token.values() if is_sil_phoneme(p)]
def build_token_encoder(token_list_file):
    """Build a TokenTextEncoder from a JSON file holding the vocabulary.

    The vocabulary is passed as ``vocab_list``, so TokenTextEncoder places the
    reserved tokens (PAD, EOS, UNK) at the lowest ids. Any out-of-vocabulary
    token seen at encode time is mapped to UNK.

    Args:
      token_list_file: path to a UTF-8 JSON file; expected to hold an array of
        token strings.
    Returns:
      A TokenTextEncoder over that vocabulary.
    """
    # Context manager closes the handle deterministically; JSON text is UTF-8,
    # so do not depend on the platform's default locale encoding.
    with io.open(token_list_file, encoding="utf-8") as f:
        token_list = json.load(f)
    return TokenTextEncoder(None, vocab_list=token_list, replace_oov=UNK)
def is_sil_phoneme(p):
    """Report whether token ``p`` counts as a silence / non-phoneme symbol.

    A token qualifies when it is the empty string or when its first character
    is not alphabetic (so separators like "|" and reserved tokens like "<pad>"
    both qualify).
    """
    if p == '':
        return True
    first_char = p[0]
    return not first_char.isalpha()
|