|
|
|
|
|
|
|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
def center_crop(img, win_size = (220, 220, 220)):
    '''
    Center-crop a 3-D volume to at most ``win_size`` samples per spatial axis.

    Args:
        img: 3-D array/tensor (d0, d1, d2), or a 4-D torch tensor
            (d0, d1, d2, C) whose last axis is a channel axis.
        win_size: maximum extent per spatial axis, or None to skip cropping.
            Cropping only happens when at least one axis exceeds its window.
            When it does, axes that already fit are kept whole.

    Returns:
        (out, crop_start, orig_shp)
        out: (1, 1, d0', d1', d2') for 3-D input, or (1, d0', d1', d2', C)
            for 4-D input.
        crop_start: start index of the crop along each spatial axis of the
            input ([0, 0, 0] when nothing was cropped).
        orig_shp: spatial shape (d0, d1, d2) of the input.

    Raises:
        ValueError: if img is neither 3-D nor 4-D.
    '''
    if len(img.shape) == 4:
        # (d0, d1, d2, C) -> (1, C, d0, d1, d2): channels first, plus a batch axis.
        vol = torch.permute(img, (3, 0, 1, 2))[None]
        channels_last = True
    elif len(img.shape) == 3:
        vol = img[None, None]  # (1, 1, d0, d1, d2)
        channels_last = False
    else:
        raise ValueError('center_crop expects a 3-D or 4-D input, got shape %s'
                         % (tuple(img.shape),))

    orig_shp = vol.shape[2:]
    crop_start = [0, 0, 0]

    if win_size is not None and any(orig_shp[i] > win_size[i] for i in range(3)):
        # Axes smaller than the window get start 0; their slice then covers the whole axis.
        crop_start = [max(orig_shp[i] - win_size[i], 0) // 2 for i in range(3)]
        window = tuple(slice(crop_start[i], crop_start[i] + win_size[i]) for i in range(3))
        vol = vol[(slice(None), slice(None)) + window]

    if channels_last:
        # Move the channel axis back to the end: (1, C, ...) -> (1, ..., C).
        vol = torch.permute(vol, (0, 2, 3, 4, 1))
    # crop_start is now reported for channel-last input too (it used to be zeroed).
    return vol, crop_start, orig_shp
|
|
|
|
|
|
|
def V_plot(Vx, Vy, save_path):
    '''
    Save a streamline plot of the 2-D vector field (Vx, Vy) to ``save_path``.

    Vx and Vy share one 2-D shape (n0, n1) and are indexed [i, j], with axis 0
    along x and axis 1 along y. That is the layout the plotting grid is built
    from: x spans shape[0], y spans shape[1]. NumPy arrays and torch tensors
    (any device, with or without autograd) are accepted.
    '''
    def _as_numpy(a):
        # Detach and move to host so streamplot can read tensors from any device.
        if torch.is_tensor(a):
            return a.detach().cpu().numpy()
        return np.asarray(a)

    U = _as_numpy(Vx)
    V = _as_numpy(Vy)
    nx, ny = U.shape

    # Default 'xy' indexing: X and Y both have shape (ny, nx).
    X, Y = np.meshgrid(np.arange(nx), np.arange(ny))

    fig = plt.figure()
    # streamplot expects u/v laid out as (rows = y, cols = x), so transpose to match
    # the grid. Without this, non-square fields raise and square fields come out mirrored.
    plt.streamplot(X, Y, U.T, V.T, density=1.4, linewidth=None, color='orange')
    plt.axis('off')
    fig.savefig(save_path)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
|
|
|
|
|
def stream_2D(Phi, batched = False, delta_lst = [1., 1.]):
    '''
    Build a 2-D velocity field from the scalar stream function ``Phi``.

    Phi: 2-D grid (r, c), or (n_batch, r, c) when ``batched`` is True.
    delta_lst: grid spacing along the two spatial axes.

    Returns (Vx, Vy) = (-dPhi/d(axis 1), dPhi/d(axis 0)). The derivatives come
    from gradient_c: central differences inside, one-sided differences at the
    edges. This is the 2-D curl of Phi, which is divergence-free by construction.
    '''
    grad = gradient_c(Phi, batched = batched, delta_lst = delta_lst)
    d_axis0 = grad[..., 0]
    d_axis1 = grad[..., 1]
    return -d_axis1, d_axis0
|
|
|
|
|
def stream_3D(Phi_a, Phi_b, Phi_c, batched = False, delta_lst = (1., 1., 1.)):
    '''
    Curl of the vector potential (Phi_a, Phi_b, Phi_c) on a 3-D grid.

    Spatial axes (0, 1, 2) are taken as (x, y, z), and (Phi_a, Phi_b, Phi_c)
    as the potential's (x, y, z) components:
        Vx = dPhi_c/dy - dPhi_b/dz
        Vy = dPhi_a/dz - dPhi_c/dx
        Vz = dPhi_b/dx - dPhi_a/dy
    The result is divergence-free by construction. Derivatives come from
    gradient_c.

    Args:
        Phi_a, Phi_b, Phi_c: tensors of shape (s, r, c), or (n_batch, s, r, c)
            when ``batched`` is True.
        batched: whether axis 0 is a batch axis.
        delta_lst: grid spacing along the three spatial axes.

    Returns:
        (Vx, Vy, Vz), each with the shape of the inputs.
    '''
    dDa = gradient_c(Phi_a, batched = batched, delta_lst = delta_lst)
    dDb = gradient_c(Phi_b, batched = batched, delta_lst = delta_lst)
    dDc = gradient_c(Phi_c, batched = batched, delta_lst = delta_lst)

    # Component [..., k] of each gradient is the derivative along spatial axis k.
    Vx = dDc[..., 1] - dDb[..., 2]
    Vy = dDa[..., 2] - dDc[..., 0]
    Vz = dDb[..., 0] - dDa[..., 1]
    return Vx, Vy, Vz
|
|
|
|
|
|
|
def gradient_f(X, batched = False, delta_lst = (1., 1., 1.)):
    '''
    Compute the gradient of a torch tensor "X" along each spatial axis.

    Non-boundaries & lower boundaries: forward difference  X[i+1] - X[i]
    Upper boundaries:                  backward difference X[-1] - X[-2]

    Args:
        X: tensor of shape (n_batch, ...) if ``batched`` else (...). Any
            number (>= 1) of spatial axes is supported, and each spatial axis
            needs at least 2 samples.
        batched: whether axis 0 of X is a batch axis (not differentiated).
        delta_lst: grid spacing per spatial axis. Only the first
            n_spatial entries are used.

    Returns:
        float32 tensor on X's device. With one spatial axis it has X's shape.
        Otherwise its shape is X.shape + (n_spatial,), where [..., k] holds the
        derivative along spatial axis k.

    Raises:
        ValueError: if X has no spatial axis, if a spatial axis has fewer than
            2 samples, or if delta_lst has fewer entries than spatial axes.
    '''
    first = 1 if batched else 0
    n_spatial = X.dim() - first
    if n_spatial < 1:
        raise ValueError('gradient_f: X of shape %s has no spatial axis (batched=%s)'
                         % (tuple(X.shape), batched))
    if len(delta_lst) < n_spatial:
        raise ValueError('gradient_f: need %d grid spacings, got %d'
                         % (n_spatial, len(delta_lst)))

    comps = []
    for k in range(n_spatial):
        ax = first + k
        n = X.size(ax)
        if n < 2:
            raise ValueError('gradient_f: spatial axis %d has size %d; need at least 2' % (k, n))
        inner = X.narrow(ax, 1, n - 1) - X.narrow(ax, 0, n - 1)   # positions 0 .. n-2
        upper = X.narrow(ax, n - 1, 1) - X.narrow(ax, n - 2, 1)   # position n-1
        # The result is always float32, regardless of X's dtype.
        d = torch.cat([inner.to(torch.float), upper.to(torch.float)], dim=ax)
        d /= delta_lst[k]
        comps.append(d)
    return comps[0] if n_spatial == 1 else torch.stack(comps, dim=-1)
|
|
|
|
|
def gradient_b(X, batched = False, delta_lst = (1., 1., 1.)):
    '''
    Compute the gradient of a torch tensor "X" along each spatial axis.

    Non-boundaries & upper boundaries: backward difference X[i] - X[i-1]
    Lower boundaries:                  forward difference  X[1] - X[0]

    Args:
        X: tensor of shape (n_batch, ...) if ``batched`` else (...). Any
            number (>= 1) of spatial axes is supported, and each spatial axis
            needs at least 2 samples.
        batched: whether axis 0 of X is a batch axis (not differentiated).
        delta_lst: grid spacing per spatial axis. Only the first
            n_spatial entries are used.

    Returns:
        float32 tensor on X's device. With one spatial axis it has X's shape.
        Otherwise its shape is X.shape + (n_spatial,), where [..., k] holds the
        derivative along spatial axis k.

    Raises:
        ValueError: if X has no spatial axis, if a spatial axis has fewer than
            2 samples, or if delta_lst has fewer entries than spatial axes.
    '''
    first = 1 if batched else 0
    n_spatial = X.dim() - first
    if n_spatial < 1:
        raise ValueError('gradient_b: X of shape %s has no spatial axis (batched=%s)'
                         % (tuple(X.shape), batched))
    if len(delta_lst) < n_spatial:
        raise ValueError('gradient_b: need %d grid spacings, got %d'
                         % (n_spatial, len(delta_lst)))

    comps = []
    for k in range(n_spatial):
        ax = first + k
        n = X.size(ax)
        if n < 2:
            raise ValueError('gradient_b: spatial axis %d has size %d; need at least 2' % (k, n))
        lower = X.narrow(ax, 1, 1) - X.narrow(ax, 0, 1)           # position 0
        inner = X.narrow(ax, 1, n - 1) - X.narrow(ax, 0, n - 1)   # positions 1 .. n-1
        # The result is always float32, regardless of X's dtype.
        d = torch.cat([lower.to(torch.float), inner.to(torch.float)], dim=ax)
        d /= delta_lst[k]
        comps.append(d)
    return comps[0] if n_spatial == 1 else torch.stack(comps, dim=-1)
|
|
|
|
|
def gradient_c(X, batched = False, delta_lst = (1., 1., 1.)):
    '''
    Compute the gradient of a torch tensor "X" along each spatial axis.

    Non-boundaries:   central difference  (X[i+1] - X[i-1]) / 2
    Upper boundaries: backward difference X[-1] - X[-2]
    Lower boundaries: forward difference  X[1] - X[0]

    Args:
        X: tensor of shape (n_batch, ...) if ``batched`` else (...). Any
            number (>= 1) of spatial axes is supported, and each spatial axis
            needs at least 2 samples.
        batched: whether axis 0 of X is a batch axis (not differentiated).
        delta_lst: grid spacing per spatial axis. Only the first
            n_spatial entries are used.

    Returns:
        float32 tensor on X's device. With one spatial axis it has X's shape.
        Otherwise its shape is X.shape + (n_spatial,), where [..., k] holds the
        derivative along spatial axis k.

    Raises:
        ValueError: if X has no spatial axis, if a spatial axis has fewer than
            2 samples, or if delta_lst has fewer entries than spatial axes.
    '''
    first = 1 if batched else 0
    n_spatial = X.dim() - first
    if n_spatial < 1:
        raise ValueError('gradient_c: X of shape %s has no spatial axis (batched=%s)'
                         % (tuple(X.shape), batched))
    if len(delta_lst) < n_spatial:
        raise ValueError('gradient_c: need %d grid spacings, got %d'
                         % (n_spatial, len(delta_lst)))

    comps = []
    for k in range(n_spatial):
        ax = first + k
        n = X.size(ax)
        if n < 2:
            raise ValueError('gradient_c: spatial axis %d has size %d; need at least 2' % (k, n))
        pieces = [X.narrow(ax, 1, 1) - X.narrow(ax, 0, 1)]                          # position 0
        if n > 2:
            pieces.append((X.narrow(ax, 2, n - 2) - X.narrow(ax, 0, n - 2)) / 2)  # positions 1 .. n-2
        pieces.append(X.narrow(ax, n - 1, 1) - X.narrow(ax, n - 2, 1))             # position n-1
        # Cast every piece to float32 before concatenating, so the output dtype is
        # fixed and integer / float pieces can be joined.
        d = torch.cat([p.to(torch.float) for p in pieces], dim=ax)
        d /= delta_lst[k]
        comps.append(d)
    return comps[0] if n_spatial == 1 else torch.stack(comps, dim=-1)
|
|
|
|
|
|