Markus28 commited on
Commit
9410275
1 Parent(s): 0ff7c3d

feat: add current_task to forward

Browse files
Files changed (1) hide show
  1. modeling_lora.py +3 -1
modeling_lora.py CHANGED
@@ -259,7 +259,9 @@ class BertLoRA(BertPreTrainedModel):
259
  partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
260
  )
261
 
262
- def forward(self, *args, **kwargs):
 
 
263
  return self.bert(*args, **kwargs)
264
 
265
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
 
259
  partial(LoRAParametrization.select_task_for_layer, task_idx=task_idx)
260
  )
261
 
262
+ def forward(self, *args, **kwargs, current_task: Union[None, int] = -1):
263
+ if current_task is None or current_task >= 0:
264
+ self.current_task = current_task
265
  return self.bert(*args, **kwargs)
266
 
267
  def parameters(self, recurse: bool = True) -> Iterator[Parameter]: