# BuildingExtraction/Utils/Augmentations.py
# Author: KyanChen (commit ab01e4a: "add model")
import numpy as np
import cv2
import torch
class Compose(object):
    """Chain several transforms into a single callable.

    Each transform receives the output of the previous one and the result
    of the last transform is returned; an empty list returns the input.

    Args:
        transforms (list of ``Transform`` objects): transforms to apply, in order.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        result = data
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self):
        # One line per transform, wrapped in "ClassName(" ... "\n)".
        pieces = [self.__class__.__name__ + '(']
        pieces.extend('\n' + ' {0}'.format(t) for t in self.transforms)
        pieces.append('\n)')
        return ''.join(pieces)
class ConvertUcharToFloat(object):
    """Cast every array of a sample to ``float32``.

    All elements of ``data`` are converted, labels included. ``astype``
    always returns a new array, so later in-place photometric transforms
    never modify the caller's arrays. The result is always a ``list``,
    whatever sequence type was passed in.
    """

    def __call__(self, data):
        def to_float32(arr):
            return arr.astype(np.float32)
        return list(map(to_float32, data))
class RandomContrast(object):
    """Randomly scale image intensities by a factor drawn from ``[lower, upper]``.

    With probability ``prob`` one factor is drawn uniformly and every pixel
    of the image is multiplied by it. In 'cd' mode the two images are
    decided and scaled independently. Labels pass through unchanged.

    The multiplication is in place, so images must already be float arrays
    (e.g. after ``ConvertUcharToFloat``). No clipping is applied.

    Args:
        phase: 'od' / 'seg' -> ``(img, target)``;
            'cd' -> ``(img1, label1, img2, label2)``.
        lower, upper: bounds of the scaling factor, ``0 < lower <= upper``.
        prob: probability of scaling each image.

    Raises:
        ValueError: on call, if ``phase`` is not 'od', 'seg' or 'cd'.
    """

    def __init__(self, phase, lower=0.8, upper=1.2, prob=0.5):
        self.phase = phase
        self.lower = lower
        self.upper = upper
        self.prob = prob
        assert self.upper >= self.lower, "contrast upper must be >= lower!"
        # The check is strictly positive, so the message says "positive".
        assert self.lower > 0, "contrast lower must be positive!"

    def _maybe_scale(self, img):
        """With probability ``prob``, multiply ``img`` in place by a random factor."""
        if torch.rand(1) < self.prob:
            alpha = torch.FloatTensor(1).uniform_(self.lower, self.upper)
            img *= alpha.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return self._maybe_scale(img), target
        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            # Same RNG draw order as before: img1's coin/factor, then img2's.
            img1 = self._maybe_scale(img1)
            img2 = self._maybe_scale(img2)
            return img1, label1, img2, label2
        raise ValueError(
            "RandomContrast: unsupported phase {!r} (expected 'od', 'seg' or 'cd')".format(self.phase))
class RandomBrightness(object):
    """Randomly shift image intensities by a constant offset.

    With probability ``prob`` an offset is drawn uniformly from
    ``[-delta, delta]`` and added to every pixel of the image. In 'cd' mode
    the two images are decided and shifted independently. Labels pass
    through unchanged.

    The addition is in place, so images must already be float arrays
    (e.g. after ``ConvertUcharToFloat``). No clipping is applied.

    Args:
        phase: 'od' / 'seg' -> ``(img, target)``;
            'cd' -> ``(img1, label1, img2, label2)``.
        delta: maximum absolute offset, ``0 <= delta < 255``.
        prob: probability of shifting each image.

    Raises:
        ValueError: on call, if ``phase`` is not 'od', 'seg' or 'cd'.
    """

    def __init__(self, phase, delta=10, prob=0.5):
        self.phase = phase
        self.delta = delta
        self.prob = prob
        assert 0. <= self.delta < 255., "brightness delta must between 0 to 255"

    def _maybe_shift(self, img):
        """With probability ``prob``, add one random offset to ``img`` in place."""
        if torch.rand(1) < self.prob:
            delta = torch.FloatTensor(1).uniform_(- self.delta, self.delta)
            img += delta.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return self._maybe_shift(img), target
        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            img1 = self._maybe_shift(img1)
            img2 = self._maybe_shift(img2)
            return img1, label1, img2, label2
        raise ValueError(
            "RandomBrightness: unsupported phase {!r} (expected 'od', 'seg' or 'cd')".format(self.phase))
class ConvertColor(object):
    """Convert images between the RGB and HSV colour spaces with ``cv2.cvtColor``.

    Only RGB -> HSV and HSV -> RGB are supported. Labels pass through
    unchanged. ``ImgDistortion`` uses this class to bracket the
    saturation/hue jitter.

    Args:
        phase: 'od' / 'seg' -> ``(img, target)``;
            'cd' -> ``(img1, label1, img2, label2)``. In 'cd' mode both
            images are converted.
        current: colour space of the incoming images ('RGB' or 'HSV').
        target: colour space to convert to ('HSV' or 'RGB').

    Raises:
        NotImplementedError: on call, for any other ``current``/``target`` pair.
        ValueError: on call, if ``phase`` is not 'od', 'seg' or 'cd'.
    """

    def __init__(self, phase, current='RGB', target='HSV'):
        self.phase = phase
        self.current = current
        self.target = target

    def _convert(self, img):
        """Convert one image from ``current`` to ``target``."""
        if self.current == 'RGB' and self.target == 'HSV':
            return cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
        if self.current == 'HSV' and self.target == 'RGB':
            return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
        raise NotImplementedError("Convert color fail!")

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return self._convert(img), target
        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            img1 = self._convert(img1)
            img2 = self._convert(img2)
            return img1, label1, img2, label2
        raise ValueError(
            "ConvertColor: unsupported phase {!r} (expected 'od', 'seg' or 'cd')".format(self.phase))
class RandomSaturation(object):
    """Randomly scale the saturation channel (index 1) of HSV images.

    With probability ``prob`` a factor is drawn uniformly from
    ``[lower, upper]`` and channel 1 is multiplied by it in place. Images
    are expected to be HSV already, e.g. after
    ``ConvertColor(current='RGB', target='HSV')``. In 'cd' mode the two
    images are decided and scaled independently. Labels pass through.

    NOTE(review): the result is not clipped. For float HSV images from
    OpenCV, saturation is presumably in [0, 1], so factors > 1 can push it
    out of range. Confirm whether clipping is wanted.

    Args:
        phase: 'od' / 'seg' -> ``(img, target)``;
            'cd' -> ``(img1, label1, img2, label2)``.
        lower, upper: bounds of the scaling factor, ``0 < lower <= upper``.
        prob: probability of scaling each image.

    Raises:
        ValueError: on call, if ``phase`` is not 'od', 'seg' or 'cd'.
    """

    def __init__(self, phase, lower=0.8, upper=1.2, prob=0.5):
        self.phase = phase
        self.lower = lower
        self.upper = upper
        self.prob = prob
        assert self.upper >= self.lower, "saturation upper must be >= lower!"
        # The check is strictly positive, so the message says "positive".
        assert self.lower > 0, "saturation lower must be positive!"

    def _maybe_saturate(self, img):
        """With probability ``prob``, scale channel 1 of ``img`` in place."""
        if torch.rand(1) < self.prob:
            alpha = torch.FloatTensor(1).uniform_(self.lower, self.upper)
            img[:, :, 1] *= alpha.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return self._maybe_saturate(img), target
        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            img1 = self._maybe_saturate(img1)
            img2 = self._maybe_saturate(img2)
            return img1, label1, img2, label2
        raise ValueError(
            "RandomSaturation: unsupported phase {!r} (expected 'od', 'seg' or 'cd')".format(self.phase))
class RandomHue(object):
    """Randomly rotate the hue channel (index 0) of HSV images.

    With probability ``prob`` an offset is drawn uniformly from
    ``[-delta, delta]`` and added to channel 0 in place. Values are then
    wrapped back into range: values above 360 lose 360 and values below 0
    gain 360. Hue is therefore assumed to be in degrees; confirm it matches
    the HSV conversion used upstream. In 'cd' mode the two images are
    decided and rotated independently. Labels pass through.

    Args:
        phase: 'od' / 'seg' -> ``(img, target)``;
            'cd' -> ``(img1, label1, img2, label2)``.
        delta: maximum absolute hue offset, ``0 <= delta < 360``.
        prob: probability of rotating each image.

    Raises:
        ValueError: on call, if ``phase`` is not 'od', 'seg' or 'cd'.
    """

    def __init__(self, phase, delta=10., prob=0.5):
        self.phase = phase
        self.delta = delta
        self.prob = prob
        assert 0 <= self.delta < 360, "Hue delta must between 0 to 360!"

    def _maybe_rotate(self, img):
        """With probability ``prob``, shift and wrap channel 0 of ``img`` in place."""
        if torch.rand(1) < self.prob:
            alpha = torch.FloatTensor(1).uniform_(-self.delta, self.delta)
            hue = img[:, :, 0]  # view: edits below write through to img
            hue += alpha.numpy()
            # |delta| < 360, so a single wrap in each direction is enough.
            hue[hue > 360.0] -= 360.0
            hue[hue < 0.0] += 360.0
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return self._maybe_rotate(img), target
        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            img1 = self._maybe_rotate(img1)
            img2 = self._maybe_rotate(img2)
            return img1, label1, img2, label2
        raise ValueError(
            "RandomHue: unsupported phase {!r} (expected 'od', 'seg' or 'cd')".format(self.phase))
class RandomChannelNoise(object):
    """Randomly permute the channels of each image (channel shuffle).

    Despite the name, no noise is added. With probability ``prob`` the last
    axis of the image is reordered by one of the six permutations of three
    channels, chosen uniformly. In 'cd' mode the two images are shuffled
    independently. Labels pass through unchanged.

    Args:
        phase: 'od' / 'seg' -> ``(img, target)``;
            'cd' -> ``(img1, label1, img2, label2)``.
        prob: probability of shuffling each image.

    Raises:
        ValueError: on call, if ``phase`` is not 'od', 'seg' or 'cd'.
    """

    def __init__(self, phase, prob=0.4):
        self.phase = phase
        self.prob = prob
        self.perms = ((0, 1, 2), (0, 2, 1),
                      (1, 0, 2), (1, 2, 0),
                      (2, 0, 1), (2, 1, 0))

    def _maybe_shuffle(self, img):
        """With probability ``prob``, return ``img`` with its channels permuted."""
        if torch.rand(1) < self.prob:
            perm = self.perms[int(torch.randint(0, len(self.perms), size=[]))]
            img = img[:, :, perm]  # advanced indexing -> new array
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return self._maybe_shuffle(img), target
        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            img1 = self._maybe_shuffle(img1)
            img2 = self._maybe_shuffle(img2)
            return img1, label1, img2, label2
        raise ValueError(
            "RandomChannelNoise: unsupported phase {!r} (expected 'od', 'seg' or 'cd')".format(self.phase))
class ImgDistortion(object):
    """Photometric distortion: brightness, contrast, saturation, hue and channel shuffle.

    On every call:
      1. with probability ``prob`` the sample is passed through
         ``RandomBrightness``, which also applies its own internal
         probability;
      2. a second draw against ``prob`` decides where ``RandomContrast``
         runs. It runs either before the RGB -> HSV -> RGB round-trip (the
         trailing contrast op is dropped) or after it (the leading one is
         dropped), never both;
      3. ``RandomChannelNoise`` is always invoked and applies its own
         probability.

    Images are expected in RGB, since the HSV round-trip starts from 'RGB'.
    """

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob
        # Contrast sits at both ends; exactly one copy is used per call.
        self.operation = [
            RandomContrast(phase),
            ConvertColor(phase, current='RGB', target='HSV'),
            RandomSaturation(phase),
            RandomHue(phase),
            ConvertColor(phase, current='HSV', target='RGB'),
            RandomContrast(phase)
        ]
        self.random_brightness = RandomBrightness(phase)
        self.random_light_noise = RandomChannelNoise(phase)

    def __call__(self, data):
        if torch.rand(1) < self.prob:
            data = self.random_brightness(data)
        contrast_first = torch.rand(1) < self.prob
        if contrast_first:
            selected = self.operation[:-1]
        else:
            selected = self.operation[1:]
        distorted = Compose(selected)(data)
        return self.random_light_noise(distorted)
class ExpandImg(object):
    """Randomly pad ("zoom out") a sample with a constant border.

    With probability ``prob`` a border is added on all four sides.
      - Left/right sizes are drawn from ``[0, int(max(1, width * rw)))``.
      - Top/bottom sizes are drawn from ``[0, int(max(1, height * rh)))``.
      - ``rw`` and ``rh`` are drawn uniformly from ``[0, expand_ratio)``.

    Images are padded with ``prior_mean * 255``; 'seg'/'cd' label maps are
    padded with 0. For 'od', box coordinates are shifted in place: columns
    1, 3, ... by the left padding and columns 2, 4, ... by the top padding.

    Args:
        phase: 'seg' -> ``(img, label)``;
            'cd' -> ``(img1, label1, img2, label2)``, all padded identically;
            'od' -> ``(img, boxes)``. Columns 1-4 of ``boxes`` are treated as
            x1, y1, x2, y2; column 0 is presumably the class id.
        prior_mean: mean used for the image border. It is multiplied by 255,
            so it is presumably given in [0, 1].
        prob: probability that padding is applied. The original code used it
            as the skip probability; it now matches the other transforms.
        expand_ratio: maximum padding per side as a fraction of the image size.

    Raises:
        ValueError: on call, if ``phase`` is not 'seg', 'cd' or 'od'.
    """

    def __init__(self, phase, prior_mean, prob=0.5, expand_ratio=0.2):
        self.phase = phase
        self.prior_mean = np.array(prior_mean) * 255
        self.prob = prob
        self.expand_ratio = expand_ratio

    def _sample_padding(self, height, width):
        """Draw the border sizes ``(top, bottom, left, right)`` as Python ints."""
        ratio_width = self.expand_ratio * torch.rand([])
        ratio_height = self.expand_ratio * torch.rand([])
        left, right = torch.randint(high=int(max(1, width * ratio_width)), size=[2])
        # Vertical padding must scale with the height (it previously used width).
        top, bottom = torch.randint(high=int(max(1, height * ratio_height)), size=[2])
        return int(top), int(bottom), int(left), int(right)

    @staticmethod
    def _pad(arr, top, bottom, left, right, value):
        """Add a constant border of the given sizes around ``arr``."""
        return cv2.copyMakeBorder(
            arr, top, bottom, left, right, cv2.BORDER_CONSTANT, value=value)

    def __call__(self, data):
        if self.phase not in ('seg', 'cd', 'od'):
            raise ValueError(
                "ExpandImg: unsupported phase {!r} (expected 'od', 'seg' or 'cd')".format(self.phase))
        if torch.rand(1) >= self.prob:
            return data
        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            height, width = img1.shape[:2]
            top, bottom, left, right = self._sample_padding(height, width)
            img1 = self._pad(img1, top, bottom, left, right, self.prior_mean)
            label1 = self._pad(label1, top, bottom, left, right, 0)
            img2 = self._pad(img2, top, bottom, left, right, self.prior_mean)
            label2 = self._pad(label2, top, bottom, left, right, 0)
            return img1, label1, img2, label2
        img, label = data
        height, width = img.shape[:2]
        top, bottom, left, right = self._sample_padding(height, width)
        img = self._pad(img, top, bottom, left, right, self.prior_mean)
        if self.phase == 'seg':
            label = self._pad(label, top, bottom, left, right, 0)
        else:  # 'od': move box corners into the padded frame (in place)
            label[:, 1::2] += left
            label[:, 2::2] += top
        return img, label
class RandomSampleCrop(object):
    """Randomly crop a window whose size is close to ``original_size``.

    With probability ``prob`` a crop window is sampled, with
    ``(lo, hi) = crop_scale_ratios_range``:
      - width from ``[min(W, int(lo * original_size[0])), min(W + 1, int(hi * original_size[0])))``;
      - height from the same range, using ``H`` and ``original_size[1]``.

    ``W``/``H`` are the real image width/height (numpy images are
    ``(rows, cols, channels)``). Sampling repeats until ``height / width``
    lies strictly inside ``aspect_ratio_range``. After ``max_try_times``
    failed attempts the sample is returned unchanged.

    Args:
        phase: 'seg' -> ``(img, label)``;
            'cd' -> ``(img1, label1, img2, label2)``, all cropped with the
            same window;
            'od' -> ``(img, boxes)``. Columns 1-4 of ``boxes`` are treated
            as x1, y1, x2, y2; column 0 is presumably the class id.
        original_size: reference ``(width, height)``.
        prob: probability of cropping. The original code used it as the skip
            probability; it now matches the other transforms.
        crop_scale_ratios_range: ``[lo, hi]`` crop scale relative to ``original_size``.
        aspect_ratio_range: open interval allowed for the crop's height / width.

    For 'od', only boxes whose centre lies inside the window are kept. They
    are shifted into the window's frame but not clipped to it.

    Raises:
        ValueError: on call, if ``phase`` is not 'seg', 'cd' or 'od'.
    """

    def __init__(self,
                 phase,
                 original_size=[512, 512],
                 prob=0.5,
                 crop_scale_ratios_range=[0.8, 1.2],
                 aspect_ratio_range=[4./5, 5./4]):
        self.phase = phase
        self.prob = prob
        self.scale_range = crop_scale_ratios_range
        self.original_size = original_size
        self.aspect_ratio_range = aspect_ratio_range  # h/w
        self.max_try_times = 500

    def _sample_window(self, height, width):
        """Return ``(top, left, crop_h, crop_w)`` as ints, or None if no window fits."""
        for _ in range(self.max_try_times):
            crop_w = int(torch.randint(
                min(width, int(self.scale_range[0] * self.original_size[0])),
                min(width + 1, int(self.scale_range[1] * self.original_size[0])),
                size=[]
            ))
            crop_h = int(torch.randint(
                min(height, int(self.scale_range[0] * self.original_size[1])),
                min(height + 1, int(self.scale_range[1] * self.original_size[1])),
                size=[]
            ))
            # aspect ratio constraint (crop_w == 0 only for empty images)
            if crop_w > 0 and self.aspect_ratio_range[0] < crop_h / crop_w < self.aspect_ratio_range[1]:
                break
        else:
            # Bounded retries: the old 'od' path looped forever here.
            print("try times over max threshold!", flush=True)
            return None
        left = int(torch.randint(0, width - crop_w + 1, size=[]))
        top = int(torch.randint(0, height - crop_h + 1, size=[]))
        return top, left, crop_h, crop_w

    @staticmethod
    def _crop_boxes(label, top, left, crop_h, crop_w):
        """Keep boxes whose centre is in the window and shift them into its frame."""
        if not len(label):
            return label
        # keep overlap with gt box IF center in sampled patch
        centers = (label[:, 1:3] + label[:, 3:]) / 2.0
        # x in [left, left + crop_w], y in [top, top + crop_h) -- bounds as before
        mask = ((left <= centers[:, 0]) & (top <= centers[:, 1])
                & ((left + crop_w) >= centers[:, 0]) & ((top + crop_h) > centers[:, 1]))
        current_label = label[mask, :]
        # adjust to crop (by subtracting crop's left, top)
        current_label[:, 1::2] -= left
        current_label[:, 2::2] -= top
        return current_label

    def __call__(self, data):
        if self.phase not in ('seg', 'cd', 'od'):
            raise ValueError(
                "RandomSampleCrop: unsupported phase {!r} (expected 'od', 'seg' or 'cd')".format(self.phase))
        if torch.rand(1) >= self.prob:
            return data
        # The first element is always the (first) image: rows = height, cols = width.
        height, width = data[0].shape[:2]
        window = self._sample_window(height, width)
        if window is None:
            return data
        top, left, crop_h, crop_w = window
        rows = slice(top, top + crop_h)
        cols = slice(left, left + crop_w)
        if self.phase == 'seg':
            img, label = data
            return img[rows, cols], label[rows, cols]
        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            return img1[rows, cols], label1[rows, cols], img2[rows, cols], label2[rows, cols]
        img, label = data
        return img[rows, cols], self._crop_boxes(label, top, left, crop_h, crop_w)
class RandomMirror(object):
    """Randomly flip a sample left-right with probability ``prob``.

    'seg' flips image and label; 'cd' flips all four arrays together. 'od'
    flips the image and mirrors the box x coordinates in place, swapping
    columns 1 and 3 so that x1 <= x2 still holds. The flipped arrays are
    negative-stride views, not copies.

    Args:
        phase: 'seg' -> ``(img, label)``;
            'cd' -> ``(img1, label1, img2, label2)``;
            'od' -> ``(img, boxes)``. ``boxes`` presumably has exactly five
            columns ``[cls, x1, y1, x2, y2]``, since the slice ``3::-2``
            only covers columns 3 and 1.
        prob: probability of flipping.

    Raises:
        ValueError: on call, if ``phase`` is not 'seg', 'cd' or 'od'.
    """

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            if torch.rand(1) < self.prob:
                img, label = img[:, ::-1], label[:, ::-1]
            return img, label
        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            if torch.rand(1) < self.prob:
                img1, label1 = img1[:, ::-1], label1[:, ::-1]
                img2, label2 = img2[:, ::-1], label2[:, ::-1]
            return img1, label1, img2, label2
        if self.phase == 'od':
            img, label = data
            if torch.rand(1) < self.prob:
                width = img.shape[1]
                img = img[:, ::-1]
                # new x1 = W - old x2, new x2 = W - old x1
                label[:, 1::2] = width - label[:, 3::-2]
            return img, label
        raise ValueError(
            "RandomMirror: unsupported phase {!r} (expected 'od', 'seg' or 'cd')".format(self.phase))
class RandomFlipV(object):
    """Randomly flip a sample upside-down with probability ``prob``.

    'seg' flips image and label; 'cd' flips all four arrays together. 'od'
    flips the image and mirrors the box y coordinates in place, swapping
    columns 2 and 4 so that y1 <= y2 still holds. The flipped arrays are
    negative-stride views, not copies.

    Args:
        phase: 'seg' -> ``(img, label)``;
            'cd' -> ``(img1, label1, img2, label2)``;
            'od' -> ``(img, boxes)``. ``boxes`` presumably has five columns
            ``[cls, x1, y1, x2, y2]``.
        prob: probability of flipping.

    Raises:
        ValueError: on call, if ``phase`` is not 'seg', 'cd' or 'od'.
    """

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            if torch.rand(1) < self.prob:
                img, label = img[::-1, :], label[::-1, :]
            return img, label
        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            if torch.rand(1) < self.prob:
                img1, label1 = img1[::-1, :], label1[::-1, :]
                img2, label2 = img2[::-1, :], label2[::-1, :]
            return img1, label1, img2, label2
        if self.phase == 'od':
            img, label = data
            if torch.rand(1) < self.prob:
                height = img.shape[0]
                img = img[::-1, :]
                # new y1 = H - old y2, new y2 = H - old y1
                label[:, 2::2] = height - label[:, 4:1:-2]
            return img, label
        raise ValueError(
            "RandomFlipV: unsupported phase {!r} (expected 'od', 'seg' or 'cd')".format(self.phase))
class Resize(object):
    """Resize every array of a sample to a fixed ``size``.

    Images use bilinear interpolation. Label maps ('seg'/'cd') use
    nearest-neighbour interpolation so class ids are not blended. For 'od',
    box coordinates are rescaled in place: columns 1, 3, ... by
    ``size[0] / width`` and columns 2, 4, ... by ``size[1] / height``.

    Args:
        phase: 'seg' -> ``(img, label)``;
            'cd' -> ``(img1, label1, img2, label2)``;
            'od' -> ``(img, boxes)``.
        size: target size passed to ``cv2.resize``, i.e. ``(width, height)``.

    Raises:
        ValueError: on call, if ``phase`` is not 'seg', 'cd' or 'od'.
    """

    def __init__(self, phase, size):
        self.phase = phase
        self.size = size

    def _resize_img(self, img):
        return cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)

    def _resize_mask(self, label):
        # nearest neighbour keeps label values discrete
        return cv2.resize(label, self.size, interpolation=cv2.INTER_NEAREST)

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            return self._resize_img(img), self._resize_mask(label)
        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            img1 = self._resize_img(img1)
            img2 = self._resize_img(img2)
            label1 = self._resize_mask(label1)
            label2 = self._resize_mask(label2)
            return img1, label1, img2, label2
        if self.phase == 'od':
            img, label = data
            height, width = img.shape[:2]
            img = self._resize_img(img)
            label[:, 1::2] = label[:, 1::2] / width * self.size[0]
            label[:, 2::2] = label[:, 2::2] / height * self.size[1]
            return img, label
        raise ValueError(
            "Resize: unsupported phase {!r} (expected 'od', 'seg' or 'cd')".format(self.phase))
class Normalize(object):
    """Scale images to [0, 1] and standardize them: ``(img / 255 - mean) / std``.

    A tiny epsilon (1e-10) is added to ``std`` to avoid division by zero.
    Labels pass through unchanged.

    Args:
        phase: 'od' / 'seg' -> ``(img, target)``;
            'cd' -> ``(img1, label1, img2, label2)``; both images are normalized.
        prior_mean, prior_std: statistics wrapped as ``[[value]]``. A length-C
            sequence gives shape ``(1, 1, C)``, which broadcasts over
            ``(H, W, C)`` images.

    Raises:
        ValueError: on call, if ``phase`` is not 'od', 'seg' or 'cd'.
    """

    def __init__(self, phase, prior_mean, prior_std):
        self.phase = phase
        self.prior_mean = np.array([[prior_mean]], dtype=np.float32)
        self.prior_std = np.array([[prior_std]], dtype=np.float32)

    def _normalize(self, img):
        """Return a new, normalized copy of ``img``."""
        img = img / 255.
        return (img - self.prior_mean) / (self.prior_std + 1e-10)

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, target = data
            return self._normalize(img), target
        if self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._normalize(img1), label1, self._normalize(img2), label2
        raise ValueError(
            "Normalize: unsupported phase {!r} (expected 'od', 'seg' or 'cd')".format(self.phase))
class InvNormalize(object):
    """Map a standardized image back to pixel range: ``clip((img * std + mean) * 255, 0, 255)``.

    Roughly inverts ``Normalize``. No epsilon is added to ``std`` here, and
    the result is clipped to [0, 255].
    """

    def __init__(self, prior_mean, prior_std):
        self.prior_mean = np.array([[prior_mean]], dtype=np.float32)
        self.prior_std = np.array([[prior_std]], dtype=np.float32)

    def __call__(self, img):
        restored = (img * self.prior_std + self.prior_mean) * 255.
        return np.clip(restored, a_min=0, a_max=255)
class Augmentations(object):
    """Select and apply the augmentation pipeline for one dataset split.

    'train' runs the full random pipeline: float conversion, photometric
    distortion, expansion, cropping, flips, resize and normalize. 'val' and
    'test' only convert to float, resize and normalize.

    Args:
        size: output size given to ``Resize`` (``(width, height)``), also
            used as the reference size for ``RandomSampleCrop``.
        prior_mean, prior_std: normalization statistics. ``ExpandImg`` scales
            the mean by 255 for its border, so they are presumably in [0, 1].
        pattern: 'train', 'val' or 'test'; any other value raises ``KeyError``.
        phase: sample layout forwarded to every transform ('seg', 'cd' or 'od').
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, size, prior_mean=0, prior_std=1, pattern='train', phase='seg', *args, **kwargs):
        self.size = size
        self.prior_mean = prior_mean
        self.prior_std = prior_std
        self.phase = phase

        def deterministic_pipeline():
            # Shared recipe for 'val' and 'test': no random steps.
            return Compose([
                ConvertUcharToFloat(),
                Resize(self.phase, self.size),
                Normalize(self.phase, self.prior_mean, self.prior_std),
            ])

        train_pipeline = Compose([
            ConvertUcharToFloat(),
            ImgDistortion(self.phase),
            ExpandImg(self.phase, self.prior_mean),
            RandomSampleCrop(self.phase, original_size=self.size),
            RandomMirror(self.phase),
            RandomFlipV(self.phase),
            Resize(self.phase, self.size),
            Normalize(self.phase, self.prior_mean, self.prior_std),
        ])
        pipelines = {
            'train': train_pipeline,
            'val': deterministic_pipeline(),
            'test': deterministic_pipeline(),
        }
        self.augment = pipelines[pattern]

    def __call__(self, data):
        """Run the selected pipeline on one sample and return the result."""
        return self.augment(data)