|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
|
from torchvision import transforms |
|
|
|
from .transforms import ( |
|
GaussianBlur, |
|
make_normalize_transform, |
|
) |
|
|
|
|
|
logger = logging.getLogger("dinov2") |
|
|
|
|
|
class DataAugmentationDINO(object): |
|
def __init__( |
|
self, |
|
global_crops_scale, |
|
local_crops_scale, |
|
local_crops_number, |
|
global_crops_size=224, |
|
local_crops_size=96, |
|
): |
|
self.global_crops_scale = global_crops_scale |
|
self.local_crops_scale = local_crops_scale |
|
self.local_crops_number = local_crops_number |
|
self.global_crops_size = global_crops_size |
|
self.local_crops_size = local_crops_size |
|
|
|
logger.info("###################################") |
|
logger.info("Using data augmentation parameters:") |
|
logger.info(f"global_crops_scale: {global_crops_scale}") |
|
logger.info(f"local_crops_scale: {local_crops_scale}") |
|
logger.info(f"local_crops_number: {local_crops_number}") |
|
logger.info(f"global_crops_size: {global_crops_size}") |
|
logger.info(f"local_crops_size: {local_crops_size}") |
|
logger.info("###################################") |
|
|
|
|
|
self.geometric_augmentation_global = transforms.Compose( |
|
[ |
|
transforms.RandomResizedCrop( |
|
global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC |
|
), |
|
transforms.RandomHorizontalFlip(p=0.5), |
|
] |
|
) |
|
|
|
self.geometric_augmentation_local = transforms.Compose( |
|
[ |
|
transforms.RandomResizedCrop( |
|
local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC |
|
), |
|
transforms.RandomHorizontalFlip(p=0.5), |
|
] |
|
) |
|
|
|
|
|
color_jittering = transforms.Compose( |
|
[ |
|
transforms.RandomApply( |
|
[transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], |
|
p=0.8, |
|
), |
|
transforms.RandomGrayscale(p=0.2), |
|
] |
|
) |
|
|
|
global_transfo1_extra = GaussianBlur(p=1.0) |
|
|
|
global_transfo2_extra = transforms.Compose( |
|
[ |
|
GaussianBlur(p=0.1), |
|
transforms.RandomSolarize(threshold=128, p=0.2), |
|
] |
|
) |
|
|
|
local_transfo_extra = GaussianBlur(p=0.5) |
|
|
|
|
|
self.normalize = transforms.Compose( |
|
[ |
|
transforms.ToTensor(), |
|
make_normalize_transform(), |
|
] |
|
) |
|
|
|
self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize]) |
|
self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize]) |
|
self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize]) |
|
|
|
def __call__(self, image): |
|
output = {} |
|
|
|
|
|
im1_base = self.geometric_augmentation_global(image) |
|
global_crop_1 = self.global_transfo1(im1_base) |
|
|
|
im2_base = self.geometric_augmentation_global(image) |
|
global_crop_2 = self.global_transfo2(im2_base) |
|
|
|
output["global_crops"] = [global_crop_1, global_crop_2] |
|
|
|
|
|
output["global_crops_teacher"] = [global_crop_1, global_crop_2] |
|
|
|
|
|
local_crops = [ |
|
self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number) |
|
] |
|
output["local_crops"] = local_crops |
|
output["offsets"] = () |
|
|
|
return output |
|
|