File size: 3,977 Bytes
fd01725 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import logging
from torchvision import transforms
from .transforms import (
GaussianBlur,
make_normalize_transform,
)
# Module-level logger, registered under the fixed name "dinov2".
logger = logging.getLogger("dinov2")
class DataAugmentationDINO:
    """Multi-crop augmentation callable yielding global and local views of an image.

    Each call produces two "global" crops and ``local_crops_number`` "local"
    crops. Every crop goes through a random resized crop and a random
    horizontal flip, then colour jitter / random grayscale, a crop-specific
    ``GaussianBlur`` stage (the second global crop additionally gets random
    solarization), and finally tensor conversion plus normalization.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        """Record the crop settings, log them, and assemble the pipelines.

        Args:
            global_crops_scale: Forwarded as ``scale`` to the
                ``RandomResizedCrop`` used for global crops.
            local_crops_scale: Forwarded as ``scale`` to the
                ``RandomResizedCrop`` used for local crops.
            local_crops_number: How many local crops ``__call__`` generates.
            global_crops_size: Output size of the global crops.
            local_crops_size: Output size of the local crops.
        """
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size

        # Report the configuration once, framed by banner lines.
        logger.info("###################################")
        logger.info("Using data augmentation parameters:")
        logger.info(f"global_crops_scale: {global_crops_scale}")
        logger.info(f"local_crops_scale: {local_crops_scale}")
        logger.info(f"local_crops_number: {local_crops_number}")
        logger.info(f"global_crops_size: {global_crops_size}")
        logger.info(f"local_crops_size: {local_crops_size}")
        logger.info("###################################")

        def crop_then_flip(size, scale):
            # Geometric stage shared by both crop families:
            # bicubic random resized crop, then a 50% horizontal flip.
            resized_crop = transforms.RandomResizedCrop(
                size,
                scale=scale,
                interpolation=transforms.InterpolationMode.BICUBIC,
            )
            return transforms.Compose([resized_crop, transforms.RandomHorizontalFlip(p=0.5)])

        self.geometric_augmentation_global = crop_then_flip(global_crops_size, global_crops_scale)
        self.geometric_augmentation_local = crop_then_flip(local_crops_size, local_crops_scale)

        # Colour distortions: jitter applied 80% of the time, then grayscale 20% of the time.
        jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
        colour_stage = transforms.Compose(
            [
                transforms.RandomApply([jitter], p=0.8),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

        # Crop-specific extras; GaussianBlur is a project transform whose ``p``
        # is presumably the probability of applying the blur (confirm in .transforms).
        blur_first_global = GaussianBlur(p=1.0)
        blur_solarize_second_global = transforms.Compose(
            [
                GaussianBlur(p=0.1),
                transforms.RandomSolarize(threshold=128, p=0.2),
            ]
        )
        blur_local = GaussianBlur(p=0.5)

        # Final stage: convert to tensor, then apply the project's normalize transform.
        to_normalized_tensor = transforms.Compose(
            [
                transforms.ToTensor(),
                make_normalize_transform(),
            ]
        )
        self.normalize = to_normalized_tensor

        # Full photometric pipelines, one per crop family.
        self.global_transfo1 = transforms.Compose([colour_stage, blur_first_global, to_normalized_tensor])
        self.global_transfo2 = transforms.Compose(
            [colour_stage, blur_solarize_second_global, to_normalized_tensor]
        )
        self.local_transfo = transforms.Compose([colour_stage, blur_local, to_normalized_tensor])

    def __call__(self, image):
        """Augment ``image`` into global and local crops.

        Returns:
            A dict with keys ``"global_crops"`` (two crops),
            ``"global_crops_teacher"`` (a separate list holding the same two
            crop objects), ``"local_crops"`` (``local_crops_number`` crops)
            and ``"offsets"`` (always an empty tuple).
        """
        # Global views: each one gets its own geometric crop before its own
        # photometric pipeline; the call order is kept so that seeded runs
        # draw random numbers in the same sequence.
        first_global = self.global_transfo1(self.geometric_augmentation_global(image))
        second_global = self.global_transfo2(self.geometric_augmentation_global(image))

        # Local views, generated one at a time.
        local_views = []
        for _ in range(self.local_crops_number):
            cropped = self.geometric_augmentation_local(image)
            local_views.append(self.local_transfo(cropped))

        return {
            "global_crops": [first_global, second_global],
            # The teacher receives the very same global crops, in a distinct list.
            "global_crops_teacher": [first_global, second_global],
            "local_crops": local_views,
            "offsets": (),
        }
|