jupyterjazz committed on
Commit
72221a7
1 Parent(s): 9b5c148

refactor: batch encoding

Browse files

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (1) hide show
  1. tokenizer.py +15 -23
tokenizer.py CHANGED
@@ -8,8 +8,8 @@ def get_tokenizer(parent_class):
8
  class TokenizerClass(parent_class):
9
  def __init__(self, *args, **kwargs):
10
  """
11
- JinaTokenizer extends the RobertaTokenizer class to include task_type_ids in
12
- the batch encoding.
13
  The task_type_ids are used to pass instruction information to the model.
14
  A task_type should either be an integer or a sequence of integers with the same
15
  length as the batch size.
@@ -19,39 +19,31 @@ def get_tokenizer(parent_class):
19
  def __call__(self, *args, task_type=None, **kwargs):
20
  batch_encoding = super().__call__(*args, **kwargs)
21
  if task_type is not None:
22
- batch_encoding = BatchEncoding(
23
- {
24
- 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
25
- **batch_encoding,
26
- },
27
- tensor_type=kwargs.get('return_tensors'),
28
- )
29
  return batch_encoding
30
 
31
  def _batch_encode_plus(self, *args, task_type=None, **kwargs):
32
  batch_encoding = super()._batch_encode_plus(*args, **kwargs)
33
  if task_type is not None:
34
- batch_encoding = BatchEncoding(
35
- {
36
- 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
37
- **batch_encoding,
38
- },
39
- tensor_type=kwargs.get('return_tensors'),
40
- )
41
  return batch_encoding
42
 
43
  def _encode_plus(self, *args, task_type=None, **kwargs):
44
  batch_encoding = super()._encode_plus(*args, **kwargs)
45
  if task_type is not None:
46
- batch_encoding = BatchEncoding(
47
- {
48
- 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
49
- **batch_encoding,
50
- },
51
- tensor_type=kwargs.get('return_tensors'),
52
- )
53
  return batch_encoding
54
 
 
 
 
 
 
 
 
 
 
 
55
  @staticmethod
56
  def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):
57
 
 
8
  class TokenizerClass(parent_class):
9
  def __init__(self, *args, **kwargs):
10
  """
11
+ This class dynamically extends a given tokenizer class from the HF
12
+ Transformers library (RobertaTokenizer or RobertaTokenizerFast).
13
  The task_type_ids are used to pass instruction information to the model.
14
  A task_type should either be an integer or a sequence of integers with the same
15
  length as the batch size.
 
19
  def __call__(self, *args, task_type=None, **kwargs):
20
  batch_encoding = super().__call__(*args, **kwargs)
21
  if task_type is not None:
22
+ batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
 
 
 
 
 
 
23
  return batch_encoding
24
 
25
  def _batch_encode_plus(self, *args, task_type=None, **kwargs):
26
  batch_encoding = super()._batch_encode_plus(*args, **kwargs)
27
  if task_type is not None:
28
+ batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
 
 
 
 
 
 
29
  return batch_encoding
30
 
31
  def _encode_plus(self, *args, task_type=None, **kwargs):
32
  batch_encoding = super()._encode_plus(*args, **kwargs)
33
  if task_type is not None:
34
+ batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
 
 
 
 
 
 
35
  return batch_encoding
36
 
37
+ @classmethod
38
+ def _add_task_type_ids(cls, batch_encoding, task_type, tensor_type):
39
+ return BatchEncoding(
40
+ {
41
+ 'task_type_ids': cls._get_task_type_ids(batch_encoding, task_type),
42
+ **batch_encoding,
43
+ },
44
+ tensor_type=tensor_type,
45
+ )
46
+
47
  @staticmethod
48
  def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):
49