# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import random
import math
import numpy as np


class MaskingGenerator:
    """Generate random block-wise boolean masks over a 2-D patch grid.

    A mask is built by repeatedly stamping random rectangles onto an
    initially empty ``(height, width)`` grid until the requested number of
    masked patches is reached or no further rectangle can be placed.
    """

    def __init__(
        self,
        input_size,
        num_masking_patches=None,
        min_num_patches=4,
        max_num_patches=None,
        min_aspect=0.3,
        max_aspect=None,
    ):
        """Configure the generator.

        Args:
            input_size: Either a single number (square grid) or a
                ``(height, width)`` tuple/list.
            num_masking_patches: Stored on the instance and used as the
                default for ``max_num_patches``; may be ``None``.
            min_num_patches: Lower bound of the area sampled for each
                rectangle.
            max_num_patches: Upper bound on the patches a single rectangle
                may add. Defaults to ``num_masking_patches``; ``None`` means
                no per-rectangle cap.
            min_aspect: Smallest height/width aspect ratio sampled.
            max_aspect: Largest aspect ratio sampled; defaults to
                ``1 / min_aspect``.
        """
        # Lists are accepted alongside tuples; any other value is treated as
        # the side length of a square grid.
        if not isinstance(input_size, (tuple, list)):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches

        max_aspect = max_aspect or 1 / min_aspect
        # Aspect ratios are sampled uniformly in log space.
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    @staticmethod
    def _format_count(value):
        """Render an optional patch count: integers via ``%d``, ``None`` as text."""
        return "None" if value is None else "%d" % value

    def __repr__(self):
        # The two patch limits may legitimately be None (constructor
        # defaults), which "%d" cannot format, so they are rendered separately.
        return "Generator(%d, %d -> [%d ~ %s], max = %s, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            self._format_count(self.max_num_patches),
            self._format_count(self.num_masking_patches),
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )

    def get_shape(self):
        """Return the grid shape as ``(height, width)``."""
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        """Try to stamp one random rectangle onto ``mask`` in place.

        Up to 10 candidate rectangles are sampled. The first one that lies
        strictly inside the grid and would add between 1 and
        ``max_mask_patches`` new patches is applied.

        Args:
            mask: 2-D array modified in place; zero entries are unmasked.
            max_mask_patches: Maximum number of new patches to add.

        Returns:
            The number of entries newly set, or 0 if every attempt failed.
        """
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if not (w < self.width and h < self.height):
                continue

            top = random.randint(0, self.height - h)
            left = random.randint(0, self.width - w)

            # Basic slicing yields a view, so writes below update ``mask``.
            region = mask[top : top + h, left : left + w]
            num_masked = region.sum()
            # Accept the rectangle only if it adds something new without
            # exceeding the remaining budget (overlap is allowed).
            if 0 < h * w - num_masked <= max_mask_patches:
                unmasked = region == 0
                delta = int(unmasked.sum())
                region[unmasked] = 1
                if delta > 0:
                    return delta
        return 0

    def __call__(self, num_masking_patches=0):
        """Build a boolean mask with at most ``num_masking_patches`` set.

        Args:
            num_masking_patches: Target number of masked patches.

        Returns:
            A ``(height, width)`` boolean NumPy array. Fewer patches than
            requested may be set if no further rectangle can be placed.
        """
        mask = np.zeros(shape=self.get_shape(), dtype=bool)
        mask_count = 0
        while mask_count < num_masking_patches:
            max_mask_patches = num_masking_patches - mask_count
            # ``max_num_patches`` is None when the generator was built
            # without limits; in that case only the remaining budget applies.
            if self.max_num_patches is not None:
                max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        return mask