from transformers import PreTrainedTokenizerFast
import numpy
import torch

class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that adapts a BERT-style vocabulary for decoder-style use.

    After the normal fast-tokenizer batch encoding it (a) removes
    ``token_type_ids`` (no segment embeddings in a decoder) and (b) trims the
    final token from ``input_ids`` and ``attention_mask`` along the last axis
    (drops the trailing special token appended by the underlying tokenizer —
    presumably [SEP]/EOS; confirm against the tokenizer's post-processor).
    """

    def _batch_encode_plus(self, *args, **kwargs):
        """Encode a batch, then strip token_type_ids and the last token.

        Returns the ``BatchEncoding`` produced by the parent class, modified
        in place. Handles ``torch.Tensor``, ``numpy.ndarray`` and plain-list
        containers; any other container type is left untouched.
        """
        outputs = super()._batch_encode_plus(*args, **kwargs)
        # pop() instead of del: the key is absent when the caller encoded
        # with return_token_type_ids=False, and del would raise KeyError.
        outputs.pop("token_type_ids", None)
        for key in ("input_ids", "attention_mask"):
            # attention_mask may be absent (return_attention_mask=False).
            if key not in outputs:
                continue
            value = outputs[key]
            if isinstance(value, (torch.Tensor, numpy.ndarray)):
                # Ellipsis slicing trims the last element of the final axis
                # identically for both tensor backends.
                outputs[key] = value[..., :-1]
            elif isinstance(value, list):
                outputs[key] = [sequence[:-1] for sequence in value]
        return outputs

# Register the class with AutoTokenizer so it can be resolved automatically
# (e.g. via AutoTokenizer.from_pretrained for a checkpoint pointing at it).
from transformers import AutoTokenizer
# NOTE(review): AutoTokenizer.register() documents its first positional
# argument as a *config* class; passing the tokenizer class itself is a
# common shortcut — verify the mapping resolves for this model's config.
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)