rishitdagli's picture
push a demo
46ff99b
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import torch
def collate_data_and_cast(
    samples_list,
    mask_ratio_tuple,
    mask_probability,
    dtype,
    n_tokens=None,
    mask_generator=None,
):
    """Stack the crops of a batch of samples and draw one mask per global crop.

    Args:
        samples_list: Non-empty sequence of samples. ``sample[0]`` must be a
            mapping with ``"global_crops"`` and ``"local_crops"`` entries, each
            an indexable collection of same-shaped tensors. The number of
            crops per sample is read from the first sample only.
        mask_ratio_tuple: ``(start, end)`` unpacked into ``torch.linspace`` to
            build ``n_samples_masked + 1`` ratio bounds; masked crop ``i``
            samples its ratio uniformly between bounds ``i`` and ``i + 1``.
        mask_probability: Fraction of the stacked global crops that receive a
            non-trivial mask request (``int(B * mask_probability)`` crops).
        dtype: ``torch.dtype`` the stacked crops are cast to.
        n_tokens: Multiplied by the sampled ratio to get the count passed to
            ``mask_generator`` (presumably the number of patch tokens per
            global crop -- confirm against the caller). May stay ``None`` only
            when no crop is masked.
        mask_generator: Callable taking the requested number of masked entries
            and returning something ``torch.BoolTensor`` accepts. It is also
            called with ``0`` for every crop that is not masked.

    Returns:
        A dict with:

        * ``"collated_global_crops"`` / ``"collated_local_crops"``: stacked
          crops cast to ``dtype``, ordered crop-index major (crop 0 of every
          sample, then crop 1 of every sample, ...).
        * ``"collated_masks"``: boolean masks, one row per global crop, with
          all mask dimensions flattened into one.
        * ``"mask_indices_list"``: indices of the ``True`` entries in the
          fully flattened mask tensor.
        * ``"masks_weight"``: for every ``True`` entry, ``1 / (number of True
          entries in its row)``.
        * ``"upperbound"``: sum over masked crops of ``int(n_tokens * upper
          ratio bound)``.
        * ``"n_masked_patches"``: 1-element ``torch.long`` tensor holding the
          total number of ``True`` entries.

    Raises:
        TypeError: If ``mask_generator`` is ``None``, or if ``n_tokens`` is
            ``None`` while at least one crop has to be masked.
    """
    first_sample = samples_list[0][0]
    n_global_crops = len(first_sample["global_crops"])
    n_local_crops = len(first_sample["local_crops"])

    # Crop index is the outer loop, so the batch is laid out crop-major.
    collated_global_crops = torch.stack(
        [s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list]
    )
    collated_local_crops = torch.stack(
        [s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list]
    )

    B = len(collated_global_crops)
    n_samples_masked = int(B * mask_probability)
    probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)

    # Fail with an explicit message instead of an opaque "NoneType" error from
    # inside the loops below (same exception type as before).
    if mask_generator is None:
        raise TypeError(
            "collate_data_and_cast() requires a callable `mask_generator`, got None"
        )
    if n_tokens is None and n_samples_masked > 0:
        raise TypeError(
            "collate_data_and_cast() requires `n_tokens` when at least one crop "
            "is masked, got None"
        )

    upperbound = 0
    masks_list = []
    # Each masked crop draws its ratio from its own [prob_min, prob_max] slice
    # of the overall ratio range.
    for prob_min, prob_max in zip(probs[:-1], probs[1:]):
        n_masked = int(n_tokens * random.uniform(prob_min, prob_max))
        masks_list.append(torch.BoolTensor(mask_generator(n_masked)))
        upperbound += int(n_tokens * prob_max)
    # The remaining crops get a mask generated for zero masked entries.
    masks_list.extend(
        torch.BoolTensor(mask_generator(0)) for _ in range(n_samples_masked, B)
    )

    # Shuffle so masked crops are not always at the front of the batch.
    random.shuffle(masks_list)

    collated_masks = torch.stack(masks_list).flatten(1)
    mask_indices_list = collated_masks.flatten().nonzero().flatten()

    # clamp(min=1) avoids dividing by zero for rows without any masked entry;
    # such rows contribute nothing after the boolean indexing anyway.
    masks_weight = (
        (1 / collated_masks.sum(-1).clamp(min=1.0))
        .unsqueeze(-1)
        .expand_as(collated_masks)[collated_masks]
    )

    return {
        "collated_global_crops": collated_global_crops.to(dtype),
        "collated_local_crops": collated_local_crops.to(dtype),
        "collated_masks": collated_masks,
        "mask_indices_list": mask_indices_list,
        "masks_weight": masks_weight,
        "upperbound": upperbound,
        "n_masked_patches": torch.full(
            (1,), fill_value=mask_indices_list.shape[0], dtype=torch.long
        ),
    }