import torch def normalize_abbreviations(text): text = text.replace(" n't ", "n't ") text = text.replace(" N'T ", "N'T ") text = text.replace(" 'll ", "'ll ") text = text.replace(" 'LL ", "'LL ") text = text.replace(" 're ", "'re ") text = text.replace(" 'RE ", "'RE ") text = text.replace(" 've ", "'ve ") text = text.replace(" 'VE ", "'VE ") text = text.replace(" 'm ", "'m ") text = text.replace(" 'M ", "'M ") text = text.replace(" 's ", "'s ") text = text.replace(" 'S ", "'S ") text = text.replace(" 'd ", "'d ") text = text.replace(" 'D ", "'D ") return text def fix_quotes(text, quote_symbol='"'): n_quotes = text.count(f" {quote_symbol}") + text.count(f"{quote_symbol} ") - text.count(f" {quote_symbol} ") if ( n_quotes == 0 or (n_quotes % 2) == 1 or f"{quote_symbol}{quote_symbol}" in text or f"{quote_symbol} {quote_symbol}" in text ): return text i, i_quote, n_changes = 0, 0, 0 while i < len(text): if text[i] != quote_symbol or (i - 1 >= 0 and text[i - 1] != ' ' and i + 1 < len(text) and text[i + 1] != ' '): i += 1 continue if (i_quote % 2) == 0: if i > 0 and text[i - 1] != ' ': text = text[:i] + ' ' + text[i:] i += 1 n_changes += 1 if i + 1 < len(text) and text[i + 1] == ' ': text = text[:i + 1] + text[i + 2:] n_changes += 1 else: if i > 0 and text[i - 1] == ' ': text = text[:i - 1] + text[i:] i -= 1 n_changes += 1 if i + 1 < len(text) and text[i + 1].isalnum(): text = text[:i + 1] + ' ' + text[i + 1:] n_changes += 1 i_quote += 1 i += 1 return text def detokenize(tokens, compact_dashes=False): text = ' '.join(tokens) text = normalize_abbreviations(text) if compact_dashes: text = text.replace(' - ', '-') for i in range(len(text) - 2, -1, -1): if text[i] == '.' and (text[i + 1].isupper() or text[i + 1] in ['‘', '(', '[', '{']): text = text[:i+1] + ' ' + text[i+1:] elif text[i] in ['?', '!', '…', '’'] and (text[i + 1].isalnum() or text[i + 1] in ['‘', '(', '[', '{']): text = text[:i+1] + ' ' + text[i+1:] elif i > 2 and text[i] == '.' and text[i - 1] == '.' and text[i - 2] == '.' and text[i + 1] != ' ': text = text[:i+1] + ' ' + text[i+1:] elif i > 2 and text[i] == '.' and text[i - 1] == '.' and text[i - 2] == '.' and text[i + 1] != ' ': text = text[:i+1] + ' ' + text[i+1:] elif text[i] == ',' and (text[i + 1].isalpha() or text[i + 1] in ['‘', '(', '[', '{']): text = text[:i+1] + ' ' + text[i+1:] elif text[i] in [';', ')', ']', '}', '%'] and (text[i + 1].isalnum() or text[i + 1] in ['‘', '(', '[', '{']): text = text[:i+1] + ' ' + text[i+1:] elif text[i] == ':' and (text[i + 1] in ['‘', '(', '[', '{'] or (text[i + 1].isalnum() and (not text[i + 1].isnumeric() or i - 1 < 0 or not text[i - 1].isnumeric()))): text = text[:i+1] + ' ' + text[i+1:] elif text[i] in ['(', '[', '{'] and text[i + 1] == ' ': text = text[:i+1] + text[i+2:] elif text[i] == ' ' and text[i+1] in ['.', ';', ':', '?', '!', '…', ',', '’', ')', ']']: text = text[:i] + text[i+1:] elif i > 0 and text[i] == ' ' and text[i - 1] in ['$', '£', '€'] and text[i + 1].isnumeric(): text = text[:i] + text[i+1:] elif i > 0 and text[i] == ' ' and text[i - 1].isnumeric() and text[i + 1] == '%': text = text[:i] + text[i+1:] text = fix_quotes(text, '"') text = fix_quotes(text, "'") spans = [] word_offset, char_offset = 0, 0 for i, ch in enumerate(text): if ch == ' ': if tokens[word_offset][char_offset] == ' ': char_offset += 1 continue assert ch == tokens[word_offset][char_offset], f"{text}\n{' '.join(tokens)}\n{tokens[word_offset]}\n{char_offset} {ch}" if char_offset == 0: start = i if char_offset == len(tokens[word_offset]) - 1: end = i + 1 spans.append((start, end)) word_offset += 1 char_offset = 0 else: char_offset += 1 return text, spans def calculate_spans(original_spans, encoding_offsets): span_id = 0 subword_spans = [[] for _ in original_spans] for i, (_, end) in enumerate(encoding_offsets): subword_spans[span_id].append(i + 1) while original_spans[span_id][1] <= end: span_id += 1 if span_id < len(original_spans) and end > original_spans[span_id][0]: subword_spans[span_id].append(i + 1) if span_id == len(original_spans): return subword_spans return subword_spans def subtokenize(tokens, tokenizer, compact_dashes=False): text, spans = detokenize(tokens, compact_dashes=compact_dashes) encoding = tokenizer(text, return_offsets_mapping=True) spans = calculate_spans(spans, encoding["offset_mapping"][1:-1]) subwords = encoding["input_ids"] subword_mask = torch.zeros(len(subwords), len(spans), dtype=torch.bool) for word_id, subword_ids in enumerate(spans): for subword_id in subword_ids: subword_mask[subword_id + 1, word_id] = True return subwords, subword_mask