Jackmin108 commited on
Commit
362ef00
1 Parent(s): e860caa

feat: 2 adapter tuning

Browse files

Signed-off-by: Meow <ongjackm@gmail.com>

Files changed (6) hide show
  1. block.py +11 -1
  2. embedding.py +26 -4
  3. mha.py +37 -5
  4. mlp.py +21 -3
  5. modeling_lora.py +0 -1
  6. modeling_xlm_roberta.py +12 -1
block.py CHANGED
@@ -233,7 +233,17 @@ class Block(nn.Module):
233
  is_rms_norm=isinstance(self.norm1, RMSNorm),
234
  )
235
  if not isinstance(self.mlp, nn.Identity):
236
- mlp_out = self.mlp(hidden_states, task_type=mixer_kwargs.get('task_type'))
 
 
 
 
 
 
 
 
 
 
237
  if self.return_residual: # mlp out is actually a pair here
238
  mlp_out, hidden_states = mlp_out
239
  if not self.fused_dropout_add_ln:
 
233
  is_rms_norm=isinstance(self.norm1, RMSNorm),
234
  )
235
  if not isinstance(self.mlp, nn.Identity):
236
+ task_type = mixer_kwargs.get('task_type')
237
+ if task_type:
238
+ if isinstance(task_type, tuple):
239
+ assert mixer_kwargs['cu_seqlens'].shape[0] % 9 == 1
240
+ split_index = int((mixer_kwargs['cu_seqlens'].shape[0] - 1) / 9)
241
+ split = mixer_kwargs['cu_seqlens'][split_index]
242
+ mlp_out = self.mlp(hidden_states, task_type=mixer_kwargs.get('task_type'), split=split)
243
+ else:
244
+ mlp_out = self.mlp(hidden_states, task_type=task_type)
245
+ else:
246
+ mlp_out = self.mlp(hidden_states)
247
  if self.return_residual: # mlp out is actually a pair here
248
  mlp_out, hidden_states = mlp_out
249
  if not self.fused_dropout_add_ln:
embedding.py CHANGED
@@ -47,8 +47,18 @@ class XLMRobertaEmbeddings(nn.Module):
47
  token_type_ids: (batch, seqlen)
48
  """
49
  batch_size, seqlen = input_ids.shape
50
- lora_kwargs = {'task_type': task_type} if task_type is not None else {}
51
- embeddings = self.word_embeddings(input_ids, **lora_kwargs)
 
 
 
 
 
 
 
 
 
 
52
  if self.max_position_embeddings > 0:
53
  if position_ids is None:
54
  position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
@@ -58,6 +68,18 @@ class XLMRobertaEmbeddings(nn.Module):
58
  if self.type_vocab_size > 0:
59
  if token_type_ids is None:
60
  token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
61
- token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
62
- embeddings = embeddings + token_type_embeddings
 
 
 
 
 
 
 
 
 
 
 
 
63
  return embeddings
 
47
  token_type_ids: (batch, seqlen)
48
  """
49
  batch_size, seqlen = input_ids.shape
50
+ if isinstance(task_type, tuple):
51
+ assert input_ids.shape[0] % 9 == 0
52
+ split = int(input_ids.shape[0] / 9)
53
+ tensor1 = input_ids[:split, :]
54
+ tensor2 = input_ids[split:, :]
55
+ emb1 = self.word_embeddings(tensor1, task_type=task_type[0])
56
+ emb2 = self.word_embeddings(tensor2, task_type=task_type[1])
57
+ embeddings = torch.cat((emb1, emb2), dim=0)
58
+ else:
59
+ lora_kwargs = {'task_type': task_type} if task_type is not None else {}
60
+ embeddings = self.word_embeddings(input_ids, **lora_kwargs)
61
+
62
  if self.max_position_embeddings > 0:
63
  if position_ids is None:
64
  position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
 
68
  if self.type_vocab_size > 0:
69
  if token_type_ids is None:
70
  token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
