|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Multi-crop dataset: implements the multi-crop augmentation and builds the dataset and data loader. |
|
""" |
|
import copy |
|
import random |
|
|
|
import torch |
|
import torchvision.transforms as transforms |
|
from PIL import Image, ImageFilter, ImageOps |
|
from src.dataset import ImageFolder |
|
from src.RandAugment import rand_augment_transform |
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
from timm.data.random_erasing import RandomErasing |
|
from timm.data.transforms import _pil_interp |
|
|
|
|
|
class GaussianBlur(object):
    """
    Randomly apply Gaussian blur to a PIL image.
    --p: probability of blurring the image
    --radius_min: lower bound of the uniformly sampled blur radius
    --radius_max: upper bound of the uniformly sampled blur radius
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.0):
        self.prob = p
        self.radius_min, self.radius_max = radius_min, radius_max

    def __call__(self, img):
        # Draw the coin first, then (only when blurring) the radius, so the
        # random stream is consumed in the same order as before.
        blur_wanted = random.random() <= self.prob
        if blur_wanted:
            radius = random.uniform(self.radius_min, self.radius_max)
            return img.filter(ImageFilter.GaussianBlur(radius=radius))
        return img
|
|
|
|
|
class Solarization(object):
    """
    Randomly solarize a PIL image.
    --p: probability of applying ``ImageOps.solarize`` to the image
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        should_solarize = random.random() < self.p
        return ImageOps.solarize(img) if should_solarize else img
|
|
|
|
|
def strong_transforms(
    img_size=224,
    scale=(0.08, 1.0),
    ratio=(0.75, 1.3333333333333333),
    hflip=0.5,
    vflip=0.0,
    color_jitter=0.4,
    auto_augment="rand-m9-mstd0.5-inc1",
    interpolation="random",
    use_prefetcher=True,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    re_prob=0.25,
    re_mode="pixel",
    re_count=1,
    re_num_splits=0,
    color_aug=False,
    strong_ratio=0.45,
):
    """
    build the strong augmentation pipeline as a single ``transforms.Compose``.

    The composed order is: random flips -> RandAugment -> optional color
    jitter / grayscale -> RandomResizedCrop -> ToTensor (-> Normalize)
    -> RandomErasing.

    --img_size: output size of the random resized crop (int or tuple); its
        minimum also scales the RandAugment translation constant
    --scale / ratio: forwarded to RandomResizedCrop; a falsy value falls back
        to (0.08, 1.0) / (3/4, 4/3)
    --hflip / vflip: flip probabilities; a flip is only added when > 0
    --color_jitter: only checked against None; the jitter strengths themselves
        are fixed (brightness 0.4, contrast 0.4, saturation 0.2, hue 0.1)
    --auto_augment: policy string; only policies starting with "rand" add a
        transform, other strings are ignored
    --interpolation: when set and not "random" it is forwarded to RandAugment;
        the final crop always uses bicubic interpolation
    --use_prefetcher: if True only ToTensor is applied (no Normalize),
        presumably because normalization happens later on the GPU
    --mean / std: normalization statistics; mean is also the RandAugment fill
    --re_prob / re_mode / re_count / re_num_splits: RandomErasing parameters;
        erasing is only added when re_prob > 0
    --color_aug: whether to add color jitter + random grayscale
    --strong_ratio: fraction of the smallest image side used as the
        RandAugment ``translate_const``

    Returns the composed transform.
    """
    scale = tuple(scale or (0.08, 1.0))
    ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0))

    # Geometric flips come first.
    primary_tfl = []
    if hflip > 0.0:
        primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
    if vflip > 0.0:
        primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]

    secondary_tfl = []
    if auto_augment:
        assert isinstance(auto_augment, str)
        img_size_min = min(img_size) if isinstance(img_size, tuple) else img_size
        aa_params = dict(
            translate_const=int(img_size_min * strong_ratio),
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        )
        if interpolation and interpolation != "random":
            aa_params["interpolation"] = _pil_interp(interpolation)
        if auto_augment.startswith("rand"):
            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]

    # NOTE(review): the original indentation was lost; this check is assumed to
    # sit at function level (independent of `auto_augment`) -- confirm.
    if color_jitter is not None and color_aug:
        secondary_tfl += [
            transforms.RandomApply(
                [
                    transforms.ColorJitter(
                        brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
                    )
                ],
                p=0.8,
            ),
            transforms.RandomGrayscale(p=0.2),
        ]

    # The crop has always used bicubic interpolation; the previously computed
    # (and never used) per-call interpolation value has been dropped.
    final_tfl = [
        transforms.RandomResizedCrop(
            size=img_size, scale=scale, ratio=ratio, interpolation=Image.BICUBIC
        ),
        transforms.ToTensor(),
    ]
    if not use_prefetcher:
        final_tfl.append(
            transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))
        )

    # NOTE(review): assumed to apply in both the prefetcher and non-prefetcher
    # cases (original indentation was lost) -- confirm.
    if re_prob > 0.0:
        final_tfl.append(
            RandomErasing(
                re_prob,
                mode=re_mode,
                max_count=re_count,
                num_splits=re_num_splits,
                device="cpu",
            )
        )
    return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
|
|
|
|
|
class DataAugmentation(object):
    """
    implement multi-crop data augmentation.
    --global_crops_scale: scale range of the 224-sized cropped image before resizing
    --local_crops_scale: scale range of the 96-sized cropped image before resizing
    --local_crops_number: Number of small local views to generate
    --prob: threshold that mixes weak and strong augmentation: a crop uses the
        strong augmentation when ``random.random() > prob``, i.e. with
        probability ``1 - prob``; ``prob >= 1.0`` disables strong augmentation
    --vanilla_weak_augmentation: whether we use the same augmentation in DINO, namely
        only using weak augmentation (must be exactly ``True`` to take effect)
    --color_aug: after AutoAugment, whether we further perform color augmentation
    --local_crop_size: the small crop size; a sequence whose first element must be 96
    --timm_auto_augment_par: the parameters for the AutoAugment used in DeiT
    --strong_ratio: the ratio of image augmentation for the AutoAugment used in DeiT
    --re_prob: the re-prob parameter of image augmentation for the AutoAugment used in DeiT
    --use_prefetcher: whether we use prefetcher which can accelerate the training speed;
        when True the transforms stop at ToTensor and do not normalize
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        prob=0.5,
        vanilla_weak_augmentation=False,
        color_aug=False,
        local_crop_size=(96,),  # immutable default (was a shared mutable list)
        timm_auto_augment_par="rand-m9-mstd0.5-inc1",
        strong_ratio=0.45,
        re_prob=0.25,
        use_prefetcher=False,
    ):
        self.prob = prob
        self.vanilla_weak_augmentation = vanilla_weak_augmentation

        flip_and_color_jitter = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [
                        transforms.ColorJitter(
                            brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
                        )
                    ],
                    p=0.8,
                ),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

        # With the prefetcher only the tensor conversion happens here.
        to_tensor_steps = [transforms.ToTensor()]
        if not use_prefetcher:
            to_tensor_steps.append(
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            )
        normalize = transforms.Compose(to_tensor_steps)

        # first weak global crop
        self.global_transfo1 = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    224, scale=global_crops_scale, interpolation=Image.BICUBIC
                ),
                flip_and_color_jitter,
                GaussianBlur(1.0),
                normalize,
            ]
        )

        # second weak global crop
        self.global_transfo2 = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    224, scale=global_crops_scale, interpolation=Image.BICUBIC
                ),
                flip_and_color_jitter,
                GaussianBlur(0.1),
                Solarization(0.2),
                normalize,
            ]
        )

        # Settings shared by the strong global and strong local pipelines.
        strong_kwargs = dict(
            ratio=(0.75, 1.3333333333333333),
            hflip=0.5,
            vflip=0.0,
            color_jitter=0.4,
            auto_augment=timm_auto_augment_par,
            interpolation="random",
            use_prefetcher=use_prefetcher,
            mean=IMAGENET_DEFAULT_MEAN,
            std=IMAGENET_DEFAULT_STD,
            re_prob=re_prob,
            re_mode="pixel",
            re_count=1,
            re_num_splits=0,
            color_aug=color_aug,
            strong_ratio=strong_ratio,
        )

        # strong global crop
        self.global_transfo3 = strong_transforms(
            img_size=224, scale=global_crops_scale, **strong_kwargs
        )

        self.local_crops_number = local_crops_number
        assert local_crop_size[0] == 96, "only 96-sized local crops are supported"

        # weak local crop
        self.local_transfo = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    local_crop_size[0],
                    scale=local_crops_scale,
                    interpolation=Image.BICUBIC,
                ),
                flip_and_color_jitter,
                GaussianBlur(p=0.5),
                normalize,
            ]
        )

        # strong local crop
        self.local_transfo2 = strong_transforms(
            img_size=local_crop_size[0], scale=local_crops_scale, **strong_kwargs
        )

    def __call__(self, image):
        """
        implement multi-crop data augmentation.

        Returns ``(crops, weak_flag)``. ``crops`` holds, in order: two weak
        224-sized global crops, a second pair of global crops (copies of the
        first pair, two strong crops, or two fresh weak crops), and then
        "local_crops_number" 96-sized local crops. ``weak_flag`` is True when
        the second global pair is weakly augmented.
        """
        img1 = self.global_transfo1(image)
        img2 = self.global_transfo2(image)
        crops = [img1, img2]

        if self.vanilla_weak_augmentation is True:
            # DINO-style: reuse the first pair instead of augmenting again.
            crops.append(copy.deepcopy(img1))
            crops.append(copy.deepcopy(img2))
            weak_flag = True
        elif self.prob < 1.0 and random.random() > self.prob:
            # strong augmentation for the second global pair
            crops.append(self.global_transfo3(image))
            crops.append(self.global_transfo3(image))
            weak_flag = False
        else:
            # a fresh weakly augmented second global pair
            crops.append(self.global_transfo1(image))
            crops.append(self.global_transfo2(image))
            weak_flag = True

        # Each local crop independently picks strong or weak augmentation.
        for _ in range(self.local_crops_number):
            if self.prob < 1.0 and random.random() > self.prob:
                crops.append(self.local_transfo2(image))
            else:
                crops.append(self.local_transfo(image))

        return crops, weak_flag
