# ddmr/layers/augmentation.py
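"""Data augmentation layer for the DDMR image-registration pipeline.

The AugmentationLayer defined below builds (fixed, moving) image and
segmentation pairs on the fly by combining a random thin-plate-spline
deformation with a random rigid transformation, optional gamma/brightness
augmentation, and downsampling to a target output shape.
"""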
import os, sys
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # make the parent package importable when this file is run directly (relative imports do not work for scripts)
PYCHARM_EXEC = os.getenv('PYCHARM_EXEC') == 'True'
import tensorflow.keras.layers as kl
import tensorflow as tf
from tensorflow.python.framework.errors import InvalidArgumentError
from ddmr.utils.operators import soft_threshold, gaussian_kernel, sample_unique
import ddmr.utils.constants as C
from ddmr.utils.thin_plate_splines import ThinPlateSplines
from voxelmorph.tf.layers import SpatialTransformer
class AugmentationLayer(kl.Layer):
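    """On-the-fly augmentation for image registration training.

    Unless `only_image` is set, the input tensor is expected to hold the image
    in channel 0 and one or more segmentation masks in the remaining channels.
    For each sample the layer produces a (fixed, moving) pair by warping the
    input with a random thin-plate-spline deformation combined with a random
    rigid transformation, optionally followed by gamma/brightness augmentation
    and downsampling to `out_img_shape`. When called with `training=False` (or
    built with `only_resize=True`) the data is only resized and normalized.
    If `return_displacement_map` is True, the ground-truth displacement map is
    returned as a fifth output.
    """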
def __init__(self,
max_deformation,
max_displacement,
max_rotation,
num_control_points,
in_img_shape,
out_img_shape,
num_augmentations=1,
gamma_augmentation=True,
brightness_augmentation=True,
only_image=False,
only_resize=True,
return_displacement_map=False,
**kwargs):
super(AugmentationLayer, self).__init__(**kwargs)
self.max_deformation = max_deformation
self.max_displacement = max_displacement
self.max_rotation = max_rotation
self.num_control_points = num_control_points
self.num_augmentations = num_augmentations
self.in_img_shape = in_img_shape
self.out_img_shape = out_img_shape
self.only_image = only_image
self.return_disp_map = return_displacement_map
self.do_gamma_augm = gamma_augmentation
self.do_brightness_augm = brightness_augmentation
grid = C.CoordinatesGrid()
grid.set_coords_grid(in_img_shape, [C.TPS_NUM_CTRL_PTS_PER_AXIS] * 3)
self.control_grid = tf.identity(grid.grid_flat(), name='control_grid')
self.target_grid = tf.identity(grid.grid_flat(), name='target_grid')
grid.set_coords_grid(in_img_shape, in_img_shape)
self.fine_grid = tf.identity(grid.grid_flat(), 'fine_grid')
if out_img_shape is not None:
self.downsample_factor = [i // o for o, i in zip(out_img_shape, in_img_shape)]
self.img_gauss_filter = gaussian_kernel(3, 0.001, 1, 1, 3)
# self.resize_transf = tf.diag([*self.downsample_factor, 1])[:-1, :]
# self.resize_transf = tf.expand_dims(tf.reshape(self.resize_transf, [-1]), 0, name='resize_transformation') # ST expects a (12,) vector
self.augment = not only_resize
def compute_output_shape(self, input_shape):
input_shape = tf.TensorShape(input_shape).as_list()
img_shape = (input_shape[0], *self.out_img_shape, 1)
seg_shape = (input_shape[0], *self.out_img_shape, input_shape[-1] - 1)
disp_shape = (input_shape[0], *self.out_img_shape, 3)
# Expect the input to have the image and segmentations in the same tensor
if self.return_disp_map:
return (img_shape, img_shape, seg_shape, seg_shape, disp_shape)
else:
return (img_shape, img_shape, seg_shape, seg_shape)
#@tf.custom_gradient
def call(self, in_data, training=None):
# def custom_grad(in_grad):
# return tf.ones_like(in_grad)
if training is not None:
self.augment = training
return self.build_batch(in_data)# , custom_grad
def build_batch(self, fix_data: tf.Tensor):
if len(fix_data.get_shape().as_list()) < 5:
fix_data = tf.expand_dims(fix_data, axis=0) # Add Batch dimension
# fix_data = tf.tile(fix_data, (self.num_augmentations, *(1,)*4))
fix_img_batch, mov_img_batch, fix_seg_batch, mov_seg_batch, disp_map = tf.map_fn(lambda x: self.augment_sample(x),
fix_data,
dtype=(tf.float32, tf.float32, tf.float32, tf.float32, tf.float32))
# map_fn unstacks elems on axis 0
if self.return_disp_map:
return fix_img_batch, mov_img_batch, fix_seg_batch, mov_seg_batch, disp_map
else:
return fix_img_batch, mov_img_batch, fix_seg_batch, mov_seg_batch
def augment_sample(self, fix_data: tf.Tensor):
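        """Build the (fixed, moving) pair for a single sample.

        Returns the fixed and moving images, the fixed and moving
        segmentations, and the displacement map used to create the moving data.
        """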
if self.only_image or not self.augment:
fix_img = fix_data
fix_segm = tf.zeros_like(fix_data, dtype=tf.float32)
else:
fix_img = fix_data[..., 0]
fix_img = tf.expand_dims(fix_img, -1)
fix_segm = fix_data[..., 1:] # We expect several segmentation masks
if self.augment:
# If we are training, do the full-fledged augmentation
fix_img = self.min_max_normalization(fix_img)
mov_img, mov_segm, disp_map = self.deform_image(tf.squeeze(fix_img), fix_segm)
mov_img = tf.expand_dims(mov_img, -1) # Add the removed channel axis
# Resample to output_shape
if self.out_img_shape is not None:
fix_img = self.downsize_image(fix_img)
mov_img = self.downsize_image(mov_img)
fix_segm = self.downsize_segmentation(fix_segm)
mov_segm = self.downsize_segmentation(mov_segm)
disp_map = self.downsize_displacement_map(disp_map)
if self.do_gamma_augm:
fix_img = self.gamma_augmentation(fix_img)
mov_img = self.gamma_augmentation(mov_img)
if self.do_brightness_augm:
fix_img = self.brightness_augmentation(fix_img)
mov_img = self.brightness_augmentation(mov_img)
else:
# During inference, just resize the input images
mov_img = tf.zeros_like(fix_img)
mov_segm = tf.zeros_like(fix_segm)
            # Dummy zero displacement map with three channels and the same spatial shape as fix_img
            disp_map = tf.zeros(tf.concat([tf.shape(fix_img)[:-1], [3]], axis=0), dtype=tf.float32)
if self.out_img_shape is not None:
fix_img = self.downsize_image(fix_img)
mov_img = self.downsize_image(mov_img)
fix_segm = self.downsize_segmentation(fix_segm)
mov_segm = self.downsize_segmentation(mov_segm)
disp_map = self.downsize_displacement_map(disp_map)
fix_img = self.min_max_normalization(fix_img)
mov_img = self.min_max_normalization(mov_img)
return fix_img, mov_img, fix_segm, mov_segm, disp_map
def downsize_image(self, img):
img = tf.expand_dims(img, axis=0)
# The filter is symmetrical along the three axes, hence there is no need for transposing the H and D dims
img = tf.nn.conv3d(img, self.img_gauss_filter, strides=[1, ] * 5, padding='SAME', data_format='NDHWC')
        img = kl.MaxPooling3D([1] * 3, self.downsample_factor, padding='valid', data_format='channels_last')(img)
return tf.squeeze(img, axis=0)
def downsize_segmentation(self, segm):
segm = tf.expand_dims(segm, axis=0)
        segm = kl.MaxPooling3D([1] * 3, self.downsample_factor, padding='valid', data_format='channels_last')(segm)
segm = tf.cast(segm, tf.float32)
return tf.squeeze(segm, axis=0)
def downsize_displacement_map(self, disp_map):
disp_map = tf.expand_dims(disp_map, axis=0)
# The filter is symmetrical along the three axes, hence there is no need for transposing the H and D dims
        disp_map = kl.AveragePooling3D([1] * 3, self.downsample_factor, padding='valid', data_format='channels_last')(disp_map)
# self.downsample_factor = in_shape / out_shape, but here we need out_shape / in_shape. Hence, 1 / factor
        if not (self.downsample_factor[0] == self.downsample_factor[1] == self.downsample_factor[2]):
# Downsize the displacement magnitude along the different axes
disp_map_x = disp_map[..., 0] * 1 / self.downsample_factor[0]
disp_map_y = disp_map[..., 1] * 1 / self.downsample_factor[1]
disp_map_z = disp_map[..., 2] * 1 / self.downsample_factor[2]
disp_map = tf.stack([disp_map_x, disp_map_y, disp_map_z], axis=-1)
else:
disp_map = disp_map * 1 / self.downsample_factor[0]
return tf.squeeze(disp_map, axis=0)
def gamma_augmentation(self, in_img: tf.Tensor):
in_img += 1e-5 # To prevent NaNs
f = tf.random.uniform((), -1, 1, tf.float32) # gamma [0.5, 2]
gamma = tf.pow(2.0, f)
return tf.clip_by_value(tf.pow(in_img, gamma), 0, 1)
def brightness_augmentation(self, in_img: tf.Tensor):
c = tf.random.uniform((), -0.2, 0.2, tf.float32) # 20% shift
return tf.clip_by_value(c + in_img, 0, 1)
def min_max_normalization(self, in_img: tf.Tensor):
        img_min = tf.reduce_min(in_img)
        img_max = tf.reduce_max(in_img)
        return (in_img - img_min) / (img_max - img_min)
def deform_image(self, fix_img: tf.Tensor, fix_segm: tf.Tensor):
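        """Warp `fix_img` and `fix_segm` with a random dense deformation.

        Control points are sampled inside the foreground (intensity > 0) and
        randomly displaced by up to `max_deformation`; a random rigid
        transformation is added on top, and the result is interpolated to a
        dense displacement map with thin plate splines. The map is applied with
        a SpatialTransformer (linear for the image, nearest for the masks).
        """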
# Get locations where the intensity > 0.0
idx_points_in_label = tf.where(tf.greater(fix_img, 0.0))
# Randomly select N points
# random_idx = tf.random.uniform((self.num_control_points,),
# minval=0, maxval=tf.shape(idx_points_in_label)[0],
# dtype=tf.int32)
#
# disp_location = tf.gather(idx_points_in_label, random_idx) # And get the coordinates
# disp_location = tf.cast(disp_location, tf.float32)
disp_location = sample_unique(idx_points_in_label, self.num_control_points, tf.float32)
# Get the coordinates of the control point displaces
rand_disp = tf.random.uniform((self.num_control_points, 3), minval=-1, maxval=1, dtype=tf.float32) * self.max_deformation
warped_location = disp_location + rand_disp
# Add the selected locations to the control grid and the warped locations to the target grid
control_grid = tf.concat([self.control_grid, disp_location], axis=0)
        trg_grid = tf.concat([self.target_grid, warped_location], axis=0)
# Apply global transformation
valid_trf = False
while not valid_trf:
trg_grid, aff = self.global_transformation(trg_grid)
# Interpolate the displacement map
try:
tps = ThinPlateSplines(control_grid, trg_grid)
def_grid = tps.interpolate(self.fine_grid)
except InvalidArgumentError as err:
# If the transformation raises a non-invertible error,
# try again until we get a valid transformation
tf.print('TPS non invertible matrix', output_stream=sys.stdout)
continue
else:
valid_trf = True
disp_map = self.fine_grid - def_grid
disp_map = tf.reshape(disp_map, (*self.in_img_shape, -1))
# Apply the displacement map
fix_img = tf.expand_dims(tf.expand_dims(fix_img, -1), 0)
fix_segm = tf.expand_dims(fix_segm, 0)
disp_map = tf.cast(tf.expand_dims(disp_map, 0), tf.float32)
mov_img = SpatialTransformer(interp_method='linear', indexing='ij', single_transform=False)([fix_img, disp_map])
mov_segm = SpatialTransformer(interp_method='nearest', indexing='ij', single_transform=False)([fix_segm, disp_map])
        # Zero out any NaN/Inf values produced by the resampling
        mov_img = tf.where(tf.math.is_nan(mov_img), tf.zeros_like(mov_img), mov_img)
        mov_img = tf.where(tf.math.is_inf(mov_img), tf.zeros_like(mov_img), mov_img)
        mov_segm = tf.where(tf.math.is_nan(mov_segm), tf.zeros_like(mov_segm), mov_segm)
        mov_segm = tf.where(tf.math.is_inf(mov_segm), tf.zeros_like(mov_segm), mov_segm)
return tf.squeeze(mov_img), tf.squeeze(mov_segm, axis=0), tf.squeeze(disp_map, axis=0)
def global_transformation(self, points: tf.Tensor):
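        """Apply a random rigid transformation to the TPS target points.

        A rotation of up to `max_rotation` degrees around one randomly chosen
        axis is combined with a translation of up to `max_displacement` along
        each axis; the rotation is performed about the image centre.
        """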
axis = tf.random.uniform((), 0, 3)
alpha = C.DEG_TO_RAD * tf.cond(tf.logical_and(tf.greater(axis, 0.), tf.less_equal(axis, 1.)),
lambda: tf.random.uniform((), -self.max_rotation, self.max_rotation),
lambda: tf.zeros((), tf.float32))
beta = C.DEG_TO_RAD * tf.cond(tf.logical_and(tf.greater(axis, 1.), tf.less_equal(axis, 2.)),
lambda: tf.random.uniform((), -self.max_rotation, self.max_rotation),
lambda: tf.zeros((), tf.float32))
gamma = C.DEG_TO_RAD * tf.cond(tf.logical_and(tf.greater(axis, 2.), tf.less_equal(axis, 3.)),
lambda: tf.random.uniform((), -self.max_rotation, self.max_rotation),
lambda: tf.zeros((), tf.float32))
ti = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * self.max_displacement
tj = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * self.max_displacement
tk = tf.random.uniform((), minval=-1, maxval=1, dtype=tf.float32) * self.max_displacement
M = self.build_affine_transformation(tf.convert_to_tensor(self.in_img_shape, tf.float32),
alpha, beta, gamma, ti, tj, tk)
points = tf.transpose(points)
new_pts = tf.matmul(M[:3, :3], points)
new_pts = tf.expand_dims(M[:3, -1], -1) + new_pts
return tf.transpose(new_pts), M
@staticmethod
def build_affine_transformation(img_shape, alpha, beta, gamma, ti, tj, tk):
img_centre = tf.divide(img_shape, 2.)
# Rotation matrix around the image centre
# R* = T(p) R(ang) T(-p)
# tf.cos and tf.sin expect radians
T = tf.convert_to_tensor([[1, 0, 0, ti],
[0, 1, 0, tj],
[0, 0, 1, tk],
[0, 0, 0, 1]], tf.float32)
Ri = tf.convert_to_tensor([[1, 0, 0, 0],
[0, tf.math.cos(alpha), -tf.math.sin(alpha), 0],
[0, tf.math.sin(alpha), tf.math.cos(alpha), 0],
[0, 0, 0, 1]], tf.float32)
Rj = tf.convert_to_tensor([[ tf.math.cos(beta), 0, tf.math.sin(beta), 0],
[0, 1, 0, 0],
[-tf.math.sin(beta), 0, tf.math.cos(beta), 0],
[0, 0, 0, 1]], tf.float32)
Rk = tf.convert_to_tensor([[tf.math.cos(gamma), -tf.math.sin(gamma), 0, 0],
[tf.math.sin(gamma), tf.math.cos(gamma), 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]], tf.float32)
R = tf.matmul(tf.matmul(Ri, Rj), Rk)
Tc = tf.convert_to_tensor([[1, 0, 0, img_centre[0]],
[0, 1, 0, img_centre[1]],
[0, 0, 1, img_centre[2]],
[0, 0, 0, 1]], tf.float32)
Tc_ = tf.convert_to_tensor([[1, 0, 0, -img_centre[0]],
[0, 1, 0, -img_centre[1]],
[0, 0, 1, -img_centre[2]],
[0, 0, 0, 1]], tf.float32)
return tf.matmul(T, tf.matmul(Tc, tf.matmul(R, Tc_)))
def get_config(self):
config = super(AugmentationLayer, self).get_config()
return config
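
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The shapes and
# parameter values below are illustrative assumptions, chosen only to show how
# the layer is constructed and called; in practice the layer is wired into the
# DDMR training pipeline rather than run standalone.
if __name__ == '__main__':
    layer = AugmentationLayer(max_deformation=6.,
                              max_displacement=4.,
                              max_rotation=10.,           # degrees
                              num_control_points=10,
                              in_img_shape=(64, 64, 64),
                              out_img_shape=(32, 32, 32),
                              only_image=True,            # single-channel input, no segmentation masks
                              only_resize=False,
                              return_displacement_map=True)
    # Random single-channel volume standing in for a real image
    batch = tf.random.uniform((1, 64, 64, 64, 1), dtype=tf.float32)
    fix_img, mov_img, fix_seg, mov_seg, disp_map = layer(batch, training=True)
    tf.print(tf.shape(fix_img), tf.shape(mov_img), tf.shape(disp_map))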