| | """ |
| | Adapted from https://github.com/nv-nguyen/template-pose/blob/main/src/utils/augmentation.py |
| | """ |
| |
|
| | from torchvision import transforms |
| | from PIL import ImageEnhance, ImageFilter, Image |
| | import numpy as np |
| | import random |
| | import logging |
| | from torchvision.transforms import RandomResizedCrop, ToTensor |
| |
|
| |
|
class PillowRGBAugmentation:
    """Randomly apply a Pillow ``ImageEnhance``-style enhancer to an image.

    Args:
        pillow_fn: Callable that takes an image and returns an object exposing
            ``enhance(factor=...)`` (e.g. ``ImageEnhance.Contrast``).
        p: Probability of applying the enhancement on a given call.
        factor_interval: ``(low, high)`` bounds; a factor is drawn uniformly
            from this interval each time the enhancement is applied.
    """

    def __init__(self, pillow_fn, p, factor_interval):
        self._pillow_fn = pillow_fn
        self.p = p
        self.factor_interval = factor_interval

    def __call__(self, PIL_image):
        """Return the enhanced RGB image, or the input unchanged when skipped."""
        if random.random() <= self.p:
            factor = random.uniform(*self.factor_interval)
            if PIL_image.mode != "RGB":
                logging.warning(
                    f"Error when apply data aug, image mode: {PIL_image.mode}"
                )
                # Convert the image being augmented (not some other variable)
                # so the enhancer below always receives RGB input.
                PIL_image = PIL_image.convert("RGB")
                logging.warning(f"Success to change to {PIL_image.mode}")
            enhancer = self._pillow_fn(PIL_image)
            PIL_image = enhancer.enhance(factor=factor).convert("RGB")
        return PIL_image
| |
|
| |
|
class PillowSharpness(PillowRGBAugmentation):
    """Random sharpness change built on ``ImageEnhance.Sharpness``.

    Args:
        p: Probability of applying the enhancement on a given call.
        factor_interval: ``(low, high)`` bounds of the uniformly drawn factor.
    """

    def __init__(self, p=0.3, factor_interval=(0, 40.0)):
        # Parent signature is (pillow_fn, p, factor_interval).
        super().__init__(ImageEnhance.Sharpness, p, factor_interval)
| |
|
| |
|
class PillowContrast(PillowRGBAugmentation):
    """Random contrast change built on ``ImageEnhance.Contrast``.

    Args:
        p: Probability of applying the enhancement on a given call.
        factor_interval: ``(low, high)`` bounds of the uniformly drawn factor.
    """

    def __init__(self, p=0.3, factor_interval=(0.5, 1.6)):
        # Parent signature is (pillow_fn, p, factor_interval).
        super().__init__(ImageEnhance.Contrast, p, factor_interval)
| |
|
| |
|
class PillowBrightness(PillowRGBAugmentation):
    """Random brightness change built on ``ImageEnhance.Brightness``.

    Args:
        p: Probability of applying the enhancement on a given call.
        factor_interval: ``(low, high)`` bounds of the uniformly drawn factor.
    """

    def __init__(self, p=0.5, factor_interval=(0.5, 2.0)):
        # Parent signature is (pillow_fn, p, factor_interval).
        super().__init__(ImageEnhance.Brightness, p, factor_interval)
| |
|
| |
|
class PillowColor(PillowRGBAugmentation):
    """Random color-balance change built on ``ImageEnhance.Color``.

    Args:
        p: Probability of applying the enhancement on a given call.
        factor_interval: ``(low, high)`` bounds of the uniformly drawn factor.
    """

    def __init__(self, p=1, factor_interval=(0.0, 20.0)):
        # Parent signature is (pillow_fn, p, factor_interval).
        super().__init__(ImageEnhance.Color, p, factor_interval)
