jupyterjazz committed
Commit 5ed05aa
Parent: 77af1c7

feat: support lora


Signed-off-by: jupyterjazz <saba.sturua@jina.ai>

Files changed (2)
  1. configuration_xlm_roberta.py +3 -1
  2. modeling_lora.py +327 -0
configuration_xlm_roberta.py CHANGED
@@ -21,6 +21,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         position_embedding_type="absolute",
         use_cache=True,
         classifier_dropout=None,
+        num_loras=5,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -39,4 +40,5 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.layer_norm_eps = layer_norm_eps
         self.position_embedding_type = position_embedding_type
         self.use_cache = use_cache
-        self.classifier_dropout = classifier_dropout
+        self.classifier_dropout = classifier_dropout
+        self.num_loras = num_loras
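
The only functional change to the config is the new num_loras field (default 5); modeling_lora.py reads it as config.num_loras to decide how many LoRA adapters to allocate. A minimal sketch of that interaction, not part of the commit (the flat import path is an assumption; inside the package the module is imported relatively):

    from configuration_xlm_roberta import XLMRobertaFlashConfig

    # Hypothetical check: the config now carries the adapter count used by XLMRobertaLoRA.
    config = XLMRobertaFlashConfig(num_loras=3)
    assert config.num_loras == 3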
modeling_lora.py ADDED
@@ -0,0 +1,327 @@
+import math
+import os
+from functools import partial
+from typing import Iterator, Optional, Tuple, Union
+
+import torch
+import torch.nn.utils.parametrize as parametrize
+from torch import nn
+from torch.nn import Parameter
+from transformers import PretrainedConfig
+
+from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaPreTrainedModel, XLMRobertaFlashConfig
+
+
+def initialized_weights(
+    shape: Tuple[int, ...], num_adaptions: int, init: str = "kaiming"
+) -> torch.Tensor:
+    weight_data = []
+    for _ in range(num_adaptions):
+        new_adaption = torch.zeros(shape)
+        if init == "kaiming":
+            nn.init.kaiming_uniform_(new_adaption, a=math.sqrt(5))
+        elif init == "normal":
+            nn.init.normal_(new_adaption)
+        else:
+            raise NotImplementedError
+        weight_data.append(new_adaption)
+    return torch.stack(weight_data, dim=0)
+
+
+class LoRAParametrization(nn.Module):
+    """
+    This LoRA implementation was inspired by https://github.com/cccntu/minLoRA
+    The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy
+    Permission is hereby granted, free of charge, to any person obtaining a copy of this software
+    and associated documentation files (the "Software"), to deal in the Software without restriction,
+    including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+    and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
+    subject to the following conditions:
+    The above copyright notice and this permission notice shall be included in all copies or substantial
+    portions of the Software.
+    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
+    LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+    IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+    SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+    """
+    def __init__(
+        self,
+        fan_in: int,
+        fan_out: int,
+        layer_type: str = "linear",
+        num_adaptions: int = 1,
+        rank: int = 4,
+        lora_dropout_p: float = 0.0,
+        lora_alpha: float = 1,
+    ):
+        super().__init__()
+        # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
+        # otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings
+        fan_in_fan_out = layer_type == "embedding"
+        self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
+
+        # For the officially "correct" LoRA initialization, check here: https://github.com/microsoft/LoRA
+        # TODO: Ensure that the initialization here is correct
+        if layer_type == "linear":
+            self.lora_A = nn.Parameter(
+                initialized_weights((rank, fan_in), num_adaptions, init="kaiming")
+            )
+            self.lora_B = nn.Parameter(torch.zeros((num_adaptions, fan_out, rank)))
+        elif layer_type == "embedding":
+            self.lora_A = nn.Parameter(torch.zeros((num_adaptions, fan_in, rank)))
+            self.lora_B = nn.Parameter(
+                initialized_weights(
+                    (rank, fan_out), num_adaptions=num_adaptions, init="normal"
+                )
+            )
+        else:
+            raise NotImplementedError
+
+        self.lora_alpha, self.rank = lora_alpha, rank
+        self.scaling = lora_alpha / rank
+        self.lora_dropout = (
+            nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
+        )
+        self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
+        self.register_buffer(
+            "lora_dropout_mask",
+            torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype),
+            persistent=False,
+        )
+        self.forward_fn = lambda x: x
+        self.current_task = None
+
+    def _dropout(self, A):
+        # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
+        return A * self.lora_dropout(self.lora_dropout_mask)
+
+    def lora_forward(self, X):
+        assert self.current_task is not None
+        return (
+            X
+            + torch.matmul(
+                *self.swap(
+                    (
+                        self.lora_B[self.current_task],
+                        self.dropout_fn(self.lora_A[self.current_task]),
+                    )
+                )
+            ).view(X.shape)
+            * self.scaling
+        )
+
+    def forward(self, X):
+        return self.forward_fn(X)
+
+    @property
+    def current_task(self):
+        return self._current_task
+
+    @current_task.setter
+    def current_task(self, task: Union[None, int]):
+        self._current_task = task
+        if task is None:
+            self.forward_fn = lambda x: x
+        else:
+            self.forward_fn = self.lora_forward
+
+    @classmethod
+    def from_linear(
+        cls,
+        layer: nn.Module,
+        num_adaptions: int = 1,
+        rank: int = 4,
+        lora_dropout_p: float = 0.0,
+        lora_alpha: int = 1,
+    ):
+        assert isinstance(layer, nn.Linear)
+        fan_out, fan_in = layer.weight.shape
+        return cls(
+            fan_in,
+            fan_out,
+            num_adaptions=num_adaptions,
+            layer_type="linear",
+            rank=rank,
+            lora_dropout_p=lora_dropout_p,
+            lora_alpha=lora_alpha,
+        )
+
+    @classmethod
+    def from_embedding(
+        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
+    ):
+        assert isinstance(layer, nn.Embedding)
+        fan_in, fan_out = layer.weight.shape
+        return cls(
+            fan_in,
+            fan_out,
+            num_adaptions=num_adaptions,
+            layer_type="embedding",
+            rank=rank,
+            lora_dropout_p=lora_dropout_p,
+            lora_alpha=lora_alpha,
+        )
+
+    @classmethod
+    def add_to_layer(
+        cls, layer, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1
+    ):
+        if isinstance(layer, nn.Linear):
+            parametrize.register_parametrization(
+                layer,
+                "weight",
+                cls.from_linear(
+                    layer,
+                    num_adaptions=num_adaptions,
+                    rank=rank,
+                    lora_dropout_p=lora_dropout_p,
+                    lora_alpha=lora_alpha,
+                ),
+            )
+        elif isinstance(layer, nn.Embedding):
+            parametrize.register_parametrization(
+                layer,
+                "weight",
+                cls.from_embedding(
+                    layer,
+                    num_adaptions=num_adaptions,
+                    rank=rank,
+                    lora_dropout_p=lora_dropout_p,
+                    lora_alpha=lora_alpha,
+                ),
+            )
+
+    @staticmethod
+    def select_task_for_layer(layer: nn.Module, task_idx: Optional[int] = None):
+        if isinstance(layer, LoRAParametrization):
+            layer.current_task = task_idx
+
+    @staticmethod
+    def merge_lora_into_layer(layer: nn.Module):
+        if hasattr(layer, "parametrizations"):
+            for attr_name in layer.parametrizations.keys():
+                parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=True)
+
+
+class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
+    def __init__(self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None, add_pooling_layer=True):
+        super().__init__(config)
+        if roberta is None:
+            self.roberta = XLMRobertaModel(config, add_pooling_layer=add_pooling_layer)
+        else:
+            self.roberta = roberta
+        self._is_merged = False
+        self._num_adaptions = config.num_loras
+        self._register_lora(self._num_adaptions)
+        self.main_params_trainable = False
+        self._task_idx = None
+        # By default, we select the first LoRA
+        self.current_task = 0
+
+    @property
+    def main_params_trainable(self):
+        return self._main_params_trainable
+
+    @main_params_trainable.setter
+    def main_params_trainable(self, val: bool):
+        """Whether the main parameters (i.e. those that are not LoRA) should be trainable.
+        This setter toggles `requires_grad` on the main weights
+        and controls which parameters are returned by `self.parameters()`.
+        :param val: Whether or not to make the parameters trainable.
+        :return: None
+        """
+        self._main_params_trainable = val
+        for name, param in super().named_parameters():
+            if "lora" not in name:
+                param.requires_grad_(val)
+
+    @classmethod
+    def from_roberta(cls, *args, **kwargs):
+        roberta = XLMRobertaModel.from_pretrained(*args, **kwargs)
+        config = XLMRobertaFlashConfig.from_pretrained(*args, **kwargs)
+        return cls(config, roberta=roberta)
+
+    def merge_lora(self):
+        """Merges the currently selected LoRA into the main weights."""
+        if self._is_merged:
+            raise Exception('LoRA has already been merged, cannot merge again')
+        self._is_merged = True
+        self.apply(LoRAParametrization.merge_lora_into_layer)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        *model_args,
+        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        ignore_mismatched_sizes: bool = False,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        use_safetensors: bool = None,
+        **kwargs,
+    ):
+        """
+        TODO: choose between from_roberta and super().from_pretrained
+        We want to be able to load both a pretrained XLMRobertaModel and a trained
+        XLMRobertaLoRA via this method. To this end, we need to check which of these
+        models we are expected to load.
+        """
+        return cls.from_roberta(pretrained_model_name_or_path)
+
+    def _register_lora(self, num_adaptions=1, rank=4, lora_dropout_p=0.0, lora_alpha=1):
+        self.apply(
+            partial(
+                LoRAParametrization.add_to_layer,
+                num_adaptions=num_adaptions,
+                rank=rank,
+                lora_dropout_p=lora_dropout_p,
+                lora_alpha=lora_alpha,
+            )
+        )
+
+    @property
+    def current_task(self):
+        """Which LoRA is currently selected.
+        :return: Integer or None (when LoRA is disabled)
+        """
+        return self._task_idx
+
+    @current_task.setter
+    def current_task(self, task_idx: Union[None, int]):
+        """Set the LoRA that is to be used.
+        The LoRA is specified by `task_idx`, which may be an integer >= 0,
+        indexing the available LoRAs. If it is None, no LoRA is used.
+        :param task_idx: Which LoRA to use
+        :return:
+        """
+        if self._is_merged:
+            raise Exception('LoRA has been merged, cannot select new task')
+        assert task_idx is None or 0 <= task_idx < self._num_adaptions
+        if self._task_idx != task_idx:
+            # In this case, we need to update the LoRAs everywhere
+            self._task_idx = task_idx
+            self.apply(
+                partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
+            )
+
+    def forward(self, *args, current_task: Union[None, int] = -1, **kwargs):
+        # -1 (the default) keeps the currently selected LoRA; None disables LoRA
+        if current_task is None or current_task >= 0:
+            self.current_task = current_task
+        return self.roberta(*args, **kwargs)
+
+    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+        for _, param in self.named_parameters(recurse=recurse):
+            yield param
+
+    def named_parameters(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, Parameter]]:
+        for name, param in super().named_parameters(
+            prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
+        ):
+            if "lora" in name or self.main_params_trainable:
+                yield name, param
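
Taken together, modeling_lora.py wraps an XLMRobertaModel, registers config.num_loras adapters on every Linear and Embedding weight through torch.nn.utils.parametrize, and switches between them via the current_task property (None disables LoRA; merge_lora bakes the selected adapter into the main weights). A rough usage sketch, not part of the commit; the import path, checkpoint name, and dummy inputs are placeholders, and from_pretrained currently just delegates to from_roberta:

    import torch
    from modeling_lora import XLMRobertaLoRA

    model = XLMRobertaLoRA.from_pretrained("path/to/xlm-roberta-flash-checkpoint")
    input_ids = torch.tensor([[0, 5, 6, 2]])  # placeholder token ids

    model.current_task = 1     # use the second LoRA adapter
    out_lora = model(input_ids)

    model.current_task = None  # disable LoRA, run the plain backbone
    out_plain = model(input_ids)

    model.current_task = 0
    model.merge_lora()         # merge the selected adapter; task switching is then locked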