Spaces:
Sleeping
Sleeping
import random | |
import numpy as np | |
from torchvision import transforms | |
from PIL import Image | |
class CutPaste:
    '''
    CutPaste and CutPaste-Scar augmentations for binary or 3-way classification.

    Each augmentation returns the augmented image together with an 'L'-mode
    mask of the pasted region (255 where the patch landed, 0 elsewhere).

    :arg
    :transform[bool]: if True, apply ColorJitter to every cut patch
    :type[str]: classification type, one of 'binary' or '3way'
    '''

    def __init__(self, transform=True, type='binary'):
        # NOTE: `type` shadows the builtin but is part of the public signature.
        self.type = type
        if transform:
            self.transform = transforms.ColorJitter(brightness=0.1,
                                                    contrast=0.1,
                                                    saturation=0.1,
                                                    hue=0.1)
        else:
            self.transform = None

    @staticmethod
    def crop_and_paste_patch(image, patch_w, patch_h, transform, rotation=False):
        '''
        Crop a patch from `image` and paste it at a random location of a copy of it.

        :image: [PIL] original image
        :patch_w: [int] width of the patch (clamped to the image width)
        :patch_h: [int] height of the patch (clamped to the image height)
        :transform: callable applied to the cut patch, or a falsy value to skip it
        :rotation: False, or a (min_deg, max_deg) range; when given, the patch is
                   rotated by a uniformly drawn angle before pasting
        :return: (augmented image, 'L' mask of the pasted region)
        '''
        org_w, org_h = image.size
        # Clamp so random.randint below never gets an empty range
        # (e.g. a scar longer than a very small image).
        patch_w = max(1, min(patch_w, org_w))
        patch_h = max(1, min(patch_h, org_h))
        mask = None

        patch_left = random.randint(0, org_w - patch_w)
        patch_top = random.randint(0, org_h - patch_h)
        patch = image.crop((patch_left, patch_top,
                            patch_left + patch_w, patch_top + patch_h))
        if transform:
            patch = transform(patch)
        if rotation:
            angle = random.uniform(*rotation)
            # expand=True enlarges the canvas; the alpha channel marks which
            # pixels of that canvas actually belong to the rotated patch.
            patch = patch.convert("RGBA").rotate(angle, expand=True)
            mask = patch.split()[-1]

        # New location for the patch.
        paste_left = random.randint(0, org_w - patch_w)
        paste_top = random.randint(0, org_h - patch_h)
        aug_image = image.copy()
        aug_image.paste(patch, (paste_left, paste_top), mask=mask)

        # Mask of the pasted area. For a rotated patch, use its alpha footprint
        # rather than the unrotated box so the label matches the pixels changed.
        paste_mask = Image.new('L', image.size, 0)
        if mask is None:
            paste_mask.paste(255, (paste_left, paste_top,
                                   paste_left + patch_w, paste_top + patch_h))
        else:
            paste_mask.paste(mask, (paste_left, paste_top))
        return aug_image, paste_mask

    def cutpaste(self, image, area_ratio=(0.02, 0.15), aspect_ratio=((0.3, 1), (1, 3.3))):
        '''
        CutPaste augmentation: paste a rectangular patch of random area and aspect.

        :image: [PIL] original image
        :area_ratio: [tuple] range for the patch area as a fraction of the image area
        :aspect_ratio: [tuple] two (min, max) ranges for width/height; one value is
                       drawn from each and one of the two is picked at random
        :return: (PIL image after CutPaste, 'L' mask of the pasted region)
        '''
        img_area = image.size[0] * image.size[1]
        patch_area = random.uniform(*area_ratio) * img_area
        patch_aspect = random.choice([random.uniform(*aspect_ratio[0]),
                                      random.uniform(*aspect_ratio[1])])
        # w * h == patch_area and w / h == patch_aspect.
        patch_w = int(np.sqrt(patch_area * patch_aspect))
        patch_h = int(np.sqrt(patch_area / patch_aspect))
        return self.crop_and_paste_patch(image, patch_w, patch_h,
                                         self.transform, rotation=False)

    def cutpaste_scar(self, image, width=(2, 16), length=(10, 25), rotation=(-45, 45)):
        '''
        CutPaste-Scar augmentation: paste a thin, randomly rotated strip.

        :image: [PIL] original image
        :width: [tuple] inclusive range for the width of the patch
        :length: [tuple] inclusive range for the length (height) of the patch
        :rotation: [tuple] range of rotation angles in degrees
        :return: (PIL image after CutPaste-Scar, 'L' mask of the pasted region)
        '''
        patch_w = random.randint(*width)
        patch_h = random.randint(*length)
        return self.crop_and_paste_patch(image, patch_w, patch_h,
                                         self.transform, rotation=rotation)

    def __call__(self, image):
        '''
        :image: [PIL] original image
        :return: if type == 'binary': (image, (aug_image, mask)) from a randomly
                 chosen augmentation; if type == '3way':
                 (image, (cutpaste_image, mask), (scar_image, mask))
        :raises ValueError: if `type` is neither 'binary' nor '3way'
        '''
        if self.type == 'binary':
            aug = random.choice([self.cutpaste, self.cutpaste_scar])
            return image, aug(image)
        if self.type == '3way':
            return image, self.cutpaste(image), self.cutpaste_scar(image)
        raise ValueError(
            f"unknown classification type {self.type!r}; expected 'binary' or '3way'")