# NaRCan_demo/util.py
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor


def get_mgrid(sidelen, vmin=-1, vmax=1):
    """Return a flattened coordinate grid of shape (prod(sidelen), len(sidelen)),
    with axis i linearly spaced over [vmin[i], vmax[i]] (scalars are broadcast)."""
    if not isinstance(vmin, list):
        vmin = [vmin for _ in range(len(sidelen))]
    if not isinstance(vmax, list):
        vmax = [vmax for _ in range(len(sidelen))]
    tensors = tuple(torch.linspace(vmin[i], vmax[i], steps=sidelen[i]) for i in range(len(sidelen)))
    # indexing='ij' keeps the original row-major behavior and silences the
    # deprecation warning on recent PyTorch versions.
    mgrid = torch.stack(torch.meshgrid(*tensors, indexing='ij'), dim=-1)
    mgrid = mgrid.reshape(-1, len(sidelen))
    return mgrid
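

# Illustrative sketch (not part of the original file): a 2x3 grid yields six
# rows, one per point, with the first axis varying slowest.
def _demo_get_mgrid():
    grid = get_mgrid([2, 3])
    assert grid.shape == (6, 2)
    print(grid)  # rows: (-1,-1), (-1,0), (-1,1), (1,-1), (1,0), (1,1)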


def apply_homography(x, h):
    """Apply a per-point 3x3 homography to 2D points.

    x: (N, 2) points; h: (N, 8) homography parameters, with the ninth
    matrix entry fixed to 1.
    """
    h = torch.cat([h, torch.ones_like(h[:, [0]])], -1)  # append h33 = 1 -> (N, 9)
    h = h.view(-1, 3, 3)
    x = torch.cat([x, torch.ones_like(x[:, [0]])], -1).unsqueeze(-1)  # homogeneous (N, 3, 1)
    o = torch.bmm(h, x).squeeze(-1)
    o = o[:, :-1] / o[:, [-1]]  # perspective divide
    return o
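

# Illustrative sketch (not part of the original file): the first eight entries
# of the identity matrix leave every point unchanged.
def _demo_apply_homography():
    x = torch.tensor([[0.5, -0.25], [1.0, 1.0]])
    h = torch.eye(3).reshape(1, 9)[:, :8].repeat(2, 1)  # identity homography per point
    assert torch.allclose(apply_homography(x, h), x)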


def jacobian(y, x):
    """Full Jacobian dy/dx of shape (B, N, M) for y of shape (B, N) and x of
    shape (B, M), built one output component at a time via vector-Jacobian products."""
    B, N = y.shape
    rows = []
    for i in range(N):
        # One-hot grad_outputs selects d(y[:, i])/dx.
        v = torch.zeros_like(y)
        v[:, i] = 1.
        dy_i_dx = torch.autograd.grad(y,
                                      x,
                                      grad_outputs=v,
                                      retain_graph=True,
                                      create_graph=True)[0]  # shape (B, M)
        rows.append(dy_i_dx)
    return torch.stack(rows, dim=1).requires_grad_()
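

# Illustrative sketch (not part of the original file): for y = (x0^2, x0*x1)
# the analytic Jacobian at x = (2, 3) is [[4, 0], [3, 2]].
def _demo_jacobian():
    x = torch.tensor([[2., 3.]], requires_grad=True)
    y = torch.stack([x[:, 0] ** 2, x[:, 0] * x[:, 1]], dim=1)
    J = jacobian(y, x)  # shape (1, 2, 2)
    assert torch.allclose(J, torch.tensor([[[4., 0.], [3., 2.]]]))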


def overlap_mix(img1, img2, img_order, overlap_num):
    """Cross-fade two overlapping frames: over overlap_num steps, img1's weight
    decays linearly from 1 to 0 while img2's grows from 0 to 1."""
    w1 = np.linspace(0, 1, overlap_num)[::-1]
    w2 = 1 - w1
    return w1[img_order] * img1 + w2[img_order] * img2
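

# Illustrative sketch (not part of the original file): step 0 returns img1
# unchanged and the last step returns img2 unchanged.
def _demo_overlap_mix():
    img1 = np.zeros((4, 4, 3))
    img2 = np.ones((4, 4, 3))
    assert np.allclose(overlap_mix(img1, img2, 0, 5), img1)
    assert np.allclose(overlap_mix(img1, img2, 4, 5), img2)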


class VideoFitting(Dataset):
    """Load a directory of frames as a video and expose every (y, x, t)
    coordinate with its RGB value as one flattened, shuffled batch."""

    def __init__(self, path, transform=None):
        super().__init__()
        self.path = path
        self.transform = ToTensor() if transform is None else transform
        self.video = self.get_video_tensor()
        self.num_frames, _, self.H, self.W = self.video.size()
        # (T, C, H, W) -> (H, W, T, C) -> (H*W*T, 3), matching the coord order below.
        self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3)
        self.coords = get_mgrid([self.H, self.W, self.num_frames])
        # Shuffle pixels and coords with the same permutation so they stay aligned.
        shuffle = torch.randperm(len(self.pixels))
        self.pixels = self.pixels[shuffle]
        self.coords = self.coords[shuffle]

    def get_video_tensor(self):
        frames = sorted(os.listdir(self.path))
        video = []
        for frame in frames:
            # Force RGB so the 3-channel flatten above holds for RGBA/grayscale inputs.
            img = Image.open(os.path.join(self.path, frame)).convert('RGB')
            video.append(self.transform(img))
        return torch.stack(video, 0)

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        if idx > 0:
            raise IndexError
        return self.coords, self.pixels
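

# Illustrative sketch (not part of the original file); 'frames/' is a
# hypothetical directory of same-sized images, one file per frame.
def _demo_video_fitting(path='frames/'):
    dataset = VideoFitting(path)
    coords, pixels = dataset[0]  # both (H*W*T, 3)
    print(coords.shape, pixels.shape)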


class TestVideoFitting(Dataset):
    """Same as VideoFitting but without shuffling: coords and pixels stay in
    raster/frame order, which is what ordered evaluation expects."""

    def __init__(self, path, transform=None):
        super().__init__()
        self.path = path
        self.transform = ToTensor() if transform is None else transform
        self.video = self.get_video_tensor()
        self.num_frames, _, self.H, self.W = self.video.size()
        self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3)
        self.coords = get_mgrid([self.H, self.W, self.num_frames])

    def get_video_tensor(self):
        frames = sorted(os.listdir(self.path))
        video = []
        for frame in frames:
            img = Image.open(os.path.join(self.path, frame)).convert('RGB')
            video.append(self.transform(img))
        return torch.stack(video, 0)

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        if idx > 0:
            raise IndexError
        return self.coords, self.pixels


class GroupVideoFitting(Dataset):
    """VideoFitting plus a per-frame mask: pixels, masks, and coords are
    flattened and shuffled with the same permutation."""

    def __init__(self, path, mask_path, transform=None, mask_transform=None):
        super().__init__()
        self.path = path
        self.mask_path = mask_path
        self.transform = ToTensor() if transform is None else transform
        self.mask_transform = ToTensor() if mask_transform is None else mask_transform
        self.video = self.get_video_tensor()
        self.mask = self.get_mask_tensor()
        self.num_frames, _, self.H, self.W = self.video.size()
        self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3)
        self.mask_pixels = self.mask.permute(2, 3, 0, 1).contiguous().view(-1, 1)
        self.coords = get_mgrid([self.H, self.W, self.num_frames])
        # One shared permutation keeps pixels, coords, and masks aligned.
        shuffle = torch.randperm(len(self.pixels))
        self.pixels = self.pixels[shuffle]
        self.coords = self.coords[shuffle]
        self.mask_pixels = self.mask_pixels[shuffle]

    def get_video_tensor(self):
        frames = sorted(os.listdir(self.path))
        video = []
        for frame in frames:
            img = Image.open(os.path.join(self.path, frame)).convert('RGB')
            video.append(self.transform(img))
        return torch.stack(video, 0)

    def get_mask_tensor(self):
        masks = sorted(os.listdir(self.mask_path))
        all_mask = []
        for name in masks:
            # Force single-channel masks so the 1-channel flatten above holds.
            mask = Image.open(os.path.join(self.mask_path, name)).convert('L')
            all_mask.append(self.mask_transform(mask))
        return torch.stack(all_mask, 0)

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        if idx > 0:
            raise IndexError
        return self.coords, self.pixels, self.mask_pixels
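

# Illustrative sketch (not part of the original file); 'frames/' and 'masks/'
# are hypothetical directories with one single-channel mask per frame.
def _demo_group_video_fitting(path='frames/', mask_path='masks/'):
    dataset = GroupVideoFitting(path, mask_path)
    coords, pixels, mask_pixels = dataset[0]
    print(coords.shape, pixels.shape, mask_pixels.shape)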


class TestGroupVideoFitting(Dataset):
    """GroupVideoFitting with an extra background mask and no shuffling, for
    ordered evaluation."""

    def __init__(self, path, mask_path, back_mask_path, transform=None, mask_transform=None):
        super().__init__()
        self.path = path
        self.mask_path = mask_path
        self.back_mask_path = back_mask_path
        self.transform = ToTensor() if transform is None else transform
        self.mask_transform = ToTensor() if mask_transform is None else mask_transform
        self.video = self.get_video_tensor()
        self.mask = self.get_mask_tensor()
        self.back_mask = self.get_back_mask_tensor()
        self.num_frames, _, self.H, self.W = self.video.size()
        self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3)
        self.mask_pixels = self.mask.permute(2, 3, 0, 1).contiguous().view(-1, 1)
        self.back_mask_pixels = self.back_mask.permute(2, 3, 0, 1).contiguous().view(-1, 1)
        self.coords = get_mgrid([self.H, self.W, self.num_frames])

    def get_video_tensor(self):
        frames = sorted(os.listdir(self.path))
        video = []
        for frame in frames:
            img = Image.open(os.path.join(self.path, frame)).convert('RGB')
            video.append(self.transform(img))
        return torch.stack(video, 0)

    def _load_masks(self, path):
        # Shared loader for both mask directories; forces single-channel masks.
        masks = sorted(os.listdir(path))
        all_mask = []
        for name in masks:
            mask = Image.open(os.path.join(path, name)).convert('L')
            all_mask.append(self.mask_transform(mask))
        return torch.stack(all_mask, 0)

    def get_mask_tensor(self):
        return self._load_masks(self.mask_path)

    def get_back_mask_tensor(self):
        return self._load_masks(self.back_mask_path)

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        if idx > 0:
            raise IndexError
        return self.coords, self.pixels, self.mask_pixels, self.back_mask_pixels
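

# Minimal smoke test (not part of the original file): writes a few synthetic
# frames to a temporary directory and checks the dataset shapes end to end.
if __name__ == '__main__':
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        for t in range(3):
            frame = np.random.randint(0, 255, (8, 8, 3), dtype=np.uint8)
            Image.fromarray(frame).save(os.path.join(tmp, f'{t:03d}.png'))
        coords, pixels = TestVideoFitting(tmp)[0]
        assert coords.shape == (8 * 8 * 3, 3)
        assert pixels.shape == (8 * 8 * 3, 3)
        print('TestVideoFitting OK:', coords.shape, pixels.shape)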