|
import numpy as np |
|
import cv2 |
|
import torch |
|
|
|
class Compose(object):
    """Chain several transforms and apply them in sequence.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        # Feed the output of each transform into the next one.
        for transform in self.transforms:
            data = transform(data)
        return data

    def __repr__(self):
        # One indented line per transform, wrapped in "ClassName( ... )".
        inner = ''.join('\n    {0}'.format(t) for t in self.transforms)
        return '{0}({1}\n)'.format(self.__class__.__name__, inner)
|
|
|
|
|
class ConvertUcharToFloat(object):
    """
    Convert every array in the sample from uchar to float32.
    """

    def __call__(self, data):
        # Cast each element (images and labels alike) and return them as a list.
        return [item.astype(np.float32) for item in data]
|
|
|
|
|
class RandomContrast(object):
    """
    Randomly scale image intensity (contrast) by a factor in [lower, upper].

    Args:
        phase (str): 'od' / 'seg' (single image + label) or 'cd' (image pair).
        lower (float): lower bound of the contrast factor.
        upper (float): upper bound of the contrast factor.
        prob (float): probability of applying the jitter to each image.
    """
    def __init__(self, phase, lower=0.8, upper=1.2, prob=0.5):
        self.phase = phase
        self.lower = lower
        self.upper = upper
        self.prob = prob
        assert self.upper >= self.lower, "contrast upper must be >= lower!"
        assert self.lower > 0, "contrast lower must be positive!"

    def _jitter(self, img):
        # Multiply in place by a random factor with probability `prob`.
        if torch.rand(1) < self.prob:
            alpha = torch.FloatTensor(1).uniform_(self.lower, self.upper)
            img *= alpha.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._jitter(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            # Each image of the pair is jittered independently.
            return self._jitter(img1), label1, self._jitter(img2), label2
        # Previously an unknown phase crashed with UnboundLocalError; fail loudly instead.
        raise ValueError("Unsupported phase: {0}".format(self.phase))
|
|
|
|
|
class RandomBrightness(object):
    """
    Randomly shift image intensity (brightness) by a delta in [-delta, delta].

    Args:
        phase (str): 'od' / 'seg' (single image + label) or 'cd' (image pair).
        delta (float): maximum absolute brightness shift.
        prob (float): probability of applying the shift to each image.
    """
    def __init__(self, phase, delta=10, prob=0.5):
        self.phase = phase
        self.delta = delta
        self.prob = prob
        assert 0. <= self.delta < 255., "brightness delta must between 0 to 255"

    def _shift(self, img):
        # Add a random offset in place with probability `prob`.
        if torch.rand(1) < self.prob:
            delta = torch.FloatTensor(1).uniform_(- self.delta, self.delta)
            img += delta.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._shift(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            # Each image of the pair is shifted independently.
            return self._shift(img1), label1, self._shift(img2), label2
        # Previously an unknown phase crashed with UnboundLocalError; fail loudly instead.
        raise ValueError("Unsupported phase: {0}".format(self.phase))
|
|
|
|
|
class ConvertColor(object):
    """
    Convert image color space RGB<->HSV for later photometric distortion.

    Args:
        phase (str): 'od' / 'seg' (single image + label) or 'cd' (image pair).
        current (str): source color space, 'RGB' or 'HSV'.
        target (str): destination color space, 'HSV' or 'RGB'.

    Raises:
        NotImplementedError: if (current, target) is not a supported pair.
        ValueError: if the phase is not one of 'od', 'seg', 'cd'.
    """
    def __init__(self, phase, current='RGB', target='HSV'):
        self.phase = phase
        self.current = current
        self.target = target

    def _convert(self, img):
        # Only RGB->HSV and HSV->RGB are supported.
        if self.current == 'RGB' and self.target == 'HSV':
            return cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
        elif self.current == 'HSV' and self.target == 'RGB':
            return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
        raise NotImplementedError("Convert color fail!")

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._convert(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._convert(img1), label1, self._convert(img2), label2
        # Previously an unknown phase crashed with UnboundLocalError; fail loudly instead.
        raise ValueError("Unsupported phase: {0}".format(self.phase))
|
|
|
|
|
class RandomSaturation(object):
    """
    Randomly scale the saturation channel (channel 1) of an HSV image.

    Args:
        phase (str): 'od' / 'seg' (single image + label) or 'cd' (image pair).
        lower (float): lower bound of the saturation factor.
        upper (float): upper bound of the saturation factor.
        prob (float): probability of applying the jitter to each image.
    """
    def __init__(self, phase, lower=0.8, upper=1.2, prob=0.5):
        self.phase = phase
        self.lower = lower
        self.upper = upper
        self.prob = prob
        assert self.upper >= self.lower, "saturation upper must be >= lower!"
        assert self.lower > 0, "saturation lower must be positive!"

    def _jitter(self, img):
        # Scale only the S channel, in place, with probability `prob`.
        if torch.rand(1) < self.prob:
            alpha = torch.FloatTensor(1).uniform_(self.lower, self.upper)
            img[:, :, 1] *= alpha.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._jitter(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._jitter(img1), label1, self._jitter(img2), label2
        # Previously an unknown phase crashed with UnboundLocalError; fail loudly instead.
        raise ValueError("Unsupported phase: {0}".format(self.phase))
|
|
|
|
|
class RandomHue(object):
    """
    Randomly shift the hue channel (channel 0) of an HSV image, wrapping
    the result back into [0, 360).

    Args:
        phase (str): 'od' / 'seg' (single image + label) or 'cd' (image pair).
        delta (float): maximum absolute hue shift in degrees.
        prob (float): probability of applying the shift to each image.
    """
    def __init__(self, phase, delta=10., prob=0.5):
        self.phase = phase
        self.delta = delta
        self.prob = prob
        assert 0 <= self.delta < 360, "Hue delta must between 0 to 360!"

    def _shift(self, img):
        if torch.rand(1) < self.prob:
            alpha = torch.FloatTensor(1).uniform_(-self.delta, self.delta)
            img[:, :, 0] += alpha.numpy()
            # Wrap hue back into the valid [0, 360) range.
            img[:, :, 0][img[:, :, 0] > 360.0] -= 360.0
            img[:, :, 0][img[:, :, 0] < 0.0] += 360.0
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._shift(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._shift(img1), label1, self._shift(img2), label2
        # Previously an unknown phase crashed with UnboundLocalError; fail loudly instead.
        raise ValueError("Unsupported phase: {0}".format(self.phase))
|
|
|
|
|
class RandomChannelNoise(object):
    """
    Randomly permute the color channels of the image.

    Args:
        phase (str): 'od' / 'seg' (single image + label) or 'cd' (image pair).
        prob (float): probability of shuffling each image's channels.
    """
    def __init__(self, phase, prob=0.4):
        self.phase = phase
        self.prob = prob
        # All 6 permutations of 3 channels (identity included).
        self.perms = ((0, 1, 2), (0, 2, 1),
                      (1, 0, 2), (1, 2, 0),
                      (2, 0, 1), (2, 1, 0))

    def _shuffle(self, img):
        # Reindex the channel axis with a randomly drawn permutation.
        if torch.rand(1) < self.prob:
            shuffle_factor = self.perms[torch.randint(0, len(self.perms), size=[])]
            img = img[:, :, shuffle_factor]
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._shuffle(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._shuffle(img1), label1, self._shuffle(img2), label2
        # Previously an unknown phase crashed with UnboundLocalError; fail loudly instead.
        raise ValueError("Unsupported phase: {0}".format(self.phase))
|
|
|
|
|
class ImgDistortion(object):
    """
    Apply a photometric distortion pipeline: optional brightness shift,
    contrast/saturation/hue jitter in HSV space, then channel shuffling.
    """
    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob
        # Per call, either the leading or the trailing RandomContrast is used.
        self.operation = [
            RandomContrast(phase),
            ConvertColor(phase, current='RGB', target='HSV'),
            RandomSaturation(phase),
            RandomHue(phase),
            ConvertColor(phase, current='HSV', target='RGB'),
            RandomContrast(phase)
        ]
        self.random_brightness = RandomBrightness(phase)
        self.random_light_noise = RandomChannelNoise(phase)

    def __call__(self, data):
        if torch.rand(1) < self.prob:
            data = self.random_brightness(data)
        # Randomly choose contrast-before-HSV or contrast-after-HSV ordering.
        chain = self.operation[:-1] if torch.rand(1) < self.prob else self.operation[1:]
        data = Compose(chain)(data)
        return self.random_light_noise(data)
|
|
|
|
|
class ExpandImg(object):
    """
    Randomly pad (expand) the image borders, filling with the prior mean.

    Args:
        phase (str): 'seg', 'cd' or 'od'.
        prior_mean: per-channel mean in [0, 1]; scaled to [0, 255] for padding.
        prob (float): probability of SKIPPING the expansion (kept as in the
            original implementation: ``rand < prob`` returns data unchanged).
        expand_ratio (float): maximum padding as a fraction of each side.
    """
    def __init__(self, phase, prior_mean, prob=0.5, expand_ratio=0.2):
        self.phase = phase
        self.prior_mean = np.array(prior_mean) * 255
        self.prob = prob
        self.expand_ratio = expand_ratio

    def _sample_padding(self, height, width):
        """Draw random (top, bottom, left, right) paddings as plain ints."""
        ratio_width = self.expand_ratio * torch.rand([])
        ratio_height = self.expand_ratio * torch.rand([])
        left, right = torch.randint(high=int(max(1, width * ratio_width)), size=[2])
        # BUG FIX: vertical padding was previously drawn from `width * ratio_height`.
        top, bottom = torch.randint(high=int(max(1, height * ratio_height)), size=[2])
        return int(top), int(bottom), int(left), int(right)

    def __call__(self, data):
        if torch.rand(1) < self.prob:
            # Skip expansion entirely (original semantics: rand < prob -> no-op).
            return data

        if self.phase == 'seg':
            img, label = data
            height, width, _ = img.shape
            top, bottom, left, right = self._sample_padding(height, width)
            img = cv2.copyMakeBorder(
                img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=self.prior_mean)
            label = cv2.copyMakeBorder(
                label, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
            return img, label

        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            height, width, _ = img1.shape
            # Both images of the pair are padded identically.
            top, bottom, left, right = self._sample_padding(height, width)
            img1 = cv2.copyMakeBorder(
                img1, top, bottom, left, right, cv2.BORDER_CONSTANT, value=self.prior_mean)
            label1 = cv2.copyMakeBorder(
                label1, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
            img2 = cv2.copyMakeBorder(
                img2, top, bottom, left, right, cv2.BORDER_CONSTANT, value=self.prior_mean)
            label2 = cv2.copyMakeBorder(
                label2, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
            return img1, label1, img2, label2

        elif self.phase == 'od':
            img, label = data
            height, width, _ = img.shape
            top, bottom, left, right = self._sample_padding(height, width)
            img = cv2.copyMakeBorder(
                img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=self.prior_mean)
            # Shift box coordinates (x at odd columns, y at even) by the pad offset.
            label[:, 1::2] += left
            label[:, 2::2] += top
            return img, label

        raise ValueError("Unsupported phase: {0}".format(self.phase))
|
|
|
|
|
class RandomSampleCrop(object):
    """
    Randomly crop a region of roughly the original training size.

    Arguments:
        img (Image): the image being input during training
        boxes (Tensor): the original bounding boxes in pt form
        label (Tensor): the class label for each bbox
        mode (float tuple): the min and max jaccard overlaps
    Return:
        (img, boxes, classes)
            img (Image): the cropped image
            boxes (Tensor): the adjusted bounding boxes in pt form
            label (Tensor): the class label for each bbox
    """
    def __init__(self,
                 phase,
                 original_size=[512, 512],
                 prob=0.5,
                 crop_scale_ratios_range=[0.8, 1.2],
                 aspect_ratio_range=[4./5, 5./4]):
        self.phase = phase
        self.prob = prob
        self.scale_range = crop_scale_ratios_range
        self.original_size = original_size  # (width, height)
        self.aspect_ratio_range = aspect_ratio_range
        self.max_try_times = 500

    def _sample_crop_size(self, h, w):
        """Sample (crop_h, crop_w) within scale and aspect-ratio limits.

        Returns None when no acceptable size is found in max_try_times draws.
        """
        for _ in range(self.max_try_times):
            crop_w = torch.randint(
                min(w, int(self.scale_range[0] * self.original_size[0])),
                min(w + 1, int(self.scale_range[1] * self.original_size[0])),
                size=[]
            )
            crop_h = torch.randint(
                min(h, int(self.scale_range[0] * self.original_size[1])),
                min(h + 1, int(self.scale_range[1] * self.original_size[1])),
                size=[]
            )
            if self.aspect_ratio_range[0] < crop_h / crop_w < self.aspect_ratio_range[1]:
                return crop_h, crop_w
        return None

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            # BUG FIX: numpy images are (height, width, channels); the original
            # code unpacked this as (w, h, c), mis-cropping non-square images.
            h, w, c = img.shape
            if torch.rand(1) < self.prob:
                return data
            size = self._sample_crop_size(h, w)
            if size is None:
                print("try times over max threshold!", flush=True)
                return img, label
            crop_h, crop_w = size
            left = torch.randint(0, w - crop_w + 1, size=[])
            top = torch.randint(0, h - crop_h + 1, size=[])
            img = img[top:(top + crop_h), left:(left + crop_w), :]
            label = label[top:(top + crop_h), left:(left + crop_w)]
            return img, label

        elif self.phase == 'od':
            if torch.rand(1) < self.prob:
                return data
            img, label = data
            # BUG FIX: same (h, w) unpacking correction as the 'seg' branch.
            h, w, c = img.shape
            size = self._sample_crop_size(h, w)
            if size is None:
                # The original looped forever here; give up gracefully instead.
                print("try times over max threshold!", flush=True)
                return img, label
            crop_h, crop_w = size
            left = torch.randint(0, w - crop_w + 1, size=[]).numpy()
            top = torch.randint(0, h - crop_h + 1, size=[]).numpy()
            crop_h = crop_h.numpy()
            crop_w = crop_w.numpy()
            img = img[top:(top + crop_h), left:(left + crop_w), :]
            if len(label):
                # Keep only boxes whose center lies inside the crop window.
                centers = (label[:, 1:3] + label[:, 3:]) / 2.0
                m1 = (left <= centers[:, 0]) * (top <= centers[:, 1])
                m2 = ((left + crop_w) >= centers[:, 0]) * ((top + crop_h) > centers[:, 1])
                mask = m1 * m2
                current_label = label[mask, :]
                # Shift surviving boxes into the crop's coordinate frame.
                current_label[:, 1::2] -= left
                current_label[:, 2::2] -= top
                label = current_label
            return img, label
|
|
|
|
|
class RandomMirror(object):
    """Randomly flip images (and labels / boxes) horizontally."""

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob

    def __call__(self, data):
        flip = bool(torch.rand(1) < self.prob)
        if self.phase == 'seg':
            img, label = data
            if flip:
                img = img[:, ::-1]
                label = label[:, ::-1]
            return img, label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            if flip:
                # Both images of the pair are mirrored together.
                img1 = img1[:, ::-1]
                label1 = label1[:, ::-1]
                img2 = img2[:, ::-1]
                label2 = label2[:, ::-1]
            return img1, label1, img2, label2
        elif self.phase == 'od':
            img, label = data
            if flip:
                _, width, _ = img.shape
                img = img[:, ::-1]
                # Mirror x coordinates: x1' = W - x2, x2' = W - x1.
                label[:, 1::2] = width - label[:, 3::-2]
            return img, label
|
|
|
|
|
class RandomFlipV(object):
    """Randomly flip images (and labels / boxes) vertically."""

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob

    def __call__(self, data):
        flip = bool(torch.rand(1) < self.prob)
        if self.phase == 'seg':
            img, label = data
            if flip:
                img = img[::-1, :]
                label = label[::-1, :]
            return img, label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            if flip:
                # Both images of the pair are flipped together.
                img1 = img1[::-1, :]
                label1 = label1[::-1, :]
                img2 = img2[::-1, :]
                label2 = label2[::-1, :]
            return img1, label1, img2, label2
        elif self.phase == 'od':
            img, label = data
            if flip:
                height, _, _ = img.shape
                img = img[::-1, :]
                # Mirror y coordinates: y1' = H - y2, y2' = H - y1.
                label[:, 2::2] = height - label[:, 4:1:-2]
            return img, label
|
|
|
|
|
class Resize(object):
    """Resize images (bilinear) and label maps (nearest) to a fixed size.

    For 'od', bounding-box coordinates are rescaled instead of a label map.
    """

    def __init__(self, phase, size):
        self.phase = phase
        self.size = size  # (width, height), as cv2.resize expects

    def _resize_img(self, img):
        return cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)

    def _resize_label(self, label):
        # Nearest neighbour keeps label values discrete.
        return cv2.resize(label, self.size, interpolation=cv2.INTER_NEAREST)

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            return self._resize_img(img), self._resize_label(label)
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return (self._resize_img(img1), self._resize_label(label1),
                    self._resize_img(img2), self._resize_label(label2))
        elif self.phase == 'od':
            img, label = data
            height, width, _ = img.shape
            img = self._resize_img(img)
            # Scale box coords in place: x at odd columns, y at even columns.
            label[:, 1::2] = label[:, 1::2] / width * self.size[0]
            label[:, 2::2] = label[:, 2::2] / height * self.size[1]
            return img, label
|
|
|
|
|
class Normalize(object):
    """Scale images to [0, 1] then standardize with the prior mean/std."""

    def __init__(self, phase, prior_mean, prior_std):
        self.phase = phase
        # Shaped (1, 1[, C]) so they broadcast over (H, W, C) images.
        self.prior_mean = np.array([[prior_mean]], dtype=np.float32)
        self.prior_std = np.array([[prior_std]], dtype=np.float32)

    def _normalize(self, img):
        # Epsilon guards against a zero std.
        return (img / 255. - self.prior_mean) / (self.prior_std + 1e-10)

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._normalize(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._normalize(img1), label1, self._normalize(img2), label2
|
|
|
|
|
class InvNormalize(object):
    """Undo Normalize: de-standardize, rescale to [0, 255] and clip."""

    def __init__(self, prior_mean, prior_std):
        self.prior_mean = np.array([[prior_mean]], dtype=np.float32)
        self.prior_std = np.array([[prior_std]], dtype=np.float32)

    def __call__(self, img):
        restored = (img * self.prior_std + self.prior_mean) * 255.
        # Clamp into valid 8-bit display range.
        return np.clip(restored, a_min=0, a_max=255)
|
|
|
|
|
class Augmentations(object):
    """Build the full augmentation pipeline for train / val / test patterns."""

    def __init__(self, size, prior_mean=0, prior_std=1, pattern='train', phase='seg', *args, **kwargs):
        self.size = size
        self.prior_mean = prior_mean
        self.prior_std = prior_std
        self.phase = phase

        # Evaluation ('val' / 'test') only resizes and normalizes.
        eval_pipeline = Compose([
            ConvertUcharToFloat(),
            Resize(self.phase, self.size),
            Normalize(self.phase, self.prior_mean, self.prior_std),
        ])
        augments = {
            'train': Compose([
                ConvertUcharToFloat(),
                ImgDistortion(self.phase),
                ExpandImg(self.phase, self.prior_mean),
                RandomSampleCrop(self.phase, original_size=self.size),
                RandomMirror(self.phase),
                RandomFlipV(self.phase),
                Resize(self.phase, self.size),
                Normalize(self.phase, self.prior_mean, self.prior_std),
            ]),
            'val': eval_pipeline,
            'test': eval_pipeline,
        }
        self.augment = augments[pattern]

    def __call__(self, data):
        return self.augment(data)
|
|
|
|