|
|
|
|
|
|
|
|
|
|
| import logging
|
|
|
| import numpy as np
|
| from torch import nn
|
| from torchvision import transforms
|
|
|
| from dinov3.data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, GaussianBlur, make_normalize_transform
|
|
|
| logger = logging.getLogger("dinov3")
|
|
|
|
|
class DataAugmentationDINO(object):
    """Multi-crop augmentation producing global, local and optional gram-teacher views.

    Calling an instance on an image returns a dict with the keys:

    * ``"weak_flag"``: always ``True``.
    * ``"global_crops"``: list with the two distorted, normalized global crops.
    * ``"global_crops_teacher"``: the same two global crops, or, when
      ``teacher_no_color_jitter`` is set, the two geometric base crops passed
      through normalization only.
    * ``"gram_teacher_crops"``: two crops resized to ``gram_teacher_crops_size``;
      only present when that size is not ``None``.
    * ``"local_crops"``: list of local crops.
    * ``"offsets"``: list of ``(rx, ry)`` offsets when
      ``local_crops_subset_of_global_crops`` is set, otherwise an empty tuple.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
        gram_teacher_crops_size=None,
        gram_teacher_no_distortions=False,
        teacher_no_color_jitter=False,
        local_crops_subset_of_global_crops=False,
        patch_size=16,
        share_color_jitter=False,
        horizontal_flips=True,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
    ):
        """Build every transform used by ``__call__``.

        Args:
            global_crops_scale: forwarded as ``scale`` to the global
                ``RandomResizedCrop``.
            local_crops_scale: forwarded as ``scale`` to the local
                ``RandomResizedCrop``.
            local_crops_number: number of local crops returned per image.
            global_crops_size: output size of the global crops.
            local_crops_size: output size of the local crops.
            gram_teacher_crops_size: when not ``None``, size of the additional
                gram-teacher crops; the geometric global crop is then taken at
                ``max(global_crops_size, gram_teacher_crops_size)``.
            gram_teacher_no_distortions: when set (and a gram size is given),
                gram-teacher crops are built from the undistorted base crops
                instead of from the distorted global crops.
            teacher_no_color_jitter: when set, ``"global_crops_teacher"`` holds
                the base crops with normalization only.
            local_crops_subset_of_global_crops: when set, local crops are
                windows cut out of the global base crops at offsets that are
                multiples of ``patch_size``.
            patch_size: offset granularity used in the subset mode above.
            share_color_jitter: when set, color jitter is applied once to the
                input image instead of independently inside each crop pipeline.
            horizontal_flips: enables random horizontal flips (p=0.5) in the
                geometric augmentations; disabled means p=0.0.
            mean: forwarded to ``make_normalize_transform``.
            std: forwarded to ``make_normalize_transform``.
        """
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size
        self.gram_teacher_crops_size = gram_teacher_crops_size
        self.gram_teacher_no_distortions = gram_teacher_no_distortions
        self.teacher_no_color_jitter = teacher_no_color_jitter
        self.local_crops_subset_of_global_crops = local_crops_subset_of_global_crops
        self.patch_size = patch_size
        self.share_color_jitter = share_color_jitter
        # Stored like every other constructor argument so the configuration
        # can be inspected on the instance.
        self.horizontal_flips = horizontal_flips
        self.mean = mean
        self.std = std

        logger.info("###################################")
        logger.info("Using data augmentation parameters:")
        logger.info(f"global_crops_scale: {global_crops_scale}")
        logger.info(f"local_crops_scale: {local_crops_scale}")
        logger.info(f"local_crops_number: {local_crops_number}")
        logger.info(f"global_crops_size: {global_crops_size}")
        logger.info(f"local_crops_size: {local_crops_size}")
        logger.info(f"gram_crops_size: {gram_teacher_crops_size}")
        logger.info(f"gram_teacher_no_distortions: {gram_teacher_no_distortions}")
        logger.info(f"teacher_no_color_jitter: {teacher_no_color_jitter}")
        logger.info(f"local_crops_subset_of_global_crops: {local_crops_subset_of_global_crops}")
        logger.info(f"patch_size if local_crops_subset_of_global_crops: {patch_size}")
        logger.info(f"share_color_jitter: {share_color_jitter}")
        logger.info(f"horizontal flips: {horizontal_flips}")
        logger.info("###################################")

        # The geometric global crop is taken at the larger of the two sizes so
        # that both the global crop and the gram-teacher crop can be derived
        # from the same base crop.
        global_crop_max_size = max(global_crops_size, gram_teacher_crops_size if gram_teacher_crops_size else 0)

        # Random crop and flip shared by both global views.
        self.geometric_augmentation_global = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    global_crop_max_size,
                    scale=global_crops_scale,
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(p=0.5 if horizontal_flips else 0.0),
            ]
        )

        # Without a gram teacher both resize steps are no-ops.
        resize_global = nn.Identity()
        self.resize_global_post_transf = nn.Identity()
        self.resize_gram_teacher = None
        if gram_teacher_crops_size is not None:
            if gram_teacher_no_distortions:
                # Gram crops come from the undistorted base crop, so the global
                # crop is resized to its final size *before* the distortions.
                resize_global = transforms.Resize(
                    global_crops_size,
                    interpolation=transforms.InterpolationMode.BICUBIC,
                )
            else:
                # Gram crops come from the distorted crop, so distortions run
                # at the base-crop size and the global crop is resized to its
                # final size *after* them.
                self.resize_global_post_transf = transforms.Resize(
                    global_crops_size,
                    interpolation=transforms.InterpolationMode.BICUBIC,
                )

            self.resize_gram_teacher = transforms.Resize(
                gram_teacher_crops_size,
                interpolation=transforms.InterpolationMode.BICUBIC,
            )

        # Random crop and flip for the local views.
        self.geometric_augmentation_local = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    local_crops_size,
                    scale=local_crops_scale,
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
                transforms.RandomHorizontalFlip(p=0.5 if horizontal_flips else 0.0),
            ]
        )

        # Color distortions: random color jitter followed by random grayscale.
        color_jittering = transforms.Compose(
            [
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                    p=0.8,
                ),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

        # View-specific extra distortions.
        global_transfo1_extra = GaussianBlur(p=1.0)

        global_transfo2_extra = transforms.Compose(
            [
                GaussianBlur(p=0.1),
                transforms.RandomSolarize(threshold=128, p=0.2),
            ]
        )

        local_transfo_extra = GaussianBlur(p=0.5)

        # Conversion to a tensor followed by mean/std normalization.
        self.normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                make_normalize_transform(mean=mean, std=std),
            ]
        )

        if self.share_color_jitter:
            # Jitter is applied once to the input image in __call__, so the
            # per-view pipelines leave it out.
            self.color_jittering = color_jittering
            self.global_transfo1 = transforms.Compose([resize_global, global_transfo1_extra, self.normalize])
            self.global_transfo2 = transforms.Compose([resize_global, global_transfo2_extra, self.normalize])
            self.local_transfo = transforms.Compose([local_transfo_extra, self.normalize])
        else:
            self.global_transfo1 = transforms.Compose(
                [resize_global, color_jittering, global_transfo1_extra, self.normalize]
            )
            self.global_transfo2 = transforms.Compose(
                [resize_global, color_jittering, global_transfo2_extra, self.normalize]
            )
            self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])

    def __call__(self, image):
        """Augment ``image`` and return the dict of views described on the class."""
        output = {}
        output["weak_flag"] = True

        if self.share_color_jitter:
            image = self.color_jittering(image)

        # Global crops: geometric base crop -> distortions -> optional resize.
        im1_base = self.geometric_augmentation_global(image)
        global_crop_1_transf = self.global_transfo1(im1_base)
        global_crop_1 = self.resize_global_post_transf(global_crop_1_transf)

        im2_base = self.geometric_augmentation_global(image)
        global_crop_2_transf = self.global_transfo2(im2_base)
        global_crop_2 = self.resize_global_post_transf(global_crop_2_transf)

        output["global_crops"] = [global_crop_1, global_crop_2]

        # Teacher views: either the student global crops, or the base crops
        # with normalization only.
        # NOTE(review): in the latter case the base crops are not resized, so
        # they keep the base-crop size when a larger gram size is configured;
        # confirm this is intended.
        if self.teacher_no_color_jitter:
            output["global_crops_teacher"] = [
                self.normalize(im1_base),
                self.normalize(im2_base),
            ]
        else:
            output["global_crops_teacher"] = [global_crop_1, global_crop_2]

        if self.gram_teacher_crops_size is not None:
            if self.gram_teacher_no_distortions:
                # Undistorted: resize the base crop, then normalize.
                gram_crop_1 = self.normalize(self.resize_gram_teacher(im1_base))
                gram_crop_2 = self.normalize(self.resize_gram_teacher(im2_base))
            else:
                # Distorted: resize the already normalized, distorted tensor.
                gram_crop_1 = self.resize_gram_teacher(global_crop_1_transf)
                gram_crop_2 = self.resize_gram_teacher(global_crop_2_transf)
            output["gram_teacher_crops"] = [gram_crop_1, gram_crop_2]

        if self.local_crops_subset_of_global_crops:
            # Split the requested number between the two base crops; the first
            # one takes the extra crop when the number is odd so that exactly
            # ``local_crops_number`` crops are returned.
            n_from_second = self.local_crops_number // 2
            n_from_first = self.local_crops_number - n_from_second
            _local_crops = [self.local_transfo(im1_base) for _ in range(n_from_first)] + [
                self.local_transfo(im2_base) for _ in range(n_from_second)
            ]

            local_crops = []
            offsets = []
            gs = self.global_crops_size
            ls = self.local_crops_size
            # Number of patch-aligned offsets for which a window of size ``ls``
            # still fits inside ``gs``. ``randint``'s upper bound is exclusive,
            # hence the ``+ 1``: it makes the last valid offset reachable and
            # keeps offset 0 available when ``gs - ls < patch_size``.
            n_positions = (gs - ls) // self.patch_size + 1
            for img in _local_crops:
                rx, ry = np.random.randint(0, n_positions, 2) * self.patch_size
                # ``rx`` indexes dim 1 and ``ry`` dim 2 of the tensor.
                local_crops.append(img[:, rx : rx + ls, ry : ry + ls])
                offsets.append((rx, ry))

            output["local_crops"] = local_crops
            output["offsets"] = offsets
        else:
            # Independent random crops taken directly from the input image.
            local_crops = [
                self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
            ]
            output["local_crops"] = local_crops
            output["offsets"] = ()

        return output
|
|
|