# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).


# --------------------------------------------------------
# Masking utils
# --------------------------------------------------------

import torch
import torch.nn as nn


class RandomMask(nn.Module):
    """Random masking: pick a fixed-size random subset of patch positions per sample.

    Args:
        num_patches: number of patch positions in each generated mask row.
        mask_ratio: fraction of the ``num_patches`` positions to mask. The
            number of masked positions per row is ``mask_ratio * num_patches``
            rounded down. The ratio is not range-checked: a ratio <= 0 masks
            nothing and a ratio >= 1 masks every position.
    """

    def __init__(self, num_patches, mask_ratio):
        super().__init__()
        self.num_patches = num_patches
        # The small tolerance keeps floating-point artifacts from dropping a
        # patch, e.g. 0.29 * 100 == 28.999999999999996, which plain int()
        # would truncate to 28 instead of 29. Genuinely fractional products
        # (e.g. 0.9 * 196 == 176.4) are still rounded down as before.
        self.num_mask = int(mask_ratio * self.num_patches + 1e-6)

    def forward(self, x):
        """Draw a fresh random mask for the batch ``x``.

        Implemented as ``forward`` (not as an ``__call__`` override) so that
        ``module(x)`` goes through ``nn.Module.__call__`` and registered
        hooks run.

        Only ``x.size(0)`` (used as the batch size) and ``x.device`` are read;
        the values in ``x`` are not used. Presumably ``x`` is laid out as
        (batch, num_patches, ...) — TODO confirm against callers.

        Returns:
            Bool tensor of shape ``(x.size(0), num_patches)`` on ``x.device``.
            Each row has exactly ``num_mask`` True entries, clamped to
            ``[0, num_patches]``.
        """
        noise = torch.rand(x.size(0), self.num_patches, device=x.device)
        # argsort of i.i.d. uniform noise is a uniformly random permutation of
        # 0..num_patches-1 in each row. The positions whose permutation value
        # is below num_mask therefore form a uniformly random subset of size
        # num_mask.
        perm = torch.argsort(noise, dim=1)
        return perm < self.num_mask