jupyterjazz committed
Commit eefe43c
1 Parent(s): 6cc0f51

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (5)
  1. embedding.py +1 -2
  2. mha.py +5 -3
  3. mlp.py +2 -2
  4. modeling_lora.py +33 -35
  5. modeling_xlm_roberta.py +1 -1
embedding.py CHANGED
@@ -47,7 +47,6 @@ class XLMRobertaEmbeddings(nn.Module):
         token_type_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
-        print('input shape', input_ids.shape)
         embeddings = self.word_embeddings(input_ids, task='sts')
         if self.max_position_embeddings > 0:
             if position_ids is None:
@@ -58,6 +57,6 @@ class XLMRobertaEmbeddings(nn.Module):
         if self.type_vocab_size > 0:
             if token_type_ids is None:
                 token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
-            token_type_embeddings = self.token_type_embeddings(token_type_ids)
+            token_type_embeddings = self.token_type_embeddings(token_type_ids, task='sts')
             embeddings = embeddings + token_type_embeddings
         return embeddings
mha.py CHANGED
@@ -341,6 +341,7 @@ class LinearResidual(nn.Linear):
     """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

     def forward(self, input: torch.Tensor, task=None) -> torch.Tensor:
+        print('aq vafshe ar modis?')
         return super().forward(input, task=task), input


@@ -450,7 +451,7 @@ class MHA(nn.Module):

         if fused_bias_fc and FusedDense is None:
             raise ImportError("fused_dense is not installed")
-        print('is this true', fused_bias_fc)
+
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         linear_resid_cls = (
             LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
@@ -647,7 +648,8 @@ class MHA(nn.Module):
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
-               qkv, x = self.Wqkv(x, task='sts')
+               qkv, x = self.Wqkv(x, task='query', residual=True)
+
            if self.dwconv:
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
@@ -732,5 +734,5 @@ class MHA(nn.Module):
                context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
-        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
+        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), task='passage')
         return out if not self.return_residual else (out, x)
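Note on the return_residual branch above: the call site now unpacks qkv, x = self.Wqkv(x, task='query', residual=True), i.e. the projection is expected to hand back its unmodified input alongside the output so the residual can be reused later. A minimal sketch of that wrapper idea, without the task routing and assuming a plain nn.Linear base (ToyLinearResidual is an illustrative name, not the LinearResidual defined in mha.py):

import torch
from torch import nn

class ToyLinearResidual(nn.Linear):
    """Like nn.Linear, but also returns the input so callers can keep the residual."""

    def forward(self, input: torch.Tensor):
        return super().forward(input), input

proj = ToyLinearResidual(16, 3 * 16)   # e.g. a fused QKV projection
x = torch.randn(2, 8, 16)              # (batch, seqlen, hidden)
qkv, residual = proj(x)                # qkv: (2, 8, 48); residual is x, unchanged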
mlp.py CHANGED
@@ -48,9 +48,9 @@ class Mlp(nn.Module):
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

     def forward(self, x):
-        y = self.fc1(x)
+        y = self.fc1(x, task='clustering')
         y = self.activation(y)
-        y = self.fc2(y)
+        y = self.fc2(y, task='sts')
         return y if not self.return_residual else (y, x)


modeling_lora.py CHANGED
@@ -9,6 +9,7 @@ import torch
 import torch.nn.utils.parametrize as parametrize
 from torch import nn
 from torch.nn import Parameter
+from torch.nn import functional as F
 from transformers import PretrainedConfig

 from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
@@ -98,8 +99,7 @@ class LoRAParametrization(nn.Module):
         # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
         return A * self.lora_dropout(self.lora_dropout_mask)

-    def lora_forward(self, X, current_task=None):
-        print('lora input shape', X.shape)
+    def lora_forward(self, X, current_task):
         return (
             X
             + torch.matmul(
@@ -114,10 +114,7 @@ class LoRAParametrization(nn.Module):
         )

     def forward(self, X):
-        print('forward input shape', X.shape, X)
-        out = self.forward_fn(X)
-        print(out.shape)
-        return out
+        return X

     @property
     def current_task(self):
@@ -195,13 +192,20 @@ class LoRAParametrization(nn.Module):
                     alpha=alpha,
                 ),
             )
-            original_forward = layer.forward

-            def new_forward(self, input, task):
-                print('an aq mitxari aba')
-                output = original_forward(input, task=task)
-                weight = self.parametrizations.weight(self.weight, task)
-                return nn.functional.linear(input, weight, self.bias)
+            def new_forward(self, input, task, residual=False):
+                task_idx = adaptation_map[task] if task else None
+                if task_idx:
+                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
+                else:
+                    weights = self.weight
+
+                out = F.linear(input, weights, self.bias)
+
+                print('lin', task_idx, input.shape, out.shape)
+                if residual:
+                    return out, input
+                return out

             layer.forward = new_forward.__get__(layer, layer.__class__)

@@ -217,20 +221,20 @@ class LoRAParametrization(nn.Module):
                     alpha=alpha,
                 ),
             )
-            original_forward = layer.forward

             def new_forward(self, input, task):
-                print('input here', input, input.shape)
-                print('func', original_forward)
-                # original_forward['parametrizations'] = None
-                # print('funcc', original_forward.__dict__)
-                output = original_forward(input)
-                print(output.shape, 'output shape')
                 task_idx = adaptation_map[task] if task else None
                 if task_idx:
-                    output = self.parametrizations.weight[0].lora_forward(output, current_task=task_idx)
-                    print('thats it')
-                return output
+                    weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
+                else:
+                    weights = self.weight
+
+                out = F.embedding(
+                    input, weights, self.padding_idx, self.max_norm,
+                    self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+                print('emb', task_idx, input.shape, out.shape)
+                return out

             layer.forward = new_forward.__get__(layer, layer.__class__)

@@ -278,13 +282,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         self._task_idx = None
         # By default, disable LoRA until it's specified which adapter/task to use
         self.current_task = None
-        for name, param in super().named_parameters():
-            if name == 'roberta.encoder.layers.22.mixer.Wqkv.parametrizations.weight.0.lora_A':
-                print('A0', param[0])
-                print('A1', param[1])
-            if name == 'roberta.encoder.layers.22.mixer.Wqkv.parametrizations.weight.0.lora_B':
-                print('B0', param[0])
-                print('B1', param[1])
+

     @property
     def main_params_trainable(self):
@@ -364,12 +362,12 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                 f"Alternatively, set `task` to `None` if you want to disable LoRA."
             )
         task_idx = self._adaptation_map[task_name] if task_name else None
-        if self._task_idx != task_idx:
-            # In this case, we need to update the LoRAs everywhere
-            self._task_idx = task_idx
-            self.apply(
-                partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
-            )
+        # if self._task_idx != task_idx:
+        #     # In this case, we need to update the LoRAs everywhere
+        #     self._task_idx = task_idx
+        #     self.apply(
+        #         partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
+        #     )

     def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
         if task != LORA_NO_UPDATE:
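The core pattern in this file is the pair of new_forward functions above: build a task-specific weight by letting the first LoRA parametrization rewrite self.weight, call the plain functional op (F.linear or F.embedding), and bind the function onto the existing layer instance with __get__ so it shadows the module's original forward. A self-contained sketch of that mechanism follows; the adapter storage (lora_A, lora_B, scaling, adaptation_map) is an illustrative stand-in, not the exact objects this repository registers via parametrize.

import torch
import torch.nn.functional as F
from torch import nn

# Illustrative stand-ins for the task -> adapter bookkeeping.
adaptation_map = {'query': 0, 'passage': 1, 'sts': 2}
num_tasks, rank, scaling = len(adaptation_map), 4, 1.0

layer = nn.Linear(16, 32)
# One low-rank (B, A) pair per task, kept next to the frozen base weight.
lora_A = torch.zeros(num_tasks, rank, 16)
lora_B = torch.randn(num_tasks, 32, rank) * 0.01

def lora_forward(weight, task_idx):
    # Base weight plus the task-specific low-rank update,
    # mirroring the role of LoRAParametrization.lora_forward above.
    return weight + scaling * (lora_B[task_idx] @ lora_A[task_idx])

def new_forward(self, input, task=None, residual=False):
    task_idx = adaptation_map[task] if task is not None else None
    # `is not None` keeps adapter index 0 usable; a plain truthiness check would skip it.
    weight = lora_forward(self.weight, task_idx) if task_idx is not None else self.weight
    out = F.linear(input, weight, self.bias)
    return (out, input) if residual else out

# Bind as an instance method so only this layer's forward is replaced,
# the same __get__ trick used in the hunks above.
layer.forward = new_forward.__get__(layer, layer.__class__)

x = torch.randn(2, 8, 16)
y, res = layer(x, task='passage', residual=True)   # y: (2, 8, 32); res is x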
modeling_xlm_roberta.py CHANGED
@@ -313,7 +313,7 @@ class XLMRobertaPooler(nn.Module):
         # We "pool" the model by simply taking the hidden state corresponding
         # to the first token.
         first_token_tensor = hidden_states[:, 0] if pool else hidden_states
-        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.dense(first_token_tensor, task='passage')
         pooled_output = self.activation(pooled_output)
         return pooled_output

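For reference, the pooling step touched above is, apart from the added task argument, the usual first-token pooler: take the hidden state of the first token, project it, and apply the activation (assumed here to be tanh, as in the standard RoBERTa pooler). A small stand-alone sketch:

import torch
from torch import nn

hidden_states = torch.randn(2, 8, 768)                  # (batch, seqlen, hidden)
dense = nn.Linear(768, 768)
activation = nn.Tanh()                                  # assumption: tanh, as in stock RoBERTa

first_token_tensor = hidden_states[:, 0]                # (batch, hidden), the [CLS] position
pooled_output = activation(dense(first_token_tensor))   # (batch, hidden)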