jupyterjazz committed
Commit 12000a5
1 Parent(s): ae4c28c

feat: mcls


Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (1)
  1. tokenizer.py +76 -43
tokenizer.py CHANGED
@@ -1,62 +1,95 @@
-import torch
-import numpy as np
-from transformers import RobertaTokenizer, BatchEncoding
 import warnings
 
+import numpy as np
+import torch
+from transformers import BatchEncoding, RobertaTokenizer
+
 
 class JinaTokenizer(RobertaTokenizer):
-    def __init__(self, *args, task_type_vocab_size=6, **kwargs):
+    def __init__(
+        self, *args, task_type_vocab_size=6, cls_token_interval=None, **kwargs
+    ):
         super().__init__(*args, **kwargs)
         self.task_type_vocab_size = task_type_vocab_size
+        self.cls_token_interval = cls_token_interval
 
     def __call__(self, *args, task_type=None, **kwargs):
-        batch_encoding = super().__call__(*args, **kwargs)
-        batch_encoding = BatchEncoding(
-            {
-                'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
-                **batch_encoding,
-            },
-            tensor_type=kwargs.get('return_tensors'),
-        )
-        return batch_encoding
-
-    def _batch_encode_plus(self, *args, task_type=None, **kwargs):
-        batch_encoding = super()._batch_encode_plus(*args, **kwargs)
-        if task_type is not None:
-            batch_encoding = BatchEncoding(
-                {
-                    'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
-                    **batch_encoding,
-                },
-                tensor_type=kwargs.get('return_tensors'),
-            )
-        return batch_encoding
-
-    def _encode_plus(self, *args, task_type=None, **kwargs):
-        batch_encoding = super()._encode_plus(*args, **kwargs)
-        if task_type is not None:
-            batch_encoding = BatchEncoding(
-                {
-                    'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
-                    **batch_encoding,
-                },
-                tensor_type=kwargs.get('return_tensors'),
-            )
-        return batch_encoding
+        kwargs["task_type"] = task_type
+        return super().__call__(*args, **kwargs)
+
+    def _encode_plus(self, *args, **kwargs):
+        return self._process_encoding(super()._encode_plus(*args, **kwargs), **kwargs)
+
+    def _batch_encode_plus(self, *args, **kwargs):
+        return self._process_encoding(
+            super()._batch_encode_plus(*args, **kwargs), **kwargs
+        )
+
+    def _process_encoding(self, batch_encoding: BatchEncoding, **kwargs):
+        task_type = kwargs.get("task_type")
+        if self.cls_token_interval is not None:
+            modified_input_ids, modified_attention_mask = self._insert_cls_tokens(
+                batch_encoding
+            )
+            batch_encoding["input_ids"] = modified_input_ids
+            if "attention_mask" in batch_encoding:
+                batch_encoding["attention_mask"] = modified_attention_mask
+
+        if task_type is not None:
+            task_type_ids = self._get_task_type_ids(batch_encoding, task_type)
+            batch_encoding["task_type_ids"] = task_type_ids
+
+        return BatchEncoding(batch_encoding, tensor_type=kwargs.get("return_tensors"))
+
+    def _insert_cls_tokens(self, batch_encoding: BatchEncoding):
+        cls_token_id = self.cls_token_id
+        new_input_ids = []
+        new_attention_masks = []
+
+        sequences = batch_encoding["input_ids"].tolist()
+        original_attention_masks = batch_encoding["attention_mask"].tolist()
+
+        for seq_index, sequence in enumerate(sequences):
+            original_sequence_length = sum(original_attention_masks[seq_index])
+            # one extra CLS goes after every full chunk except the last
+            num_cls_tokens_to_add = (
+                original_sequence_length - 2
+            ) // self.cls_token_interval
+            new_sequence_length = original_sequence_length + num_cls_tokens_to_add
+
+            # the leading CLS/BOS token stays at position 0
+            modified_sequence = [sequence[0]]
+            for i in range(1, len(sequence), self.cls_token_interval):
+                chunk = sequence[i : i + self.cls_token_interval]
+                modified_sequence.extend(chunk)
+
+                if i + self.cls_token_interval < len(sequence):
+                    modified_sequence.append(cls_token_id)
+
+            new_input_ids.append(modified_sequence)
+            new_attention_mask = [1] * new_sequence_length + [0] * (
+                len(modified_sequence) - new_sequence_length
+            )
+            new_attention_masks.append(new_attention_mask)
+
+        new_input_ids = torch.tensor(new_input_ids, dtype=torch.long)
+        new_attention_masks = torch.tensor(new_attention_masks, dtype=torch.long)
+
+        return new_input_ids, new_attention_masks
 
     @staticmethod
     def _get_task_type_ids(batch_encoding: BatchEncoding, task_type: int):
-        if isinstance(batch_encoding['input_ids'], torch.Tensor):
-            shape = batch_encoding['input_ids'].shape
+        if isinstance(batch_encoding["input_ids"], torch.Tensor):
+            shape = batch_encoding["input_ids"].shape
             return torch.ones(shape, dtype=torch.long) * task_type
         else:
-            shape = torch.tensor(batch_encoding['input_ids']).shape
-            if isinstance(batch_encoding['input_ids'], list):
+            shape = torch.tensor(batch_encoding["input_ids"]).shape
+            if isinstance(batch_encoding["input_ids"], list):
                 return (torch.ones(shape, dtype=torch.long) * task_type).tolist()
-            elif isinstance(batch_encoding['input_ids'], np.array):
+            elif isinstance(batch_encoding["input_ids"], np.ndarray):
                 return (torch.ones(shape, dtype=torch.long) * task_type).numpy()
             else:
                 warnings.warn(
-                    'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
+                    "input_ids is not a torch tensor, numpy array, or list. Returning torch tensor"
                 )
                 return torch.ones(shape, dtype=torch.long) * task_type
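
For context, here is a minimal usage sketch of the changed interface. It is not part of the commit: the checkpoint name, the interval, and the task_type value are placeholder assumptions, and the import assumes tokenizer.py is importable from the working directory.

from tokenizer import JinaTokenizer

# hypothetical checkpoint and settings; any RoBERTa-style vocabulary should do
tokenizer = JinaTokenizer.from_pretrained(
    "roberta-base",
    cls_token_interval=4,  # insert an extra CLS token every 4 tokens
)
encoded = tokenizer(
    ["first example sentence", "second one"],
    task_type=1,          # emitted as task_type_ids with the same shape as input_ids
    padding=True,
    return_tensors="pt",  # _insert_cls_tokens calls .tolist(), so tensor output is expected
)
print(encoded["input_ids"].shape, encoded["task_type_ids"].shape)

Note that task_type travels through **kwargs: __call__ writes it into kwargs, the parent tokenizer passes kwargs down to _batch_encode_plus, and _process_encoding reads it back out.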
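
And a pure-Python trace of the insertion pattern in _insert_cls_tokens, with hypothetical token ids (0 for the CLS/BOS token, 2 for EOS) and an interval of 3; the loop mirrors the one in the diff:

interval = 3
cls_id = 0
sequence = [0, 11, 12, 13, 14, 15, 16, 17, 18, 2]  # 10 tokens, unpadded

out = [sequence[0]]  # the leading CLS/BOS stays put
for i in range(1, len(sequence), interval):
    out.extend(sequence[i : i + interval])  # copy the next chunk
    if i + interval < len(sequence):        # no CLS after the final chunk
        out.append(cls_id)

print(out)
# [0, 11, 12, 13, 0, 14, 15, 16, 0, 17, 18, 2]

Two CLS tokens are inserted, matching (10 - 2) // 3 from num_cls_tokens_to_add, so the rebuilt attention mask has the same length as the rebuilt input ids.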