# Retrieved from a Hugging Face Hub file page: author michael-guenther,
# commit 4b66519 ("add assertions and docs"), file size 3.65 kB
# (Hub scan: no virus).
import torch
import numpy as np
from transformers import RobertaTokenizer, BatchEncoding
import warnings
class JinaTokenizer(RobertaTokenizer):
    """RobertaTokenizer that can add ``task_type_ids`` to its batch encoding.

    The ``task_type_ids`` are used to pass instruction information to the
    model. A ``task_type`` should either be an integer (applied to every
    token) or a sequence of integers with the same length as the batch size
    (one task type per row, broadcast across that row's tokens).
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``RobertaTokenizer``."""
        super().__init__(*args, **kwargs)

    def __call__(self, *args, task_type=None, **kwargs):
        """Tokenize exactly like ``RobertaTokenizer.__call__``.

        Args:
            *args, **kwargs: Passed through to the parent tokenizer.
            task_type: Optional int or per-row sequence of ints. When not
                ``None``, a ``task_type_ids`` entry is added to the result.

        Returns:
            The parent's ``BatchEncoding``, rebuilt with ``task_type_ids``
            when ``task_type`` is given.
        """
        batch_encoding = super().__call__(*args, **kwargs)
        return self._add_task_type_ids(
            batch_encoding, task_type, kwargs.get('return_tensors')
        )

    def _batch_encode_plus(self, *args, task_type=None, **kwargs):
        """Parent ``_batch_encode_plus`` plus optional ``task_type_ids``.

        NOTE(review): presumably reached with ``task_type`` when callers use
        ``batch_encode_plus(..., task_type=...)`` directly; when going through
        ``__call__`` the keyword is consumed there first. Confirm against the
        installed transformers version.
        """
        batch_encoding = super()._batch_encode_plus(*args, **kwargs)
        return self._add_task_type_ids(
            batch_encoding, task_type, kwargs.get('return_tensors')
        )

    def _encode_plus(self, *args, task_type=None, **kwargs):
        """Parent ``_encode_plus`` plus optional ``task_type_ids``.

        See ``_batch_encode_plus`` for when ``task_type`` arrives here.
        """
        batch_encoding = super()._encode_plus(*args, **kwargs)
        return self._add_task_type_ids(
            batch_encoding, task_type, kwargs.get('return_tensors')
        )

    def _add_task_type_ids(self, batch_encoding, task_type, return_tensors=None):
        """Return *batch_encoding* with ``task_type_ids`` merged in.

        Shared by ``__call__``, ``_batch_encode_plus`` and ``_encode_plus`` so
        the merge logic lives in one place. If ``task_type`` is ``None`` the
        encoding is returned unchanged.

        Args:
            batch_encoding: Encoding produced by the parent tokenizer.
            task_type: ``None``, an int, or a per-row sequence of ints.
            return_tensors: Tensor type forwarded to the rebuilt
                ``BatchEncoding`` (the caller's ``return_tensors`` kwarg).
        """
        if task_type is None:
            return batch_encoding
        return BatchEncoding(
            {
                # Listed first, so an existing 'task_type_ids' key in the
                # parent encoding would win. Same precedence as before.
                'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
                **batch_encoding,
            },
            tensor_type=return_tensors,
        )

    @staticmethod
    def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):
        """Build ``task_type_ids`` with the same shape as ``input_ids``.

        Args:
            batch_encoding: Parent tokenizer output. Only its ``input_ids``
                entry is read.
            task_type: An int applied to every position, or a sequence of
                ints whose length equals ``input_ids.shape[0]``.

        Returns:
            The ids in the same container kind as ``input_ids``:
            ``torch.Tensor`` for tensors, ``numpy.ndarray`` for arrays, and
            nested ``list`` for lists. Any other type yields a torch tensor
            and a warning.

        Raises:
            ValueError: if ``input_ids`` is not rectangular (for example,
                unpadded sequences of different lengths), or if a sequence
                ``task_type`` does not match the batch size.
        """
        input_ids = batch_encoding['input_ids']

        def apply_task_type(mask, task_type_value):
            # as_tensor avoids torch's copy warning when task_type is already
            # a tensor; for ints and lists it behaves like torch.tensor.
            task_tensor = torch.as_tensor(task_type_value)
            if task_tensor.dim() == 0:
                return mask * task_tensor
            # Explicit raise instead of assert: asserts vanish under -O.
            if task_tensor.shape[0] != mask.shape[0]:
                raise ValueError(
                    'The shape of task_type does not match the size of the batch.'
                )
            # One task type per row, broadcast across that row's tokens.
            return mask * task_tensor[:, None]

        if isinstance(input_ids, torch.Tensor):
            return apply_task_type(
                torch.ones(input_ids.shape, dtype=torch.long), task_type
            )

        try:
            shape = torch.tensor(input_ids).shape
        except (TypeError, ValueError, RuntimeError) as err:
            raise ValueError(
                "Unable to create tensor, you should probably "
                "activate truncation and/or padding with "
                "'padding=True' 'truncation=True' to have batched "
                "tensors with the same length."
            ) from err

        task_type_ids = apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
        if isinstance(input_ids, list):
            return task_type_ids.tolist()
        # np.ndarray is the array type. np.array is only a factory function
        # and would make isinstance raise TypeError.
        if isinstance(input_ids, np.ndarray):
            return task_type_ids.numpy()
        warnings.warn(
            'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
        )
        return task_type_ids