File size: 2,553 Bytes
6343db7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import torch
import numpy as np
from transformers import RobertaTokenizer, BatchEncoding
import warnings
class JinaTokenizer(RobertaTokenizer):
    """RoBERTa tokenizer that can additionally emit per-token ``task_type_ids``.

    When a ``task_type`` is passed to ``__call__``, ``encode_plus`` or
    ``batch_encode_plus``, the returned :class:`BatchEncoding` gains a
    ``task_type_ids`` entry laid out exactly like ``input_ids`` with every
    position set to ``task_type``. When ``task_type`` is ``None`` the parent
    tokenizer's output is returned unchanged.
    """

    def __init__(self, *args, task_type_vocab_size=6, **kwargs):
        """Create the tokenizer.

        Args:
            *args, **kwargs: Forwarded unchanged to ``RobertaTokenizer``.
            task_type_vocab_size: Stored as ``self.task_type_vocab_size``. It is
                not read anywhere in this class; presumably it is the number of
                task types a downstream model embeds — TODO confirm.
        """
        super().__init__(*args, **kwargs)
        self.task_type_vocab_size = task_type_vocab_size

    def __call__(self, *args, task_type=None, **kwargs):
        """Tokenize like ``RobertaTokenizer.__call__``, optionally adding task ids.

        ``task_type`` is consumed here and is not forwarded to the parent, so
        the parent's internal ``_encode_plus``/``_batch_encode_plus`` calls
        see ``task_type=None`` and do not add the ids a second time.
        """
        batch_encoding = super().__call__(*args, **kwargs)
        return self._add_task_type_ids(
            batch_encoding, task_type, kwargs.get('return_tensors')
        )

    def _batch_encode_plus(self, *args, task_type=None, **kwargs):
        """Batch-encode via the parent, then optionally add ``task_type_ids``."""
        batch_encoding = super()._batch_encode_plus(*args, **kwargs)
        return self._add_task_type_ids(
            batch_encoding, task_type, kwargs.get('return_tensors')
        )

    def _encode_plus(self, *args, task_type=None, **kwargs):
        """Encode one input via the parent, then optionally add ``task_type_ids``."""
        batch_encoding = super()._encode_plus(*args, **kwargs)
        return self._add_task_type_ids(
            batch_encoding, task_type, kwargs.get('return_tensors')
        )

    def _add_task_type_ids(self, batch_encoding, task_type, return_tensors):
        """Return ``batch_encoding`` with a ``task_type_ids`` entry prepended.

        Args:
            batch_encoding: Output of one of the parent encoding methods.
            task_type: Integer task id, or ``None`` to leave the encoding as is.
            return_tensors: The caller's ``return_tensors`` value, used as the
                ``tensor_type`` of the rebuilt :class:`BatchEncoding`.

        Returns:
            ``batch_encoding`` itself when ``task_type`` is ``None``; otherwise
            a new :class:`BatchEncoding` containing ``task_type_ids`` followed
            by all original entries.
        """
        if task_type is None:
            # Without this guard ``__call__`` would multiply by None and crash.
            return batch_encoding
        return BatchEncoding(
            {
                'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
                **batch_encoding,
            },
            tensor_type=return_tensors,
        )

    @staticmethod
    def _get_task_type_ids(batch_encoding: BatchEncoding, task_type: int):
        """Build ids equal to ``task_type`` everywhere, shaped like ``input_ids``.

        Args:
            batch_encoding: Encoding whose ``'input_ids'`` entry defines the layout.
            task_type: Integer id written at every position.

        Returns:
            A ``torch.long`` tensor, an ``int64`` numpy array, or a (possibly
            nested, possibly ragged) list of ints, matching the container type
            of ``input_ids``. For any other container a warning is emitted and
            a ``torch.long`` tensor is returned.
        """
        input_ids = batch_encoding['input_ids']
        if isinstance(input_ids, torch.Tensor):
            return torch.full_like(input_ids, task_type, dtype=torch.long)
        if isinstance(input_ids, np.ndarray):
            # ``np.ndarray`` is the array type; ``np.array`` is a factory
            # function and makes isinstance() raise TypeError.
            return np.full(input_ids.shape, task_type, dtype=np.int64)
        if isinstance(input_ids, list):
            fill_value = int(task_type)

            def _fill(seq):
                # Mirror the nesting of ``seq`` element by element so ragged
                # (unpadded) batches work; torch.tensor() rejects those.
                return [
                    _fill(item) if isinstance(item, list) else fill_value
                    for item in seq
                ]

            return _fill(input_ids)
        shape = torch.tensor(input_ids).shape
        warnings.warn(
            'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
        )
        return torch.full(shape, task_type, dtype=torch.long)
|