feat: make main parameters trainable
modeling_lora.py CHANGED (+14 -3)
@@ -207,11 +207,21 @@ class BertLoRA(BertPreTrainedModel):
             self.bert = BertModel(config, add_pooling_layer=add_pooling_layer)
         else:
             self.bert = bert
+        self._num_adaptions = num_adaptions
         self._register_lora(num_adaptions)
+        self.main_params_trainable = False
+        self.current_task = 0
+
+    @property
+    def main_params_trainable(self):
+        return self._main_params_trainable
+
+    @main_params_trainable.setter
+    def main_params_trainable(self, val):
+        self._main_params_trainable = val
         for name, param in super().named_parameters():
             if "lora" not in name:
-                param.requires_grad_(False)
-        self.current_task = 0
+                param.requires_grad_(val)
 
     @classmethod
     def from_bert(cls, *args, num_adaptions=1, **kwargs):
@@ -254,6 +264,7 @@ class BertLoRA(BertPreTrainedModel):
 
     @current_task.setter
     def current_task(self, task_idx: Union[None, int]):
+        assert task_idx is None or 0 <= task_idx < self._num_adaptions
         self._task_idx = task_idx
         self.apply(
             partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
@@ -274,5 +285,5 @@ class BertLoRA(BertPreTrainedModel):
         for name, param in super().named_parameters(
             prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate
         ):
-            if "lora" in name:
+            if "lora" in name or self.main_params_trainable:
                 yield name, param