File size: 1,121 Bytes
6d20d8a
6e82f17
 
3608e05
6d20d8a
f64965c
8a083e2
b54c050
 
 
 
 
 
 
 
 
 
 
970954b
 
 
6a605a0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from transformers import PreTrainedTokenizerFast
import numpy
import torch

class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer wrapper that adapts a BERT-style tokenizer for
    decoder-style training: it removes ``token_type_ids`` from the
    encoding and trims the final token (presumably the trailing [SEP] —
    confirm against the underlying tokenizer's postprocessor) from both
    ``input_ids`` and ``attention_mask``.
    """

    def _batch_encode_plus(self, *args, **kwargs):
        """Encode a batch via the parent class, then post-process.

        Returns the parent's ``BatchEncoding`` with ``token_type_ids``
        removed and the last position sliced off ``input_ids`` and
        ``attention_mask``, preserving the container type (list,
        ``numpy.ndarray``, or ``torch.Tensor``) of each field.
        """
        outputs = super()._batch_encode_plus(*args, **kwargs)
        # token_type_ids may be absent (e.g. return_token_type_ids=False);
        # pop with a default instead of `del` so we don't raise KeyError.
        outputs.pop("token_type_ids", None)
        for key in ("input_ids", "attention_mask"):
            if key not in outputs:
                continue
            value = outputs[key]
            if isinstance(value, torch.Tensor):
                # Single slice on the last axis: keeps dtype and device,
                # and avoids the per-row rebuild (torch.tensor on a list
                # of tensors copies and warns).
                outputs[key] = value[..., :-1]
            elif isinstance(value, numpy.ndarray):
                # ndarray outputs are rectangular, so one slice is
                # equivalent to re-stacking each trimmed row.
                outputs[key] = value[..., :-1]
            elif isinstance(value, list):
                # Lists may be ragged (no padding requested); trim each
                # sequence individually.
                outputs[key] = [sequence[:-1] for sequence in value]
        return outputs

# Register the class
# Registers this tokenizer with the AutoTokenizer factory so it can be
# resolved via AutoTokenizer.from_pretrained for this model type.
# NOTE(review): AutoTokenizer.register's first positional argument is
# documented as a *config* class; passing the tokenizer class itself is
# unconventional — confirm this keying is intentional for the lookup
# path being used (e.g. trust_remote_code / auto_map loading).
from transformers import AutoTokenizer
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)