"""Dataset transforms that operate on dict samples.

Each transform accepts either a plain image (delegating to its torchvision
parent) or a dict with an 'image' key and, optionally, 'mask', 'bbox'
(x1, y1, x2, y2, in pixels) and 'expr' keys. Geometric transforms append
their parameters to input['tr_param'] so that box coordinates can later be
mapped back with undo_box_transforms.
"""

import torch

from PIL import Image
from torchvision import transforms
from torchvision.transforms import Compose


class ToTensor(transforms.ToTensor):
    """ToTensor that also accepts a dict sample, converting only its image."""

    def __call__(self, input):
        if not isinstance(input, dict):
            return super().__call__(input)
        assert 'image' in input
        input['image'] = super().__call__(input['image'])
        return input


class Normalize(transforms.Normalize):
    """Normalize that also accepts a dict sample, normalizing only its image."""

    def __call__(self, input):
        if not isinstance(input, dict):
            return super().__call__(input)
        assert 'image' in input
        input['image'] = super().__call__(input['image'])
        return input


class NormalizeBoxCoords(torch.nn.Module):
    """Normalize 'bbox' coordinates to [0, 1] by the image width and height."""

    def __call__(self, input):
        if not isinstance(input, dict):
            raise NotImplementedError('NormalizeBoxCoords expects a dict sample')
        assert 'image' in input and 'bbox' in input
        _, H, W = input['image'].size()
        input['bbox'][:, (0, 2)] /= W  # x1, x2
        input['bbox'][:, (1, 3)] /= H  # y1, y2

        if 'tr_param' not in input:
            input['tr_param'] = []
        input['tr_param'].append({'normalize_box_coords': (H, W)})

        return input


class SquarePad(torch.nn.Module):
    """Zero-pad the image (and mask) to a square, shifting 'bbox' to match."""

    def __call__(self, input):
        if isinstance(input, Image.Image):
            raise NotImplementedError('put the SquarePad transform after ToTensor')

        assert 'image' in input
        _, h, w = input['image'].size()

        max_wh = max(w, h)
        xp = int(0.5 * (max_wh - w))  # left padding
        yp = int(0.5 * (max_wh - h))  # top padding
        padding = (xp, yp, (max_wh - xp) - w, (max_wh - yp) - h)

        input['image'] = transforms.functional.pad(
            input['image'], padding, fill=0, padding_mode='constant'
        )

        if 'mask' in input:
            input['mask'] = transforms.functional.pad(
                input['mask'], padding, fill=0, padding_mode='constant'
            )

        if 'bbox' in input:
            input['bbox'][:, (0, 2)] += xp
            input['bbox'][:, (1, 3)] += yp

        if 'tr_param' not in input:
            input['tr_param'] = []
        input['tr_param'].append({'square_pad': padding})

        return input


class Resize(transforms.Resize):
    """Resize that also resizes 'mask' and rescales 'bbox', recording the
    scale factors in 'tr_param'."""

    def __call__(self, input):
        if not isinstance(input, dict):
            return super().__call__(input)

        assert 'image' in input

        if not torch.is_tensor(input['image']):
            raise NotImplementedError('put the Resize transform after ToTensor')

        _, img_h, img_w = input['image'].size()

        # an int size resizes the smaller edge and keeps the aspect ratio,
        # matching torchvision's Resize semantics
        if isinstance(self.size, int):
            dst_h = self.size if img_h < img_w else int(self.size * img_h / img_w)
            dst_w = self.size if img_w < img_h else int(self.size * img_w / img_h)
        else:
            dst_h, dst_w = self.size

        input['image'] = super().__call__(input['image'])

        if 'mask' in input:
            input['mask'] = super().__call__(input['mask'])

        sx, sy = dst_w / img_w, dst_h / img_h

        if 'bbox' in input:
            input['bbox'][:, (0, 2)] *= sx
            input['bbox'][:, (1, 3)] *= sy

        if 'tr_param' not in input:
            input['tr_param'] = []
        input['tr_param'].append({'resize': (sx, sy)})

        return input


class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
    """Horizontal flip that also flips 'mask', mirrors 'bbox', and swaps
    'left'/'right' in the expression string."""

    def __call__(self, input):
        if not isinstance(input, dict):
            return super().__call__(input)

        assert 'image' in input

        if not torch.is_tensor(input['image']):
            raise NotImplementedError('put the RandomHorizontalFlip transform after ToTensor')

        # the parent returns the very same object when the flip was not
        # applied, so an identity check tells us whether it happened
        result = super().__call__(input['image'])
        if result is input['image']:
            return input
        input['image'] = result

        if 'mask' in input:
            input['mask'] = torch.flip(input['mask'], dims=(-1,))

        img_w = input['image'].size(2)

        if 'bbox' in input:
            # x1' = W - x2, x2' = W - x1
            input['bbox'][:, (0, 2)] = img_w - input['bbox'][:, (2, 0)]

        if 'expr' in input:
            # swap 'left' and 'right' through a temporary placeholder
            input['expr'] = input['expr'].replace('left', '<LEFT>').replace('right', 'left').replace('<LEFT>', 'right')

        return input


class RandomAffine(transforms.RandomAffine):
    """Random affine that also warps 'mask' and replaces each 'bbox' with the
    axis-aligned bounding box of its transformed corners."""

    def get_params(self, *args, **kwargs):
        # cache the sampled parameters so __call__ can apply the same
        # transform to the mask and the boxes
        self.params = super().get_params(*args, **kwargs)
        return self.params

    def __call__(self, input):
        if not isinstance(input, dict):
            return super().__call__(input)

        assert 'image' in input

        if not torch.is_tensor(input['image']):
            raise NotImplementedError('put the RandomAffine transform after ToTensor')

        result = super().__call__(input['image'])
        if result is input['image']:
            return input
        input['image'] = result

        _, img_h, img_w = input['image'].size()

        # _get_inverse_affine_matrix (a private torchvision helper) returns
        # the matrix used to warp the image; invert it to get the forward
        # point transform for the box corners
        angle, translate, scale, shear = self.params
        center = (img_w * 0.5, img_h * 0.5)
        matrix = transforms.functional._get_inverse_affine_matrix(center, angle, translate, scale, shear)
        matrix = torch.FloatTensor([matrix[:3], matrix[3:], [0, 0, 1]])
        matrix = torch.linalg.inv(matrix)

        if 'mask' in input:
            input['mask'] = transforms.functional.affine(
                input['mask'], *self.params, self.interpolation, self.fill
            )

        if 'bbox' in input:
            for i, (x1, y1, x2, y2) in enumerate(input['bbox']):
                # map the four corners and take their axis-aligned hull
                pt = matrix @ torch.FloatTensor([
                    [x1, y1, 1],
                    [x2, y1, 1],
                    [x2, y2, 1],
                    [x1, y2, 1]
                ]).T
                x_min, y_min, _ = pt.min(dim=1).values
                x_max, y_max, _ = pt.max(dim=1).values
                input['bbox'][i, :] = torch.FloatTensor([x_min, y_min, x_max, y_max])

        return input


class ColorJitter(transforms.ColorJitter):
    """ColorJitter that also accepts a dict sample, jittering only its image."""

    def __call__(self, input):
        if not isinstance(input, dict):
            return super().__call__(input)
        assert 'image' in input
        input['image'] = super().__call__(input['image'])
        return input


def get_transform(split, input_size=512):
    # ImageNet normalization statistics
    mean = [0.485, 0.456, 0.406]
    sdev = [0.229, 0.224, 0.225]

    if split in ('train', 'trainval'):
        transform = Compose([
            ToTensor(),
            Normalize(mean, sdev),
            SquarePad(),
            Resize(size=(input_size, input_size)),
            RandomAffine(degrees=5, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            NormalizeBoxCoords(),
        ])
    elif split in ('val', 'test', 'testA', 'testB', 'testC'):
        transform = Compose([
            ToTensor(),
            Normalize(mean, sdev),
            SquarePad(),
            Resize(size=(input_size, input_size)),
            NormalizeBoxCoords(),
        ])
    elif split in ('visu',):
        # visualization split: skip normalization so images stay displayable
        transform = Compose([
            ToTensor(),
            SquarePad(),
            Resize(size=(input_size, input_size)),
            NormalizeBoxCoords(),
        ])
    else:
        raise ValueError(f"'{split}' is not a valid data split")

    return transform


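# Minimal usage sketch (hypothetical sample data; the 'image'/'bbox'/'expr'
# keys follow the dict convention assumed by the transforms above, with
# 'bbox' given as (x1, y1, x2, y2) pixel coordinates):
#
#   transform = get_transform('train', input_size=512)
#   sample = {
#       'image': Image.new('RGB', (640, 480)),
#       'bbox': torch.FloatTensor([[120., 60., 320., 240.]]),
#       'expr': 'the dog on the left',
#   }
#   sample = transform(sample)
#   # sample['image'] is a (3, 512, 512) tensor, sample['bbox'] lies in
#   # [0, 1], and sample['tr_param'] records square_pad, resize, and
#   # normalize_box_coords for later inversion.

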
def denormalize(img):
    """Invert the ImageNet normalization (for visualization)."""
    mean = [0.485, 0.456, 0.406]
    sdev = [0.229, 0.224, 0.225]
    return Normalize(
        mean=[-m/s for m, s in zip(mean, sdev)], std=[1./s for s in sdev]
    )(img)


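# For visualization one can undo the normalization first, e.g. (a sketch;
# 'sample' is a transformed dict as above):
#   img = denormalize(sample['image']).clamp(0., 1.)
#   transforms.functional.to_pil_image(img).save('sample.png')

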
def undo_box_transforms(bbox, tr_param):
    """Undo the recorded geometric transforms in reverse order, mapping
    normalized box coordinates back to pixels in the original image."""
    bbox = bbox.clone()
    for tr in tr_param[::-1]:
        if 'resize' in tr:
            sx, sy = tr['resize']
            bbox[:, (0, 2)] /= sx
            bbox[:, (1, 3)] /= sy
        elif 'square_pad' in tr:
            px, py, _, _ = tr['square_pad']
            bbox[:, (0, 2)] -= px
            bbox[:, (1, 3)] -= py
        elif 'normalize_box_coords' in tr:
            img_h, img_w = tr['normalize_box_coords']
            bbox[:, (0, 2)] *= img_w
            bbox[:, (1, 3)] *= img_h
    return bbox


def undo_box_transforms_batch(bbox, tr_param):
    """Apply undo_box_transforms to each (box, tr_param) pair in a batch."""
    output = []
    for i in range(bbox.size(0)):
        bb = undo_box_transforms(torch.atleast_2d(bbox[i]), tr_param[i])
        output.append(bb)
    return torch.cat(output, dim=0)
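

# Round-trip smoke test, a minimal sketch run as a script: the 'val'
# pipeline has no random augmentations, so undo_box_transforms_batch
# should map the transformed 'bbox' back to the original pixel
# coordinates (the sample data below is hypothetical).
if __name__ == '__main__':
    transform = get_transform('val', input_size=512)
    sample = {
        'image': Image.new('RGB', (640, 480)),
        'bbox': torch.FloatTensor([[120., 60., 320., 240.]]),
    }
    original = sample['bbox'].clone()
    sample = transform(sample)
    recovered = undo_box_transforms_batch(sample['bbox'], [sample['tr_param']])
    print('original :', original)   # [[120.,  60., 320., 240.]]
    print('recovered:', recovered)  # should match the original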