File size: 4,419 Bytes
6343db7
11b09c9
6343db7
ed1b276
6343db7
 
 
ed1b276
 
11b09c9
9072f7f
11b09c9
 
 
9072f7f
11b09c9
 
ed1b276
 
 
 
 
 
 
 
 
6343db7
11b09c9
ed1b276
 
11b09c9
 
 
ed1b276
6343db7
11b09c9
ed1b276
 
11b09c9
 
 
ed1b276
 
11b09c9
ed1b276
 
11b09c9
 
 
ed1b276
6343db7
ed1b276
11b09c9
 
 
ed1b276
6343db7
ed1b276
6343db7
 
ed1b276
6343db7
 
ed1b276
11b09c9
ed1b276
 
 
11b09c9
ed1b276
 
e151a8f
ed1b276
 
e151a8f
ed1b276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11b09c9
 
 
ed1b276
 
 
 
 
 
11b09c9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
from enum import IntEnum
import numpy as np
from transformers import RobertaTokenizer, BatchEncoding, RobertaTokenizerFast
import warnings


def get_tokenizer(parent_class):
    """Build a tokenizer class that extends ``parent_class`` with task types.

    The returned class behaves like ``parent_class`` but accepts an extra
    ``task_type`` keyword in ``__call__``, ``_batch_encode_plus`` and
    ``_encode_plus``. When ``task_type`` is given, a ``task_type_ids`` entry
    with the same shape as ``input_ids`` is added to the returned encoding.

    Args:
        parent_class: The tokenizer class to subclass (this module uses
            ``RobertaTokenizer`` and ``RobertaTokenizerFast``).

    Returns:
        The newly created ``TokenizerClass`` subclass of ``parent_class``.
    """

    class TokenizerClass(parent_class):
        """Tokenizer that can attach ``task_type_ids`` to its encodings.

        The task_type_ids are used to pass instruction information to the
        model. A task_type should either be an integer or a sequence of
        integers with the same length as the batch size.
        """

        class TaskTypes(IntEnum):
            NULL = 0
            QUERY = 1
            DOCUMENT = 2
            STS = 3
            CLUSTERING = 4
            CLASSIFICATION = 5

        def __init__(self, *args, **kwargs):
            """Forward all arguments unchanged to the parent tokenizer."""
            super().__init__(*args, **kwargs)

        def __call__(self, *args, task_type: TaskTypes = None, **kwargs):
            """Tokenize via the parent class, optionally adding task_type_ids.

            ``task_type`` is consumed here and not forwarded to the parent;
            every other argument is passed through untouched.
            """
            batch_encoding = super().__call__(*args, **kwargs)
            if task_type is not None:
                batch_encoding = self._add_task_type_ids(
                    batch_encoding, task_type, kwargs.get('return_tensors')
                )
            return batch_encoding

        def _batch_encode_plus(self, *args, task_type: TaskTypes = None, **kwargs):
            """Batch-encode via the parent class, optionally adding task_type_ids."""
            batch_encoding = super()._batch_encode_plus(*args, **kwargs)
            if task_type is not None:
                batch_encoding = self._add_task_type_ids(
                    batch_encoding, task_type, kwargs.get('return_tensors')
                )
            return batch_encoding

        def _encode_plus(self, *args, task_type: TaskTypes = None, **kwargs):
            """Encode via the parent class, optionally adding task_type_ids."""
            batch_encoding = super()._encode_plus(*args, **kwargs)
            if task_type is not None:
                batch_encoding = self._add_task_type_ids(
                    batch_encoding, task_type, kwargs.get('return_tensors')
                )
            return batch_encoding

        @classmethod
        def _add_task_type_ids(
            cls, batch_encoding: 'BatchEncoding', task_type: TaskTypes, tensor_type: str
        ):
            """Return a new ``BatchEncoding`` that also holds ``task_type_ids``.

            Args:
                batch_encoding: Encoding produced by the parent tokenizer.
                task_type: Scalar task type, or one task type per batch row.
                tensor_type: Value of ``return_tensors`` (may be ``None``),
                    handed to ``BatchEncoding`` as ``tensor_type``.
            """
            return BatchEncoding(
                {
                    'task_type_ids': cls._get_task_type_ids(batch_encoding, task_type),
                    # Unpacked last, so an existing 'task_type_ids' key in the
                    # incoming encoding takes precedence over the computed one.
                    **batch_encoding,
                },
                tensor_type=tensor_type,
            )

        @staticmethod
        def _get_task_type_ids(batch_encoding: 'BatchEncoding', task_type: TaskTypes):
            """Build task-type ids shaped like ``batch_encoding['input_ids']``.

            Args:
                batch_encoding: Mapping with an ``'input_ids'`` entry that is
                    a torch tensor, a numpy array, or a (nested) list.
                task_type: A scalar applied to every position, or a sequence
                    whose length equals the first dimension of ``input_ids``
                    (one value per row).

            Returns:
                The ids in the same container kind as ``input_ids`` (torch
                tensor, numpy array, or list). For any other container a
                warning is issued and a torch tensor is returned.

            Raises:
                ValueError: If ``input_ids`` cannot be converted to a
                    rectangular tensor, or if a non-scalar ``task_type`` does
                    not match the first dimension of ``input_ids``.
            """

            def apply_task_type(ones, task):
                task = torch.as_tensor(task)
                if task.dim() == 0:
                    return ones * task
                # Raised explicitly rather than asserted so the check is not
                # stripped under ``python -O``, where a length-1 task_type
                # would otherwise broadcast silently across the whole batch.
                if task.shape[0] != ones.shape[0]:
                    raise ValueError(
                        'The shape of task_type does not match the size of the batch.'
                    )
                return ones * task[:, None]

            input_ids = batch_encoding['input_ids']
            if isinstance(input_ids, (torch.Tensor, np.ndarray)):
                shape = tuple(input_ids.shape)
            else:
                try:
                    shape = torch.tensor(input_ids).shape
                except (ValueError, TypeError, RuntimeError) as err:
                    raise ValueError(
                        "Unable to create tensor, you should probably "
                        "activate truncation and/or padding with "
                        "'padding=True' 'truncation=True' to have batched "
                        "tensors with the same length."
                    ) from err

            task_type_ids = apply_task_type(
                torch.ones(shape, dtype=torch.long), task_type
            )

            # Hand the result back in the same container kind as input_ids.
            if isinstance(input_ids, torch.Tensor):
                return task_type_ids
            if isinstance(input_ids, list):
                return task_type_ids.tolist()
            # np.ndarray is the array *type*; np.array is a factory function
            # and is not a valid second argument to isinstance().
            if isinstance(input_ids, np.ndarray):
                return task_type_ids.numpy()
            warnings.warn(
                'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
            )
            return task_type_ids

    return TokenizerClass


# Concrete task-type-aware tokenizers: one built on the slow (pure Python)
# Roberta tokenizer and one on the fast variant.
JinaTokenizer, JinaTokenizerFast = (
    get_tokenizer(base_class)
    for base_class in (RobertaTokenizer, RobertaTokenizerFast)
)