# BrainFM / ShapeID / misc.py
# Repository page header: peirong26's picture
# Upload 187 files
# 2571f24 verified
# ported from https://github.com/pvigier/perlin-numpy
import torch
import numpy as np
import matplotlib.pyplot as plt
def center_crop(img, win_size = [220, 220, 220]):
    '''
    Center-crop a volume to at most `win_size` along its three spatial axes.

    img: 3-D tensor (s, r, c), or 4-D tensor whose LAST axis is moved to the
         front before cropping and moved back to the end afterwards
         (presumably a channel axis -- confirm against callers).
    win_size: three window lengths, or None to disable cropping. Cropping
         happens only if at least one spatial axis is longer than its window;
         an axis shorter than its window is kept whole.

    Returns (out, crop_start, orig_shp):
        out: (1, 1, s', r', c') for 3-D input, (1, s', r', c', C) for 4-D input
        crop_start: per-axis start index of the crop ([0, 0, 0] if not cropped)
        orig_shp: the three spatial sizes before cropping
    '''
    if len(img.shape) == 4:
        # Move the last axis to the front and add a leading singleton axis.
        img = torch.permute(img, (3, 0, 1, 2))[None]
        permuted = True
    else:
        assert len(img.shape) == 3
        img = img[None, None]
        permuted = False

    orig_shp = img.shape[2:]  # the three spatial sizes

    crop_start = [0, 0, 0]
    if win_size is not None and any(orig_shp[i] > win_size[i] for i in range(3)):
        crop_start = [max(orig_shp[i] - win_size[i], 0) // 2 for i in range(3)]
        img = img[:, :,
                  crop_start[0] : crop_start[0] + win_size[0],
                  crop_start[1] : crop_start[1] + win_size[1],
                  crop_start[2] : crop_start[2] + win_size[2]]

    if permuted:
        # Restore the original "last axis" layout for 4-D inputs.
        img = torch.permute(img, (0, 2, 3, 4, 1))

    # The real crop offset is reported for 3-D and 4-D inputs alike.
    return img, crop_start, orig_shp
def V_plot(Vx, Vy, save_path):
    '''
    Draw a streamline plot of the 2D vector field (Vx, Vy) and save it.

    Vx, Vy: 2D arrays of identical shape (rows, cols) holding the two
            vector components passed to streamplot as u and v.
    save_path: destination passed to plt.savefig.
    '''
    # streamplot needs len(x) == number of columns and len(y) == number of
    # rows, so x runs over axis 1 and y over axis 0. For square fields this
    # is the same grid as swapping them; for non-square fields it is the
    # only shape streamplot accepts.
    n_rows, n_cols = Vx.shape[0], Vx.shape[1]
    X, Y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    fig = plt.figure()
    try:
        plt.streamplot(X, Y, Vx, Vy, density=1.4, linewidth=None, color='orange')
        plt.axis('off')
        plt.savefig(save_path)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
def stream_2D(Phi, batched = False, delta_lst = [1., 1.]):
    '''
    Build a 2D vector field from a scalar stream function.

    input: Phi as a scalar field in 2D grid: (r, c) or (n_batch, r, c)
    output: (Vx, Vy), the curl of Phi (divergence-free by definition),
            with Vx = -(derivative along the second grid axis) and
            Vy = derivative along the first grid axis, both taken
            from gradient_c.
    '''
    grad_phi = gradient_c(Phi, batched = batched, delta_lst = delta_lst)
    d_axis0, d_axis1 = grad_phi[..., 0], grad_phi[..., 1]
    return -d_axis1, d_axis0
def stream_3D(Phi_a, Phi_b, Phi_c, batched = False, delta_lst = [1., 1., 1.]):
    '''
    Compute the curl of the 3D vector potential (Phi_a, Phi_b, Phi_c).

    input: three scalar fields on the same 3D grid: (s, r, c), or
           (n_batch, s, r, c) when batched = True
    output: (Vx, Vy, Vz), each with the same shape as the inputs. Derivatives
            come from gradient_c, whose last axis indexes the three grid
            axes in order (0, 1, 2), named x, y, z here.
    '''
    grad_a = gradient_c(Phi_a, batched = batched, delta_lst = delta_lst)
    grad_b = gradient_c(Phi_b, batched = batched, delta_lst = delta_lst)
    grad_c = gradient_c(Phi_c, batched = batched, delta_lst = delta_lst)

    # curl(A) = (dC/dy - dB/dz, dA/dz - dC/dx, dB/dx - dA/dy)
    Vx = grad_c[..., 1] - grad_b[..., 2]
    Vy = grad_a[..., 2] - grad_c[..., 0]
    Vz = grad_b[..., 0] - grad_a[..., 1]
    return Vx, Vy, Vz
def gradient_f(X, batched = False, delta_lst = [1., 1., 1.]):
    '''
    Compute gradient of a torch tensor "X" in each direction
    Non-boundaries & Lower-boundaries: Forward Difference
    Upper-boundaries: Backward Difference

    X: tensor with 1, 2 or 3 spatial axes;
       if X is batched: (n_batch, ...); else: (...)
    delta_lst: grid spacing per spatial axis (only the first `dim` entries
       are read).
    Returns a torch.float tensor on X's device:
       1 spatial axis: same shape as X;
       2 or 3 spatial axes: X's shape plus a trailing axis of size `dim`,
       where [..., k] is the derivative along the k-th spatial axis.
    Raises ValueError if X does not have 1, 2 or 3 spatial axes.
    '''
    n_spatial = X.dim() - 1 if batched else X.dim()
    if n_spatial not in (1, 2, 3):
        raise ValueError(
            'gradient_f supports 1, 2 or 3 spatial dimensions, got %d' % n_spatial)

    first_axis = 1 if batched else 0
    out_shape = tuple(X.size())
    if n_spatial > 1:
        out_shape = out_shape + (n_spatial,)
    dX = torch.zeros(out_shape, dtype = torch.float, device = X.device)

    for k in range(n_spatial):
        # View of the k-th derivative component (all of dX in the 1D case).
        comp = dX if n_spatial == 1 else dX[..., k]
        # Bring the differentiated axis to the front; both are views, so
        # writes to `dst` land in dX.
        src = X.transpose(0, first_axis + k)
        dst = comp.transpose(0, first_axis + k)
        dst[-1] = src[-1] - src[-2]     # Backward Difference
        dst[:-1] = src[1:] - src[:-1]   # Forward Difference
        comp /= delta_lst[k]
    return dX
def gradient_b(X, batched = False, delta_lst = [1., 1., 1.]):
    '''
    Compute gradient of a torch tensor "X" in each direction
    Non-boundaries & Upper-boundaries: Backward Difference
    Lower-boundaries: Forward Difference

    X: tensor with 1, 2 or 3 spatial axes;
       if X is batched: (n_batch, ...); else: (...)
    delta_lst: grid spacing per spatial axis (only the first `dim` entries
       are read).
    Returns a torch.float tensor on X's device:
       1 spatial axis: same shape as X;
       2 or 3 spatial axes: X's shape plus a trailing axis of size `dim`,
       where [..., k] is the derivative along the k-th spatial axis.
    Raises ValueError if X does not have 1, 2 or 3 spatial axes.
    '''
    n_spatial = X.dim() - 1 if batched else X.dim()
    if n_spatial not in (1, 2, 3):
        raise ValueError(
            'gradient_b supports 1, 2 or 3 spatial dimensions, got %d' % n_spatial)

    first_axis = 1 if batched else 0
    out_shape = tuple(X.size())
    if n_spatial > 1:
        out_shape = out_shape + (n_spatial,)
    dX = torch.zeros(out_shape, dtype = torch.float, device = X.device)

    for k in range(n_spatial):
        # View of the k-th derivative component (all of dX in the 1D case).
        comp = dX if n_spatial == 1 else dX[..., k]
        # Bring the differentiated axis to the front; both are views, so
        # writes to `dst` land in dX.
        src = X.transpose(0, first_axis + k)
        dst = comp.transpose(0, first_axis + k)
        dst[1:] = src[1:] - src[:-1]    # Backward Difference
        dst[0] = src[1] - src[0]        # Forward Difference
        comp /= delta_lst[k]
    return dX
def gradient_c(X, batched = False, delta_lst = [1., 1., 1.]):
    '''
    Compute gradient of a torch tensor "X" in each direction
    Non-boundaries: Central Difference
    Upper-boundaries: Backward Difference
    Lower-boundaries: Forward Difference

    X: tensor with 1, 2 or 3 spatial axes;
       if X is batched: (n_batch, ...); else: (...)
    delta_lst: grid spacing per spatial axis (only the first `dim` entries
       are read).
    Returns a torch.float tensor on X's device:
       1 spatial axis: same shape as X;
       2 or 3 spatial axes: X's shape plus a trailing axis of size `dim`,
       where [..., k] is the derivative along the k-th spatial axis.
    Raises ValueError if X does not have 1, 2 or 3 spatial axes.
    '''
    n_spatial = X.dim() - 1 if batched else X.dim()
    if n_spatial not in (1, 2, 3):
        raise ValueError(
            'gradient_c supports 1, 2 or 3 spatial dimensions, got %d' % n_spatial)

    first_axis = 1 if batched else 0
    out_shape = tuple(X.size())
    if n_spatial > 1:
        out_shape = out_shape + (n_spatial,)
    dX = torch.zeros(out_shape, dtype = torch.float, device = X.device)

    for k in range(n_spatial):
        # View of the k-th derivative component (all of dX in the 1D case).
        comp = dX if n_spatial == 1 else dX[..., k]
        # Bring the differentiated axis to the front; both are views, so
        # writes to `dst` land in dX.
        src = X.transpose(0, first_axis + k)
        dst = comp.transpose(0, first_axis + k)
        dst[1:-1] = (src[2:] - src[:-2]) / 2    # Central Difference
        dst[0] = src[1] - src[0]                # Forward Difference
        dst[-1] = src[-1] - src[-2]             # Backward Difference
        comp /= delta_lst[k]
    return dX