|
"""
|
|
Image augmentation transforms (photometric and geometric), adapted from https://github.com/nv-nguyen/template-pose/blob/main/src/utils/augmentation.py
|
|
"""
|
|
|
|
from torchvision import transforms
|
|
from PIL import ImageEnhance, ImageFilter, Image
|
|
import numpy as np
|
|
import random
|
|
import logging
|
|
from torchvision.transforms import RandomResizedCrop, ToTensor
|
|
|
|
|
|
class PillowRGBAugmentation:
    """Randomly apply a Pillow ``ImageEnhance``-style operation to an image.

    On each call, with probability ``p``, a factor is drawn uniformly from
    ``factor_interval``. The image is then wrapped with ``pillow_fn`` and
    enhanced with that factor. Otherwise the input is returned unchanged.

    Args:
        pillow_fn: Enhancer class (e.g. ``ImageEnhance.Contrast``). It is
            called with the image, and the result's ``enhance(factor=...)``
            method is used.
        p: Probability of applying the enhancement on each call.
        factor_interval: ``(low, high)`` bounds passed to ``random.uniform``
            to draw the enhancement factor.
    """

    def __init__(self, pillow_fn, p, factor_interval):
        self._pillow_fn = pillow_fn
        self.p = p
        self.factor_interval = factor_interval

    def __call__(self, PIL_image):
        """Return ``PIL_image`` enhanced (and in RGB) with probability ``p``."""
        # random.random() lies in [0, 1). With a strict "<", p=0 never applies
        # and p=1 always applies.
        if random.random() < self.p:
            factor = random.uniform(*self.factor_interval)
            if PIL_image.mode != "RGB":
                logging.warning(
                    f"Error when apply data aug, image mode: {PIL_image.mode}"
                )
                # Convert the image itself so the enhancer always receives RGB
                # input (this line used to reference an undefined name).
                PIL_image = PIL_image.convert("RGB")
                logging.warning(f"Success to change to {PIL_image.mode}")
            enhanced = self._pillow_fn(PIL_image).enhance(factor=factor)
            PIL_image = enhanced.convert("RGB")
        return PIL_image
|
|
|
|
|
|
class PillowSharpness(PillowRGBAugmentation):
    """Random sharpness enhancement backed by ``ImageEnhance.Sharpness``.

    Defaults: applied with probability 0.3; the factor is drawn from
    ``(0, 40.0)``.
    """

    def __init__(self, p=0.3, factor_interval=(0, 40.0)):
        settings = dict(
            pillow_fn=ImageEnhance.Sharpness,
            p=p,
            factor_interval=factor_interval,
        )
        super().__init__(**settings)
|
|
|
|
|
|
class PillowContrast(PillowRGBAugmentation):
    """Random contrast enhancement backed by ``ImageEnhance.Contrast``.

    Defaults: applied with probability 0.3; the factor is drawn from
    ``(0.5, 1.6)``.
    """

    def __init__(self, p=0.3, factor_interval=(0.5, 1.6)):
        enhancer_cls = ImageEnhance.Contrast
        super().__init__(enhancer_cls, p, factor_interval)
|
|
|
|
|
|
class PillowBrightness(PillowRGBAugmentation):
    """Random brightness enhancement backed by ``ImageEnhance.Brightness``.

    Defaults: applied with probability 0.5; the factor is drawn from
    ``(0.5, 2.0)``.
    """

    def __init__(self, p=0.5, factor_interval=(0.5, 2.0)):
        super().__init__(
            factor_interval=factor_interval, p=p, pillow_fn=ImageEnhance.Brightness
        )
|
|
|
|
|
|
class PillowColor(PillowRGBAugmentation):
    """Random color (saturation) enhancement backed by ``ImageEnhance.Color``.

    Defaults: applied on every call (``p=1``); the factor is drawn from
    ``(0.0, 20.0)``.
    """

    def __init__(self, p=1, factor_interval=(0.0, 20.0)):
        enhancer_cls = ImageEnhance.Color
        super().__init__(pillow_fn=enhancer_cls, p=p, factor_interval=factor_interval)
