jupyterjazz committed
Commit: eefe43c
Parent(s): 6cc0f51

poc

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed:
- embedding.py +1 -2
- mha.py +5 -3
- mlp.py +2 -2
- modeling_lora.py +33 -35
- modeling_xlm_roberta.py +1 -1
embedding.py CHANGED

@@ -47,7 +47,6 @@ class XLMRobertaEmbeddings(nn.Module):
             token_type_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
-        print('input shape', input_ids.shape)
         embeddings = self.word_embeddings(input_ids, task='sts')
         if self.max_position_embeddings > 0:
             if position_ids is None:
@@ -58,6 +57,6 @@ class XLMRobertaEmbeddings(nn.Module):
         if self.type_vocab_size > 0:
             if token_type_ids is None:
                 token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
-            token_type_embeddings = self.token_type_embeddings(token_type_ids)
+            token_type_embeddings = self.token_type_embeddings(token_type_ids, task='sts')
             embeddings = embeddings + token_type_embeddings
         return embeddings
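The hard-coded `task='sts'` calls above only work because the commit also monkey-patches the embedding layer's `forward` to accept a `task` keyword (see the modeling_lora.py hunks below). A minimal sketch of that calling convention, using a hypothetical `TaskAwareEmbedding` stand-in rather than the repo's patched layer:

import torch
from torch import nn


class TaskAwareEmbedding(nn.Embedding):
    """Hypothetical stand-in for the patched nn.Embedding in this commit:
    it accepts (and here simply ignores) the extra `task` keyword so that
    self.word_embeddings(input_ids, task='sts') is a valid call."""

    def forward(self, input_ids, task=None):
        # A real implementation would select task-specific LoRA weights here.
        return super().forward(input_ids)


emb = TaskAwareEmbedding(num_embeddings=1000, embedding_dim=64, padding_idx=1)
input_ids = torch.randint(2, 1000, (2, 16))   # (batch, seqlen)
print(emb(input_ids, task='sts').shape)       # torch.Size([2, 16, 64])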
mha.py CHANGED

@@ -341,6 +341,7 @@ class LinearResidual(nn.Linear):
     """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

     def forward(self, input: torch.Tensor, task=None) -> torch.Tensor:
+        print('aq vafshe ar modis?')
         return super().forward(input, task=task), input


@@ -450,7 +451,7 @@ class MHA(nn.Module):

         if fused_bias_fc and FusedDense is None:
             raise ImportError("fused_dense is not installed")
-
+
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         linear_resid_cls = (
             LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
@@ -647,7 +648,8 @@ class MHA(nn.Module):
             if not self.return_residual:
                 qkv = self.Wqkv(x)
             else:
-                qkv, x = self.Wqkv(x, task='
+                qkv, x = self.Wqkv(x, task='query', residual=True)
+
             if self.dwconv:
                 qkv = rearrange(
                     self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
@@ -732,5 +734,5 @@ class MHA(nn.Module):
                 context = self._update_kvcache_attention(q, kv, inference_params)
         else:
             context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
-        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
+        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"), task='passage')
         return out if not self.return_residual else (out, x)
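For orientation: the `Wqkv` and `out_proj` call sites now pass a `task` keyword, and the qkv projection is additionally asked to hand back its input so the caller can keep it as a residual. A minimal sketch of a linear layer with that calling convention (an illustration under the commit's apparent assumptions, not the repo's patched layer; 'query' is the hard-coded PoC value from the hunk above):

import torch
from torch import nn
from torch.nn import functional as F


class TaskAwareLinear(nn.Linear):
    """Hypothetical linear layer matching the call sites in this hunk:
    accepts a `task` keyword and can return (output, input) as residual."""

    def forward(self, x, task=None, residual=False):
        # A real implementation would swap in task-specific LoRA weights here.
        out = F.linear(x, self.weight, self.bias)
        return (out, x) if residual else out


wqkv = TaskAwareLinear(768, 3 * 768)
x = torch.randn(2, 16, 768)
qkv, residual = wqkv(x, task='query', residual=True)   # mirrors: qkv, x = self.Wqkv(x, ...)
print(qkv.shape, residual.shape)                        # (2, 16, 2304), (2, 16, 768)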
mlp.py CHANGED

@@ -48,9 +48,9 @@ class Mlp(nn.Module):
         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

     def forward(self, x):
-        y = self.fc1(x)
+        y = self.fc1(x, task='clustering')
         y = self.activation(y)
-        y = self.fc2(y)
+        y = self.fc2(y, task='sts')
         return y if not self.return_residual else (y, x)

modeling_lora.py CHANGED

@@ -9,6 +9,7 @@ import torch
 import torch.nn.utils.parametrize as parametrize
 from torch import nn
 from torch.nn import Parameter
+from torch.nn import functional as F
 from transformers import PretrainedConfig

 from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
@@ -98,8 +99,7 @@ class LoRAParametrization(nn.Module):
         # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
         return A * self.lora_dropout(self.lora_dropout_mask)

-    def lora_forward(self, X, current_task
-        print('lora input shape', X.shape)
+    def lora_forward(self, X, current_task):
         return (
             X
             + torch.matmul(
@@ -114,10 +114,7 @@ class LoRAParametrization(nn.Module):
         )

     def forward(self, X):
-
-        out = self.forward_fn(X)
-        print(out.shape)
-        return out
+        return X

     @property
     def current_task(self):
@@ -195,13 +192,20 @@ class LoRAParametrization(nn.Module):
                 alpha=alpha,
             ),
         )
-        original_forward = layer.forward

-        def new_forward(self, input, task):
-
-
-
-
+        def new_forward(self, input, task, residual=False):
+            task_idx = adaptation_map[task] if task else None
+            if task_idx:
+                weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
+            else:
+                weights = self.weight
+
+            out = F.linear(input, weights, self.bias)
+
+            print('lin', task_idx, input.shape, out.shape)
+            if residual:
+                return out, input
+            return out

         layer.forward = new_forward.__get__(layer, layer.__class__)

@@ -217,20 +221,20 @@ class LoRAParametrization(nn.Module):
                 alpha=alpha,
             ),
         )
-        original_forward = layer.forward

         def new_forward(self, input, task):
-            print('input here', input, input.shape)
-            print('func', original_forward)
-            # original_forward['parametrizations'] = None
-            # print('funcc', original_forward.__dict__)
-            output = original_forward(input)
-            print(output.shape, 'output shape')
             task_idx = adaptation_map[task] if task else None
             if task_idx:
-
-
-
+                weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_idx)
+            else:
+                weights = self.weight
+
+            out = F.embedding(
+                input, weights, self.padding_idx, self.max_norm,
+                self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+            print('emb', task_idx, input.shape, out.shape)
+            return out

         layer.forward = new_forward.__get__(layer, layer.__class__)

@@ -278,13 +282,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         self._task_idx = None
         # By default, disable LoRA until it's specified which adapter/task to use
         self.current_task = None
-
-        if name == 'roberta.encoder.layers.22.mixer.Wqkv.parametrizations.weight.0.lora_A':
-            print('A0', param[0])
-            print('A1', param[1])
-        if name == 'roberta.encoder.layers.22.mixer.Wqkv.parametrizations.weight.0.lora_B':
-            print('B0', param[0])
-            print('B1', param[1])
+

     @property
     def main_params_trainable(self):
@@ -364,12 +362,12 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
                 f"Alternatively, set `task` to `None` if you want to disable LoRA."
             )
         task_idx = self._adaptation_map[task_name] if task_name else None
-        if self._task_idx != task_idx:
-            # In this case, we need to update the LoRAs everywhere
-            self._task_idx = task_idx
-            self.apply(
-                partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
-            )
+        # if self._task_idx != task_idx:
+        #     # In this case, we need to update the LoRAs everywhere
+        #     self._task_idx = task_idx
+        #     self.apply(
+        #         partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
+        #     )

     def forward(self, *args, task: Union[str, None] = LORA_NO_UPDATE, **kwargs):
         if task != LORA_NO_UPDATE:
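Taken together, the two `new_forward` patches implement per-task LoRA at call time: the task name is mapped to an adapter index via `adaptation_map`, `lora_forward` composes the adapted weight, and the functional op (`F.linear` / `F.embedding`) is applied with it. A condensed, self-contained sketch of that pattern for a linear layer; the adapter shapes, scaling, and the `adaptation_map` contents are illustrative assumptions, not the repo's actual layout:

import torch
from torch import nn
from torch.nn import functional as F

# Hypothetical task -> adapter-index mapping, mirroring `adaptation_map` in the diff.
adaptation_map = {'query': 0, 'passage': 1, 'sts': 2, 'clustering': 3}


class PerTaskLoRALinear(nn.Module):
    """Sketch of the pattern the patched forward implements:
    weight' = weight + scaling * (B[task] @ A[task]), then F.linear(input, weight')."""

    def __init__(self, in_features, out_features, num_tasks, rank=4, alpha=1.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # One low-rank pair per task (shapes are an assumption, not the repo's layout).
        self.lora_A = nn.Parameter(torch.randn(num_tasks, rank, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(num_tasks, out_features, rank))
        self.scaling = alpha / rank

    def lora_forward(self, weight, current_task):
        # weight + scaling * (B @ A) for the selected adapter only.
        delta = self.lora_B[current_task] @ self.lora_A[current_task]
        return weight + self.scaling * delta

    def forward(self, x, task=None):
        task_idx = adaptation_map[task] if task else None
        weight = (self.lora_forward(self.linear.weight, task_idx)
                  if task_idx is not None else self.linear.weight)
        return F.linear(x, weight, self.linear.bias)


layer = PerTaskLoRALinear(768, 768, num_tasks=len(adaptation_map))
x = torch.randn(2, 16, 768)
print(layer(x, task='passage').shape)   # torch.Size([2, 16, 768])

One detail worth noting: the hunks above guard with `if task_idx:`, which is falsy for adapter index 0, so the first task in `adaptation_map` would silently skip its LoRA weights; the sketch checks `task_idx is not None` instead.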
modeling_xlm_roberta.py CHANGED

@@ -313,7 +313,7 @@ class XLMRobertaPooler(nn.Module):
         # We "pool" the model by simply taking the hidden state corresponding
         # to the first token.
         first_token_tensor = hidden_states[:, 0] if pool else hidden_states
-        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.dense(first_token_tensor, task='passage')
         pooled_output = self.activation(pooled_output)
         return pooled_output