roman-bachmann commited on
Commit
ff0b3d2
1 Parent(s): bf23504

Update multimae/multimae.py

Browse files
Files changed (1) hide show
  1. multimae/multimae.py +1 -1
multimae/multimae.py CHANGED
@@ -201,7 +201,7 @@ class MultiMAE(nn.Module):
201
  task_masks.append(mask)
202
 
203
  mask_all = torch.cat(task_masks, dim=1)
204
- ids_shuffle = torch.argsort(mask_all + torch.rand_like(mask_all.float(), dim=1)
205
  ids_restore = torch.argsort(ids_shuffle, dim=1)
206
  ids_keep = ids_shuffle[:, :num_encoded_tokens]
207
 
 
201
  task_masks.append(mask)
202
 
203
  mask_all = torch.cat(task_masks, dim=1)
204
+ ids_shuffle = torch.argsort(mask_all + torch.rand_like(mask_all.float()), dim=1)
205
  ids_restore = torch.argsort(ids_shuffle, dim=1)
206
  ids_keep = ids_shuffle[:, :num_encoded_tokens]
207