| |
|
| |
|
class PillowBlur:
    """Randomly apply a Gaussian blur with a radius re-drawn on every call.

    Args:
        p: Probability of blurring on a given call.
        factor_interval: Inclusive ``(low, high)`` integer bounds passed to
            ``random.randint`` to pick the blur radius.

    Attributes:
        k: The most recently drawn blur radius (initialised at construction).
    """

    def __init__(self, p=0.4, factor_interval=(1, 3)):
        self.p = p
        # Keep the interval so the radius can vary between images instead of
        # being frozen for the lifetime of the instance.
        self.factor_interval = factor_interval
        self.k = random.randint(*factor_interval)

    def __call__(self, PIL_image):
        """Return the blurred image, or the input unchanged when skipped."""
        if random.random() <= self.p:
            self.k = random.randint(*self.factor_interval)
            PIL_image = PIL_image.filter(ImageFilter.GaussianBlur(self.k))
        return PIL_image
| |
|
| |
|
class NumpyGaussianNoise:
    """Add zero-mean Gaussian pixel noise with probability ``p``.

    An upper bound ``noise_ratio`` is drawn uniformly from ``factor_interval``
    once at construction. On each applied call the standard deviation is drawn
    uniformly from ``[0, noise_ratio]``, the noise is scaled by 255, added to
    the image and the result is clipped to ``[0, 255]``.

    Args:
        p: Probability of adding noise on a given call.
        factor_interval: ``(low, high)`` bounds for ``noise_ratio``.
    """

    def __init__(self, p, factor_interval=(0.01, 0.3)):
        self.p = p
        self.noise_ratio = random.uniform(*factor_interval)

    def __call__(self, img):
        """Return a PIL image built from the (possibly noised) uint8 pixels."""
        if random.random() <= self.p:
            pixels = np.copy(img)
            sigma = random.uniform(0, self.noise_ratio)
            noise = np.random.normal(0, sigma, pixels.shape) * 255
            # Keep values inside the valid 8-bit range before the uint8 cast.
            img = np.clip(pixels + noise, 0, 255)
        return Image.fromarray(np.uint8(img))
| |
|
| |
|
class StandardAugmentation:
    """Apply a configurable sequence of photometric augmentations.

    Args:
        names: Comma-separated augmentation names, applied in the given order.
            Valid names are ``brightness``, ``contrast``, ``sharpness``,
            ``color``, ``blur`` and ``gaussian_noise``. Surrounding whitespace
            is ignored and empty entries are dropped, so ``""`` means
            "apply nothing".
        brightness: Callable applied for the ``brightness`` name.
        contrast: Callable applied for the ``contrast`` name.
        sharpness: Callable applied for the ``sharpness`` name.
        color: Callable applied for the ``color`` name.
        blur: Callable applied for the ``blur`` name.
        gaussian_noise: Callable applied for the ``gaussian_noise`` name.
    """

    def __init__(
        self, names, brightness, contrast, sharpness, color, blur, gaussian_noise
    ):
        self.brightness = brightness
        self.contrast = contrast
        self.sharpness = sharpness
        self.color = color
        self.blur = blur
        self.gaussian_noise = gaussian_noise

        # Tolerate "a, b" and trailing commas; a plain split would keep
        # whitespace/empty tokens and fail the lookup in __call__.
        self.names = [
            name.strip() for name in names.split(",") if name.strip()
        ]
        self.augmentations = {
            "brightness": self.brightness,
            "contrast": self.contrast,
            "sharpness": self.sharpness,
            "color": self.color,
            "blur": self.blur,
            "gaussian_noise": self.gaussian_noise,
        }

    def __call__(self, img):
        """Run the selected augmentations on ``img`` in order.

        Raises:
            KeyError: If a configured name is not a known augmentation.
        """
        for name in self.names:
            img = self.augmentations[name](img)
        return img
