LituRout's picture
add dps
0de10e5
import numpy as np
import torch
import scipy
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
from motionblur.motionblur import Kernel
from .fastmri_utils import fft2c_new, ifft2c_new
"""
Helper functions for new types of inverse problems
"""
def fft2(x):
    """Return the 2-D FFT of ``x`` with the DC bin moved to the image centre.

    Args:
        x: Tensor whose last two dimensions are transformed.

    Returns:
        Complex tensor of the same shape, shifted along the last two axes so
        the zero-frequency component sits in the middle.
    """
    spectrum = torch.fft.fft2(x)
    return torch.fft.fftshift(spectrum, dim=(-2, -1))
def ifft2(x):
    """Return the 2-D inverse FFT of a centred spectrum.

    The DC bin is first moved from the centre back to the corner of the last
    two axes, then the inverse transform is applied.

    Args:
        x: Tensor holding a spectrum whose last two axes are centred.

    Returns:
        Complex tensor with the inverse transform over the last two axes.
    """
    corner_dc = torch.fft.ifftshift(x, dim=(-2, -1))
    return torch.fft.ifft2(corner_dc)
def fft2_m(x):
    """FFT for multi-coil data, delegated to ``fft2c_new``.

    Real inputs are promoted to ``complex64``. The tensor is then handed to
    ``fft2c_new`` as a real view (trailing dimension of size 2) and the result
    is converted back to a complex tensor.

    Args:
        x: Real or complex tensor.

    Returns:
        Complex tensor produced by ``fft2c_new``.
    """
    as_complex = x if torch.is_complex(x) else x.type(torch.complex64)
    real_pairs = torch.view_as_real(as_complex)
    return torch.view_as_complex(fft2c_new(real_pairs))
def ifft2_m(x):
    """Inverse FFT for multi-coil data, delegated to ``ifft2c_new``.

    Real inputs are promoted to ``complex64``. The tensor is then handed to
    ``ifft2c_new`` as a real view (trailing dimension of size 2) and the result
    is converted back to a complex tensor.

    Args:
        x: Real or complex tensor.

    Returns:
        Complex tensor produced by ``ifft2c_new``.
    """
    as_complex = x if torch.is_complex(x) else x.type(torch.complex64)
    real_pairs = torch.view_as_real(as_complex)
    return torch.view_as_complex(ifft2c_new(real_pairs))
def clear(x):
    """Convert a tensor into a squeezed NumPy array rescaled by ``normalize_np``.

    Complex tensors are reduced to their magnitude first, the same way
    ``clear_color`` does, so the returned array is always real-valued and can
    be displayed or saved directly.

    Args:
        x: Tensor on any device; it is detached and moved to the CPU.

    Returns:
        NumPy array with all size-1 dimensions removed, passed through
        ``normalize_np``.
    """
    if torch.is_complex(x):
        x = torch.abs(x)
    x = x.detach().cpu().squeeze().numpy()
    return normalize_np(x)
def clear_color(x):
    """Convert a channel-first tensor into a channel-last array via ``normalize_np``.

    Complex tensors are replaced by their magnitude. The tensor is detached,
    moved to the CPU and squeezed; the result must be 3-D because its axes are
    reordered with ``(1, 2, 0)``.

    Args:
        x: Tensor that squeezes to three dimensions.

    Returns:
        The transposed array after ``normalize_np``.
    """
    magnitude = torch.abs(x) if torch.is_complex(x) else x
    channel_first = magnitude.detach().cpu().squeeze().numpy()
    channel_last = np.transpose(channel_first, (1, 2, 0))
    return normalize_np(channel_last)
