from transformers import PreTrainedTokenizerFast
import numpy
import torch


class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):

    def _batch_encode_plus(self, *args, **kwargs):
        outputs = super()._batch_encode_plus(*args, **kwargs)

        # BERT-style token_type_ids are not used here, so drop them from the output.
        del outputs["token_type_ids"]

        # Trim the final token (typically the trailing [SEP]) from every sequence,
        # handling plain lists, NumPy arrays, and PyTorch tensors alike.
        for key in ['input_ids', 'attention_mask']:
            if isinstance(outputs[key], list):
                outputs[key] = [sequence[:-1] for sequence in outputs[key]]
            elif isinstance(outputs[key], numpy.ndarray):
                outputs[key] = numpy.array([sequence[:-1] for sequence in outputs[key]], dtype=outputs[key].dtype)
            elif isinstance(outputs[key], torch.Tensor):
                # torch.stack keeps the dtype and device of the sliced rows.
                outputs[key] = torch.stack([sequence[:-1] for sequence in outputs[key]])

        return outputs


# Register the class
from transformers import AutoTokenizer

AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)
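

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how this tokenizer might be exercised, assuming a
# checkpoint directory with a tokenizer.json is available; "path/to/checkpoint"
# below is a placeholder, not a real repository.
if __name__ == "__main__":
    tokenizer = ModernDecoderBERTTokenizer.from_pretrained("path/to/checkpoint")

    batch = tokenizer(["Hello world", "Another example"], padding=True, return_tensors="pt")

    # token_type_ids has been removed and the final token of each sequence trimmed.
    print(batch.keys())
    print(batch["input_ids"].shape)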