File size: 869 Bytes
6d20d8a
6e82f17
 
3608e05
6d20d8a
f64965c
8a083e2
b54c050
 
38e83eb
 
 
 
 
 
 
b54c050
970954b
 
 
6a605a0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from transformers import PreTrainedTokenizerFast
import numpy
import torch

class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that strips the final token from every encoded sequence.

    After the parent ``_batch_encode_plus`` runs, this subclass drops
    ``token_type_ids`` and deletes the last *attended* token of every row
    from each per-token output listed in ``_trimmed_keys``. Deleting the last
    attended position (rather than blindly the last column) keeps
    right-padded batches correct: the removed token is the row's final real
    token, never a padding token. For left-padded or unpadded batches this is
    the same as removing the last position.

    NOTE(review): the removed token is presumably the trailing special token
    (e.g. ``[SEP]``/EOS) appended by the tokenizer's post-processor, so that
    decoder-style generation can continue from the prompt -- confirm against
    the tokenizer's post-processor template.
    """

    # Per-token outputs that must stay aligned with ``input_ids``.
    # NOTE(review): ``length`` (return_length=True) and ``outputs.encodings``
    # are not adjusted and still describe the untrimmed sequences.
    _trimmed_keys = (
        "input_ids",
        "attention_mask",
        "special_tokens_mask",
        "offset_mapping",
    )

    def _batch_encode_plus(self, *args, **kwargs):
        """Encode a batch via the parent class, then strip the final token.

        Args:
            *args, **kwargs: forwarded unchanged to
                ``PreTrainedTokenizerFast._batch_encode_plus``.

        Returns:
            The parent's ``BatchEncoding`` without ``token_type_ids`` and with
            one position removed from each row of every key in
            ``_trimmed_keys`` that is present. Nothing is stripped when
            ``add_special_tokens=False`` was requested. Outputs whose
            container is not a list, ``torch.Tensor`` or ``numpy.ndarray``
            are returned untrimmed.
        """
        outputs = super()._batch_encode_plus(*args, **kwargs)
        # ``token_type_ids`` may be absent (e.g. return_token_type_ids=False);
        # ``del`` would raise KeyError in that case.
        outputs.pop("token_type_ids", None)

        # Without special tokens nothing was appended, so stripping would
        # delete a real content token. The parent class's encode entry points
        # pass this argument by keyword.
        if not kwargs.get("add_special_tokens", True):
            return outputs

        id_rows = self._rows_of(outputs["input_ids"])
        if id_rows is None:
            # Unsupported container type: leave the outputs untouched.
            return outputs
        mask_rows = (
            self._rows_of(outputs["attention_mask"])
            if "attention_mask" in outputs
            else None
        )
        positions = self._last_token_positions(id_rows, mask_rows)

        for key in self._trimmed_keys:
            # Skip outputs the caller did not request (e.g. no attention mask).
            if key in outputs:
                outputs[key] = self._remove_positions(outputs[key], positions)
        return outputs

    @staticmethod
    def _rows_of(value):
        """Return ``value`` as a list of per-row lists, or None if unsupported."""
        if isinstance(value, list):
            return value
        if isinstance(value, (torch.Tensor, numpy.ndarray)):
            return value.tolist()
        return None

    @staticmethod
    def _last_token_positions(id_rows, mask_rows):
        """Return, per row, the index of its last attended token.

        Empty rows yield None. Without an attention mask (or for a row with
        no attended positions) the row's last index is used.
        """
        positions = []
        for i, row in enumerate(id_rows):
            if not row:
                positions.append(None)
                continue
            last = len(row) - 1
            if mask_rows is not None:
                attended = [j for j, flag in enumerate(mask_rows[i]) if flag]
                if attended:
                    last = attended[-1]
            positions.append(last)
        return positions

    @staticmethod
    def _remove_positions(value, positions):
        """Delete ``positions[i]`` from row ``i`` of ``value`` (along axis 1)."""
        if isinstance(value, list):
            return [
                row if pos is None else row[:pos] + row[pos + 1:]
                for row, pos in zip(value, positions)
            ]
        if not isinstance(value, (torch.Tensor, numpy.ndarray)):
            return value
        # A rectangular batch with no rows, or with empty rows, has nothing
        # to remove.
        if not positions or None in positions:
            return value
        keep = numpy.ones(tuple(value.shape[:2]), dtype=bool)
        keep[numpy.arange(len(positions)), positions] = False
        new_shape = (value.shape[0], value.shape[1] - 1, *value.shape[2:])
        # Boolean indexing flattens the first two axes in row-major order and
        # each row loses exactly one element, so the reshape restores
        # (batch, seq_len - 1, ...).
        if isinstance(value, torch.Tensor):
            return value[torch.from_numpy(keep)].reshape(new_shape)
        return value[keep].reshape(new_shape)

# Make ModernDecoderBERTTokenizer discoverable through AutoTokenizer as its
# fast tokenizer implementation.
# NOTE(review): in the transformers API the first argument of
# ``AutoTokenizer.register`` is the *config* class that maps to the tokenizer;
# passing the tokenizer class itself there looks unintended -- confirm whether
# the model's config class should be used instead.
from transformers import AutoTokenizer

AutoTokenizer.register(
    ModernDecoderBERTTokenizer,
    fast_tokenizer_class=ModernDecoderBERTTokenizer,
)