Markus28 commited on
Commit
462e28d
1 Parent(s): a416a9d

feat: only apply select_task_for_layer if task has changed

Browse files
Files changed (1) hide show
  1. modeling_lora.py +5 -4
modeling_lora.py CHANGED
@@ -265,10 +265,11 @@ class BertLoRA(BertPreTrainedModel):
265
  @current_task.setter
266
  def current_task(self, task_idx: Union[None, int]):
267
  assert task_idx is None or 0 <= task_idx < self._num_adaptions
268
- self._task_idx = task_idx
269
- self.apply(
270
- partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
271
- )
 
272
 
273
  def forward(self, *args, current_task: Union[None, int] = -1, **kwargs):
274
  if current_task is None or current_task >= 0:
 
265
  @current_task.setter
266
  def current_task(self, task_idx: Union[None, int]):
267
  assert task_idx is None or 0 <= task_idx < self._num_adaptions
268
+ if self._task_idx != task_idx
269
+ self._task_idx = task_idx
270
+ self.apply(
271
+ partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
272
+ )
273
 
274
  def forward(self, *args, current_task: Union[None, int] = -1, **kwargs):
275
  if current_task is None or current_task >= 0: