File size: 869 Bytes
6d20d8a 6e82f17 3608e05 6d20d8a f64965c 8a083e2 b54c050 38e83eb b54c050 970954b 6a605a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from transformers import PreTrainedTokenizerFast
import numpy
import torch
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that removes the trailing special token from every sequence.

    Assumes the backing ``tokenizer.json`` post-processor appends exactly one
    special token to the end of each sequence when ``add_special_tokens`` is
    true (TODO confirm against the checkpoint's tokenizer config). That token is
    removed from every per-token output, and ``token_type_ids`` is dropped.

    NOTE(review): the ``tokenizers.Encoding`` objects exposed through
    ``BatchEncoding.encodings`` are not modified, so helpers built on them
    (``tokens()``, ``word_ids()``, ...) still include the removed token.
    """

    # Outputs holding one entry per token; they must stay aligned with input_ids.
    _PER_TOKEN_KEYS = ("input_ids", "attention_mask", "special_tokens_mask", "offset_mapping")

    def _batch_encode_plus(self, *args, **kwargs):
        """Encode via the parent class, then strip each row's trailing special token.

        Tensor conversion is deferred. The parent is asked for plain Python
        lists, the token is removed row by row, and only then are the outputs
        converted to the requested ``return_tensors`` type. One code path
        therefore serves every tensor backend. In right-padded batches the
        special token is removed rather than the last (pad) column.

        Assumes ``return_tensors`` and ``add_special_tokens`` arrive as keyword
        arguments, which is how ``batch_encode_plus`` and ``_encode_plus`` call
        this method.

        Returns:
            The parent's ``BatchEncoding``. It has the trailing token removed and
            no ``token_type_ids``, and it is converted to ``return_tensors``.
        """
        return_tensors = kwargs.pop("return_tensors", None)
        outputs = super()._batch_encode_plus(*args, return_tensors=None, **kwargs)

        # The model takes no segment ids; pop() tolerates return_token_type_ids=False.
        outputs.pop("token_type_ids", None)

        # With add_special_tokens=False nothing was appended, so nothing is removed.
        if kwargs.get("add_special_tokens", True):
            # Positions are computed before attention_mask itself is trimmed below.
            positions = self._trailing_token_positions(outputs)
            for key in self._PER_TOKEN_KEYS:
                if key in outputs:
                    outputs[key] = [
                        self._drop_index(row, pos) for row, pos in zip(outputs[key], positions)
                    ]
            if "length" in outputs:
                # Per-row token counts shrink by the one removed token.
                outputs["length"] = [length - 1 for length in outputs["length"]]

        outputs.convert_to_tensors(tensor_type=return_tensors)
        return outputs

    @staticmethod
    def _trailing_token_positions(outputs):
        """Return, for each row, the index of its last non-padding token.

        That is the last position whose attention-mask value is non-zero. Assuming
        padding carries a zero mask, this finds the trailing special token for
        left-padded, right-padded and unpadded rows alike. Falls back to the
        encodings' masks when ``attention_mask`` was not requested, and to each
        row's final index when no mask is available at all.
        """
        if "attention_mask" in outputs:
            masks = outputs["attention_mask"]
        elif outputs.encodings is not None:
            masks = [encoding.attention_mask for encoding in outputs.encodings]
        else:
            masks = [[1] * len(row) for row in outputs["input_ids"]]
        return [
            max((index for index, flag in enumerate(mask) if flag), default=len(mask) - 1)
            for mask in masks
        ]

    @staticmethod
    def _drop_index(row, index):
        """Return a copy of the sequence ``row`` without the element at ``index``."""
        if not 0 <= index < len(row):
            return row
        return row[:index] + row[index + 1:]
# Register the tokenizer class with AutoTokenizer at import time.
# NOTE(review): `AutoTokenizer.register` takes a *config* class as its first
# argument, and AutoTokenizer resolves tokenizers by the model config's type.
# Keying the entry on the tokenizer class itself therefore likely never matches
# a lookup. The call is kept as-is to preserve this module's import-time side
# effect; confirm the intended config class.
from transformers import AutoTokenizer

AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)