michael-guenther committed on
Commit db57d38
1 Parent(s): 59c0808

feat: add enum for task type ids

Files changed (1):
tokenizer.py  +30 -11
tokenizer.py CHANGED
@@ -1,4 +1,5 @@
 import torch
+from enum import IntEnum
 import numpy as np
 from transformers import RobertaTokenizer, BatchEncoding, RobertaTokenizerFast
 import warnings
@@ -6,6 +7,14 @@ import warnings
 
 def get_tokenizer(parent_class):
     class TokenizerClass(parent_class):
+        class TaskTypes(IntEnum):
+            NULL = (0,)
+            QUERY = 1
+            DOCUMENT = 2
+            STS = 3
+            CLUSTERING = (4,)
+            CLASSIFICATION = 5
+
         def __init__(self, *args, **kwargs):
             """
             This class dynamically extends a given tokenizer class from the HF
@@ -16,26 +25,34 @@ def get_tokenizer(parent_class):
             """
             super().__init__(*args, **kwargs)
 
-        def __call__(self, *args, task_type=None, **kwargs):
+        def __call__(self, *args, task_type: TaskTypes = None, **kwargs):
             batch_encoding = super().__call__(*args, **kwargs)
             if task_type is not None:
-                batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
+                batch_encoding = self._add_task_type_ids(
+                    batch_encoding, task_type, kwargs.get('return_tensors')
+                )
             return batch_encoding
 
-        def _batch_encode_plus(self, *args, task_type=None, **kwargs):
+        def _batch_encode_plus(self, *args, task_type: TaskTypes = None, **kwargs):
             batch_encoding = super()._batch_encode_plus(*args, **kwargs)
             if task_type is not None:
-                batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
+                batch_encoding = self._add_task_type_ids(
+                    batch_encoding, task_type, kwargs.get('return_tensors')
+                )
             return batch_encoding
 
-        def _encode_plus(self, *args, task_type=None, **kwargs):
+        def _encode_plus(self, *args, task_type: TaskTypes = None, **kwargs):
             batch_encoding = super()._encode_plus(*args, **kwargs)
             if task_type is not None:
-                batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
+                batch_encoding = self._add_task_type_ids(
+                    batch_encoding, task_type, kwargs.get('return_tensors')
+                )
             return batch_encoding
 
         @classmethod
-        def _add_task_type_ids(cls, batch_encoding, task_type, tensor_type):
+        def _add_task_type_ids(
+            cls, batch_encoding: BatchEncoding, task_type: TaskTypes, tensor_type: str
+        ):
             return BatchEncoding(
                 {
                     'task_type_ids': cls._get_task_type_ids(batch_encoding, task_type),
@@ -45,12 +62,11 @@ def get_tokenizer(parent_class):
             )
 
         @staticmethod
-        def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):
-
+        def _get_task_type_ids(batch_encoding: BatchEncoding, task_type: TaskTypes):
             def apply_task_type(m, x):
                 x = torch.tensor(x)
                 assert (
                     len(x.shape) == 0 or x.shape[0] == m.shape[0]
                 ), 'The shape of task_type does not match the size of the batch.'
                 return m * x if len(x.shape) == 0 else m * x[:, None]
 
@@ -79,10 +95,13 @@ def get_tokenizer(parent_class):
                     warnings.warn(
                         'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
                     )
-                    return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
+                    return apply_task_type(
+                        torch.ones(shape, dtype=torch.long), task_type
+                    )
 
     return TokenizerClass
 
 
 JinaTokenizer = get_tokenizer(RobertaTokenizer)
 JinaTokenizerFast = get_tokenizer(RobertaTokenizerFast)
+
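
For context, below is a minimal usage sketch of the extended tokenizer after this change. The checkpoint name is a hypothetical placeholder, and the example assumes the unchanged branches of `_get_task_type_ids` handle torch tensors the way the warning text suggests; it is not part of this commit.

    import torch
    from tokenizer import JinaTokenizer  # the file changed in this commit

    # Hypothetical checkpoint; any vocabulary compatible with RobertaTokenizer would do.
    tokenizer = JinaTokenizer.from_pretrained('roberta-base')

    batch = tokenizer(
        ['how do I sort a list in python?', 'sorting lists in python'],
        padding=True,
        return_tensors='pt',
        task_type=JinaTokenizer.TaskTypes.QUERY,
    )

    # Besides the usual keys, the BatchEncoding now carries 'task_type_ids':
    # a tensor shaped like input_ids and filled with the enum value (1 for QUERY),
    # since apply_task_type multiplies a ones tensor by the scalar task_type.
    print(batch['task_type_ids'])

Per the assert in `apply_task_type`, `task_type` may also be a sequence with one value per batch row, in which case each row of `task_type_ids` gets its own value instead of a broadcast scalar. Note that `IntEnum` coerces the tuple values `(0,)` and `(4,)` back to the plain integers 0 and 4, so the members still compare equal to 0 through 5.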