""" pytorch grid_sample doesn't support second-order derivative implement custom version """ import torch import torch.nn.functional as F import numpy as np def grid_sample_2d(image, optical): N, C, IH, IW = image.shape _, H, W, _ = optical.shape ix = optical[..., 0] iy = optical[..., 1] ix = ((ix + 1) / 2) * (IW - 1); iy = ((iy + 1) / 2) * (IH - 1); with torch.no_grad(): ix_nw = torch.floor(ix); iy_nw = torch.floor(iy); ix_ne = ix_nw + 1; iy_ne = iy_nw; ix_sw = ix_nw; iy_sw = iy_nw + 1; ix_se = ix_nw + 1; iy_se = iy_nw + 1; nw = (ix_se - ix) * (iy_se - iy) ne = (ix - ix_sw) * (iy_sw - iy) sw = (ix_ne - ix) * (iy - iy_ne) se = (ix - ix_nw) * (iy - iy_nw) with torch.no_grad(): torch.clamp(ix_nw, 0, IW - 1, out=ix_nw) torch.clamp(iy_nw, 0, IH - 1, out=iy_nw) torch.clamp(ix_ne, 0, IW - 1, out=ix_ne) torch.clamp(iy_ne, 0, IH - 1, out=iy_ne) torch.clamp(ix_sw, 0, IW - 1, out=ix_sw) torch.clamp(iy_sw, 0, IH - 1, out=iy_sw) torch.clamp(ix_se, 0, IW - 1, out=ix_se) torch.clamp(iy_se, 0, IH - 1, out=iy_se) image = image.view(N, C, IH * IW) nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1)) ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1)) sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1)) se_val = torch.gather(image, 2, (iy_se * IW + ix_se).long().view(N, 1, H * W).repeat(1, C, 1)) out_val = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) + ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) + sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) + se_val.view(N, C, H, W) * se.view(N, 1, H, W)) return out_val # - checked for correctness def grid_sample_3d(volume, optical): """ bilinear sampling cannot guarantee continuous first-order gradient mimic pytorch grid_sample function The 8 corner points of a volume noted as: 4 points (front view); 4 points (back view) fnw (front north west) point bse (back south east) point :param volume: [B, C, X, Y, Z] :param optical: [B, x, y, z, 3] :return: """ N, C, ID, IH, IW = volume.shape _, D, H, W, _ = optical.shape ix = optical[..., 0] iy = optical[..., 1] iz = optical[..., 2] ix = ((ix + 1) / 2) * (IW - 1) iy = ((iy + 1) / 2) * (IH - 1) iz = ((iz + 1) / 2) * (ID - 1) mask_x = (ix > 0) & (ix < IW) mask_y = (iy > 0) & (iy < IH) mask_z = (iz > 0) & (iz < ID) mask = mask_x & mask_y & mask_z # [B, x, y, z] mask = mask[:, None, :, :, :].repeat(1, C, 1, 1, 1) # [B, C, x, y, z] with torch.no_grad(): # back north west ix_bnw = torch.floor(ix) iy_bnw = torch.floor(iy) iz_bnw = torch.floor(iz) ix_bne = ix_bnw + 1 iy_bne = iy_bnw iz_bne = iz_bnw ix_bsw = ix_bnw iy_bsw = iy_bnw + 1 iz_bsw = iz_bnw ix_bse = ix_bnw + 1 iy_bse = iy_bnw + 1 iz_bse = iz_bnw # front view ix_fnw = ix_bnw iy_fnw = iy_bnw iz_fnw = iz_bnw + 1 ix_fne = ix_bnw + 1 iy_fne = iy_bnw iz_fne = iz_bnw + 1 ix_fsw = ix_bnw iy_fsw = iy_bnw + 1 iz_fsw = iz_bnw + 1 ix_fse = ix_bnw + 1 iy_fse = iy_bnw + 1 iz_fse = iz_bnw + 1 # back view bnw = (ix_fse - ix) * (iy_fse - iy) * (iz_fse - iz) # smaller volume, larger weight bne = (ix - ix_fsw) * (iy_fsw - iy) * (iz_fsw - iz) bsw = (ix_fne - ix) * (iy - iy_fne) * (iz_fne - iz) bse = (ix - ix_fnw) * (iy - iy_fnw) * (iz_fnw - iz) # front view fnw = (ix_bse - ix) * (iy_bse - iy) * (iz - iz_bse) # smaller volume, larger weight fne = (ix - ix_bsw) * (iy_bsw - iy) * (iz - iz_bsw) fsw = (ix_bne - ix) * (iy - iy_bne) * (iz - iz_bne) fse = (ix - ix_bnw) * (iy - iy_bnw) * (iz - iz_bnw) with torch.no_grad(): # back view torch.clamp(ix_bnw, 0, IW - 1, out=ix_bnw) torch.clamp(iy_bnw, 0, IH - 1, out=iy_bnw) torch.clamp(iz_bnw, 0, ID - 1, out=iz_bnw) torch.clamp(ix_bne, 0, IW - 1, out=ix_bne) torch.clamp(iy_bne, 0, IH - 1, out=iy_bne) torch.clamp(iz_bne, 0, ID - 1, out=iz_bne) torch.clamp(ix_bsw, 0, IW - 1, out=ix_bsw) torch.clamp(iy_bsw, 0, IH - 1, out=iy_bsw) torch.clamp(iz_bsw, 0, ID - 1, out=iz_bsw) torch.clamp(ix_bse, 0, IW - 1, out=ix_bse) torch.clamp(iy_bse, 0, IH - 1, out=iy_bse) torch.clamp(iz_bse, 0, ID - 1, out=iz_bse) # front view torch.clamp(ix_fnw, 0, IW - 1, out=ix_fnw) torch.clamp(iy_fnw, 0, IH - 1, out=iy_fnw) torch.clamp(iz_fnw, 0, ID - 1, out=iz_fnw) torch.clamp(ix_fne, 0, IW - 1, out=ix_fne) torch.clamp(iy_fne, 0, IH - 1, out=iy_fne) torch.clamp(iz_fne, 0, ID - 1, out=iz_fne) torch.clamp(ix_fsw, 0, IW - 1, out=ix_fsw) torch.clamp(iy_fsw, 0, IH - 1, out=iy_fsw) torch.clamp(iz_fsw, 0, ID - 1, out=iz_fsw) torch.clamp(ix_fse, 0, IW - 1, out=ix_fse) torch.clamp(iy_fse, 0, IH - 1, out=iy_fse) torch.clamp(iz_fse, 0, ID - 1, out=iz_fse) # xxx = volume[:, :, iz_bnw.long(), iy_bnw.long(), ix_bnw.long()] volume = volume.view(N, C, ID * IH * IW) # yyy = volume[:, :, (iz_bnw * ID + iy_bnw * IW + ix_bnw).long()] # back view bnw_val = torch.gather(volume, 2, (iz_bnw * ID ** 2 + iy_bnw * IW + ix_bnw).long().view(N, 1, D * H * W).repeat(1, C, 1)) bne_val = torch.gather(volume, 2, (iz_bne * ID ** 2 + iy_bne * IW + ix_bne).long().view(N, 1, D * H * W).repeat(1, C, 1)) bsw_val = torch.gather(volume, 2, (iz_bsw * ID ** 2 + iy_bsw * IW + ix_bsw).long().view(N, 1, D * H * W).repeat(1, C, 1)) bse_val = torch.gather(volume, 2, (iz_bse * ID ** 2 + iy_bse * IW + ix_bse).long().view(N, 1, D * H * W).repeat(1, C, 1)) # front view fnw_val = torch.gather(volume, 2, (iz_fnw * ID ** 2 + iy_fnw * IW + ix_fnw).long().view(N, 1, D * H * W).repeat(1, C, 1)) fne_val = torch.gather(volume, 2, (iz_fne * ID ** 2 + iy_fne * IW + ix_fne).long().view(N, 1, D * H * W).repeat(1, C, 1)) fsw_val = torch.gather(volume, 2, (iz_fsw * ID ** 2 + iy_fsw * IW + ix_fsw).long().view(N, 1, D * H * W).repeat(1, C, 1)) fse_val = torch.gather(volume, 2, (iz_fse * ID ** 2 + iy_fse * IW + ix_fse).long().view(N, 1, D * H * W).repeat(1, C, 1)) out_val = ( # back bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W) + bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W) + bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W) + bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W) + # front fnw_val.view(N, C, D, H, W) * fnw.view(N, 1, D, H, W) + fne_val.view(N, C, D, H, W) * fne.view(N, 1, D, H, W) + fsw_val.view(N, C, D, H, W) * fsw.view(N, 1, D, H, W) + fse_val.view(N, C, D, H, W) * fse.view(N, 1, D, H, W) ) # * zero padding out_val = torch.where(mask, out_val, torch.zeros_like(out_val).float().to(out_val.device)) return out_val # Interpolation kernel def get_weight(s, a=-0.5): mask_0 = (torch.abs(s) >= 0) & (torch.abs(s) <= 1) mask_1 = (torch.abs(s) > 1) & (torch.abs(s) <= 2) mask_2 = torch.abs(s) > 2 weight = torch.zeros_like(s).to(s.device) weight = torch.where(mask_0, (a + 2) * (torch.abs(s) ** 3) - (a + 3) * (torch.abs(s) ** 2) + 1, weight) weight = torch.where(mask_1, a * (torch.abs(s) ** 3) - (5 * a) * (torch.abs(s) ** 2) + (8 * a) * torch.abs(s) - 4 * a, weight) # if (torch.abs(s) >= 0) & (torch.abs(s) <= 1): # return (a + 2) * (torch.abs(s) ** 3) - (a + 3) * (torch.abs(s) ** 2) + 1 # # elif (torch.abs(s) > 1) & (torch.abs(s) <= 2): # return a * (torch.abs(s) ** 3) - (5 * a) * (torch.abs(s) ** 2) + (8 * a) * torch.abs(s) - 4 * a # return 0 return weight def cubic_interpolate(p, x): """ one dimensional cubic interpolation :param p: [N, 4] (4) should be in order :param x: [N] :return: """ return p[:, 1] + 0.5 * x * (p[:, 2] - p[:, 0] + x * ( 2.0 * p[:, 0] - 5.0 * p[:, 1] + 4.0 * p[:, 2] - p[:, 3] + x * ( 3.0 * (p[:, 1] - p[:, 2]) + p[:, 3] - p[:, 0]))) def bicubic_interpolate(p, x, y, if_batch=True): """ two dimensional cubic interpolation :param p: [N, 4, 4] :param x: [N] :param y: [N] :return: """ num = p.shape[0] if not if_batch: arr0 = cubic_interpolate(p[:, 0, :], x) # [N] arr1 = cubic_interpolate(p[:, 1, :], x) arr2 = cubic_interpolate(p[:, 2, :], x) arr3 = cubic_interpolate(p[:, 3, :], x) return cubic_interpolate(torch.stack([arr0, arr1, arr2, arr3], dim=-1), y) # [N] else: x = x[:, None].repeat(1, 4).view(-1) p = p.contiguous().view(num * 4, 4) arr = cubic_interpolate(p, x) arr = arr.view(num, 4) return cubic_interpolate(arr, y) def tricubic_interpolate(p, x, y, z): """ three dimensional cubic interpolation :param p: [N,4,4,4] :param x: [N] :param y: [N] :param z: [N] :return: """ num = p.shape[0] arr0 = bicubic_interpolate(p[:, 0, :, :], x, y) # [N] arr1 = bicubic_interpolate(p[:, 1, :, :], x, y) arr2 = bicubic_interpolate(p[:, 2, :, :], x, y) arr3 = bicubic_interpolate(p[:, 3, :, :], x, y) return cubic_interpolate(torch.stack([arr0, arr1, arr2, arr3], dim=-1), z) # [N] def cubic_interpolate_batch(p, x): """ one dimensional cubic interpolation :param p: [B, N, 4] (4) should be in order :param x: [B, N] :return: """ return p[:, :, 1] + 0.5 * x * (p[:, :, 2] - p[:, :, 0] + x * ( 2.0 * p[:, :, 0] - 5.0 * p[:, :, 1] + 4.0 * p[:, :, 2] - p[:, :, 3] + x * ( 3.0 * (p[:, :, 1] - p[:, :, 2]) + p[:, :, 3] - p[:, :, 0]))) def bicubic_interpolate_batch(p, x, y): """ two dimensional cubic interpolation :param p: [B, N, 4, 4] :param x: [B, N] :param y: [B, N] :return: """ B, N, _, _ = p.shape x = x[:, :, None].repeat(1, 1, 4).view(B, N * 4) # [B, N*4] arr = cubic_interpolate_batch(p.contiguous().view(B, N * 4, 4), x) arr = arr.view(B, N, 4) return cubic_interpolate_batch(arr, y) # [B, N] # * batch version cannot speed up training def tricubic_interpolate_batch(p, x, y, z): """ three dimensional cubic interpolation :param p: [N,4,4,4] :param x: [N] :param y: [N] :param z: [N] :return: """ N = p.shape[0] x = x[None, :].repeat(4, 1) y = y[None, :].repeat(4, 1) p = p.permute(1, 0, 2, 3).contiguous() arr = bicubic_interpolate_batch(p[:, :, :, :], x, y) # [4, N] arr = arr.permute(1, 0).contiguous() # [N, 4] return cubic_interpolate(arr, z) # [N] def tricubic_sample_3d(volume, optical): """ tricubic sampling; can guarantee continuous gradient (interpolation border) :param volume: [B, C, ID, IH, IW] :param optical: [B, D, H, W, 3] :param sample_num: :return: """ @torch.no_grad() def get_shifts(x): x1 = -1 * (1 + x - torch.floor(x)) x2 = -1 * (x - torch.floor(x)) x3 = torch.floor(x) + 1 - x x4 = torch.floor(x) + 2 - x return torch.stack([x1, x2, x3, x4], dim=-1) # (B,d,h,w,4) N, C, ID, IH, IW = volume.shape _, D, H, W, _ = optical.shape device = volume.device ix = optical[..., 0] iy = optical[..., 1] iz = optical[..., 2] ix = ((ix + 1) / 2) * (IW - 1) # (B,d,h,w) iy = ((iy + 1) / 2) * (IH - 1) iz = ((iz + 1) / 2) * (ID - 1) ix = ix.view(-1) iy = iy.view(-1) iz = iz.view(-1) with torch.no_grad(): shifts_x = get_shifts(ix).view(-1, 4) # (B*d*h*w,4) shifts_y = get_shifts(iy).view(-1, 4) shifts_z = get_shifts(iz).view(-1, 4) perm_weights = torch.ones([N * D * H * W, 4 * 4 * 4]).long().to(device) perm = torch.cumsum(perm_weights, dim=-1) - 1 # (B*d*h*w,64) perm_z = perm // 16 # [N*D*H*W, num] perm_y = (perm - perm_z * 16) // 4 perm_x = (perm - perm_z * 16 - perm_y * 4) shifts_x = torch.gather(shifts_x, 1, perm_x) # [N*D*H*W, num] shifts_y = torch.gather(shifts_y, 1, perm_y) shifts_z = torch.gather(shifts_z, 1, perm_z) ix_target = (ix[:, None] + shifts_x).long() # [N*D*H*W, num] iy_target = (iy[:, None] + shifts_y).long() iz_target = (iz[:, None] + shifts_z).long() torch.clamp(ix_target, 0, IW - 1, out=ix_target) torch.clamp(iy_target, 0, IH - 1, out=iy_target) torch.clamp(iz_target, 0, ID - 1, out=iz_target) local_dist_x = ix - ix_target[:, 1] # ! attention here is [:, 1] local_dist_y = iy - iy_target[:, 1 + 4] local_dist_z = iz - iz_target[:, 1 + 16] local_dist_x = local_dist_x.view(N, 1, D * H * W).repeat(1, C, 1).view(-1) local_dist_y = local_dist_y.view(N, 1, D * H * W).repeat(1, C, 1).view(-1) local_dist_z = local_dist_z.view(N, 1, D * H * W).repeat(1, C, 1).view(-1) # ! attention: IW is correct idx_target = iz_target * ID ** 2 + iy_target * IW + ix_target # [N*D*H*W, num] volume = volume.view(N, C, ID * IH * IW) out = torch.gather(volume, 2, idx_target.view(N, 1, D * H * W * 64).repeat(1, C, 1)) out = out.view(N * C * D * H * W, 4, 4, 4) # - tricubic_interpolate() is a bit faster than tricubic_interpolate_batch() final = tricubic_interpolate(out, local_dist_x, local_dist_y, local_dist_z).view(N, C, D, H, W) # [N,C,D,H,W] return final if __name__ == "__main__": # image = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).view(1, 3, 1, 3) # # optical = torch.Tensor([0.9, 0.5, 0.6, -0.7]).view(1, 1, 2, 2) # # print(grid_sample_2d(image, optical)) # # print(F.grid_sample(image, optical, padding_mode='border', align_corners=True)) from ops.generate_grids import generate_grid p = torch.tensor([x for x in range(4)]).view(1, 4).float() v = cubic_interpolate(p, torch.tensor([0.5]).view(1)) # v = bicubic_interpolate(p, torch.tensor([2/3]).view(1) , torch.tensor([2/3]).view(1)) vsize = 9 volume = generate_grid([vsize, vsize, vsize], 1) # [1,3,10,10,10] # volume = torch.tensor([x for x in range(1000)]).view(1, 1, 10, 10, 10).float() X, Y, Z = 0, 0, 6 x = 2 * X / (vsize - 1) - 1 y = 2 * Y / (vsize - 1) - 1 z = 2 * Z / (vsize - 1) - 1 # print(volume[:, :, Z, Y, X]) # volume = volume.view(1, 3, -1) # xx = volume[:, :, Z * 9*9 + Y * 9 + X] optical = torch.Tensor([-0.6, -0.7, 0.5, 0.3, 0.5, 0.5]).view(1, 1, 1, 2, 3) print(F.grid_sample(volume, optical, padding_mode='border', align_corners=True)) print(grid_sample_3d(volume, optical)) print(tricubic_sample_3d(volume, optical)) # target, relative_coords = implicit_sample_3d(volume, optical, 1) # print(target)