# xai-cl / data_transforms.py
# Uploaded by: Anonymous
# Commit message: Upload 3 files
# Commit: cddd431
# (file-viewer residue: raw | history | blame; "No virus"; 5.03 kB)
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from PIL import Image, ImageOps, ImageFilter
import random
def add_normalization_to_transform(unnormalized_transforms,
                                   mean=(0.485, 0.456, 0.406),
                                   std=(0.229, 0.224, 0.225)):
    """Append a ``transforms.Normalize`` step to every transform in a dict.

    The defaults are the ImageNet channel statistics, so calling this with a
    single argument behaves exactly as before.

    Args:
        unnormalized_transforms: mapping from a name to a transform callable
            (in this file, ``transforms.Compose`` pipelines ending in
            ``ToTensor()``). The mapping itself is not modified.
        mean: per-channel means forwarded to ``transforms.Normalize``.
        std: per-channel standard deviations forwarded to
            ``transforms.Normalize``.

    Returns:
        A new dict with the same keys in the same order. Each value is
        ``transforms.Compose([original, transforms.Normalize(mean, std)])``.
    """
    normalized_transform = {}
    for key, value in unnormalized_transforms.items():
        # Normalize runs on tensors, so it is appended after the wrapped
        # pipeline (which ends in ToTensor() for every transform in this file).
        # A fresh Normalize is built per entry, matching the original code.
        normalized_transform[key] = transforms.Compose(
            [value, transforms.Normalize(mean=list(mean), std=list(std))])
    return normalized_transform
def modify_transforms(normal_transforms, no_shift_transforms, ig_transforms):
    """Return normalized copies of the three transform dictionaries.

    Each dictionary is passed, in argument order, through
    ``add_normalization_to_transform``. The three results are returned as a
    tuple ``(normal, no_shift, ig)``; the input dictionaries are left as-is.
    """
    transform_groups = (normal_transforms, no_shift_transforms, ig_transforms)
    return tuple(add_normalization_to_transform(group) for group in transform_groups)
class Solarization(object):
    """Randomly solarize an image with ``PIL.ImageOps.solarize``.

    With probability ``p`` the input image is solarized; otherwise it is
    returned unchanged (the very same object).

    Args:
        p: probability of applying solarization. ``random.random()`` lies in
            ``[0, 1)``, so ``p >= 1`` always solarizes and ``p <= 0`` never does.
        threshold: forwarded to ``ImageOps.solarize``. The default of 128
            matches Pillow's own default, so existing ``Solarization(p)``
            calls behave as before.
    """

    def __init__(self, p, threshold=128):
        self.p = p
        self.threshold = threshold

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img, threshold=self.threshold)
        return img

    def __repr__(self):
        # Gives a readable entry when a transforms.Compose pipeline is printed.
        return "{}(p={}, threshold={})".format(
            type(self).__name__, self.p, self.threshold)
# The pipelines below end at ToTensor(): no ImageNet normalization is applied
# here (SimCLRv2 takes unnormalized inputs). add_normalization_to_transform /
# modify_transforms append Normalize for models that need it.

# Deterministic view: shorter edge to 256, then a 224x224 center crop.
pure_transform = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
])

# Randomized view: random resized crop and horizontal flip, followed by
# probabilistic color jitter, grayscale conversion and Gaussian blur.
aug_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.8, contrast=0.8,
                                saturation=0.8, hue=0.2)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply(
        [transforms.GaussianBlur(kernel_size=(21, 21), sigma=(0.1, 2.0))],
        p=0.5,
    ),
    transforms.ToTensor(),
])
def _ig_pipeline(*steps):
    """Build Resize(256) -> CenterCrop(224) -> *steps -> ToTensor().

    Every call creates fresh Resize/CenterCrop/ToTensor instances, so each
    resulting pipeline owns its own transform objects.
    """
    return transforms.Compose(
        [transforms.Resize(size=256), transforms.CenterCrop(size=224)]
        + list(steps)
        + [transforms.ToTensor()]
    )


# "ig" variants: the same deterministic center crop, optionally followed by
# one photometric perturbation. (The "ig" prefix presumably refers to
# integrated gradients -- NOTE(review): confirm with callers.)
ig_pure_transform = _ig_pipeline()

# RandomApply with p=1 means the jitter is always applied.
ig_transform_colorjitter = _ig_pipeline(
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.8, contrast=0.8,
                                saturation=0.8, hue=0.4)],
        p=1,
    )
)

# sigma=(5, 5) pins the blur sigma to 5; p=1 always applies it.
ig_transform_blur = _ig_pipeline(
    transforms.RandomApply(
        [transforms.GaussianBlur(kernel_size=(11, 11), sigma=(5, 5))],
        p=1,
    )
)

ig_transform_solarize = _ig_pipeline(Solarization(p=1.0))

ig_transform_grayscale = _ig_pipeline(transforms.RandomGrayscale(p=1))

# Same photometric recipe (and probabilities) as aug_transform, minus the
# random crop and flip.
ig_transform_combine = _ig_pipeline(
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.8, contrast=0.8,
                                saturation=0.8, hue=0.2)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply(
        [transforms.GaussianBlur(kernel_size=(21, 21), sigma=(0.1, 2.0))],
        p=0.5,
    ),
)
# "No shift" variants: resize straight to 224x224 (no crop, no flip), so the
# whole image is kept. A (h, w) tuple size makes Resize ignore aspect ratio.
pure_transform_no_shift = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
])

aug_transform_no_shift = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.8, contrast=0.8,
                                saturation=0.8, hue=0.2)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
])
# Named transform collections. Key order is significant for anything that
# iterates these dicts, so it mirrors the definition order above.
normal_transforms = dict(
    pure=pure_transform,
    aug=aug_transform,
)

no_shift_transforms = dict(
    pure=pure_transform_no_shift,
    aug=aug_transform_no_shift,
)

ig_transforms = dict(
    pure=ig_pure_transform,
    color_jitter=ig_transform_colorjitter,
    blur=ig_transform_blur,
    grayscale=ig_transform_grayscale,
    solarize=ig_transform_solarize,
    combine=ig_transform_combine,
)