NaRCan_demo / util.py
Koi953215's picture
init commit
f9cbc98
raw
history blame contribute delete
No virus
7.79 kB
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
def get_mgrid(sidelen, vmin=-1, vmax=1):
    """Build a flattened N-dimensional grid of evenly spaced coordinates.

    Args:
        sidelen: Sequence with the number of samples along each axis.
        vmin: Lower bound of the grid. Either a scalar applied to every axis
            or a list/tuple with one bound per axis.
        vmax: Upper bound of the grid, same conventions as ``vmin``.

    Returns:
        Tensor of shape ``(prod(sidelen), len(sidelen))``. Rows are ordered
        with the last axis varying fastest (matrix "ij" indexing).
    """
    ndim = len(sidelen)
    # Accept tuples as well as lists for per-axis bounds; a scalar is broadcast
    # to every axis.
    if not isinstance(vmin, (list, tuple)):
        vmin = [vmin] * ndim
    if not isinstance(vmax, (list, tuple)):
        vmax = [vmax] * ndim

    axes = [
        torch.linspace(lo, hi, steps=steps)
        for lo, hi, steps in zip(vmin, vmax, sidelen)
    ]
    # "ij" is what torch.meshgrid does by default; passing it explicitly
    # silences the warning about the missing indexing argument and pins the
    # row order.
    mgrid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
    return mgrid.reshape(-1, ndim)
def apply_homography(x, h):
    """Warp 2D points with a per-point homography.

    Args:
        x: Tensor of shape ``(B, 2)`` holding the points to transform.
        h: Tensor of shape ``(B, 8)`` holding the first eight entries of each
            3x3 homography in row-major order; the ninth entry is fixed to 1.

    Returns:
        Tensor of shape ``(B, 2)`` with the transformed points after the
        perspective divide.
    """
    # Complete each 8-vector with a trailing 1 and fold it into a 3x3 matrix.
    last_entry = torch.ones_like(h[:, [0]])
    matrices = torch.cat([h, last_entry], -1).view(-1, 3, 3)

    # Lift the points to homogeneous column vectors of shape (B, 3, 1).
    ones_col = torch.ones_like(x[:, 0]).unsqueeze(-1)
    points = torch.cat([x, ones_col], -1).unsqueeze(-1)

    projected = torch.bmm(matrices, points).squeeze(-1)
    # Divide by the homogeneous coordinate to return to 2D.
    return projected[:, :-1] / projected[:, [-1]]
def jacobian(y, x):
    """Compute the Jacobian of ``y`` with respect to ``x`` one output column at a time.

    Args:
        y: 2D tensor of shape ``(B, N)`` computed from ``x``.
        x: Tensor that ``y`` depends on through the autograd graph.

    Returns:
        Tensor with the per-column gradients stacked along dim 1, i.e. shape
        ``(x.shape[0], N, *x.shape[1:])``, flagged as requiring grad. Each
        gradient is taken of the sum of one column of ``y`` over the batch, so
        the result is a per-sample Jacobian only when sample ``b`` of ``y``
        depends solely on sample ``b`` of ``x`` (assumed, not checked here).
    """
    _, num_outputs = y.shape
    columns = []
    for out_idx in range(num_outputs):
        # One-hot selector over the output dimension: picks y[:, out_idx].
        selector = torch.zeros_like(y)
        selector[:, out_idx] = 1.
        grads = torch.autograd.grad(
            y,
            x,
            grad_outputs=selector,
            retain_graph=True,   # the graph is reused for the next column
            create_graph=True,   # keep the result differentiable
        )
        columns.append(grads[0])  # same shape as x
    return torch.stack(columns, dim=1).requires_grad_()
def overlap_mix(img1, img2, img_order, overlap_num):
    """Linearly cross-fade between two images over an overlap window.

    The weight of ``img1`` falls from 1 to 0 in ``overlap_num`` evenly spaced
    steps; ``img2`` receives the complementary weight.

    Args:
        img1: First image (or scalar), weighted most at the start of the window.
        img2: Second image (or scalar), weighted most at the end of the window.
        img_order: Index of the current position inside the overlap window.
        overlap_num: Number of positions in the overlap window.

    Returns:
        The weighted sum ``w * img1 + (1 - w) * img2`` for the selected position.
    """
    falling = np.linspace(0, 1, overlap_num)[::-1]
    rising = 1 - falling
    weight_first = falling[img_order]
    weight_second = rising[img_order]
    return weight_first * img1 + weight_second * img2
