Drexubery's picture
update
df13f4b
raw
history blame
741 Bytes
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
# --------------------------------------------------------
# Masking utils
# --------------------------------------------------------
import torch
import torch.nn as nn
class RandomMask(nn.Module):
    """
    Random masking.

    Produces a boolean mask of shape ``(x.size(0), num_patches)`` in which
    each row has ``num_mask = int(mask_ratio * num_patches)`` randomly chosen
    positions set to ``True``. This holds for ``0 <= mask_ratio <= 1``. A
    larger ratio masks every position, and a ratio of 0 masks none.

    Args:
        num_patches: number of patches per sample, i.e. the length of each
            mask row.
        mask_ratio: fraction of patches to mask. The resulting count is
            truncated with ``int()``.
    """

    def __init__(self, num_patches, mask_ratio):
        super().__init__()
        self.num_patches = num_patches
        self.num_mask = int(mask_ratio * self.num_patches)

    # Implemented as ``forward`` (not ``__call__``) so that ``nn.Module.__call__``
    # dispatches here and registered forward hooks / pre-hooks are honored.
    def forward(self, x):
        """
        Sample a fresh random mask for the batch ``x``.

        Only ``x.size(0)`` and ``x.device`` are read. Presumably ``x`` is a
        batch of patch tokens with the batch dimension first; confirm this
        with callers.

        Returns:
            Bool tensor of shape ``(x.size(0), num_patches)`` on ``x.device``,
            with ``num_mask`` ``True`` entries per row.
        """
        noise = torch.rand(x.size(0), self.num_patches, device=x.device)
        # Each row of the argsort is a random permutation of 0..num_patches-1,
        # so thresholding it selects exactly num_mask positions per row.
        argsort = torch.argsort(noise, dim=1)
        return argsort < self.num_mask