def normalize_np(img):
    """Normalize img in arbitrary range to [0, 1].

    The input is never modified: the result is computed on a new array. This
    matters because arrays obtained from ``tensor.numpy()`` share memory with
    the tensor, so in-place arithmetic would silently rescale the caller's
    data.

    Args:
        img: Array-like image. Integer and boolean inputs are converted to
            ``float32``; floating and complex inputs keep their dtype.

    Returns:
        New array equal to ``(img - min) / (max - min)``. A constant image,
        whose range is zero, yields an all-zero array instead of NaNs.

    Raises:
        ValueError: If ``img`` is empty (raised by ``np.min``).
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.inexact):
        # In-place true division is impossible on integer arrays; work in float.
        img = img.astype(np.float32)
    shifted = img - np.min(img)
    peak = np.max(shifted)
    if peak == 0:
        # Constant image: avoid 0/0.
        return np.zeros_like(shifted)
    return shifted / peak
def prepare_im(load_dir, image_size, device):
    """Load an image file as a ``[1, 3, image_size, image_size]`` tensor in [-1, 1].

    Only the first three channels of the loaded image are kept. The pixels are
    rescaled with ``normalize_np`` and then mapped to [-1, 1] via ``x * 2 - 1``.

    Args:
        load_dir: Path handed to ``plt.imread``.
        image_size: Side length the image must already have; the tensor is
            viewed (not resized) to this size.
        device: Device the tensor is moved to.

    Returns:
        Channel-first image tensor with a leading batch dimension of 1.
    """
    rgb = plt.imread(load_dir)[:, :, :3].astype(np.float32)
    unit_range = torch.from_numpy(normalize_np(rgb)).to(device)
    batched = unit_range.permute(2, 0, 1).view(1, 3, image_size, image_size)
    return batched * 2 - 1
def fold_unfold(img_t, kernel, stride):
    """Split an image batch into sliding patches and fold them back together.

    Intermediate shapes are printed to stdout at every stage.

    Args:
        img_t: Tensor of shape ``[B, C, H, W]``.
        kernel: Side length of the square patches.
        stride: Step between neighbouring patches.

    Returns:
        Tuple ``(patches, output)``. ``patches`` has shape
        ``[B, C * kernel * kernel, L]``, the layout ``F.fold`` expects.
        ``output`` is the folded image divided by the per-pixel patch count,
        so overlapping regions are averaged.
    """
    in_shape = img_t.shape
    B, C, H, W = in_shape
    print("\n----- input shape: ", in_shape)

    # Slide over width, then height; the permute puts each window in (row, col) order.
    windows = img_t.unfold(3, kernel, stride).unfold(2, kernel, stride)
    windows = windows.permute(0, 1, 2, 3, 5, 4)
    print("\n----- patches shape:", windows.shape)

    # Reshape step by step into the column layout used by F.fold.
    flat = windows.contiguous().view(B, C, -1, kernel * kernel)
    print("\n", flat.shape)  # [B, C, nb_patches_all, kernel_size*kernel_size]
    flat = flat.permute(0, 1, 3, 2)
    print("\n", flat.shape)  # [B, C, kernel_size*kernel_size, nb_patches_all]
    patches = flat.contiguous().view(B, C * kernel * kernel, -1)
    print("\n", patches.shape)  # [B, C*prod(kernel_size), L] as expected by Fold

    fold_args = dict(output_size=(H, W), kernel_size=kernel, stride=stride)
    summed = F.fold(patches, **fold_args)
    # Folding a tensor of ones counts how many patches cover each pixel.
    coverage = F.fold(torch.ones_like(patches), **fold_args)
    return patches, summed / coverage
def reshape_patch(x, crop_size=128, dim_size=3):
    """Turn unfolded patch columns into a batch of 3-channel square patches.

    Args:
        x: Tensor whose first and last axes are swapped and whose size-1 axes
            are then dropped. With the defaults this is expected to leave
            ``[9, 3*(128**2)]``.
        crop_size: Side length of each patch.
        dim_size: Number of patches per image side; ``dim_size**2`` patches
            are produced.

    Returns:
        Tensor viewed as ``[dim_size**2, 3, crop_size, crop_size]``.
    """
    per_patch = x.transpose(0, 2).squeeze()  # [9, 3*(128**2)]
    return per_patch.view(dim_size ** 2, 3, crop_size, crop_size)
def reshape_patch_back(x, crop_size=128, dim_size=3):
    """Flatten a batch of 3-channel patches into patch columns.

    Args:
        x: Tensor viewable as ``[dim_size**2, 3 * crop_size**2]``.
        crop_size: Side length of each patch.
        dim_size: Number of patches per image side.

    Returns:
        Tensor of shape ``[1, 3 * crop_size**2, dim_size**2]``.
    """
    flat = x.view(dim_size ** 2, 3 * (crop_size ** 2))
    return flat.unsqueeze(dim=-1).transpose(0, 2)
class Unfolder:
    """Callable that cuts an image into a grid of overlapping square crops.

    Attributes are plain copies of the constructor arguments, plus:
        unfold: ``nn.Unfold`` configured with ``crop_size`` and ``stride``.
        dim_size: Number of crop positions along one image side.
    """

    def __init__(self, img_size=256, crop_size=128, stride=64):
        self.img_size, self.crop_size, self.stride = img_size, crop_size, stride
        self.unfold = nn.Unfold(crop_size, stride=stride)
        # Count of sliding-window positions that fit along one side.
        self.dim_size = 1 + (img_size - crop_size) // stride

    def __call__(self, x):
        """Unfold ``x`` and reshape the columns with ``reshape_patch``."""
        columns = self.unfold(x)
        return reshape_patch(columns, crop_size=self.crop_size, dim_size=self.dim_size)
def center_crop(img, new_width=None, new_height=None):
    """Crop the central ``new_height`` x ``new_width`` window out of an image.

    Args:
        img: Array indexed as ``[row, column, ...]``; any trailing axes (for
            example colour channels) are kept whole.
        new_width: Width of the crop. Defaults to the shorter image side.
        new_height: Height of the crop. Defaults to the shorter image side.

    Returns:
        The cropped view of ``img``. A requested size larger than the image is
        clamped to the image size along that axis. Without the clamp the
        offsets become negative and the slice wraps around, returning a wrong
        region.
    """
    height, width = img.shape[0], img.shape[1]
    shortest = min(width, height)
    new_width = shortest if new_width is None else min(new_width, width)
    new_height = shortest if new_height is None else min(new_height, height)

    # With an odd surplus the extra pixel is trimmed from the left/top side.
    left = int(np.ceil((width - new_width) / 2))
    right = width - int(np.floor((width - new_width) / 2))
    top = int(np.ceil((height - new_height) / 2))
    bottom = height - int(np.floor((height - new_height) / 2))

    # The ellipsis keeps any trailing axes and is a no-op for 2-D images.
    return img[top:bottom, left:right, ...]
class Folder:
    """Callable that reassembles a grid of square crops into one image.

    Attributes are plain copies of the constructor arguments, plus:
        fold: ``nn.Fold`` with output size ``img_size``, kernel ``crop_size``
            and the given ``stride``. Overlapping contributions are combined
            by ``nn.Fold`` itself; no averaging is applied here.
        dim_size: Number of crop positions along one image side.
    """

    def __init__(self, img_size=256, crop_size=128, stride=64):
        self.img_size, self.crop_size, self.stride = img_size, crop_size, stride
        self.fold = nn.Fold(img_size, crop_size, stride=stride)
        # Count of sliding-window positions that fit along one side.
        self.dim_size = 1 + (img_size - crop_size) // stride

    def __call__(self, patch2D):
        """Flatten ``patch2D`` with ``reshape_patch_back`` and fold it."""
        columns = reshape_patch_back(patch2D, crop_size=self.crop_size, dim_size=self.dim_size)
        return self.fold(columns)
def random_sq_bbox(img, mask_shape, image_size=256, margin=(16, 16)):
    """Generate a random box mask for inpainting.

    Args:
        img: Tensor of shape ``[B, C, H, W]``; only its shape and device are used.
        mask_shape: ``(height, width)`` of the box to cut out.
        image_size: Side length used to bound the box position.
        margin: ``(vertical, horizontal)`` distance the box keeps from the
            borders of an ``image_size`` square.

    Returns:
        Tuple ``(mask, top, bottom, left, right)``. ``mask`` is a
        ``[B, C, H, W]`` tensor of ones on ``img.device`` with zeros inside
        the box ``[top:bottom, left:right]``.
    """
    B, C, H, W = img.shape
    box_h, box_w = mask_shape
    margin_h, margin_w = margin

    # Draw the top edge first, then the left edge (upper bounds are exclusive).
    top = np.random.randint(margin_h, image_size - margin_h - box_h)
    left = np.random.randint(margin_w, image_size - margin_w - box_w)
    bottom, right = top + box_h, left + box_w

    mask = torch.ones([B, C, H, W], device=img.device)
    mask[..., top:bottom, left:right] = 0
    return mask, top, bottom, left, right
class mask_generator:
    """Callable producing inpainting masks (1 = keep, 0 = masked).

    Supported ``mask_type`` values:
        'box': a random rectangle is zeroed out.
        'random': a random fraction of pixels is zeroed out.
        'extreme': inverse of 'box' (only the rectangle is kept).
        'both': accepted by the constructor, but no masking rule exists for
            it; calling the generator raises ``NotImplementedError``.
    """

    def __init__(self, mask_type, mask_len_range=None, mask_prob_range=None,
                 image_size=256, margin=(16, 16)):
        """
        (mask_len_range): given in (min, max) tuple.
        Specifies the range of box size in each dimension
        (mask_prob_range): for the case of random masking,
        specify the probability of individual pixels being masked
        """
        assert mask_type in ['box', 'random', 'both', 'extreme']
        self.mask_type = mask_type
        self.mask_len_range = mask_len_range
        self.mask_prob_range = mask_prob_range
        self.image_size = image_size
        self.margin = margin

    def _retrieve_box(self, img):
        """Return ``(mask, top, bottom, left, right)`` for a random box."""
        low, high = (int(v) for v in self.mask_len_range)
        # Height is drawn before width; both upper bounds are exclusive.
        mask_h = np.random.randint(low, high)
        mask_w = np.random.randint(low, high)
        return random_sq_bbox(img,
                              mask_shape=(mask_h, mask_w),
                              image_size=self.image_size,
                              margin=self.margin)

    def _retrieve_random(self, img):
        """Return a mask like ``img`` with a random pixel subset zeroed.

        The same spatial pattern is shared by every batch item and channel.
        """
        total = self.image_size ** 2
        # random pixel sampling
        low, high = self.mask_prob_range
        prob = np.random.uniform(low, high)
        mask_vec = torch.ones([1, total])
        samples = np.random.choice(total, int(total * prob), replace=False)
        mask_vec[:, samples] = 0
        mask_hw = mask_vec.view(1, self.image_size, self.image_size)
        mask = torch.ones_like(img)
        # Broadcast over batch and channels, so any channel count works.
        mask[:, ...] = mask_hw
        return mask

    def __call__(self, img):
        """Build a mask for ``img`` according to ``mask_type``.

        Raises:
            NotImplementedError: If ``mask_type`` has no masking rule
                ('both'), instead of silently returning ``None``.
        """
        if self.mask_type == 'random':
            return self._retrieve_random(img)
        if self.mask_type == 'box':
            mask, *_ = self._retrieve_box(img)
            return mask
        if self.mask_type == 'extreme':
            mask, *_ = self._retrieve_box(img)
            return 1. - mask
        raise NotImplementedError(
            f"mask_type {self.mask_type!r} has no masking rule implemented"
        )
def unnormalize(img, s=0.95):
    """Scale ``img`` up by the ``s``-quantile of its absolute values.

    This is the counterpart of ``normalize``, which divides by that quantile.
    The scale is recomputed from the tensor passed in, so it is not an exact
    inverse of an earlier ``normalize`` call.

    Args:
        img: Input tensor.
        s: Quantile in [0, 1] taken over ``img.abs()``.

    Returns:
        ``img`` multiplied by the quantile.
    """
    scaling = torch.quantile(img.abs(), s)
    return img * scaling
def normalize(img, s=0.95):
    """Scale ``img`` so the ``s``-quantile of its absolute values becomes 1.

    Args:
        img: Input tensor.
        s: Quantile in [0, 1] taken over ``img.abs()``.

    Returns:
        ``img`` divided by the quantile. Values whose magnitude exceeds the
        quantile end up outside [-1, 1].
    """
    scaling = torch.quantile(img.abs(), s)
    return img / scaling
def dynamic_thresholding(img, s=0.95):
    """Rescale ``img`` with ``normalize`` and clip the result to [-1, 1].

    Args:
        img: Input tensor.
        s: Quantile forwarded to ``normalize``.

    Returns:
        Tensor with every value limited to the range [-1, 1].
    """
    rescaled = normalize(img, s=s)
    return torch.clip(rescaled, -1., 1.)
def get_gaussian_kernel(kernel_size=31, std=0.5):
    """Build a 2-D Gaussian kernel by blurring a centred unit impulse.

    Args:
        kernel_size: Side length of the square kernel.
        std: Standard deviation handed to ``gaussian_filter`` as ``sigma``.

    Returns:
        ``float32`` array of shape ``(kernel_size, kernel_size)``.
    """
    # Imported explicitly: a bare ``import scipy`` does not guarantee that the
    # ``ndimage`` submodule has been loaded.
    from scipy.ndimage import gaussian_filter

    impulse = np.zeros([kernel_size, kernel_size])
    impulse[kernel_size // 2, kernel_size // 2] = 1
    return gaussian_filter(impulse, sigma=std).astype(np.float32)
def init_kernel_torch(kernel, device="cuda:0"):
    """Wrap a 2-D NumPy kernel as a gradient-tracking 3-channel tensor.

    Args:
        kernel: 2-D NumPy array.
        device: Device the kernel is moved to.

    Returns:
        Tensor of shape ``[1, 3, h, w]`` holding the kernel repeated for each
        of the three channels, derived from a tensor with
        ``requires_grad=True``.
    """
    rows, cols = kernel.shape
    tracked = Variable(torch.from_numpy(kernel).to(device), requires_grad=True)
    return tracked.view(1, 1, rows, cols).repeat(1, 3, 1, 1)
class Blurkernel(nn.Module):
    """Depth-wise blur for 3-channel images with a fixed, replaceable kernel.

    The module is a reflection pad followed by a grouped, bias-free
    convolution, so every channel is filtered independently by the same
    kernel.

    Args:
        blur_type: 'gaussian' or 'motion'. Any other value leaves the
            convolution at its default initialisation and stores no kernel.
        kernel_size: Side length of the square kernel.
        std: Gaussian sigma, or the ``intensity`` passed to the motion
            ``Kernel``.
        device: Device that NumPy kernels given to ``update_weights`` are
            moved to; ``None`` leaves them where they are.
    """

    def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
        super().__init__()
        self.blur_type = blur_type
        self.kernel_size = kernel_size
        self.std = std
        self.device = device
        self.seq = nn.Sequential(
            nn.ReflectionPad2d(self.kernel_size // 2),
            nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),
        )
        self.weights_init()

    def forward(self, x):
        """Apply the pad + convolution to ``x``."""
        return self.seq(x)

    def _load_kernel(self, k):
        """Remember ``k`` for ``get_kernel`` and copy it into every parameter."""
        self.k = k
        for param in self.parameters():
            # copy_ broadcasts the 2-D kernel over the per-channel filters.
            param.data.copy_(k)

    def weights_init(self):
        """Create the initial kernel for ``blur_type`` and load it."""
        if self.blur_type == "gaussian":
            # Imported explicitly: a bare ``import scipy`` does not guarantee
            # that the ``ndimage`` submodule has been loaded.
            from scipy.ndimage import gaussian_filter

            impulse = np.zeros((self.kernel_size, self.kernel_size))
            impulse[self.kernel_size // 2, self.kernel_size // 2] = 1
            k = gaussian_filter(impulse, sigma=self.std)
        elif self.blur_type == "motion":
            k = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
        else:
            return
        self._load_kernel(torch.from_numpy(k))

    def update_weights(self, k):
        """Replace the blur kernel with ``k`` (tensor or NumPy array).

        The stored kernel is refreshed as well, so ``get_kernel`` always
        reflects the weights currently in use.
        """
        if not torch.is_tensor(k):
            k = torch.from_numpy(k)
            if self.device is not None:
                k = k.to(self.device)
        self._load_kernel(k)

    def get_kernel(self):
        """Return the kernel most recently loaded into the convolution."""
        return self.k
class exact_posterior():
    """Gaussian-style likelihood terms for a measurement ``y`` and a noisy ``xt``.

    Args:
        betas: Indexable sequence of noise levels; ``betas[t]`` is used at step ``t``.
        sigma_0: Measurement noise level.
        label_dim: Exponent of ``2*pi`` in the ``xt | x0`` constant.
        input_dim: Exponent of ``2*pi`` in the ``y | x0`` constant.
    """

    def __init__(self, betas, sigma_0, label_dim, input_dim):
        self.betas = betas
        self.sigma_0 = sigma_0
        self.label_dim = label_dim
        self.input_dim = input_dim

    def py_given_x0(self, x0, y, A, verbose=False):
        """Evaluate ``norm_const * exp(-||y - A(x0)||^2 / (2 * sigma_0^2))``.

        Args:
            x0: Clean signal.
            y: Measurement.
            A: Callable forward operator applied to ``x0``.
            verbose: If true, also return the constant and the exponent.

        Returns:
            The density, or ``(density, norm_const, exponent)`` when ``verbose``.
        """
        # NOTE(review): the constant is 1 / ((2*pi)**input_dim * sigma_0**2) as
        # written; confirm this is the intended normaliser.
        norm_const = 1 / ((2 * np.pi) ** self.input_dim * self.sigma_0 ** 2)
        exp_in = -1 / (2 * self.sigma_0 ** 2) * torch.linalg.norm(y - A(x0)) ** 2
        density = norm_const * torch.exp(exp_in)
        if verbose:
            return density, norm_const, exp_in
        return density

    def pxt_given_x0(self, x0, xt, t, verbose=False):
        """Evaluate ``norm_const * exp(-||xt - sqrt(1 - beta_t) * x0||^2 / (2 * beta_t))``.

        Args:
            x0: Clean signal.
            xt: Noisy signal at step ``t``.
            t: Index into ``betas``.
            verbose: If true, also return the constant and the exponent.

        Returns:
            The density, or ``(density, norm_const, exponent)`` when ``verbose``.
        """
        beta_t = self.betas[t]
        norm_const = 1 / ((2 * np.pi) ** self.label_dim * beta_t)
        exp_in = -1 / (2 * beta_t) * torch.linalg.norm(xt - np.sqrt(1 - beta_t) * x0) ** 2
        density = norm_const * torch.exp(exp_in)
        if verbose:
            return density, norm_const, exp_in
        return density

    def prod_logsumexp(self, x0, xt, y, A, t):
        """Return ``logsumexp`` over dim 0 of the product of both densities.

        The product is formed from the two constants and the two exponents.
        Both exponents already carry their negative sign, so they are added.
        """
        _, pyx0_nc, pyx0_ei = self.py_given_x0(x0, y, A, verbose=True)
        _, pxtx0_nc, pxtx0_ei = self.pxt_given_x0(x0, xt, t, verbose=True)
        summand = (pyx0_nc * pxtx0_nc) * torch.exp(pyx0_ei + pxtx0_ei)
        return torch.logsumexp(summand, dim=0)
def map2tensor(gray_map, device="cuda"):
    """Convert a gray map to a float tensor on ``device``; no normalization is done.

    Args:
        gray_map: 2-D array-like of values.
        device: Target device. Defaults to the current CUDA device, which is
            where the tensor was always placed before this parameter existed.

    Returns:
        ``float32`` tensor with two leading singleton dimensions added
        (``[1, 1, ...]``).
    """
    return torch.FloatTensor(gray_map).unsqueeze(0).unsqueeze(0).to(device)
def create_penalty_mask(k_size, penalty_scale):
    """Generate a mask of weights penalizing values close to the boundaries.

    A ``k_size`` x ``k_size`` Gaussian (sigma ``k_size``) is inverted so the
    weights grow away from the centre. The central region is then set to
    zero and the result is scaled by ``penalty_scale``.

    Args:
        k_size: Side length of the square mask.
        penalty_scale: Factor applied to the final mask.

    Returns:
        NumPy array of shape ``(k_size, k_size)``.
    """
    gauss = create_gaussian(size=k_size, sigma1=k_size, is_tensor=False)
    weights = 1 - gauss / np.max(gauss)
    centre = k_size // 2 + k_size % 2
    border = (k_size - centre) // 2 - 1
    # NOTE(review): when border is 0 the slice 0:-0 is empty, so nothing is
    # zeroed; confirm this is intended for small k_size.
    weights[border:-border, border:-border] = 0
    return penalty_scale * weights
def create_gaussian(size, sigma1, sigma2=-1, is_tensor=False):
    """Return a 2-D Gaussian as the outer product of two 1-D profiles.

    Args:
        size: Number of samples per axis, taken at integer offsets
            ``-size // 2 + 1`` up to ``size // 2``.
        sigma1: Standard deviation of the first (row) profile.
        sigma2: Standard deviation of the second (column) profile; ``-1``
            reuses the first profile.
        is_tensor: If true, return a CUDA ``FloatTensor`` instead of a NumPy
            array.

    Returns:
        ``size`` x ``size`` array (or CUDA tensor) of Gaussian values.
    """
    offsets = range(-size // 2 + 1, size // 2 + 1)

    def profile(sigma):
        # Normalised 1-D Gaussian density sampled at the integer offsets.
        return [np.exp(-z ** 2 / (2 * sigma ** 2)) / np.sqrt(2 * np.pi * sigma ** 2)
                for z in offsets]

    rows = profile(sigma1)
    cols = rows if sigma2 == -1 else profile(sigma2)
    kernel = np.outer(rows, cols)
    if is_tensor:
        return torch.FloatTensor(kernel).cuda()
    return kernel
def total_variation_loss(img, weight):
    """Weighted squared total-variation penalty of a 4-D image batch.

    Args:
        img: 4-D tensor; differences are taken along dimensions 2 and 3.
        weight: Factor applied to the summed penalty.

    Returns:
        ``weight`` times the sum of the mean squared neighbour differences
        along dimension 2 and along dimension 3.
    """
    diff_rows = img[:, :, 1:, :] - img[:, :, :-1, :]
    diff_cols = img[:, :, :, 1:] - img[:, :, :, :-1]
    return weight * (diff_rows.pow(2).mean() + diff_cols.pow(2).mean())
if __name__ == '__main__':
    # Manual smoke test: load one image, generate a mask for it and display it.
    # Fall back to the CPU so the demo also runs on machines without CUDA.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    load_path = '/media/harry/tomo/FFHQ/256/test/00000.png'
    img = torch.tensor(plt.imread(load_path)[:, :, :3])  # rgb
    img = torch.permute(img, (2, 0, 1)).view(1, 3, 256, 256).to(device)
    mask_len_range = (32, 128)
    mask_prob_range = (0.3, 0.7)
    image_size = 256
    # mask
    # mask_type is a required argument; 'box' exercises mask_len_range.
    mask_gen = mask_generator(
        mask_type='box',
        mask_len_range=mask_len_range,
        mask_prob_range=mask_prob_range,
        image_size=image_size
    )
    mask = mask_gen(img)
    # Channel-last layout for matplotlib.
    mask = np.transpose(mask.squeeze().cpu().detach().numpy(), (1, 2, 0))
    plt.imshow(mask)
    plt.show()