|
|
|
import html |
|
import string |
|
|
|
import ftfy |
|
import regex as re |
|
from transformers import AutoTokenizer |
|
|
|
__all__ = ['HuggingfaceTokenizer'] |
|
|
|
|
|
def basic_clean(text):
    """Run ``text`` through ftfy, decode HTML entities, and strip the ends.

    Args:
        text: The string to clean.

    Returns:
        The result of ``ftfy.fix_text`` with HTML entities unescaped twice
        and leading/trailing whitespace removed.
    """
    text = ftfy.fix_text(text)
    # Two passes so doubly escaped input is fully decoded:
    # '&amp;lt;' -> '&lt;' -> '<'.
    text = html.unescape(html.unescape(text))
    return text.strip()
|
|
|
|
|
def whitespace_clean(text):
    """Collapse each run of whitespace in ``text`` into a single space.

    Args:
        text: The string to normalize.

    Returns:
        ``text`` with every whitespace run (as matched by ``\\s+``) replaced
        by one space and with leading/trailing whitespace removed.
    """
    return re.sub(r'\s+', ' ', text).strip()
|
|
|
|
|
def canonicalize(text, keep_punctuation_exact_string=None):
    """Lowercase ``text`` and remove ASCII punctuation.

    Underscores are turned into spaces first, then every character in
    ``string.punctuation`` is deleted, the text is lowercased, whitespace
    runs are collapsed to a single space, and the ends are stripped.

    Args:
        text: The string to canonicalize.
        keep_punctuation_exact_string: Optional marker (for example ``'{}'``)
            whose occurrences are preserved while punctuation is removed
            from the text around them. Because underscores are replaced
            before the marker is searched for, a marker containing ``'_'``
            will never match. The marker itself is still lowercased along
            with the rest of the text.

    Returns:
        The canonicalized string.
    """
    # Build the deletion table once rather than once per split segment.
    strip_punctuation = str.maketrans('', '', string.punctuation)

    text = text.replace('_', ' ')
    if keep_punctuation_exact_string:
        # Split on the marker, clean each piece, and glue the pieces back
        # together with the untouched marker in between.
        text = keep_punctuation_exact_string.join(
            part.translate(strip_punctuation)
            for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(strip_punctuation)
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
|
|
|
|
|
class HuggingfaceTokenizer:
    """Callable wrapper around ``AutoTokenizer`` with optional text cleaning.

    Attributes are set in ``__init__``:
        name: The identifier passed to ``AutoTokenizer.from_pretrained``.
        seq_len: Fixed sequence length, or ``None`` for no padding/truncation
            defaults.
        clean: Cleaning mode: ``None``, ``'whitespace'``, ``'lower'`` or
            ``'canonicalize'``.
        tokenizer: The object returned by ``AutoTokenizer.from_pretrained``.
        vocab_size: Copied from ``tokenizer.vocab_size``.
    """

    # Accepted values for the ``clean`` argument.
    _CLEAN_MODES = (None, 'whitespace', 'lower', 'canonicalize')

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        """Load the tokenizer and remember the preprocessing settings.

        Args:
            name: Passed as the first argument to
                ``AutoTokenizer.from_pretrained``.
            seq_len: If not ``None``, inputs are padded and truncated to this
                length by default when the instance is called.
            clean: One of ``None``, ``'whitespace'``, ``'lower'`` or
                ``'canonicalize'``.
            **kwargs: Forwarded to ``AutoTokenizer.from_pretrained``.

        Raises:
            ValueError: If ``clean`` is not one of the accepted modes.
        """
        # Raise explicitly instead of using ``assert`` so the check still
        # runs when Python is started with -O.
        if clean not in self._CLEAN_MODES:
            raise ValueError(
                f'clean must be one of {self._CLEAN_MODES!r}, got {clean!r}')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # Load the underlying tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        """Tokenize a string or a collection of strings.

        Args:
            sequence: A single string (wrapped into a one-element list) or an
                iterable of strings.
            **kwargs: ``return_mask`` (default ``False``) is consumed here;
                every other keyword overrides the defaults and is forwarded
                to the underlying tokenizer call.

        Returns:
            ``input_ids`` of the tokenizer output, or the tuple
            ``(input_ids, attention_mask)`` when ``return_mask`` is truthy.
        """
        return_mask = kwargs.pop('return_mask', False)

        # Defaults; 'pt' presumably requests PyTorch tensors (see the
        # Hugging Face tokenizer documentation).
        _kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            _kwargs.update({
                'padding': 'max_length',
                'truncation': True,
                'max_length': self.seq_len,
            })
        # Caller-supplied keywords win over the defaults above.
        _kwargs.update(**kwargs)

        # Normalize the input to a list and apply the configured cleaning.
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        if return_mask:
            return ids.input_ids, ids.attention_mask
        return ids.input_ids

    def _clean(self, text):
        """Apply the cleaning mode stored in ``self.clean`` to ``text``.

        Any mode other than the three handled below returns ``text``
        unchanged.
        """
        if self.clean == 'whitespace':
            return whitespace_clean(basic_clean(text))
        if self.clean == 'lower':
            return whitespace_clean(basic_clean(text)).lower()
        if self.clean == 'canonicalize':
            return canonicalize(basic_clean(text))
        return text
|
|