# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch


def apply_masks(x, masks, concat=True):
    """
    Gather, for every mask, the patches of ``x`` whose indices the mask lists.

    :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
    :param masks: list of tensors of shape [B, K] containing indices of K patches
        in [N] to keep (integer indices, as required by ``torch.gather``)
    :param concat: if True (default), concatenate the per-mask results along
        dim 0 and return a single tensor of shape [len(masks) * B, K, D]; this
        requires every mask to have the same K. If False, return the list of
        per-mask tensors, each of shape [B, K, D].
    :returns: a tensor (``concat=True``) or a list of tensors (``concat=False``)
    """
    feat_dim = x.size(-1)
    all_x = []
    for m in masks:
        # Broadcast the [B, K] indices across the feature dim. expand() returns
        # a stride-0 view instead of the [B, K, D] copy repeat() would allocate;
        # gather() only reads the index, so the results are identical.
        mask_keep = m.unsqueeze(-1).expand(-1, -1, feat_dim)
        all_x.append(torch.gather(x, dim=1, index=mask_keep))
    if not concat:
        return all_x
    return torch.cat(all_x, dim=0)