import torch
import numpy as np
from transformers import RobertaTokenizer, BatchEncoding
import warnings


class JinaTokenizer(RobertaTokenizer):
    """RobertaTokenizer subclass that adds a `task_type_ids` field, shaped like
    `input_ids`, to every encoding it produces."""
    def __init__(self, *args, task_type_vocab_size=6, **kwargs):
        super().__init__(*args, **kwargs)
        self.task_type_vocab_size = task_type_vocab_size

    def __call__(self, *args, task_type=None, **kwargs):
        batch_encoding = super().__call__(*args, **kwargs)
        # Guard against task_type=None, mirroring _batch_encode_plus/_encode_plus;
        # without it, _get_task_type_ids would multiply by None and raise.
        if task_type is not None:
            batch_encoding = BatchEncoding(
                {
                    'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
                    **batch_encoding,
                },
                tensor_type=kwargs.get('return_tensors'),
            )
        return batch_encoding

    def _batch_encode_plus(self, *args, task_type=None, **kwargs):
        batch_encoding = super()._batch_encode_plus(*args, **kwargs)
        if task_type is not None:
            batch_encoding = BatchEncoding(
                {
                    'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
                    **batch_encoding,
                },
                tensor_type=kwargs.get('return_tensors'),
            )
        return batch_encoding

    def _encode_plus(self, *args, task_type=None, **kwargs):
        batch_encoding = super()._encode_plus(*args, **kwargs)
        if task_type is not None:
            batch_encoding = BatchEncoding(
                {
                    'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
                    **batch_encoding,
                },
                tensor_type=kwargs.get('return_tensors'),
            )
        return batch_encoding

    @staticmethod
    def _get_task_type_ids(batch_encoding: BatchEncoding, task_type: int):
        if isinstance(batch_encoding['input_ids'], torch.Tensor):
            shape = batch_encoding['input_ids'].shape
            return torch.ones(shape, dtype=torch.long) * task_type
        else:
            shape = torch.tensor(batch_encoding['input_ids']).shape
            if isinstance(batch_encoding['input_ids'], list):
                return (torch.ones(shape, dtype=torch.long) * task_type).tolist()
            # np.ndarray is the array type; np.array is a factory function and
            # would make this isinstance check raise a TypeError.
            elif isinstance(batch_encoding['input_ids'], np.ndarray):
                return (torch.ones(shape, dtype=torch.long) * task_type).numpy()
            else:
                warnings.warn(
                    'input_ids is not a torch tensor, numpy array, or list; '
                    'returning a torch tensor'
                )
                return torch.ones(shape, dtype=torch.long) * task_type
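

if __name__ == '__main__':
    # Minimal self-check of the task-type broadcasting used above. No
    # checkpoint or vocab files are needed: a plain nested list stands in for
    # batch_encoding['input_ids'] (the list branch of _get_task_type_ids),
    # and task_type=4 is an arbitrary example value, not a defined task id.
    input_ids = [[0, 31414, 2], [0, 9064, 2]]
    shape = torch.tensor(input_ids).shape
    task_type_ids = (torch.ones(shape, dtype=torch.long) * 4).tolist()
    print(task_type_ids)  # one id per token, all equal to 4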