feat: use property in LoRA parametrization
Browse files- modeling_lora.py +8 -3
modeling_lora.py
CHANGED
@@ -116,8 +116,13 @@ class LoRAParametrization(nn.Module):
|
|
116 |
def forward(self, X):
|
117 |
return self.forward_fn(X)
|
118 |
|
119 |
-
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
121 |
if task is None:
|
122 |
self.forward_fn = lambda x: x
|
123 |
else:
|
@@ -192,7 +197,7 @@ class LoRAParametrization(nn.Module):
|
|
192 |
@classmethod
|
193 |
def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
|
194 |
if isinstance(layer, LoRAParametrization):
|
195 |
-
layer.
|
196 |
|
197 |
|
198 |
class BertLoRA(BertPreTrainedModel):
|
|
|
116 |
def forward(self, X):
|
117 |
return self.forward_fn(X)
|
118 |
|
119 |
+
@property
|
120 |
+
def current_task(self):
|
121 |
+
return self._current_task
|
122 |
+
|
123 |
+
@current_task.setter
|
124 |
+
def current_task(self, task: Union[None, int]):
|
125 |
+
self._current_task = task
|
126 |
if task is None:
|
127 |
self.forward_fn = lambda x: x
|
128 |
else:
|
|
|
197 |
@classmethod
|
198 |
def select_task_for_layer(cls, layer: nn.Module, task_idx: Optional[int] = None):
|
199 |
if isinstance(layer, LoRAParametrization):
|
200 |
+
layer.current_task = task_idx
|
201 |
|
202 |
|
203 |
class BertLoRA(BertPreTrainedModel):
|