|
import torch |
|
|
|
|
|
def normalize_abbreviations(text):
    """Glue space-separated contraction suffixes onto the preceding word.

    For each of the suffixes n't, 'll, 're, 've, 'm, 's and 'd (in all-lower
    and all-upper spelling) the space in front of the suffix is removed.
    Only occurrences that have a space on BOTH sides are rewritten, so a
    suffix sitting at the very end of ``text`` is left untouched, as are
    mixed-case spellings.

    Args:
        text: space-joined token string.

    Returns:
        The text with the matching suffixes attached to the previous word.
    """
    suffixes = ("n't", "'ll", "'re", "'ve", "'m", "'s", "'d")
    for suffix in suffixes:
        # Lower-case spelling first, then upper-case, one suffix at a time.
        for variant in (suffix, suffix.upper()):
            text = text.replace(f" {variant} ", f"{variant} ")
    return text
|
|
|
|
|
def fix_quotes(text, quote_symbol='"'):
    """Re-space paired quote marks so they hug the quoted material.

    Quotes are treated as alternating opening / closing marks, in order of
    appearance. An opening quote ends up with a space before it and none
    after it; a closing quote loses the space before it and gains a space
    after it when directly followed by an alphanumeric character. A quote
    with a non-space character on both sides is skipped and does not count
    towards the alternation.

    The text is returned unchanged when the number of quotes that touch a
    space is zero or odd, or when two quotes are adjacent or separated by
    a single space.

    Args:
        text: the string to clean up.
        quote_symbol: the quote character to process.

    Returns:
        The re-spaced text (or the input itself if no rewrite is attempted).
    """
    q = quote_symbol

    # Quotes with a space on at least one side; a quote with spaces on both
    # sides is matched by both of the first two counts, hence the subtraction.
    space_before = text.count(f" {q}")
    space_after = text.count(f"{q} ")
    space_both = text.count(f" {q} ")
    n_quotes = space_before + space_after - space_both

    if n_quotes == 0 or n_quotes % 2 == 1:
        return text
    if f"{q}{q}" in text or f"{q} {q}" in text:
        return text

    pos = 0
    opening = True  # the first processed quote opens, then they alternate
    while pos < len(text):
        if text[pos] != q:
            pos += 1
            continue

        glued_left = pos >= 1 and text[pos - 1] != ' '
        glued_right = pos + 1 < len(text) and text[pos + 1] != ' '
        if glued_left and glued_right:
            # Quote embedded between two non-space characters: leave it alone.
            pos += 1
            continue

        if opening:
            if pos > 0 and text[pos - 1] != ' ':
                # Detach the opening quote from the preceding character.
                text = f"{text[:pos]} {text[pos:]}"
                pos += 1
            if pos + 1 < len(text) and text[pos + 1] == ' ':
                # Pull the quoted material up against the opening quote.
                text = text[:pos + 1] + text[pos + 2:]
        else:
            if pos > 0 and text[pos - 1] == ' ':
                # Pull the closing quote up against the quoted material.
                text = text[:pos - 1] + text[pos:]
                pos -= 1
            if pos + 1 < len(text) and text[pos + 1].isalnum():
                # Detach the closing quote from a following word.
                text = f"{text[:pos + 1]} {text[pos + 1:]}"

        opening = not opening
        pos += 1

    return text
|
|
|
|
|
def detokenize(tokens, compact_dashes=False):
    """Join tokens into conventionally spaced text and locate each token in it.

    The tokens are joined with single spaces, contraction suffixes are
    attached via ``normalize_abbreviations``, spacing around punctuation is
    adjusted, and double and single quotes are re-spaced via ``fix_quotes``.
    Finally every token is located inside the resulting string.

    Args:
        tokens: sequence of token strings.
        compact_dashes: when true, every ``" - "`` is collapsed to ``"-"``.

    Returns:
        A ``(text, spans)`` pair; ``spans[k]`` is the ``(start, end)``
        character range (end exclusive) of ``tokens[k]`` inside ``text``.

    Raises:
        AssertionError: if a non-space character of the final text differs
            from the token character expected at that position.
    """
    text = ' '.join(tokens)
    text = normalize_abbreviations(text)

    if compact_dashes:
        text = text.replace(' - ', '-')

    text = _respace_punctuation(text)

    text = fix_quotes(text, '"')
    text = fix_quotes(text, "'")

    return text, _align_token_spans(text, tokens)


