Jackmin108 commited on
Commit
c1736a8
·
1 Parent(s): 362ef00

feat: adapter masking wip

Browse files

Signed-off-by: Meow <ongjackm@gmail.com>

Files changed (4) hide show
  1. embedding.py +21 -4
  2. modeling_lora.py +10 -2
  3. modeling_xlm_roberta.py +7 -6
  4. xlm_padding.py +9 -1
embedding.py CHANGED
@@ -40,7 +40,7 @@ class XLMRobertaEmbeddings(nn.Module):
40
  if self.type_vocab_size > 0:
41
  self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
42
 
43
- def forward(self, input_ids, position_ids=None, token_type_ids=None, task_type=None):
44
  """
45
  input_ids: (batch, seqlen)
46
  position_ids: (batch, seqlen)
@@ -55,9 +55,25 @@ class XLMRobertaEmbeddings(nn.Module):
55
  emb1 = self.word_embeddings(tensor1, task_type=task_type[0])
56
  emb2 = self.word_embeddings(tensor2, task_type=task_type[1])
57
  embeddings = torch.cat((emb1, emb2), dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
58
  else:
59
- lora_kwargs = {'task_type': task_type} if task_type is not None else {}
60
- embeddings = self.word_embeddings(input_ids, **lora_kwargs)
 
 
 
 
61
 
62
  if self.max_position_embeddings > 0:
63
  if position_ids is None:
@@ -79,7 +95,8 @@ class XLMRobertaEmbeddings(nn.Module):
79
  emb2 = emb2 + token_type_embs2
80
  embeddings = torch.cat((emb1, emb2), dim=0)
81
  else:
82
- lora_kwargs = {'task_type': task_type} if task_type is not None else {}
 
83
  token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
84
  embeddings = embeddings + token_type_embeddings
85
  return embeddings
 
40
  if self.type_vocab_size > 0:
41
  self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
42
 
43
+ def forward(self, input_ids, position_ids=None, token_type_ids=None, task_type=None, adapter_mask=None):
44
  """
45
  input_ids: (batch, seqlen)
46
  position_ids: (batch, seqlen)
 
55
  emb1 = self.word_embeddings(tensor1, task_type=task_type[0])
56
  emb2 = self.word_embeddings(tensor2, task_type=task_type[1])
57
  embeddings = torch.cat((emb1, emb2), dim=0)
58
+
59
+ unique_tasks = torch.unique(adapter_mask).tolist()
60
+ torch_dtype = next(self.word_embeddings.parameters()).dtype
61
+ embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim, dtype=torch_dtype).to(input_ids.device)
62
+ for task in unique_tasks:
63
+ indices = (adapter_mask == task).nonzero(as_tuple=True)[0]
64
+ inp = input_ids[indices]
65
+ lora_kwargs = {'task_type': task} if task is not None else {}
66
+ emb = self.word_embeddings(inp, **lora_kwargs)
67
+ embeddings[indices] = emb
68
+
69
+ exit(0)
70
  else:
71
+ unique_task = torch.unique(adapter_mask)[0]
72
+ task1_indices = (adapter_mask == unique_task).nonzero(as_tuple=True)[0]
73
+ input1 = input_ids[task1_indices]
74
+ lora_kwargs = {'task_type': unique_task} if unique_task is not None else {}
75
+ embeddings = self.word_embeddings(input1, **lora_kwargs)
76
+
77
 
78
  if self.max_position_embeddings > 0:
79
  if position_ids is None:
 
95
  emb2 = emb2 + token_type_embs2
96
  embeddings = torch.cat((emb1, emb2), dim=0)
97
  else:
98
+ unique_task = torch.unique(adapter_mask)[0]
99
+ lora_kwargs = {'task_type': unique_task} if unique_task is not None else {}
100
  token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
101
  embeddings = embeddings + token_type_embeddings
102
  return embeddings
modeling_lora.py CHANGED
@@ -177,7 +177,11 @@ class LoRAParametrization(nn.Module):
177
  )
178
 
179
  def new_forward(self, input, task_type, residual=False):
180
- task_idx = adaptation_map[task_type] if task_type else None
 
 
 
 
181
  if task_idx is not None:
182
  weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
183
  else:
@@ -205,7 +209,11 @@ class LoRAParametrization(nn.Module):
205
  )
206
 
207
  def new_forward(self, input, task_type):
208
- task_idx = adaptation_map[task_type] if task_type else None
 
 
 
 
209
  if task_idx is not None:
210
  weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
211
  else:
 
177
  )
178
 
179
  def new_forward(self, input, task_type, residual=False):
180
+ if isinstance(task_type, str):
181
+ task_idx = adaptation_map[task_type] if task_type else None
182
+ else:
183
+ task_idx = task_type
184
+
185
  if task_idx is not None:
186
  weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
187
  else:
 
209
  )
210
 
211
  def new_forward(self, input, task_type):
212
+ if isinstance(task_type, str):
213
+ task_idx = adaptation_map[task_type] if task_type else None
214
+ else:
215
+ task_idx = task_type
216
+
217
  if task_idx is not None:
218
  weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
219
  else:
modeling_xlm_roberta.py CHANGED
@@ -204,7 +204,7 @@ class XLMRobertaEncoder(nn.Module):
204
  def gradient_checkpointing(self, value):
205
  self._grad_checkpointing = value
206
 
207
- def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task_type=None):
208
  """If subset_mask is not None, we only want output for the subset of the sequence.
209
  This means that we only compute the last layer output for these tokens.
210
  subset_mask: (batch, seqlen), dtype=torch.bool
@@ -230,10 +230,10 @@ class XLMRobertaEncoder(nn.Module):
230
  hidden_states = hidden_states[subset_mask]
231
  else:
232
  batch, seqlen = hidden_states.shape[:2]
233
- hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
234
- hidden_states, key_padding_mask
235
  )
236
- mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "task_type": task_type}
237
  if subset_mask is None:
238
  for layer in self.layers:
239
  if self._grad_checkpointing:
@@ -649,6 +649,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
649
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
650
  """
651
  task_type = kwargs.pop('task_type', None)
 
652
  if kwargs:
653
  for key, value in kwargs.items():
654
  if value is not None:
@@ -662,7 +663,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
662
  )
663
 
664
  hidden_states = self.embeddings(
665
- input_ids, position_ids=position_ids, token_type_ids=token_type_ids, task_type=task_type
666
  )
667
  # TD [2022-12:18]: Don't need to force residual in fp32
668
  # BERT puts embedding LayerNorm before embedding dropout.
@@ -686,7 +687,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
686
  subset_mask = None
687
 
688
  sequence_output = self.encoder(
689
- hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, task_type=task_type
690
  )
691
 
692
  if masked_tokens_mask is None:
 
204
  def gradient_checkpointing(self, value):
205
  self._grad_checkpointing = value
206
 
207
+ def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, task_type=None, adapter_mask=None):
208
  """If subset_mask is not None, we only want output for the subset of the sequence.
209
  This means that we only compute the last layer output for these tokens.
210
  subset_mask: (batch, seqlen), dtype=torch.bool
 
230
  hidden_states = hidden_states[subset_mask]
231
  else:
232
  batch, seqlen = hidden_states.shape[:2]
233
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = unpad_input(
234
+ hidden_states, key_padding_mask, adapter_mask
235
  )
236
+ mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "task_type": task_type, "cu_adapter_mask": cu_adapter_mask}
237
  if subset_mask is None:
238
  for layer in self.layers:
239
  if self._grad_checkpointing:
 
649
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
650
  """
651
  task_type = kwargs.pop('task_type', None)
652
+ adapter_mask = kwargs.pop('adapter_mask', None)
653
  if kwargs:
654
  for key, value in kwargs.items():
655
  if value is not None:
 
663
  )
664
 
665
  hidden_states = self.embeddings(
666
+ input_ids, position_ids=position_ids, token_type_ids=token_type_ids, task_type=task_type, adapter_mask=adapter_mask
667
  )
668
  # TD [2022-12:18]: Don't need to force residual in fp32
669
  # BERT puts embedding LayerNorm before embedding dropout.
 
687
  subset_mask = None
688
 
689
  sequence_output = self.encoder(
690
+ hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, task_type=task_type, adapter_mask=adapter_mask
691
  )
692
 
693
  if masked_tokens_mask is None:
xlm_padding.py CHANGED
@@ -98,7 +98,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
98
  index_first_axis_residual = IndexFirstAxisResidual.apply
99
 
100
 
101
- def unpad_input(hidden_states, attention_mask):
102
  """
103
  Arguments:
104
  hidden_states: (batch, seqlen, ...)
@@ -113,6 +113,13 @@ def unpad_input(hidden_states, attention_mask):
113
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
114
  max_seqlen_in_batch = seqlens_in_batch.max().item()
115
  cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
 
 
 
 
 
 
 
116
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
117
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
118
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
@@ -123,6 +130,7 @@ def unpad_input(hidden_states, attention_mask):
123
  indices,
124
  cu_seqlens,
125
  max_seqlen_in_batch,
 
126
  )
127
 
128
 
 
98
  index_first_axis_residual = IndexFirstAxisResidual.apply
99
 
100
 
101
+ def unpad_input(hidden_states, attention_mask, adapter_mask):
102
  """
103
  Arguments:
104
  hidden_states: (batch, seqlen, ...)
 
113
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
114
  max_seqlen_in_batch = seqlens_in_batch.max().item()
115
  cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
116
+
117
+ cu_adapter_mask = torch.empty(cu_seqlens[-1], dtype=torch.int32)
118
+ for i in range(len(adapter_mask)):
119
+ start_idx = cu_seqlens[i]
120
+ end_idx = cu_seqlens[i + 1]
121
+ cu_adapter_mask[start_idx:end_idx] = adapter_mask[i]
122
+
123
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
124
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
125
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
 
130
  indices,
131
  cu_seqlens,
132
  max_seqlen_in_batch,
133
+ cu_adapter_mask,
134
  )
135
 
136