|
|
|
|
|
|
class PillowBlur:
    """Randomly apply a Gaussian blur to a PIL image.

    Args:
        p: Probability of blurring on each call.
        factor_interval: Inclusive ``(low, high)`` integer bounds for the
            blur radius (passed to ``random.randint``). A new radius is drawn
            on every call that applies the blur.
    """

    def __init__(self, p=0.4, factor_interval=(1, 3)):
        self.p = p
        self.factor_interval = factor_interval
        # Radius used by the most recent blur. Kept as a public attribute so
        # existing code that reads ``.k`` keeps working.
        self.k = self._sample_radius()

    def _sample_radius(self):
        """Draw a blur radius uniformly from the inclusive integer interval."""
        return random.randint(*self.factor_interval)

    def __call__(self, PIL_image):
        """Return ``PIL_image``, blurred with probability ``p``."""
        # Resample per call. Drawing only once in __init__ would pin a single
        # radius for the lifetime of the transform.
        if random.random() < self.p:
            self.k = self._sample_radius()
            PIL_image = PIL_image.filter(ImageFilter.GaussianBlur(self.k))
        return PIL_image
|
|
|
|
|
|
class NumpyGaussianNoise:
    """Randomly add Gaussian noise to an image and return a PIL image.

    Args:
        p: Probability of adding noise on each call.
        factor_interval: ``(low, high)`` bounds for the noise ratio. On each
            noisy call a ratio is drawn from this interval. The noise standard
            deviation is then drawn uniformly from ``[0, ratio]`` and scaled
            by 255.

    ``img`` is presumably a PIL image; any array-like that NumPy can convert
    is handled the same way (TODO confirm against callers).
    """

    def __init__(self, p, factor_interval=(0.01, 0.3)):
        self.factor_interval = factor_interval
        # Ratio used by the most recent noisy call. Kept as a public attribute
        # for backward compatibility.
        self.noise_ratio = random.uniform(*factor_interval)
        self.p = p

    def _add_noise(self, img):
        """Return ``img`` plus Gaussian noise as a clipped, rounded uint8 array."""
        # Redraw the ratio per call. A single draw in __init__ would fix the
        # maximum noise level for the lifetime of the transform.
        self.noise_ratio = random.uniform(*self.factor_interval)
        noisesigma = random.uniform(0, self.noise_ratio)
        pixels = np.asarray(img, dtype=np.float64)
        gauss = np.random.normal(0, noisesigma, pixels.shape) * 255
        # Round before the uint8 cast: plain truncation floors every value and
        # biases the image slightly darker.
        return np.rint(np.clip(pixels + gauss, 0, 255)).astype(np.uint8)

    def __call__(self, img):
        """Return a PIL image of ``img``, noised with probability ``p``."""
        if random.random() < self.p:
            return Image.fromarray(self._add_noise(img))
        return Image.fromarray(np.uint8(img))
|
|
|
|
|
|
class StandardAugmentation:
    """Apply named photometric augmentations in a configured order.

    Args:
        names: Comma-separated augmentation names, applied left to right.
            Surrounding whitespace and empty entries are ignored, so ``""``
            applies nothing. Recognised names: ``brightness``, ``contrast``,
            ``sharpness``, ``color``, ``blur``, ``gaussian_noise``. Any other
            name raises ``KeyError`` when the object is called.
        brightness, contrast, sharpness, color, blur, gaussian_noise:
            Callables that each take an image and return the augmented image.
    """

    def __init__(
        self, names, brightness, contrast, sharpness, color, blur, gaussian_noise
    ):
        self.brightness = brightness
        self.contrast = contrast
        self.sharpness = sharpness
        self.color = color
        self.blur = blur
        self.gaussian_noise = gaussian_noise

        # Tolerate "a, b" and trailing commas. A bare split would keep names
        # such as " b" or "" that fail the dict lookup at call time.
        self.names = [name.strip() for name in names.split(",") if name.strip()]
        self.augmentations = {
            "brightness": self.brightness,
            "contrast": self.contrast,
            "sharpness": self.sharpness,
            "color": self.color,
            "blur": self.blur,
            "gaussian_noise": self.gaussian_noise,
        }

    def __call__(self, img):
        """Feed ``img`` through each configured augmentation in order."""
        for name in self.names:
            img = self.augmentations[name](img)
        return img
