Files changed (1)
  1. tokenizer.py +128 -50
tokenizer.py CHANGED
@@ -1,62 +1,140 @@
  import torch
  import numpy as np
- from transformers import RobertaTokenizer, BatchEncoding
+ from transformers import RobertaTokenizer, BatchEncoding, RobertaTokenizerFast
  import warnings


- class JinaTokenizer(RobertaTokenizer):
-     def __init__(self, *args, task_type_vocab_size=6, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.task_type_vocab_size = task_type_vocab_size
-
-     def __call__(self, *args, task_type=None, **kwargs):
-         batch_encoding = super().__call__(*args, **kwargs)
-         batch_encoding = BatchEncoding(
-             {
-                 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
-                 **batch_encoding,
-             },
-             tensor_type=kwargs.get('return_tensors'),
-         )
-         return batch_encoding
-
-     def _batch_encode_plus(self, *args, task_type=None, **kwargs):
-         batch_encoding = super()._batch_encode_plus(*args, **kwargs)
-         if task_type is not None:
-             batch_encoding = BatchEncoding(
-                 {
-                     'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
-                     **batch_encoding,
-                 },
-                 tensor_type=kwargs.get('return_tensors'),
+ def get_tokenizer(parent_class):
+     class TokenizerClass(parent_class):
+         def __init__(self, *args, **kwargs):
+             """
+             This class dynamically extends a given tokenizer class from the HF
+             Transformers library (RobertaTokenizer or RobertaTokenizerFast).
+             The task_type_ids are used to pass instruction information to the model.
+             A task_type should either be an integer or a sequence of integers with the same
+             length as the batch size.
+             """
+             super().__init__(*args, **kwargs)
+             self.cls_token_interval = kwargs.get('cls_token_interval')
+
+         def __call__(self, *args, task_type=None, **kwargs):
+             return super().__call__(*args, **kwargs)
+
+         def _encode_plus(self, *args, **kwargs):
+             return self._process_encoding(super()._encode_plus(*args, **kwargs), **kwargs)
+
+         def _batch_encode_plus(self, *args, **kwargs):
+             return self._process_encoding(
+                 super()._batch_encode_plus(*args, **kwargs), **kwargs
              )
-         return batch_encoding

-     def _encode_plus(self, *args, task_type=None, **kwargs):
-         batch_encoding = super()._encode_plus(*args, **kwargs)
-         if task_type is not None:
-             batch_encoding = BatchEncoding(
+         def _process_encoding(self, batch_encoding: BatchEncoding, **kwargs):
+             task_type = kwargs.get("task_type")
+             if self.cls_token_interval is not None:
+                 modified_input_ids, modified_attention_mask, modified_special_tokens_mask = self._insert_cls_tokens(
+                     batch_encoding
+                 )
+                 batch_encoding["input_ids"] = modified_input_ids
+
+                 if "attention_mask" in batch_encoding:
+                     batch_encoding["attention_mask"] = modified_attention_mask
+
+                 if "special_tokens_mask" in batch_encoding:
+                     batch_encoding["special_tokens_mask"] = modified_special_tokens_mask
+
+             if task_type is not None:
+                 batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
+
+             return BatchEncoding(batch_encoding, tensor_type=kwargs.get("return_tensors"))
+
+         def _insert_cls_tokens(self, batch_encoding: BatchEncoding):
+             cls_token_id = self.cls_token_id
+             new_input_ids = []
+             new_attention_masks = []
+             new_special_tokens_masks = []
+
+             sequences = batch_encoding["input_ids"].tolist()
+             original_attention_masks = batch_encoding["attention_mask"].tolist()
+             original_special_tokens_mask = batch_encoding["special_tokens_mask"].tolist()
+             for seq_index, sequence in enumerate(sequences):
+                 original_sequence_length = sum(original_attention_masks[seq_index])
+                 num_cls_tokens_to_add = (
+                     original_sequence_length - 1
+                 ) // self.cls_token_interval
+                 new_sequence_length = original_sequence_length + num_cls_tokens_to_add
+                 special_tokens_mask = original_special_tokens_mask[seq_index]
+                 modified_sequence = [sequence[0]]
+                 modified_special_tokens_mask = [special_tokens_mask[0]]
+                 for i in range(1, len(sequence), self.cls_token_interval):
+                     modified_sequence.extend(sequence[i: i + self.cls_token_interval])
+                     modified_special_tokens_mask.extend(special_tokens_mask[i: i + self.cls_token_interval])
+
+                     if i + self.cls_token_interval < len(sequence):
+                         modified_sequence.append(cls_token_id)
+                         modified_special_tokens_mask.append(1)
+
+                 new_input_ids.append(modified_sequence)
+                 new_attention_mask = [1] * new_sequence_length + [0] * (
+                     len(modified_sequence) - new_sequence_length
+                 )
+                 new_special_tokens_masks.append(modified_special_tokens_mask)
+                 new_attention_masks.append(new_attention_mask)
+
+             new_input_ids = torch.tensor(new_input_ids, dtype=torch.long)
+             new_attention_masks = torch.tensor(new_attention_masks, dtype=torch.long)
+             new_special_tokens_masks = torch.tensor(new_special_tokens_masks, dtype=torch.long)
+
+             return new_input_ids, new_attention_masks, new_special_tokens_masks
+
+         @classmethod
+         def _add_task_type_ids(cls, batch_encoding, task_type, tensor_type):
+             return BatchEncoding(
                  {
-                     'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
+                     'task_type_ids': cls._get_task_type_ids(batch_encoding, task_type),
                      **batch_encoding,
                  },
-                 tensor_type=kwargs.get('return_tensors'),
+                 tensor_type=tensor_type,
              )
-         return batch_encoding
-
-     @staticmethod
-     def _get_task_type_ids(batch_encoding: BatchEncoding, task_type: int):
-         if isinstance(batch_encoding['input_ids'], torch.Tensor):
-             shape = batch_encoding['input_ids'].shape
-             return torch.ones(shape, dtype=torch.long) * task_type
-         else:
-             shape = torch.tensor(batch_encoding['input_ids']).shape
-             if isinstance(batch_encoding['input_ids'], list):
-                 return (torch.ones(shape, dtype=torch.long) * task_type).tolist()
-             elif isinstance(batch_encoding['input_ids'], np.array):
-                 return (torch.ones(shape, dtype=torch.long) * task_type).numpy()
+
+         @staticmethod
+         def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):
+
+             def apply_task_type(m, x):
+                 x = torch.tensor(x)
+                 assert (
+                     len(x.shape) == 0 or x.shape[0] == m.shape[0]
+                 ), 'The shape of task_type does not match the size of the batch.'
+                 return m * x if len(x.shape) == 0 else m * x[:, None]
+
+             if isinstance(batch_encoding['input_ids'], torch.Tensor):
+                 shape = batch_encoding['input_ids'].shape
+                 return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
              else:
-                 warnings.warn(
-                     'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
-                 )
-                 return torch.ones(shape, dtype=torch.long) * task_type
+                 try:
+                     shape = torch.tensor(batch_encoding['input_ids']).shape
+                 except:
+                     raise ValueError(
+                         "Unable to create tensor, you should probably "
+                         "activate truncation and/or padding with "
+                         "'padding=True' 'truncation=True' to have batched "
+                         "tensors with the same length."
+                     )
+                 if isinstance(batch_encoding['input_ids'], list):
+                     return (
+                         apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
+                     ).tolist()
+                 elif isinstance(batch_encoding['input_ids'], np.ndarray):
+                     return (
+                         apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
+                     ).numpy()
+                 else:
+                     warnings.warn(
+                         'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
+                     )
+                     return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
+
+     return TokenizerClass
+
+
+ JinaTokenizer = get_tokenizer(RobertaTokenizer)
+ JinaTokenizerFast = get_tokenizer(RobertaTokenizerFast)
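
For reviewers: a minimal sketch of how the task_type broadcasting in _get_task_type_ids behaves. It is not part of the diff; it assumes the file above is importable as the tokenizer module and exercises the static helper directly on a hand-built batch.

import torch
from transformers import BatchEncoding

from tokenizer import JinaTokenizer  # module name taken from this PR

# Toy 2 x 4 batch; in real use this comes out of the tokenizer itself.
enc = BatchEncoding({'input_ids': torch.ones(2, 4, dtype=torch.long)})

# A scalar task_type is broadcast over the whole batch.
print(JinaTokenizer._get_task_type_ids(enc, 3))
# tensor([[3, 3, 3, 3],
#         [3, 3, 3, 3]])

# A per-example sequence must match the batch size (checked by the assert in apply_task_type).
print(JinaTokenizer._get_task_type_ids(enc, [1, 2]))
# tensor([[1, 1, 1, 1],
#         [2, 2, 2, 2]])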