michael-guenther committed
Commit 6170b43
Parent: 326b1c4

support multiple task ids

Files changed (1): tokenizer.py (+31 -13)
tokenizer.py CHANGED
@@ -11,13 +11,14 @@ class JinaTokenizer(RobertaTokenizer):
 
     def __call__(self, *args, task_type=None, **kwargs):
         batch_encoding = super().__call__(*args, **kwargs)
-        batch_encoding = BatchEncoding(
-            {
-                'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
-                **batch_encoding,
-            },
-            tensor_type=kwargs.get('return_tensors'),
-        )
+        if task_type is not None:
+            batch_encoding = BatchEncoding(
+                {
+                    'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
+                    **batch_encoding,
+                },
+                tensor_type=kwargs.get('return_tensors'),
+            )
         return batch_encoding
 
     def _batch_encode_plus(self, *args, task_type=None, **kwargs):
@@ -45,18 +46,35 @@ class JinaTokenizer(RobertaTokenizer):
         return batch_encoding
 
     @staticmethod
-    def _get_task_type_ids(batch_encoding: BatchEncoding, task_type: int):
+    def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):
+
+        def apply_task_type(m, x):
+            x = torch.tensor(x)
+            return m * x if len(x.shape) == 0 else m * x[:, None]
+
         if isinstance(batch_encoding['input_ids'], torch.Tensor):
             shape = batch_encoding['input_ids'].shape
-            return torch.ones(shape, dtype=torch.long) * task_type
+            return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
         else:
-            shape = torch.tensor(batch_encoding['input_ids']).shape
+            try:
+                shape = torch.tensor(batch_encoding['input_ids']).shape
+            except:
+                raise ValueError(
+                    "Unable to create tensor, you should probably "
+                    "activate truncation and/or padding with "
+                    "'padding=True' 'truncation=True' to have batched "
+                    "tensors with the same length."
+                )
             if isinstance(batch_encoding['input_ids'], list):
-                return (torch.ones(shape, dtype=torch.long) * task_type).tolist()
+                return (
+                    apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
+                ).tolist()
             elif isinstance(batch_encoding['input_ids'], np.array):
-                return (torch.ones(shape, dtype=torch.long) * task_type).numpy()
+                return (
+                    apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
+                ).numpy()
            else:
                 warnings.warn(
                     'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
                 )
-                return torch.ones(shape, dtype=torch.long) * task_type
+                return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
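
The heart of the change is the new `apply_task_type` helper: `task_type` may now be either a single id or a sequence with one id per input. A minimal, self-contained sketch of that broadcasting is below; the `(2, 4)` shape and the example ids are illustrative assumptions, not values taken from the commit.

import torch

# Sketch of the broadcasting used by _get_task_type_ids.
def apply_task_type(m, x):
    x = torch.tensor(x)
    # 0-dim x: one task id for the whole batch; 1-dim x: one id per row.
    return m * x if len(x.shape) == 0 else m * x[:, None]

ones = torch.ones((2, 4), dtype=torch.long)  # stand-in for the input_ids shape

print(apply_task_type(ones, 3))
# tensor([[3, 3, 3, 3],
#         [3, 3, 3, 3]])

print(apply_task_type(ones, [0, 2]))
# tensor([[0, 0, 0, 0],
#         [2, 2, 2, 2]])

Combined with the new `if task_type is not None` guard in `__call__`, a call such as `tokenizer(texts, padding=True, task_type=[0, 2])` would attach one task id per input sequence, while omitting `task_type` leaves the encoding without a `task_type_ids` entry.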