import torch
from enum import IntEnum
import numpy as np
from transformers import RobertaTokenizer, BatchEncoding, RobertaTokenizerFast
import warnings


def get_tokenizer(parent_class):
    class TokenizerClass(parent_class):
        class TaskTypes(IntEnum):
            NULL = 0
            QUERY = 1
            DOCUMENT = 2
            STS = 3
            CLUSTERING = 4
            CLASSIFICATION = 5

        def __init__(self, *args, **kwargs):
            """
            This class dynamically extends a given tokenizer class from the HF
            Transformers library (RobertaTokenizer or RobertaTokenizerFast).
            The task_type_ids are used to pass instruction information to the model.
            A task_type should either be an integer or a sequence of integers with
            the same length as the batch size.
            """
            super().__init__(*args, **kwargs)
        def __call__(self, *args, task_type: TaskTypes = None, **kwargs):
            batch_encoding = super().__call__(*args, **kwargs)
            if task_type is not None:
                batch_encoding = self._add_task_type_ids(
                    batch_encoding, task_type, kwargs.get('return_tensors')
                )
            return batch_encoding

        def _batch_encode_plus(self, *args, task_type: TaskTypes = None, **kwargs):
            batch_encoding = super()._batch_encode_plus(*args, **kwargs)
            if task_type is not None:
                batch_encoding = self._add_task_type_ids(
                    batch_encoding, task_type, kwargs.get('return_tensors')
                )
            return batch_encoding

        def _encode_plus(self, *args, task_type: TaskTypes = None, **kwargs):
            batch_encoding = super()._encode_plus(*args, **kwargs)
            if task_type is not None:
                batch_encoding = self._add_task_type_ids(
                    batch_encoding, task_type, kwargs.get('return_tensors')
                )
            return batch_encoding
        @classmethod
        def _add_task_type_ids(
            cls, batch_encoding: BatchEncoding, task_type: TaskTypes, tensor_type: str
        ):
            # Prepend the computed task_type_ids and rebuild the BatchEncoding so
            # that the requested tensor type is applied to the new key as well.
            return BatchEncoding(
                {
                    'task_type_ids': cls._get_task_type_ids(batch_encoding, task_type),
                    **batch_encoding,
                },
                tensor_type=tensor_type,
            )
        @staticmethod
        def _get_task_type_ids(batch_encoding: BatchEncoding, task_type: TaskTypes):
            def apply_task_type(m, x):
                x = torch.tensor(x)
                assert (
                    len(x.shape) == 0 or x.shape[0] == m.shape[0]
                ), 'The shape of task_type does not match the size of the batch.'
                # Broadcast a scalar task_type over the whole batch, or apply one
                # task_type per row when a sequence is given.
                return m * x if len(x.shape) == 0 else m * x[:, None]

            if isinstance(batch_encoding['input_ids'], torch.Tensor):
                shape = batch_encoding['input_ids'].shape
                return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
            else:
                try:
                    shape = torch.tensor(batch_encoding['input_ids']).shape
                except Exception:
                    raise ValueError(
                        "Unable to create tensor, you should probably "
                        "activate truncation and/or padding with "
                        "'padding=True' 'truncation=True' to have batched "
                        "tensors with the same length."
                    )
                # Return the task_type_ids in the same container type as input_ids.
                if isinstance(batch_encoding['input_ids'], list):
                    return (
                        apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
                    ).tolist()
                elif isinstance(batch_encoding['input_ids'], np.ndarray):
                    return (
                        apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
                    ).numpy()
                else:
                    warnings.warn(
                        'input_ids is not a torch tensor, numpy array, or list. '
                        'Returning a torch tensor.'
                    )
                    return apply_task_type(
                        torch.ones(shape, dtype=torch.long), task_type
                    )
    return TokenizerClass


JinaTokenizer = get_tokenizer(RobertaTokenizer)
JinaTokenizerFast = get_tokenizer(RobertaTokenizerFast)
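

# Illustrative usage sketch (not part of the upstream file). The checkpoint name
# 'roberta-base' is only an example of a RoBERTa vocabulary this class can load;
# in practice the tokenizer is loaded from the model repository that ships it.
if __name__ == '__main__':
    tokenizer = JinaTokenizer.from_pretrained('roberta-base')
    encoding = tokenizer(
        ['how do I reset my password', 'what is the capital of France'],
        padding=True,
        return_tensors='pt',
        task_type=JinaTokenizer.TaskTypes.QUERY,
    )
    # The encoding now contains 'task_type_ids' next to 'input_ids' and
    # 'attention_mask'; every position is set to TaskTypes.QUERY (= 1).
    print(encoding['task_type_ids'])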