| |
|
| |
|
class GeometricAugmentation:
    """Apply a configurable sequence of geometric augmentations.

    Args:
        names: Comma-separated augmentation names, applied in the given order.
            Valid names are ``random_resized_crop``,
            ``random_horizontal_flip``, ``random_vertical_flip`` and
            ``random_rotation``. Surrounding whitespace is ignored and empty
            entries are dropped, so ``""`` means "apply nothing".
        random_resized_crop: Callable applied for ``random_resized_crop``.
        random_horizontal_flip: Callable applied for ``random_horizontal_flip``.
        random_vertical_flip: Callable applied for ``random_vertical_flip``.
        random_rotation: Callable applied for ``random_rotation``.
    """

    def __init__(
        self,
        names,
        random_resized_crop,
        random_horizontal_flip,
        random_vertical_flip,
        random_rotation,
    ):
        self.random_resized_crop = random_resized_crop
        self.random_horizontal_flip = random_horizontal_flip
        self.random_vertical_flip = random_vertical_flip
        self.random_rotation = random_rotation
        # Tolerate "a, b" and trailing commas; a plain split would keep
        # whitespace/empty tokens and fail the lookup in __call__.
        self.names = [
            name.strip() for name in names.split(",") if name.strip()
        ]

        self.augmentations = {
            "random_resized_crop": self.random_resized_crop,
            "random_horizontal_flip": self.random_horizontal_flip,
            "random_vertical_flip": self.random_vertical_flip,
            "random_rotation": self.random_rotation,
        }

    def __call__(self, img):
        """Run the selected augmentations on ``img`` in order.

        Raises:
            KeyError: If a configured name is not a known augmentation.
        """
        for name in self.names:
            img = self.augmentations[name](img)
        return img
| |
|
| |
|
class ImageAugmentation:
    """Chain the top-level image transforms selected by ``names``.

    Args:
        names: Comma-separated transform names, applied in the given order.
            Valid names are ``clip_transform``, ``standard_augmentation`` and
            ``geometric_augmentation``. Surrounding whitespace is ignored and
            empty entries are dropped, so ``""`` means "apply nothing".
        clip_transform: Callable applied for ``clip_transform``.
        standard_augmentation: Callable applied for ``standard_augmentation``.
        geometric_augmentation: Callable applied for
            ``geometric_augmentation``.
    """

    def __init__(
        self, names, clip_transform, standard_augmentation, geometric_augmentation
    ):
        self.clip_transform = clip_transform
        self.standard_augmentation = standard_augmentation
        self.geometric_augmentation = geometric_augmentation
        # Tolerate "a, b" and trailing commas; a plain split would keep
        # whitespace/empty tokens and fail the lookup in __call__.
        self.names = [
            name.strip() for name in names.split(",") if name.strip()
        ]
        self.transforms = {
            "clip_transform": self.clip_transform,
            "standard_augmentation": self.standard_augmentation,
            "geometric_augmentation": self.geometric_augmentation,
        }
        print(f"Image augmentation: {self.names}")

    def __call__(self, img):
        """Run the selected transforms on ``img`` in order.

        Raises:
            KeyError: If a configured name is not a known transform.
        """
        for name in self.names:
            img = self.transforms[name](img)
        return img
| |
|
| |
|
if __name__ == "__main__":
    # Visual smoke test: for a handful of images, save grids showing the
    # original next to several randomly augmented versions.
    import glob

    import torch
    from hydra.utils import instantiate
    from omegaconf import OmegaConf
    from torchvision.utils import save_image

    augmentation_config = OmegaConf.load(
        "./configs/dataset/train_transform/augmentation.yaml"
    )
    augmentation_config.names = "standard_augmentation,geometric_augmentation"
    augmentation_transform = instantiate(augmentation_config)

    num_try = 20
    num_try_per_image = 8
    num_imgs = 8

    # Sort for a reproducible selection and cap at num_imgs so a small
    # dataset folder does not cause an IndexError.
    img_paths = sorted(glob.glob("./datasets/osv5m/test/images/*.jpg"))[:num_imgs]
    if not img_paths:
        raise FileNotFoundError(
            "No .jpg images found in ./datasets/osv5m/test/images/"
        )

    to_tensor = ToTensor()
    for idx in range(num_try):
        imgs = []
        for img_path in img_paths:
            # Force RGB so every tensor has the same channel count, and close
            # the file handle as soon as the pixels are loaded.
            with Image.open(img_path) as handle:
                img = handle.convert("RGB")
            # First column of each row: the un-augmented reference image.
            imgs.append(to_tensor(img.resize((224, 224))))
            # NOTE(review): stacking assumes the configured augmentation
            # outputs 224x224 images — confirm against the YAML config.
            for _ in range(num_try_per_image):
                imgs.append(to_tensor(augmentation_transform(img.copy())))
        # One row per source image: the original plus its augmented versions.
        save_image(
            torch.stack(imgs),
            f"augmentation_{idx:03d}.png",
            nrow=num_try_per_image + 1,
        )
| |
|