|
|
|
|
|
def get_dataset(args):
    """
    create the multi-crop augmentation, wrap the image folder with it and
    return a distributed data loader over the resulting dataset
    """
    multi_crop = DataAugmentation(
        args.global_crops_scale,
        args.local_crops_scale,
        args.local_crops_number,
        vanilla_weak_augmentation=args.vanilla_weak_augmentation,
        prob=args.prob,
        color_aug=args.color_aug,
        local_crop_size=args.size_crops,
        timm_auto_augment_par=args.timm_auto_augment_par,
        strong_ratio=args.strong_ratio,
        re_prob=args.re_prob,
        use_prefetcher=args.use_prefetcher,
    )

    # debug runs use 2 classes instead of the full 1000
    if args.debug:
        num_classes = 2
    else:
        num_classes = 1000
    train_set = ImageFolder(
        args.data_path, transform=multi_crop, class_num=num_classes
    )

    loader_options = dict(
        sampler=torch.utils.data.DistributedSampler(train_set, shuffle=True),
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )
    return torch.utils.data.DataLoader(train_set, **loader_options)
|
|
|
|
|
class data_prefetcher:
    """
    implement data prefetcher. we perform the normalization on GPUs instead of CPUs,
    copying the next minibatch on a side CUDA stream while the current one is in use
    --loader: a data loader yielding ``(multi_crops, weak_flag)`` batches
    --fp16: whether we use fp16, if yes, we need to transform the data to be fp16
    """

    def __init__(self, loader, fp16=True):
        self.loader = iter(loader)
        self.fp16 = fp16
        # side stream on which the host-to-device copies and normalization run
        self.stream = torch.cuda.Stream()
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1, 3, 1, 1)
        if fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()

        self.preload()

    def preload(self):
        """
        preload the next minibatch of data; once the loader is exhausted both
        ``multi_crops`` and ``weak_flag`` are set to None
        """
        try:
            self.multi_crops, self.weak_flag = next(self.loader)
        except StopIteration:
            self.multi_crops, self.weak_flag = None, None
            return

        with torch.cuda.stream(self.stream):
            for i, crop in enumerate(self.multi_crops):
                crop = crop.cuda(non_blocking=True)
                crop = crop.half() if self.fp16 else crop.float()
                self.multi_crops[i] = crop.sub_(self.mean).div_(self.std)

    def next(self):
        """
        load the next minibatch of data; returns ``(multi_crops, weak_flag)``,
        or ``(None, None)`` when the loader is exhausted
        """
        current = torch.cuda.current_stream()
        current.wait_stream(self.stream)
        multi_crops, weak_flags = self.multi_crops, self.weak_flag
        if multi_crops is not None:
            # The crops were allocated on the side stream but are consumed on
            # the current stream; record that so the caching allocator does not
            # hand their memory back to the side stream while still in use.
            for crop in multi_crops:
                crop.record_stream(current)
        self.preload()
        return multi_crops, weak_flags
|
|