import re |
import unicodedata |
# Translation table for ``str.translate`` that folds typographic punctuation
# into plain ASCII:
#   U+2018 / U+2019 (curly single quotes) and U+00B4 (acute accent) -> '
#   U+201C / U+201D (curly double quotes)                           -> "
#   U+2013 (en dash), U+2014 (em dash) and the ASCII hyphen         -> -
# The source characters are spelled as escapes so the table does not depend
# on how this file is encoded or re-encoded; a mis-decoded literal here
# silently maps unrelated letters (e.g. Greek beta) to punctuation.
TRANS_TABLE = str.maketrans(
    "\u2018\u2019\u00b4\u201c\u201d\u2013\u2014-",
    "'''\"\"---",
)
def _is_punctuation(char): |
"""Checks whether `chars` is a punctuation character.""" |
cp = ord(char) |
if (33 <= cp <= 47) or (58 <= cp <= 64) or (91 <= cp <= 96) or (123 <= cp <= 126): |
return True |
cat = unicodedata.category(char) |
if cat.startswith("P"): |
return True |
return False |
def _handle_single_quote(tokens): |
line = ' '.join(tokens) |
line = re.sub(r"' ([smdSMDtT])\b", r"'\1", line) |
line = re.sub(r"' ll\b", "'ll", line) |
line = re.sub(r"' re\b", "'re", line) |
line = re.sub(r"' ve\b", "'ve", line) |
line = re.sub(r"' LL\b", "'LL ", line) |
line = re.sub(r"' RE\b", "'RE ", line) |
line = re.sub(r"' VE\b", "'VE ", line) |
return line.split() |
def _split_on_cont_punc(tokens):
    """Break every token into runs of punctuation and non-punctuation.

    A token is cut at each position where `_is_punctuation` changes value
    between neighbouring characters, so ``"a,,b"`` yields ``"a"``, ``",,"``
    and ``"b"``.  Tokens of length 0 or 1 are passed through unchanged.
    Returns a new flat list; the input list is not modified.
    """
    pieces = []
    for word in tokens:
        if len(word) <= 1:
            pieces.append(word)
            continue
        flags = [_is_punctuation(symbol) for symbol in word]
        start = 0
        for pos in range(1, len(word)):
            # A change of class between neighbours closes the current run.
            if flags[pos] != flags[pos - 1]:
                pieces.append(word[start:pos])
                start = pos
        # The final run always holds at least one character.
        pieces.append(word[start:])
    return pieces
def _split_pre_and_post_punc(tokens):
    """Detach leading and trailing punctuation runs from every token.

    For a token longer than one character:
      * if it starts with punctuation, the leading punctuation run becomes
        its own token, and any trailing punctuation run of the remainder is
        detached as well;
      * otherwise, if it ends with punctuation, the trailing punctuation run
        becomes its own token;
      * otherwise it is kept as is.
    Punctuation inside a token is left alone, tokens that consist solely of
    punctuation stay in one piece, and tokens of length 0 or 1 are passed
    through unchanged.  Returns a new flat list.
    """

    def leading_cut(text):
        # Caller guarantees text[0] is punctuation.  Returns the index of
        # the first non-punctuation character, or 0 when there is none.
        for idx in range(1, len(text)):
            if not _is_punctuation(text[idx]):
                return idx
        return 0

    def trailing_cut(text):
        # Caller guarantees text[-1] is punctuation.  Returns the index at
        # which the trailing punctuation run starts, or len(text) when no
        # non-punctuation character precedes the last one.
        for idx in reversed(range(len(text) - 1)):
            if not _is_punctuation(text[idx]):
                return idx + 1
        return len(text)

    result = []
    for word in tokens:
        if len(word) < 2:
            result.append(word)
            continue

        if _is_punctuation(word[0]):
            cut = leading_cut(word)
            head, rest = word[:cut], word[cut:]
            parts = [head]
            if rest and _is_punctuation(rest[-1]):
                cut = trailing_cut(rest)
                parts += [rest[:cut], rest[cut:]]
            else:
                parts.append(rest)
            # Empty slices (e.g. no split point found) are not emitted.
            result.extend(part for part in parts if part)
        elif _is_punctuation(word[-1]):
            cut = trailing_cut(word)
            result.extend(part for part in (word[:cut], word[cut:]) if part)
        else:
            result.append(word)
    return result
class GuokeTokenizer:
    """Rule-based pre-tokenizer producing a space-separated token string.

    `cfg` is stored as given; `encode` reads its `lower` attribute to decide
    whether the output is lower-cased.
    """

    def __init__(self, cfg):
        self.cfg = cfg

    def encode(self, x: str) -> str:
        """Normalise and tokenize `x`, returning tokens joined by spaces.

        Steps, in order: strip surrounding whitespace; turn doubled
        backticks and doubled apostrophes into a double quote; apply
        `TRANS_TABLE`; split on whitespace; detach leading/trailing
        punctuation from each token; re-attach apostrophe suffixes;
        join with single spaces; lower-case if `cfg.lower` is truthy.
        """
        normalized = x.strip().replace("``", '"').replace("''", '"')
        normalized = normalized.translate(TRANS_TABLE)
        pieces = _split_pre_and_post_punc(normalized.split())
        pieces = _handle_single_quote(pieces)
        text = " ".join(pieces)
        return text.lower() if self.cfg.lower else text

    def decode(self, x: str) -> str:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError()