# Data augmentation transforms for segmentation ('seg'), object detection ('od'),
# and change detection ('cd') pipelines.
import numpy as np
import cv2
import torch
class Compose(object):
    """Chain several transforms into one callable applied in order.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        # Feed the output of each transform into the next one.
        for transform in self.transforms:
            data = transform(data)
        return data

    def __repr__(self):
        body = ''.join('\n' + '    {0}'.format(t) for t in self.transforms)
        return self.__class__.__name__ + '(' + body + '\n)'
class ConvertUcharToFloat(object):
    """Cast every array in ``data`` (images and labels alike) to float32."""

    def __call__(self, data):
        return [item.astype(np.float32) for item in data]
class RandomContrast(object):
    """Randomly scale image intensities by a factor drawn from [lower, upper].

    Args:
        phase (str): 'od'/'seg' (data is (img, label)) or 'cd'
            (data is (img1, label1, img2, label2)).
        lower (float): minimum contrast factor, must be > 0.
        upper (float): maximum contrast factor, must be >= lower.
        prob (float): probability of jittering each image independently.

    Fixes: an unrecognized phase previously hit an unbound ``return_data``
    (NameError); data is now passed through unchanged. The lower-bound assert
    message said "non-negative" while the check is strictly positive.
    """

    def __init__(self, phase, lower=0.8, upper=1.2, prob=0.5):
        self.phase = phase
        self.lower = lower
        self.upper = upper
        self.prob = prob
        assert self.upper >= self.lower, "contrast upper must be >= lower!"
        assert self.lower > 0, "contrast lower must be positive!"

    def _jitter(self, img):
        # Multiply in place by a random factor in [lower, upper].
        if torch.rand(1) < self.prob:
            alpha = torch.FloatTensor(1).uniform_(self.lower, self.upper)
            img *= alpha.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._jitter(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._jitter(img1), label1, self._jitter(img2), label2
        # Unknown phase: pass through unchanged instead of raising NameError.
        return data
class RandomBrightness(object):
    """Randomly shift image intensities by a value drawn from [-delta, delta].

    Args:
        phase (str): 'od'/'seg' (data is (img, label)) or 'cd'
            (data is (img1, label1, img2, label2)).
        delta (float): maximum absolute brightness shift, in [0, 255).
        prob (float): probability of jittering each image independently.

    Fix: an unrecognized phase previously hit an unbound ``return_data``
    (NameError); data is now passed through unchanged.
    """

    def __init__(self, phase, delta=10, prob=0.5):
        self.phase = phase
        self.delta = delta
        self.prob = prob
        assert 0. <= self.delta < 255., "brightness delta must between 0 to 255"

    def _jitter(self, img):
        # Add a random offset in [-delta, delta] in place.
        if torch.rand(1) < self.prob:
            shift = torch.FloatTensor(1).uniform_(-self.delta, self.delta)
            img += shift.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._jitter(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._jitter(img1), label1, self._jitter(img2), label2
        # Unknown phase: pass through unchanged instead of raising NameError.
        return data
class ConvertColor(object):
    """Convert image color space between RGB and HSV for photometric jitter.

    Args:
        phase (str): 'od'/'seg' (data is (img, label)) or 'cd'
            (data is (img1, label1, img2, label2)).
        current (str): source color space, 'RGB' or 'HSV'.
        target (str): destination color space, 'HSV' or 'RGB'.

    Raises:
        NotImplementedError: for unsupported current/target combinations.

    Fix: an unrecognized phase previously hit an unbound ``return_data``
    (NameError); data is now passed through unchanged.
    """

    def __init__(self, phase, current='RGB', target='HSV'):
        self.phase = phase
        self.current = current
        self.target = target

    def _convert(self, img):
        if self.current == 'RGB' and self.target == 'HSV':
            return cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
        elif self.current == 'HSV' and self.target == 'RGB':
            return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
        raise NotImplementedError("Convert color fail!")

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._convert(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._convert(img1), label1, self._convert(img2), label2
        # Unknown phase: pass through unchanged instead of raising NameError.
        return data
class RandomSaturation(object):
    """Randomly scale the saturation (HSV channel 1) of an image.

    Expects the image to be in HSV layout (channel 1 = S).

    Args:
        phase (str): 'od'/'seg' (data is (img, label)) or 'cd'
            (data is (img1, label1, img2, label2)).
        lower (float): minimum saturation factor, must be > 0.
        upper (float): maximum saturation factor, must be >= lower.
        prob (float): probability of jittering each image independently.

    Fixes: unknown phase previously raised NameError on an unbound
    ``return_data``; the lower-bound assert message said "non-negative"
    while the check is strictly positive.
    """

    def __init__(self, phase, lower=0.8, upper=1.2, prob=0.5):
        self.phase = phase
        self.lower = lower
        self.upper = upper
        self.prob = prob
        assert self.upper >= self.lower, "saturation upper must be >= lower!"
        assert self.lower > 0, "saturation lower must be positive!"

    def _jitter(self, img):
        # Scale only the saturation channel, in place.
        if torch.rand(1) < self.prob:
            alpha = torch.FloatTensor(1).uniform_(self.lower, self.upper)
            img[:, :, 1] *= alpha.numpy()
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._jitter(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._jitter(img1), label1, self._jitter(img2), label2
        # Unknown phase: pass through unchanged instead of raising NameError.
        return data
class RandomHue(object):
    """Randomly shift the hue (HSV channel 0) of an image, wrapping at 360.

    Expects the image to be in float HSV layout (channel 0 = H in [0, 360)).

    Args:
        phase (str): 'od'/'seg' (data is (img, label)) or 'cd'
            (data is (img1, label1, img2, label2)).
        delta (float): maximum absolute hue shift, in [0, 360).
        prob (float): probability of jittering each image independently.

    Fix: an unrecognized phase previously hit an unbound ``return_data``
    (NameError); data is now passed through unchanged.
    """

    def __init__(self, phase, delta=10., prob=0.5):
        self.phase = phase
        self.delta = delta
        self.prob = prob
        assert 0 <= self.delta < 360, "Hue delta must between 0 to 360!"

    def _jitter(self, img):
        # Shift the hue channel in place and wrap it back into [0, 360).
        if torch.rand(1) < self.prob:
            alpha = torch.FloatTensor(1).uniform_(-self.delta, self.delta)
            img[:, :, 0] += alpha.numpy()
            img[:, :, 0][img[:, :, 0] > 360.0] -= 360.0
            img[:, :, 0][img[:, :, 0] < 0.0] += 360.0
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._jitter(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._jitter(img1), label1, self._jitter(img2), label2
        # Unknown phase: pass through unchanged instead of raising NameError.
        return data
class RandomChannelNoise(object):
    """Randomly permute the channel order of an image.

    Args:
        phase (str): 'od'/'seg' (data is (img, label)) or 'cd'
            (data is (img1, label1, img2, label2)).
        prob (float): probability of shuffling each image independently.

    Fix: an unrecognized phase previously hit an unbound ``return_data``
    (NameError); data is now passed through unchanged.
    """

    def __init__(self, phase, prob=0.4):
        self.phase = phase
        self.prob = prob
        # All 6 permutations of a 3-channel image.
        self.perms = ((0, 1, 2), (0, 2, 1),
                      (1, 0, 2), (1, 2, 0),
                      (2, 0, 1), (2, 1, 0))

    def _shuffle(self, img):
        # Reindex the channel axis with a randomly chosen permutation.
        if torch.rand(1) < self.prob:
            order = self.perms[torch.randint(0, len(self.perms), size=[])]
            img = img[:, :, order]
        return img

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._shuffle(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._shuffle(img1), label1, self._shuffle(img2), label2
        # Unknown phase: pass through unchanged instead of raising NameError.
        return data
class ImgDistortion(object):
    """Full photometric distortion: brightness, contrast/saturation/hue, channel shuffle.

    Mirrors the classic SSD photometric pipeline: brightness first (with
    probability ``prob``), then either contrast-before or contrast-after the
    HSV-space jitters, and finally a random channel permutation.
    """

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob
        self.operation = [
            RandomContrast(phase),
            ConvertColor(phase, current='RGB', target='HSV'),
            RandomSaturation(phase),
            RandomHue(phase),
            ConvertColor(phase, current='HSV', target='RGB'),
            RandomContrast(phase)
        ]
        self.random_brightness = RandomBrightness(phase)
        self.random_light_noise = RandomChannelNoise(phase)

    def __call__(self, data):
        if torch.rand(1) < self.prob:
            data = self.random_brightness(data)
        # Randomly apply contrast either before or after the HSV jitters.
        ops = self.operation[:-1] if torch.rand(1) < self.prob else self.operation[1:]
        data = Compose(ops)(data)
        return self.random_light_noise(data)
class ExpandImg(object):
    """Randomly pad ("expand") images with the dataset mean color.

    Labels are padded with 0 ('seg'/'cd') or have their box coordinates
    shifted by the left/top padding ('od').

    Args:
        phase (str): 'seg', 'cd' or 'od'.
        prior_mean: per-channel mean in [0, 1]; scaled to 0-255 for padding.
        prob (float): probability of SKIPPING the expansion (kept from the
            original implementation).
        expand_ratio (float): maximum padding as a fraction of each dimension.

    Bug fix: the vertical (top/bottom) padding bound was previously computed
    from the image *width*; it now correctly uses the image height. Unknown
    phases now pass data through instead of implicitly returning None.
    """

    def __init__(self, phase, prior_mean, prob=0.5, expand_ratio=0.2):
        self.phase = phase
        self.prior_mean = np.array(prior_mean) * 255  # mean color on 0-255 scale
        self.prob = prob
        self.expand_ratio = expand_ratio

    def _sample_margins(self, height, width):
        # Draw left/right (bounded by width) and top/bottom (bounded by height)
        # margins, each at most expand_ratio of the corresponding dimension.
        ratio_width = self.expand_ratio * torch.rand([])
        ratio_height = self.expand_ratio * torch.rand([])
        left, right = torch.randint(high=int(max(1, width * ratio_width)), size=[2])
        top, bottom = torch.randint(high=int(max(1, height * ratio_height)), size=[2])
        return int(top), int(bottom), int(left), int(right)

    def _pad_img(self, img, top, bottom, left, right):
        return cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=self.prior_mean)

    def _pad_label(self, label, top, bottom, left, right):
        return cv2.copyMakeBorder(
            label, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)

    def __call__(self, data):
        if torch.rand(1) < self.prob:
            # Skip expansion for this sample.
            return data
        if self.phase == 'seg':
            img, label = data
            height, width, _ = img.shape
            t, b, l, r = self._sample_margins(height, width)
            return self._pad_img(img, t, b, l, r), self._pad_label(label, t, b, l, r)
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            height, width, _ = img1.shape
            t, b, l, r = self._sample_margins(height, width)
            return (self._pad_img(img1, t, b, l, r), self._pad_label(label1, t, b, l, r),
                    self._pad_img(img2, t, b, l, r), self._pad_label(label2, t, b, l, r))
        elif self.phase == 'od':
            img, label = data
            height, width, _ = img.shape
            t, b, l, r = self._sample_margins(height, width)
            img = self._pad_img(img, t, b, l, r)
            # Shift boxes by the padding: columns 1,3 hold x, columns 2,4 hold y.
            label[:, 1::2] += l
            label[:, 2::2] += t
            return img, label
        # Unknown phase: pass through unchanged.
        return data
class RandomSampleCrop(object):
    """Randomly crop a patch from the input.

    The crop size is drawn around ``original_size`` scaled by
    ``crop_scale_ratios_range``, subject to ``aspect_ratio_range`` (h/w).

    Arguments:
        img (Image): the image being input during training
        label: segmentation mask ('seg') or bounding boxes in pt form ('od',
            rows [cls, x1, y1, x2, y2]).
    Return:
        (img, label): the cropped image and adjusted label.

    Bug fixes: numpy image shape is (height, width, channels) but was
    unpacked as (width, height, channels), corrupting crops on non-square
    images; the 'od' branch could loop forever (it now shares the seg
    branch's max-try guard); 'cd'/unknown phases returned None implicitly
    and now pass data through unchanged.
    """

    def __init__(self,
                 phase,
                 original_size=[512, 512],
                 prob=0.5,
                 crop_scale_ratios_range=[0.8, 1.2],
                 aspect_ratio_range=[4./5, 5./4]):
        self.phase = phase
        self.prob = prob
        self.scale_range = crop_scale_ratios_range
        self.original_size = original_size  # (width, height)
        self.aspect_ratio_range = aspect_ratio_range  # h/w
        self.max_try_times = 500

    def _sample_crop_size(self, h, w):
        """Draw (crop_h, crop_w) obeying the aspect constraint, or None on failure."""
        for _ in range(self.max_try_times):
            crop_w = torch.randint(
                min(w, int(self.scale_range[0] * self.original_size[0])),
                min(w + 1, int(self.scale_range[1] * self.original_size[0])),
                size=[]
            )
            crop_h = torch.randint(
                min(h, int(self.scale_range[0] * self.original_size[1])),
                min(h + 1, int(self.scale_range[1] * self.original_size[1])),
                size=[]
            )
            # aspect ratio constraint (h/w)
            if self.aspect_ratio_range[0] < crop_h / crop_w < self.aspect_ratio_range[1]:
                return int(crop_h), int(crop_w)
        return None

    def __call__(self, data):
        if torch.rand(1) < self.prob:
            # Skip cropping for this sample.
            return data
        if self.phase == 'seg':
            img, label = data
            h, w, _ = img.shape
            size = self._sample_crop_size(h, w)
            if size is None:
                print("try times over max threshold!", flush=True)
                return img, label
            crop_h, crop_w = size
            left = int(torch.randint(0, w - crop_w + 1, size=[]))
            top = int(torch.randint(0, h - crop_h + 1, size=[]))
            img = img[top:(top + crop_h), left:(left + crop_w), :]
            label = label[top:(top + crop_h), left:(left + crop_w)]
            return img, label
        elif self.phase == 'od':
            img, label = data
            h, w, _ = img.shape
            size = self._sample_crop_size(h, w)
            if size is None:
                print("try times over max threshold!", flush=True)
                return img, label
            crop_h, crop_w = size
            left = int(torch.randint(0, w - crop_w + 1, size=[]))
            top = int(torch.randint(0, h - crop_h + 1, size=[]))
            img = img[top:(top + crop_h), left:(left + crop_w), :]
            if len(label):
                # keep a gt box only IF its center lies in the sampled patch
                centers = (label[:, 1:3] + label[:, 3:]) / 2.0
                # mask in all gt boxes that are below/right of the patch origin
                m1 = (left <= centers[:, 0]) * (top <= centers[:, 1])
                # mask in all gt boxes that are above/left of the patch end
                m2 = ((left + crop_w) >= centers[:, 0]) * ((top + crop_h) > centers[:, 1])
                mask = m1 * m2
                current_label = label[mask, :]
                # shift surviving boxes into crop coordinates
                current_label[:, 1::2] -= left
                current_label[:, 2::2] -= top
                label = current_label
            return img, label
        # Other phases (e.g. 'cd') have no crop implementation: pass through.
        return data
class RandomMirror(object):
    """Randomly flip images horizontally (and adjust labels accordingly)."""

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            if torch.rand(1) < self.prob:
                img, label = img[:, ::-1], label[:, ::-1]
            return img, label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            if torch.rand(1) < self.prob:
                img1, label1 = img1[:, ::-1], label1[:, ::-1]
                img2, label2 = img2[:, ::-1], label2[:, ::-1]
            return img1, label1, img2, label2
        elif self.phase == 'od':
            img, label = data
            if torch.rand(1) < self.prob:
                width = img.shape[1]
                img = img[:, ::-1]
                # New x-coords: x1' = W - x2, x2' = W - x1 (columns 1 and 3 hold x).
                label[:, 1::2] = width - label[:, 3::-2]
            return img, label
class RandomFlipV(object):
    """Randomly flip images vertically (and adjust labels accordingly)."""

    def __init__(self, phase, prob=0.5):
        self.phase = phase
        self.prob = prob

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            if torch.rand(1) < self.prob:
                img, label = img[::-1, :], label[::-1, :]
            return img, label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            if torch.rand(1) < self.prob:
                img1, label1 = img1[::-1, :], label1[::-1, :]
                img2, label2 = img2[::-1, :], label2[::-1, :]
            return img1, label1, img2, label2
        elif self.phase == 'od':
            img, label = data
            if torch.rand(1) < self.prob:
                height = img.shape[0]
                img = img[::-1, :]
                # New y-coords: y1' = H - y2, y2' = H - y1 (columns 2 and 4 hold y).
                label[:, 2::2] = height - label[:, 4:1:-2]
            return img, label
class Resize(object):
    """Resize images bilinearly; masks use nearest-neighbour, boxes are rescaled.

    ``size`` is (width, height) as cv2.resize expects.
    """

    def __init__(self, phase, size):
        self.phase = phase
        self.size = size

    def __call__(self, data):
        if self.phase == 'seg':
            img, label = data
            img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
            # Nearest-neighbour keeps mask values discrete.
            label = cv2.resize(label, self.size, interpolation=cv2.INTER_NEAREST)
            return img, label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            img1 = cv2.resize(img1, self.size, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, self.size, interpolation=cv2.INTER_LINEAR)
            label1 = cv2.resize(label1, self.size, interpolation=cv2.INTER_NEAREST)
            label2 = cv2.resize(label2, self.size, interpolation=cv2.INTER_NEAREST)
            return img1, label1, img2, label2
        elif self.phase == 'od':
            img, label = data
            height, width = img.shape[:2]
            img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
            # Rescale box coords from original pixels to the new size.
            label[:, 1::2] = label[:, 1::2] / width * self.size[0]
            label[:, 2::2] = label[:, 2::2] / height * self.size[1]
            return img, label
class Normalize(object):
    """Scale images to [0, 1], then standardize with the dataset mean/std."""

    def __init__(self, phase, prior_mean, prior_std):
        self.phase = phase
        self.prior_mean = np.array([[prior_mean]], dtype=np.float32)
        self.prior_std = np.array([[prior_std]], dtype=np.float32)

    def _normalize(self, img):
        # Epsilon guards against division by a zero std.
        return (img / 255. - self.prior_mean) / (self.prior_std + 1e-10)

    def __call__(self, data):
        if self.phase in ['od', 'seg']:
            img, label = data
            return self._normalize(img), label
        elif self.phase == 'cd':
            img1, label1, img2, label2 = data
            return self._normalize(img1), label1, self._normalize(img2), label2
class InvNormalize(object):
    """Undo Normalize: map a standardized image back to clipped [0, 255]."""

    def __init__(self, prior_mean, prior_std):
        self.prior_mean = np.array([[prior_mean]], dtype=np.float32)
        self.prior_std = np.array([[prior_std]], dtype=np.float32)

    def __call__(self, img):
        restored = (img * self.prior_std + self.prior_mean) * 255.
        return np.clip(restored, a_min=0, a_max=255)
class Augmentations(object):
    """Assemble the full transform pipeline for a pattern and task phase.

    Args:
        size: target (width, height) used by Resize and RandomSampleCrop.
        prior_mean, prior_std: normalization statistics.
        pattern (str): 'train' (full augmentation), 'val' or 'test'
            (convert + resize + normalize only).
        phase (str): task phase forwarded to every transform.
    """

    def __init__(self, size, prior_mean=0, prior_std=1, pattern='train', phase='seg', *args, **kwargs):
        self.size = size
        self.prior_mean = prior_mean
        self.prior_std = prior_std
        self.phase = phase
        train_pipeline = Compose([
            ConvertUcharToFloat(),
            ImgDistortion(self.phase),
            ExpandImg(self.phase, self.prior_mean),
            RandomSampleCrop(self.phase, original_size=self.size),
            RandomMirror(self.phase),
            RandomFlipV(self.phase),
            Resize(self.phase, self.size),
            Normalize(self.phase, self.prior_mean, self.prior_std),
        ])
        eval_pipeline = Compose([
            ConvertUcharToFloat(),
            Resize(self.phase, self.size),
            Normalize(self.phase, self.prior_mean, self.prior_std),
        ])
        # 'val' and 'test' share the same deterministic pipeline.
        self.augment = {'train': train_pipeline,
                        'val': eval_pipeline,
                        'test': eval_pipeline}[pattern]

    def __call__(self, data):
        return self.augment(data)