class VideoFitting(Dataset):
    """Single-item dataset holding every pixel of a video in shuffled order.

    All files in ``path`` are read in sorted filename order, converted with
    ``transform`` and stacked along a new leading axis. The pixels are then
    flattened into rows of three values and paired with coordinates produced
    by ``get_mgrid([H, W, num_frames])``. Pixels and coordinates are shuffled
    with the same random permutation.

    The reshape to rows of three assumes 3-channel frames; this is not checked.

    Args:
        path: Directory containing one image file per frame.
        transform: Callable applied to each opened image. Defaults to
            ``ToTensor()``.
    """

    def __init__(self, path, transform=None):
        super().__init__()

        self.path = path
        self.transform = ToTensor() if transform is None else transform

        self.video = self.get_video_tensor()
        self.num_frames, _, self.H, self.W = self.video.size()
        # (T, C, H, W) -> (H, W, T, C) so that flattened rows follow the same
        # (H, W, T) order as the coordinate grid below.
        self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3)
        self.coords = get_mgrid([self.H, self.W, self.num_frames])

        # One permutation for both tensors keeps pixel/coordinate pairs aligned.
        shuffle = torch.randperm(len(self.pixels))
        self.pixels = self.pixels[shuffle]
        self.coords = self.coords[shuffle]

    def get_video_tensor(self):
        """Load every file in ``self.path`` and stack the transformed frames."""
        video = []
        for name in sorted(os.listdir(self.path)):
            # The context manager closes the file handle once the frame has
            # been converted.
            with Image.open(os.path.join(self.path, name)) as img:
                video.append(self.transform(img))
        return torch.stack(video, 0)

    def __len__(self):
        # The whole video is served as a single sample.
        return 1

    def __getitem__(self, idx):
        """Return ``(coords, pixels)`` for the whole video; only one item exists."""
        if idx > 0:
            raise IndexError

        return self.coords, self.pixels
class TestVideoFitting(Dataset):
    """Single-item dataset holding every pixel of a video in grid order.

    All files in ``path`` are read in sorted filename order, converted with
    ``transform`` and stacked along a new leading axis. The pixels are then
    flattened into rows of three values and paired with coordinates produced
    by ``get_mgrid([H, W, num_frames])``. No shuffling is applied, so rows
    stay in (H, W, T) order.

    The reshape to rows of three assumes 3-channel frames; this is not checked.

    Args:
        path: Directory containing one image file per frame.
        transform: Callable applied to each opened image. Defaults to
            ``ToTensor()``.
    """

    def __init__(self, path, transform=None):
        super().__init__()

        self.path = path
        self.transform = ToTensor() if transform is None else transform

        self.video = self.get_video_tensor()
        self.num_frames, _, self.H, self.W = self.video.size()
        # (T, C, H, W) -> (H, W, T, C) so that flattened rows follow the same
        # (H, W, T) order as the coordinate grid below.
        self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3)
        self.coords = get_mgrid([self.H, self.W, self.num_frames])

    def get_video_tensor(self):
        """Load every file in ``self.path`` and stack the transformed frames."""
        video = []
        for name in sorted(os.listdir(self.path)):
            # The context manager closes the file handle once the frame has
            # been converted.
            with Image.open(os.path.join(self.path, name)) as img:
                video.append(self.transform(img))
        return torch.stack(video, 0)

    def __len__(self):
        # The whole video is served as a single sample.
        return 1

    def __getitem__(self, idx):
        """Return ``(coords, pixels)`` for the whole video; only one item exists."""
        if idx > 0:
            raise IndexError

        return self.coords, self.pixels