def _respace_punctuation(text):
    """Insert or drop single spaces around punctuation, scanning right to left."""
    # NOTE(review): the non-ASCII literals in this function ('β', 'β¦', 'Β£',
    # 'β¬') look mis-encoded (mojibake). They are kept exactly as found;
    # confirm the intended characters against the upstream source.
    openers = ['β', '(', '[', '{']

    # Right-to-left: every edit happens at index i or later, so the indices
    # that are still to be visited never shift.
    for i in range(len(text) - 2, -1, -1):
        if text[i] == '.' and (text[i + 1].isupper() or text[i + 1] in openers):
            # Full stop directly followed by an upper-case letter or opener.
            text = text[:i+1] + ' ' + text[i+1:]
        elif text[i] in ['?', '!', 'β¦', 'β'] and (text[i + 1].isalnum() or text[i + 1] in openers):
            text = text[:i+1] + ' ' + text[i+1:]
        elif i > 2 and text[i] == '.' and text[i - 1] == '.' and text[i - 2] == '.' and text[i + 1] != ' ':
            # Third dot of "..." not followed by a space. The original had
            # this branch twice in a row; the unreachable copy was removed.
            text = text[:i+1] + ' ' + text[i+1:]
        elif text[i] == ',' and (text[i + 1].isalpha() or text[i + 1] in openers):
            # Digits after a comma are deliberately not separated (isalpha).
            text = text[:i+1] + ' ' + text[i+1:]
        elif text[i] in [';', ')', ']', '}', '%'] and (text[i + 1].isalnum() or text[i + 1] in openers):
            text = text[:i+1] + ' ' + text[i+1:]
        elif text[i] == ':' and (text[i + 1] in openers or (text[i + 1].isalnum() and (not text[i + 1].isnumeric() or i - 1 < 0 or not text[i - 1].isnumeric()))):
            # A colon sitting between two numeric characters is left as is.
            text = text[:i+1] + ' ' + text[i+1:]
        elif text[i] in ['(', '[', '{'] and text[i + 1] == ' ':
            # Drop the space right after an opening bracket.
            text = text[:i+1] + text[i+2:]
        elif text[i] == ' ' and text[i+1] in ['.', ';', ':', '?', '!', 'β¦', ',', 'β', ')', ']']:
            # Drop the space in front of closing punctuation.
            text = text[:i] + text[i+1:]
        elif i > 0 and text[i] == ' ' and text[i - 1] in ['$', 'Β£', 'β¬'] and text[i + 1].isnumeric():
            # Drop the space between a currency sign and a number.
            text = text[:i] + text[i+1:]
        elif i > 0 and text[i] == ' ' and text[i - 1].isnumeric() and text[i + 1] == '%':
            # Drop the space between a number and a percent sign.
            text = text[:i] + text[i+1:]

    return text


def _align_token_spans(text, tokens):
    """Return the (start, end) character range of every token inside text."""
    spans = []
    word_offset, char_offset = 0, 0
    for i, ch in enumerate(text):
        if ch == ' ':
            # A space either belongs to the current token (consume it) or is
            # a separator; in both cases it neither opens nor closes a span.
            if tokens[word_offset][char_offset] == ' ':
                char_offset += 1
            continue

        assert ch == tokens[word_offset][char_offset], f"{text}\n{' '.join(tokens)}\n{tokens[word_offset]}\n{char_offset} {ch}"

        if char_offset == 0:
            start = i

        if char_offset == len(tokens[word_offset]) - 1:
            # Last character of the current token: close its span.
            end = i + 1
            spans.append((start, end))
            word_offset += 1
            char_offset = 0
        else:
            char_offset += 1

    return spans
|
|
|
|
|
def calculate_spans(original_spans, encoding_offsets):
    """Assign subword positions to the word spans they overlap.

    Subwords are numbered from 1 in the order they appear in
    ``encoding_offsets``. Each subword is first credited to the current
    word; while the subword's end reaches or passes the end of the current
    word, the cursor moves on to the next word, which is credited too if
    the subword extends past that word's start.

    Args:
        original_spans: ``(start, end)`` character spans, one per word.
        encoding_offsets: ``(start, end)`` character offsets, one per
            subword; only the end offset is used.

    Returns:
        A list with one entry per word, each a list of 1-based subword
        positions.
    """
    n_words = len(original_spans)
    subword_spans = [[] for _ in original_spans]

    word = 0
    for position, (_, sub_end) in enumerate(encoding_offsets, start=1):
        subword_spans[word].append(position)

        # Step over every word that is finished by the end of this subword.
        while original_spans[word][1] <= sub_end:
            word += 1
            if word == n_words:
                # All words are covered; remaining subwords are ignored.
                return subword_spans
            if sub_end > original_spans[word][0]:
                # The subword spills over into the next word as well.
                subword_spans[word].append(position)

    return subword_spans
|
|
|
|
|
def subtokenize(tokens, tokenizer, compact_dashes=False):
    """Tokenize words into subwords and build a subword-to-word mask.

    The tokens are detokenized into plain text, the text is passed to
    ``tokenizer`` with ``return_offsets_mapping=True``, and the character
    offsets of the subwords are matched against the character spans of the
    original tokens.

    Args:
        tokens: sequence of token strings.
        tokenizer: callable returning a mapping with ``"input_ids"`` and
            ``"offset_mapping"`` entries (presumably a Hugging Face fast
            tokenizer -- confirm against the caller).
        compact_dashes: forwarded to ``detokenize``.

    Returns:
        ``(subwords, subword_mask)``: the ``input_ids`` as returned by the
        tokenizer, and a boolean tensor with one row per entry of
        ``input_ids`` and one column per token.
    """
    text, char_spans = detokenize(tokens, compact_dashes=compact_dashes)
    encoding = tokenizer(text, return_offsets_mapping=True)

    # The first and last offsets are dropped before alignment -- presumably
    # they belong to the tokenizer's special tokens; verify for the tokenizer
    # in use.
    inner_offsets = encoding["offset_mapping"][1:-1]
    word_to_subwords = calculate_spans(char_spans, inner_offsets)
    subwords = encoding["input_ids"]

    mask_shape = (len(subwords), len(word_to_subwords))
    subword_mask = torch.zeros(*mask_shape, dtype=torch.bool)

    # NOTE(review): calculate_spans already yields 1-based positions and a
    # further +1 is applied here, so the first inner subword is written to
    # row 2. Confirm this double offset is intended.
    for word_id, subword_ids in enumerate(word_to_subwords):
        for subword_id in subword_ids:
            subword_mask[subword_id + 1, word_id] = True

    return subwords, subword_mask
|
|