import re import six from six.moves import range # pylint: disable=redefined-builtin PAD = "" EOS = "" UNK = "" SEG = "|" RESERVED_TOKENS = [PAD, EOS, UNK] NUM_RESERVED_TOKENS = len(RESERVED_TOKENS) PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0 EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1 UNK_ID = RESERVED_TOKENS.index(UNK) # Normally 2 if six.PY2: RESERVED_TOKENS_BYTES = RESERVED_TOKENS else: RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")] # Regular expression for unescaping token strings. # '\u' is converted to '_' # '\\' is converted to '\' # '\213;' is converted to unichr(213) _UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);") _ESCAPE_CHARS = set(u"\\_u;0123456789") def strip_ids(ids, ids_to_strip): """Strip ids_to_strip from the end ids.""" ids = list(ids) while ids and ids[-1] in ids_to_strip: ids.pop() return ids class TextEncoder(object): """Base class for converting from ints to/from human readable strings.""" def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS): self._num_reserved_ids = num_reserved_ids @property def num_reserved_ids(self): return self._num_reserved_ids def encode(self, s): """Transform a human-readable string into a sequence of int ids. The ids should be in the range [num_reserved_ids, vocab_size). Ids [0, num_reserved_ids) are reserved. EOS is not appended. Args: s: human-readable string to be converted. Returns: ids: list of integers """ return [int(w) + self._num_reserved_ids for w in s.split()] def decode(self, ids, strip_extraneous=False): """Transform a sequence of int ids into a human-readable string. EOS is not expected in ids. Args: ids: list of integers to be converted. strip_extraneous: bool, whether to strip off extraneous tokens (EOS and PAD). Returns: s: human-readable string. """ if strip_extraneous: ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) return " ".join(self.decode_list(ids)) def decode_list(self, ids): """Transform a sequence of int ids into a their string versions. This method supports transforming individual input/output ids to their string versions so that sequence to/from text conversions can be visualized in a human readable format. Args: ids: list of integers to be converted. Returns: strs: list of human-readable string. """ decoded_ids = [] for id_ in ids: if 0 <= id_ < self._num_reserved_ids: decoded_ids.append(RESERVED_TOKENS[int(id_)]) else: decoded_ids.append(id_ - self._num_reserved_ids) return [str(d) for d in decoded_ids] @property def vocab_size(self): raise NotImplementedError() class ByteTextEncoder(TextEncoder): """Encodes each byte to an id. For 8-bit strings only.""" def encode(self, s): numres = self._num_reserved_ids if six.PY2: if isinstance(s, unicode): s = s.encode("utf-8") return [ord(c) + numres for c in s] # Python3: explicitly convert to UTF-8 return [c + numres for c in s.encode("utf-8")] def decode(self, ids, strip_extraneous=False): if strip_extraneous: ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) numres = self._num_reserved_ids decoded_ids = [] int2byte = six.int2byte for id_ in ids: if 0 <= id_ < numres: decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) else: decoded_ids.append(int2byte(id_ - numres)) if six.PY2: return "".join(decoded_ids) # Python3: join byte arrays and then decode string return b"".join(decoded_ids).decode("utf-8", "replace") def decode_list(self, ids): numres = self._num_reserved_ids decoded_ids = [] int2byte = six.int2byte for id_ in ids: if 0 <= id_ < numres: decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) else: decoded_ids.append(int2byte(id_ - numres)) # Python3: join byte arrays and then decode string return decoded_ids @property def vocab_size(self): return 2**8 + self._num_reserved_ids class ByteTextEncoderWithEos(ByteTextEncoder): """Encodes each byte to an id and appends the EOS token.""" def encode(self, s): return super(ByteTextEncoderWithEos, self).encode(s) + [EOS_ID] class TokenTextEncoder(TextEncoder): """Encoder based on a user-supplied vocabulary (file or list).""" def __init__(self, vocab_filename, reverse=False, vocab_list=None, replace_oov=None, num_reserved_ids=NUM_RESERVED_TOKENS): """Initialize from a file or list, one token per line. Handling of reserved tokens works as follows: - When initializing from a list, we add reserved tokens to the vocab. - When initializing from a file, we do not add reserved tokens to the vocab. - When saving vocab files, we save reserved tokens to the file. Args: vocab_filename: If not None, the full filename to read vocab from. If this is not None, then vocab_list should be None. reverse: Boolean indicating if tokens should be reversed during encoding and decoding. vocab_list: If not None, a list of elements of the vocabulary. If this is not None, then vocab_filename should be None. replace_oov: If not None, every out-of-vocabulary token seen when encoding will be replaced by this string (which must be in vocab). num_reserved_ids: Number of IDs to save for reserved tokens like . """ super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) self._reverse = reverse self._replace_oov = replace_oov if vocab_filename: self._init_vocab_from_file(vocab_filename) else: assert vocab_list is not None self._init_vocab_from_list(vocab_list) self.pad_index = self._token_to_id[PAD] self.eos_index = self._token_to_id[EOS] self.unk_index = self._token_to_id[UNK] self.seg_index = self._token_to_id[SEG] if SEG in self._token_to_id else self.eos_index def encode(self, s): """Converts a space-separated string of tokens to a list of ids.""" sentence = s tokens = sentence.strip().split() if self._replace_oov is not None: tokens = [t if t in self._token_to_id else self._replace_oov for t in tokens] ret = [self._token_to_id[tok] for tok in tokens] return ret[::-1] if self._reverse else ret def decode(self, ids, strip_eos=False, strip_padding=False): if strip_padding and self.pad() in list(ids): pad_pos = list(ids).index(self.pad()) ids = ids[:pad_pos] if strip_eos and self.eos() in list(ids): eos_pos = list(ids).index(self.eos()) ids = ids[:eos_pos] return " ".join(self.decode_list(ids)) def decode_list(self, ids): seq = reversed(ids) if self._reverse else ids return [self._safe_id_to_token(i) for i in seq] @property def vocab_size(self): return len(self._id_to_token) def __len__(self): return self.vocab_size def _safe_id_to_token(self, idx): return self._id_to_token.get(idx, "ID_%d" % idx) def _init_vocab_from_file(self, filename): """Load vocab from a file. Args: filename: The file to load vocabulary from. """ with open(filename) as f: tokens = [token.strip() for token in f.readlines()] def token_gen(): for token in tokens: yield token self._init_vocab(token_gen(), add_reserved_tokens=False) def _init_vocab_from_list(self, vocab_list): """Initialize tokens from a list of tokens. It is ok if reserved tokens appear in the vocab list. They will be removed. The set of tokens in vocab_list should be unique. Args: vocab_list: A list of tokens. """ def token_gen(): for token in vocab_list: if token not in RESERVED_TOKENS: yield token self._init_vocab(token_gen()) def _init_vocab(self, token_generator, add_reserved_tokens=True): """Initialize vocabulary with tokens from token_generator.""" self._id_to_token = {} non_reserved_start_index = 0 if add_reserved_tokens: self._id_to_token.update(enumerate(RESERVED_TOKENS)) non_reserved_start_index = len(RESERVED_TOKENS) self._id_to_token.update( enumerate(token_generator, start=non_reserved_start_index)) # _token_to_id is the reverse of _id_to_token self._token_to_id = dict((v, k) for k, v in six.iteritems(self._id_to_token)) def pad(self): return self.pad_index def eos(self): return self.eos_index def unk(self): return self.unk_index def seg(self): return self.seg_index def store_to_file(self, filename): """Write vocab file to disk. Vocab files have one token per line. The file ends in a newline. Reserved tokens are written to the vocab file as well. Args: filename: Full path of the file to store the vocab to. """ with open(filename, "w") as f: for i in range(len(self._id_to_token)): f.write(self._id_to_token[i] + "\n") def sil_phonemes(self): return [p for p in self._id_to_token.values() if not p[0].isalpha()]