class GroupVideoFitting(Dataset):
    """Single-item dataset of shuffled video pixels with a matching mask value.

    Frames are read from ``path`` and masks from ``mask_path``, both in sorted
    filename order. Video pixels are flattened into rows of three values, mask
    pixels into rows of one value, and coordinates come from
    ``get_mgrid([H, W, num_frames])``. All three tensors are shuffled with the
    same random permutation.

    The reshapes assume 3-channel frames and single-channel masks with the
    same frame count and resolution as the video; none of this is checked.

    Args:
        path: Directory containing one image file per frame.
        mask_path: Directory containing one mask image per frame.
        transform: Callable applied to each frame. Defaults to ``ToTensor()``.
        mask_transform: Callable applied to each mask. Defaults to
            ``ToTensor()``.
    """

    def __init__(self, path, mask_path, transform=None, mask_transform=None):
        super().__init__()

        self.path = path
        self.mask_path = mask_path
        self.transform = ToTensor() if transform is None else transform
        self.mask_transform = (
            ToTensor() if mask_transform is None else mask_transform
        )

        self.video = self.get_video_tensor()
        self.mask = self.get_mask_tensor()
        self.num_frames, _, self.H, self.W = self.video.size()
        # (T, C, H, W) -> (H, W, T, C) so that flattened rows follow the same
        # (H, W, T) order as the coordinate grid below.
        self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3)
        self.mask_pixels = self.mask.permute(2, 3, 0, 1).contiguous().view(-1, 1)
        self.coords = get_mgrid([self.H, self.W, self.num_frames])

        # One permutation for all tensors keeps each row's pixel, coordinate
        # and mask value aligned.
        shuffle = torch.randperm(len(self.pixels))
        self.pixels = self.pixels[shuffle]
        self.coords = self.coords[shuffle]
        self.mask_pixels = self.mask_pixels[shuffle]

    @staticmethod
    def _load_stack(directory, transform):
        """Open every file in ``directory`` (sorted), transform it, and stack."""
        items = []
        for name in sorted(os.listdir(directory)):
            # The context manager closes the file handle once the image has
            # been converted.
            with Image.open(os.path.join(directory, name)) as img:
                items.append(transform(img))
        return torch.stack(items, 0)

    def get_video_tensor(self):
        """Load the frames in ``self.path`` using ``self.transform``."""
        return self._load_stack(self.path, self.transform)

    def get_mask_tensor(self):
        """Load the masks in ``self.mask_path`` using ``self.mask_transform``."""
        return self._load_stack(self.mask_path, self.mask_transform)

    def __len__(self):
        # The whole video is served as a single sample.
        return 1

    def __getitem__(self, idx):
        """Return ``(coords, pixels, mask_pixels)``; only one item exists."""
        if idx > 0:
            raise IndexError

        return self.coords, self.pixels, self.mask_pixels
class TestGroupVideoFitting(Dataset):
    """Single-item dataset of video pixels with two mask values, in grid order.

    Frames are read from ``path``, masks from ``mask_path`` and a second set
    of masks from ``back_mask_path``, all in sorted filename order. Video
    pixels are flattened into rows of three values, each mask into rows of one
    value, and coordinates come from ``get_mgrid([H, W, num_frames])``. No
    shuffling is applied, so rows stay in (H, W, T) order.

    The reshapes assume 3-channel frames and single-channel masks with the
    same frame count and resolution as the video; none of this is checked.

    Args:
        path: Directory containing one image file per frame.
        mask_path: Directory containing one mask image per frame.
        back_mask_path: Directory containing the second set of mask images.
        transform: Callable applied to each frame. Defaults to ``ToTensor()``.
        mask_transform: Callable applied to every mask of both sets. Defaults
            to ``ToTensor()``.
    """

    def __init__(self, path, mask_path, back_mask_path, transform=None, mask_transform=None):
        super().__init__()

        self.path = path
        self.mask_path = mask_path
        self.back_mask_path = back_mask_path
        self.transform = ToTensor() if transform is None else transform
        self.mask_transform = (
            ToTensor() if mask_transform is None else mask_transform
        )

        self.video = self.get_video_tensor()
        self.mask = self.get_mask_tensor()
        self.back_mask = self.get_back_mask_tensor()
        self.num_frames, _, self.H, self.W = self.video.size()
        # (T, C, H, W) -> (H, W, T, C) so that flattened rows follow the same
        # (H, W, T) order as the coordinate grid below.
        self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3)
        self.mask_pixels = self.mask.permute(2, 3, 0, 1).contiguous().view(-1, 1)
        self.back_mask_pixels = self.back_mask.permute(2, 3, 0, 1).contiguous().view(-1, 1)
        self.coords = get_mgrid([self.H, self.W, self.num_frames])

    @staticmethod
    def _load_stack(directory, transform):
        """Open every file in ``directory`` (sorted), transform it, and stack."""
        items = []
        for name in sorted(os.listdir(directory)):
            # The context manager closes the file handle once the image has
            # been converted.
            with Image.open(os.path.join(directory, name)) as img:
                items.append(transform(img))
        return torch.stack(items, 0)

    def get_video_tensor(self):
        """Load the frames in ``self.path`` using ``self.transform``."""
        return self._load_stack(self.path, self.transform)

    def get_mask_tensor(self):
        """Load the masks in ``self.mask_path`` using ``self.mask_transform``."""
        return self._load_stack(self.mask_path, self.mask_transform)

    def get_back_mask_tensor(self):
        """Load the masks in ``self.back_mask_path`` using ``self.mask_transform``."""
        return self._load_stack(self.back_mask_path, self.mask_transform)

    def __len__(self):
        # The whole video is served as a single sample.
        return 1

    def __getitem__(self, idx):
        """Return ``(coords, pixels, mask_pixels, back_mask_pixels)``; only one item exists."""
        if idx > 0:
            raise IndexError

        return self.coords, self.pixels, self.mask_pixels, self.back_mask_pixels