import warnings

import numpy as np
import torch
from transformers import BatchEncoding, RobertaTokenizer


class JinaTokenizer(RobertaTokenizer):
    """RoBERTa tokenizer that can attach per-token ``task_type_ids`` to its output."""

    def __init__(self, *args, task_type_vocab_size=6, **kwargs):
        super().__init__(*args, **kwargs)
        self.task_type_vocab_size = task_type_vocab_size

    def __call__(self, *args, task_type=None, **kwargs):
        batch_encoding = super().__call__(*args, **kwargs)
        # Only attach task_type_ids when a task type is requested; a plain call
        # behaves exactly like RobertaTokenizer.
        if task_type is not None:
            batch_encoding = BatchEncoding(
                {
                    'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
                    **batch_encoding,
                },
                tensor_type=kwargs.get('return_tensors'),
            )
        return batch_encoding

    def _batch_encode_plus(self, *args, task_type=None, **kwargs):
        batch_encoding = super()._batch_encode_plus(*args, **kwargs)
        if task_type is not None:
            batch_encoding = BatchEncoding(
                {
                    'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
                    **batch_encoding,
                },
                tensor_type=kwargs.get('return_tensors'),
            )
        return batch_encoding

    def _encode_plus(self, *args, task_type=None, **kwargs):
        batch_encoding = super()._encode_plus(*args, **kwargs)
        if task_type is not None:
            batch_encoding = BatchEncoding(
                {
                    'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
                    **batch_encoding,
                },
                tensor_type=kwargs.get('return_tensors'),
            )
        return batch_encoding

    @staticmethod
    def _get_task_type_ids(batch_encoding: BatchEncoding, task_type: int):
        # Build an array filled with ``task_type`` that mirrors input_ids in
        # both shape and container type (tensor, list, or numpy array).
        input_ids = batch_encoding['input_ids']
        if isinstance(input_ids, torch.Tensor):
            shape = input_ids.shape
            return torch.ones(shape, dtype=torch.long) * task_type
        else:
            shape = torch.tensor(input_ids).shape
            if isinstance(input_ids, list):
                return (torch.ones(shape, dtype=torch.long) * task_type).tolist()
            elif isinstance(input_ids, np.ndarray):
                return (torch.ones(shape, dtype=torch.long) * task_type).numpy()
            else:
                warnings.warn(
                    'input_ids is not a torch tensor, numpy array, or list. '
                    'Returning a torch tensor.'
                )
                return torch.ones(shape, dtype=torch.long) * task_type
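

# --- Usage sketch (illustrative, not part of the class) ----------------------
# A minimal example of how the tokenizer is meant to be called. The checkpoint
# name 'roberta-base' and the task id 1 are assumptions for demonstration; any
# RoBERTa-compatible vocabulary and any id in [0, task_type_vocab_size) should
# work the same way.
if __name__ == '__main__':
    tokenizer = JinaTokenizer.from_pretrained('roberta-base')

    encoding = tokenizer(
        ['how do jellyfish swim?', 'Jellyfish move by contracting their bells.'],
        padding=True,
        return_tensors='pt',
        task_type=1,  # hypothetical task id, e.g. distinguishing queries from documents
    )

    # task_type_ids mirrors input_ids: same shape, filled with the task id.
    print(encoding['input_ids'].shape)   # e.g. torch.Size([2, 12])
    print(encoding['task_type_ids'][0])  # tensor([1, 1, ..., 1])

    # Without task_type, the output is a standard RoBERTa encoding.
    plain = tokenizer('hello world', return_tensors='pt')
    assert 'task_type_ids' not in plain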