|
|
|
|
|
|
class GeometricAugmentation:
    """Apply named geometric augmentations in a configured order.

    Args:
        names: Comma-separated augmentation names, applied left to right.
            Surrounding whitespace and empty entries are ignored, so ``""``
            applies nothing. Recognised names: ``random_resized_crop``,
            ``random_horizontal_flip``, ``random_vertical_flip``,
            ``random_rotation``. Any other name raises ``KeyError`` when the
            object is called.
        random_resized_crop, random_horizontal_flip, random_vertical_flip,
        random_rotation: Callables that each take an image and return the
            transformed image.
    """

    def __init__(
        self,
        names,
        random_resized_crop,
        random_horizontal_flip,
        random_vertical_flip,
        random_rotation,
    ):
        self.random_resized_crop = random_resized_crop
        self.random_horizontal_flip = random_horizontal_flip
        self.random_vertical_flip = random_vertical_flip
        self.random_rotation = random_rotation
        # Tolerate "a, b" and trailing commas. A bare split would keep names
        # such as " b" or "" that fail the dict lookup at call time.
        self.names = [name.strip() for name in names.split(",") if name.strip()]

        self.augmentations = {
            "random_resized_crop": self.random_resized_crop,
            "random_horizontal_flip": self.random_horizontal_flip,
            "random_vertical_flip": self.random_vertical_flip,
            "random_rotation": self.random_rotation,
        }

    def __call__(self, img):
        """Feed ``img`` through each configured augmentation in order."""
        for name in self.names:
            img = self.augmentations[name](img)
        return img
|
|
|
|
|
|
class ImageAugmentation:
    """Top-level image transform chaining named sub-transforms in order.

    Args:
        names: Comma-separated transform names, applied left to right.
            Surrounding whitespace and empty entries are ignored, so ``""``
            applies nothing. Recognised names: ``clip_transform``,
            ``standard_augmentation``, ``geometric_augmentation``. Any other
            name raises ``KeyError`` when the object is called.
        clip_transform, standard_augmentation, geometric_augmentation:
            Callables that each take an image and return the transformed
            result.

    The parsed name list is printed to stdout on construction.
    """

    def __init__(
        self, names, clip_transform, standard_augmentation, geometric_augmentation
    ):
        self.clip_transform = clip_transform
        self.standard_augmentation = standard_augmentation
        self.geometric_augmentation = geometric_augmentation
        # Tolerate "a, b" and trailing commas. A bare split would keep names
        # such as " b" or "" that fail the dict lookup at call time.
        self.names = [name.strip() for name in names.split(",") if name.strip()]
        self.transforms = {
            "clip_transform": self.clip_transform,
            "standard_augmentation": self.standard_augmentation,
            "geometric_augmentation": self.geometric_augmentation,
        }
        print(f"Image augmentation: {self.names}")

    def __call__(self, img):
        """Feed ``img`` through each configured transform in order."""
        for name in self.names:
            img = self.transforms[name](img)
        return img
|
|
|
|
|
|
if __name__ == "__main__":
    # Visual sanity check. For each trial, save a grid in which every row is
    # one test image: its resized original followed by several augmented
    # versions.
    import glob

    import torch
    from hydra.utils import instantiate
    from omegaconf import OmegaConf
    from torchvision.utils import save_image

    augmentation_config = OmegaConf.load(
        "./configs/dataset/train_transform/augmentation.yaml"
    )
    augmentation_config.names = "standard_augmentation,geometric_augmentation"
    augmentation_transform = instantiate(augmentation_config)
    # Sort so the same images are picked on every run; glob order is arbitrary.
    img_paths = sorted(glob.glob("./datasets/osv5m/test/images/*.jpg"))
    if not img_paths:
        raise SystemExit("No images found in ./datasets/osv5m/test/images/")

    num_try = 20
    num_try_per_image = 8
    # Never index past the available images.
    num_imgs = min(8, len(img_paths))
    to_tensor = ToTensor()

    for idx in range(num_try):
        imgs = []
        for img_path in img_paths[:num_imgs]:
            # Close the file promptly. Converting to RGB gives every tile
            # three channels, so grayscale JPEGs do not break torch.stack().
            with Image.open(img_path) as src:
                img = src.convert("RGB")
            imgs.append(to_tensor(img.resize((224, 224))))
            for _ in range(num_try_per_image):
                imgs.append(to_tensor(augmentation_transform(img.copy())))
        # One grid row per source image: the original plus its augmentations.
        save_image(
            torch.stack(imgs),
            f"augmentation_{idx:03d}.png",
            nrow=num_try_per_image + 1,
        )
|
|
|