71
+ if isinstance(task_type, tuple):
72
+ assert embeddings.shape[0] % 9 == 0
73
+ split = int(embeddings.shape[0] / 9)
74
+ emb1 = embeddings[:split, :, :]
75
+ emb2 = embeddings[split:, :, :]
76
+ token_type_embs1 = self.token_type_embeddings(token_type_ids, task_type=task_type[0])
77
+ token_type_embs2 = self.token_type_embeddings(token_type_ids, task_type=task_type[1])
78
+ emb1 = emb1 + token_type_embs1
79
+ emb2 = emb2 + token_type_embs2
80
+ embeddings = torch.cat((emb1, emb2), dim=0)
81
+ else:
82
+ lora_kwargs = {'task_type': task_type} if task_type is not None else {}
83
+ token_type_embeddings = self.token_type_embeddings(token_type_ids, **lora_kwargs)
84
+ embeddings = embeddings + token_type_embeddings
85
  return embeddings
mha.py CHANGED
@@ -643,15 +643,39 @@ class MHA(nn.Module):
643
  inference_params.max_sequence_len if inference_params is not None else max_seqlen
644
  )
645
  batch, seqlen = x.shape[:2]
 
646
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
647
  assert x_kv is None and mixer_subset is None
 
 
 
 
 
 
 
648
  lora_kwargs = {'task_type': task_type} if task_type is not None else {}
 
649
  if not self.return_residual:
650
- qkv = self.Wqkv(x, **lora_kwargs)
 
 
 
 
 
 
 
651
  else:
652
- if lora_kwargs:
653
- lora_kwargs['residual'] = True
654
- qkv, x = self.Wqkv(x, **lora_kwargs)
 
 
 
 
 
 
 
 
655
 
656
  if self.dwconv:
657
  qkv = rearrange(
@@ -739,5 +763,13 @@ class MHA(nn.Module):
739
  context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
740
 
741
  lora_kwargs.pop('residual', None)
742
- out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), **lora_kwargs)
 
 
 
 
 
 
 
 
743
  return out if not self.return_residual else (out, x)
 
643
  inference_params.max_sequence_len if inference_params is not None else max_seqlen
644
  )
645
  batch, seqlen = x.shape[:2]
646
+ lora_kwargs = {}
647
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
648
  assert x_kv is None and mixer_subset is None
649
+
650
+ split = None
651
+ if isinstance(task_type, tuple):
652
+ assert cu_seqlens.shape[0] % 9 == 1
653
+ split_index = int((cu_seqlens.shape[0] - 1) / 9)
654
+ split = cu_seqlens[split_index]
655
+
656
  lora_kwargs = {'task_type': task_type} if task_type is not None else {}
657
+
658
  if not self.return_residual:
659
+ if isinstance(task_type, tuple):
660
+ tensor1 = x[:split, :]
661
+ tensor2 = x[split:, :]
662
+ qkv1 = self.Wqkv(tensor1, task_type=task_type[0])
663
+ qkv2 = self.Wqkv(tensor2, task_type=task_type[1])
664
+ qkv = torch.cat((qkv1, qkv2), dim=0)
665
+ else:
666
+ qkv = self.Wqkv(x, **lora_kwargs)
667
  else:
668
+ if isinstance(task_type, tuple):
669
+ tensor1 = x[:split, :]
670
+ tensor2 = x[split:, :]
671
+ qkv1, tensor1 = self.Wqkv(tensor1, task_type=task_type[0], residual=True)
672
+ qkv2, tensor2 = self.Wqkv(tensor2, task_type=task_type[1], residual=True)
673
+ qkv = torch.cat((qkv1, qkv2), dim=0)
674
+ x = torch.cat((tensor1, tensor2), dim=0)
675
+ else:
676
+ if lora_kwargs:
677
+ lora_kwargs['residual'] = True
678
+ qkv, x = self.Wqkv(x, **lora_kwargs)
679
 
680
  if self.dwconv:
