Update build_mlp.py
#14
by
unsubscribe
- opened
- build_mlp.py +1 -1
build_mlp.py
CHANGED
@@ -192,9 +192,9 @@ class PLoRA(nn.Linear):
|
|
192 |
def forward(self, x, im_mask=None):
|
193 |
B, N, C = x.shape
|
194 |
x = x.reshape(-1, C)
|
195 |
-
im_mask = im_mask.view(-1)
|
196 |
res = super().forward(x)
|
197 |
if im_mask is not None:
|
|
|
198 |
if torch.sum(im_mask) > 0:
|
199 |
part_x = x[im_mask]
|
200 |
res[im_mask] += self.Plora_B(self.Plora_A(
|
|
|
192 |
def forward(self, x, im_mask=None):
|
193 |
B, N, C = x.shape
|
194 |
x = x.reshape(-1, C)
|
|
|
195 |
res = super().forward(x)
|
196 |
if im_mask is not None:
|
197 |
+
im_mask = im_mask.view(-1)
|
198 |
if torch.sum(im_mask) > 0:
|
199 |
part_x = x[im_mask]
|
200 |
res[im_mask] += self.Plora_B(self.Plora_A(
|