import torch
import numpy as np
from transformers import RobertaTokenizer, BatchEncoding, RobertaTokenizerFast
import warnings


def get_tokenizer(parent_class):
    class TokenizerClass(parent_class):
        """
        Dynamically extends a given tokenizer class from the HF Transformers
        library (RobertaTokenizer or RobertaTokenizerFast). The resulting
        tokenizer produces task_type_ids, which pass instruction information
        to the model. A task_type should be either a single integer or a
        sequence of integers with one entry per example in the batch.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

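        # All three tokenization entry points below accept an extra task_type
        # keyword: the underlying tokenizer runs unchanged, and task_type_ids
        # are attached to the resulting BatchEncoding afterwards.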
        def __call__(self, *args, task_type=None, **kwargs):
            batch_encoding = super().__call__(*args, **kwargs)
            if task_type is not None:
                batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
            return batch_encoding

        def _batch_encode_plus(self, *args, task_type=None, **kwargs):
            batch_encoding = super()._batch_encode_plus(*args, **kwargs)
            if task_type is not None:
                batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
            return batch_encoding

        def _encode_plus(self, *args, task_type=None, **kwargs):
            batch_encoding = super()._encode_plus(*args, **kwargs)
            if task_type is not None:
                batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
            return batch_encoding

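        # Prepend task_type_ids to the encoding and re-wrap it so that the
        # originally requested tensor type (e.g. 'pt' or 'np') is preserved.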
        @classmethod
        def _add_task_type_ids(cls, batch_encoding, task_type, tensor_type):
            return BatchEncoding(
                {
                    'task_type_ids': cls._get_task_type_ids(batch_encoding, task_type),
                    **batch_encoding,
                },
                tensor_type=tensor_type,
            )

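        # Build task_type_ids with the same shape and container type
        # (torch tensor, list, or numpy array) as input_ids.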
        @staticmethod
        def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):

            def apply_task_type(m, x):
                x = torch.tensor(x)
                assert (
                        len(x.shape) == 0 or x.shape[0] == m.shape[0]
                ), 'The shape of task_type does not match the size of the batch.'
                return m * x if len(x.shape) == 0 else m * x[:, None]

            if isinstance(batch_encoding['input_ids'], torch.Tensor):
                shape = batch_encoding['input_ids'].shape
                return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
            else:
                try:
                    shape = torch.tensor(batch_encoding['input_ids']).shape
                except Exception as e:
                    raise ValueError(
                        "Unable to create tensor, you should probably "
                        "activate truncation and/or padding with "
                        "'padding=True' and 'truncation=True' to have batched "
                        "tensors with the same length."
                    ) from e
                if isinstance(batch_encoding['input_ids'], list):
                    return (
                        apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
                    ).tolist()
                elif isinstance(batch_encoding['input_ids'], np.ndarray):
                    return (
                        apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
                    ).numpy()
                else:
                    warnings.warn(
                        'input_ids is not a torch tensor, numpy array, or list; '
                        'returning a torch tensor.'
                    )
                    return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)

    return TokenizerClass


JinaTokenizer = get_tokenizer(RobertaTokenizer)
JinaTokenizerFast = get_tokenizer(RobertaTokenizerFast)
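

if __name__ == '__main__':
    # Usage sketch: 'roberta-base' is only an illustrative checkpoint here;
    # any vocabulary compatible with RobertaTokenizerFast works, since
    # JinaTokenizerFast only adds task_type handling on top of it.
    tokenizer = JinaTokenizerFast.from_pretrained('roberta-base')
    batch = tokenizer(
        ['first sentence', 'second sentence'],
        task_type=1,  # a scalar task id is broadcast to every example
        padding=True,
        return_tensors='pt',
    )
    # task_type_ids has the same shape as input_ids and is filled with 1s;
    # a per-example sequence such as task_type=[1, 2] is accepted as well.
    print(batch['input_ids'].shape, batch['task_type_ids'].shape)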