michael-guenther committed on
Commit
6343db7
1 Parent(s): 32458be

add tokenizer

Browse files
Files changed (1) hide show
  1. tokenizer.py +62 -0
tokenizer.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from transformers import RobertaTokenizer, BatchEncoding
4
+ import warnings
5
+
6
+
7
class JinaTokenizer(RobertaTokenizer):
    """RoBERTa tokenizer that can attach a ``task_type_ids`` entry to its output.

    When a ``task_type`` is passed to ``__call__``, ``_batch_encode_plus`` or
    ``_encode_plus``, the returned ``BatchEncoding`` gains a ``task_type_ids``
    entry that has the same structure as ``input_ids`` and is filled with
    ``task_type``. Without a ``task_type`` the output of the parent tokenizer
    is returned unchanged.
    """

    def __init__(self, *args, task_type_vocab_size=6, **kwargs):
        """Initialise the parent tokenizer and remember ``task_type_vocab_size``.

        Args:
            *args: Positional arguments forwarded to ``RobertaTokenizer``.
            task_type_vocab_size: Stored on the instance as
                ``self.task_type_vocab_size``; it is not otherwise used by
                this class.
            **kwargs: Keyword arguments forwarded to ``RobertaTokenizer``.
        """
        super().__init__(*args, **kwargs)
        self.task_type_vocab_size = task_type_vocab_size

    def __call__(self, *args, task_type=None, **kwargs):
        """Tokenize via the parent ``__call__`` and optionally add task type ids.

        Args:
            *args: Positional arguments forwarded to the parent ``__call__``.
            task_type: Value used to fill ``task_type_ids``. When ``None``
                (the default) no ``task_type_ids`` entry is added.
            **kwargs: Keyword arguments forwarded to the parent ``__call__``;
                ``return_tensors`` is additionally used as the tensor type of
                the re-wrapped encoding.

        Returns:
            The ``BatchEncoding`` from the parent, extended with
            ``task_type_ids`` when ``task_type`` is given.
        """
        batch_encoding = super().__call__(*args, **kwargs)
        # Guard against ``task_type=None``: multiplying a tensor by ``None``
        # raises ``TypeError``, and the sibling methods skip the entry too.
        if task_type is None:
            return batch_encoding
        return self._with_task_type_ids(
            batch_encoding, task_type, kwargs.get('return_tensors')
        )

    def _batch_encode_plus(self, *args, task_type=None, **kwargs):
        """Run the parent ``_batch_encode_plus`` and optionally add task type ids.

        ``task_type`` and ``return_tensors`` are treated as in ``__call__``.
        """
        batch_encoding = super()._batch_encode_plus(*args, **kwargs)
        if task_type is None:
            return batch_encoding
        return self._with_task_type_ids(
            batch_encoding, task_type, kwargs.get('return_tensors')
        )

    def _encode_plus(self, *args, task_type=None, **kwargs):
        """Run the parent ``_encode_plus`` and optionally add task type ids.

        ``task_type`` and ``return_tensors`` are treated as in ``__call__``.
        """
        batch_encoding = super()._encode_plus(*args, **kwargs)
        if task_type is None:
            return batch_encoding
        return self._with_task_type_ids(
            batch_encoding, task_type, kwargs.get('return_tensors')
        )

    def _with_task_type_ids(self, batch_encoding, task_type, return_tensors):
        """Return a new ``BatchEncoding`` that also carries ``task_type_ids``.

        The generated entry is listed first and ``batch_encoding`` is unpacked
        after it, so a ``task_type_ids`` entry that already exists in
        ``batch_encoding`` takes precedence over the generated one.
        """
        return BatchEncoding(
            {
                'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
                **batch_encoding,
            },
            tensor_type=return_tensors,
        )

    @staticmethod
    def _get_task_type_ids(batch_encoding: BatchEncoding, task_type: int):
        """Build a ``task_type``-filled container shaped like ``input_ids``.

        Args:
            batch_encoding: Encoding whose ``'input_ids'`` entry determines
                the shape and container type of the result.
            task_type: Value every position is filled with.

        Returns:
            A ``torch.long`` tensor for tensor input, a NumPy array for array
            input, and a (possibly nested, possibly ragged) list for list
            input. Any other input type triggers a warning and yields a
            ``torch.long`` tensor.
        """
        input_ids = batch_encoding['input_ids']

        if isinstance(input_ids, torch.Tensor):
            return torch.ones(input_ids.shape, dtype=torch.long) * task_type

        if isinstance(input_ids, list):
            # Mirror the list structure directly rather than going through
            # ``torch.tensor``, which cannot represent unpadded batches whose
            # sequences differ in length.
            def fill_like(ids):
                if isinstance(ids, list):
                    return [fill_like(item) for item in ids]
                return task_type

            return fill_like(input_ids)

        # ``np.ndarray`` is the array type; ``np.array`` is a factory function
        # and is not a valid second argument to ``isinstance``.
        if isinstance(input_ids, np.ndarray):
            return (torch.ones(input_ids.shape, dtype=torch.long) * task_type).numpy()

        warnings.warn(
            'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
        )
        shape = torch.tensor(input_ids).shape
        return torch.ones(shape, dtype=torch.long) * task_type