681
  qkv = rearrange(
 
763
  context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
764
 
765
  lora_kwargs.pop('residual', None)
766
+ inp = rearrange(context, "... h d -> ... (h d)")
767
+ if isinstance(task_type, tuple):
768
+ tensor1 = inp[:split, :]
769
+ tensor2 = inp[split:, :]
770
+ out1 = self.out_proj(tensor1, task_type=task_type[0])
771
+ out2 = self.out_proj(tensor2, task_type=task_type[1])
772
+ out = torch.cat((out1, out2), dim=0)
773
+ else:
774
+ out = self.out_proj(inp, **lora_kwargs)
775
  return out if not self.return_residual else (out, x)
mlp.py CHANGED
@@ -47,11 +47,29 @@ class Mlp(nn.Module):
47
  self.activation = activation
48
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
49
 
50
- def forward(self, x, task_type=None):
51
  lora_kwargs = {'task_type': task_type} if task_type is not None else {}
52
- y = self.fc1(x, **lora_kwargs)
 
 
 
 
 
 
 
 
 
53
  y = self.activation(y)
54
- y = self.fc2(y, **lora_kwargs)
 
 
 
 
 
 
 
 
 
55
  return y if not self.return_residual else (y, x)
56
 
57
 
 
47
  self.activation = activation
48
  self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
49
 
50
+ def forward(self, x, task_type=None, split=None):
51
  lora_kwargs = {'task_type': task_type} if task_type is not None else {}
52
+ if split:
53
+ assert isinstance(task_type, tuple)
54
+ tensor1 = x[:split, :]
55
+ tensor2 = x[split:, :]
56
+ y1 = self.fc1(tensor1, task_type=task_type[0])
57
+ y2 = self.fc1(tensor2, task_type=task_type[1])
58
+ y = torch.cat((y1, y2), dim=0)
59
+ else:
60
+ y = self.fc1(x, **lora_kwargs)
61
+
62
  y = self.activation(y)
63
+
64
+ if split:
65
+ assert isinstance(task_type, tuple)
66
+ tensor1 = y[:split, :]
67
+ tensor2 = y[split:, :]
68
+ y1 = self.fc2(tensor1, task_type=task_type[0])
69
+ y2 = self.fc2(tensor2, task_type=task_type[1])
70
+ y = torch.cat((y1, y2), dim=0)
71
+ else:
72
+ y = self.fc2(y, **lora_kwargs)
73
  return y if not self.return_residual else (y, x)
74
 
75
 
modeling_lora.py CHANGED
@@ -227,7 +227,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
227
  roberta: Optional[XLMRobertaModel] = None
228
  ):
229
  super().__init__(config)
230
-
231
  if roberta is None:
232
  self.roberta = XLMRobertaModel(config)
233
  else:
 
227
  roberta: Optional[XLMRobertaModel] = None
228
  ):
229
  super().__init__(config)
 
230
  if roberta is None:
231
  self.roberta = XLMRobertaModel(config)
232
  else:
modeling_xlm_roberta.py CHANGED
@@ -316,7 +316,18 @@ class XLMRobertaPooler(nn.Module):
316
  lora_kwargs = {'task_type': task_type} if task_type is not None else {}
317
 
318
  first_token_tensor = hidden_states[:, 0] if pool else hidden_states
319
- pooled_output = self.dense(first_token_tensor, **lora_kwargs)
 
 
 
 
 
 
 
 
 
 
 
320
  pooled_output = self.activation(pooled_output)
321
  return pooled_output
322
 
 
316
  lora_kwargs = {'task_type': task_type} if task_type is not None else {}
317
 
318
  first_token_tensor = hidden_states[:, 0] if pool else hidden_states
319
+
320
+ if isinstance(task_type, tuple):
321
+ assert first_token_tensor.shape[0] % 9 == 0
322
+ split = int(first_token_tensor.shape[0] / 9)
323
+ tensor1 = first_token_tensor[:split, :]
324
+ tensor2 = first_token_tensor[split:, :]
325
+ pooled_out1 = self.dense(tensor1, task_type=task_type[0])
326
+ pooled_out2 = self.dense(tensor2, task_type=task_type[0])
327
+ pooled_output = torch.cat((pooled_out1, pooled_out2), dim=0)
328
+ else:
329
+ pooled_output = self.dense(first_token_tensor, **lora_kwargs)
330
+
331
  pooled_output = self.activation(pooled_output)
332
  return pooled_output
333