"""Whitespace-and-punctuation tokenizer with ASCII quote/dash normalization."""

import re
import unicodedata

# Map curly quotes, the acute accent, and dash variants to plain-ASCII
# equivalents for str.translate().
TRANS_TABLE = {ord(x): ord(y) for x, y in zip("‘’´“”—–-", "'''\"\"---")}


def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation. Characters such
    # as "^", "$", and "`" are not in the Unicode Punctuation category, but
    # we treat them as punctuation anyway, for consistency.
    if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False


def _handle_single_quote(tokens):
    # Re-attach contraction suffixes ("'s", "'ll", "'re", "'ve", ...) that
    # punctuation splitting detached from their preceding word.
    line = " ".join(tokens)
    line = re.sub(r"' ([smdSMDtT])\b", r"'\1", line)
    line = re.sub(r"' ll\b", "'ll", line)
    line = re.sub(r"' re\b", "'re", line)
    line = re.sub(r"' ve\b", "'ve", line)
    line = re.sub(r"' LL\b", "'LL", line)
    line = re.sub(r"' RE\b", "'RE", line)
    line = re.sub(r"' VE\b", "'VE", line)
    return line.split()


def _split_on_cont_punc(tokens):
    # Split each token at every boundary between a run of punctuation
    # characters and a run of non-punctuation characters.
    # (Not currently used by GuokeTokenizer.encode.)
    new_tokens = []
    for token in tokens:
        if len(token) > 1:
            last_j = 0
            pre_is_punc = _is_punctuation(token[0])
            for j, ch in enumerate(token):
                is_punc = _is_punctuation(ch)
                if is_punc != pre_is_punc:
                    new_tokens.append(token[last_j:j])
                    last_j = j
                    pre_is_punc = is_punc
            if last_j < len(token):
                new_tokens.append(token[last_j:])
        else:
            new_tokens.append(token)
    return new_tokens


def _split_pre_and_post_punc(tokens):
    # Detach leading and trailing runs of punctuation from each token,
    # leaving interior punctuation (e.g. the apostrophe in "isn't") intact.

    def pre_punc(token):
        # Split off the leading run of punctuation characters.
        last_j = 0
        for j in range(1, len(token)):
            if not _is_punctuation(token[j]):
                last_j = j
                break
        return token[:last_j], token[last_j:]

    def post_punc(token):
        # Split off the trailing run of punctuation characters.
        last_j = len(token)
        for j in range(len(token) - 2, -1, -1):
            if not _is_punctuation(token[j]):
                last_j = j + 1
                break
        return token[:last_j], token[last_j:]

    new_tokens = []
    for token in tokens:
        if len(token) > 1 and _is_punctuation(token[0]):
            a, b = pre_punc(token)
            if a:
                new_tokens.append(a)
            if b:
                if _is_punctuation(b[-1]):
                    c, d = post_punc(b)
                    if c:
                        new_tokens.append(c)
                    if d:
                        new_tokens.append(d)
                else:
                    new_tokens.append(b)
        elif len(token) > 1 and _is_punctuation(token[-1]):
            a, b = post_punc(token)
            if a:
                new_tokens.append(a)
            if b:
                new_tokens.append(b)
        else:
            new_tokens.append(token)
    return new_tokens


class GuokeTokenizer:
    def __init__(self, cfg):
        self.cfg = cfg

    def encode(self, x: str) -> str:
        x = x.strip()
        # Normalize PTB/LaTeX-style quotes and the Unicode quote/dash
        # variants covered by TRANS_TABLE to plain ASCII.
        x = x.replace("``", '"').replace("''", '"')
        x = x.translate(TRANS_TABLE)
        tokens = x.split()
        tokens = _split_pre_and_post_punc(tokens)
        tokens = _handle_single_quote(tokens)
        x = " ".join(tokens)
        if self.cfg.lower:
            x = x.lower()
        return x

    def decode(self, x: str) -> str:
        raise NotImplementedError()
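

# A minimal usage sketch, not part of the library: GuokeTokenizer only needs
# a config object exposing a boolean `lower` attribute, so a SimpleNamespace
# stands in here for whatever config class the surrounding project uses.
if __name__ == "__main__":
    from types import SimpleNamespace

    tokenizer = GuokeTokenizer(SimpleNamespace(lower=True))
    # Curly quotes and the em dash are normalized to ASCII, leading/trailing
    # punctuation is detached, the stray "'s" is re-attached to its word,
    # and the interior apostrophe in "isn't" is left alone.
    print(tokenizer.encode("“The dog ’s bone — it isn’t here!”"))
    # -> " the dog 's bone - it isn't here !"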