Spaces:
Running
Running
| import os | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| from torch.utils.data import Dataset | |
| from torchvision.transforms import ToTensor | |
def get_mgrid(sidelen, vmin=-1, vmax=1):
    """Build a flattened N-D grid of evenly spaced coordinates.

    Args:
        sidelen: sequence of ints, number of samples along each axis.
        vmin: lower bound of every axis (scalar) or a per-axis list/tuple.
        vmax: upper bound of every axis (scalar) or a per-axis list/tuple.

    Returns:
        Tensor of shape ``(prod(sidelen), len(sidelen))``. Row ``k`` holds the
        coordinates of the k-th grid point in ``'ij'`` (row-major) order, i.e.
        the last axis varies fastest.

    Raises:
        ValueError: if a per-axis ``vmin``/``vmax`` sequence does not have one
            entry per axis.
    """
    ndim = len(sidelen)
    # Broadcast scalar bounds to one value per axis. Tuples are accepted as well
    # as lists, so a tuple of bounds is not mistaken for a scalar.
    if not isinstance(vmin, (list, tuple)):
        vmin = [vmin] * ndim
    if not isinstance(vmax, (list, tuple)):
        vmax = [vmax] * ndim
    if len(vmin) != ndim or len(vmax) != ndim:
        raise ValueError(
            f"vmin/vmax need {ndim} entries, got {len(vmin)} and {len(vmax)}"
        )
    axes = [torch.linspace(lo, hi, steps=n) for lo, hi, n in zip(vmin, vmax, sidelen)]
    # 'ij' is torch's historical default; passing it explicitly keeps the same
    # ordering and avoids the deprecation warning for the implicit form.
    grids = torch.meshgrid(*axes, indexing="ij")
    return torch.stack(grids, dim=-1).reshape(-1, ndim)
def apply_homography(x, h):
    """Map 2-D points through a batch of homographies.

    Args:
        x: points, shape ``(B, 2)`` (the homogeneous lift must give length 3
            to match the 3x3 matrices in ``torch.bmm``).
        h: the first 8 entries of each row-major 3x3 homography, shape
            ``(B, 8)``; the bottom-right entry is fixed to 1.

    Returns:
        Projected points of shape ``(B, 2)`` after dividing by the homogeneous
        coordinate.
    """
    # Complete each 8-vector with h33 = 1 and reshape it into a 3x3 matrix.
    last_entry = torch.ones_like(h[:, [0]])
    matrices = torch.cat([h, last_entry], dim=-1).view(-1, 3, 3)

    # Lift points to homogeneous coordinates (x, y, 1) as column vectors.
    homogeneous = torch.cat([x, torch.ones_like(x[:, 0]).unsqueeze(-1)], dim=-1)
    projected = torch.bmm(matrices, homogeneous.unsqueeze(-1)).squeeze(-1)

    # Perspective divide.
    return projected[:, :-1] / projected[:, [-1]]
def jacobian(y, x):
    """Compute d y / d x one output column at a time via autograd.

    Args:
        y: 2-D tensor ``(B, N)`` computed from ``x`` with autograd enabled.
        x: tensor that ``y`` was computed from (``requires_grad=True``).

    Returns:
        Tensor with the gradient of each output column stacked at dim 1; for
        ``x`` of shape ``(B, D)`` the result is ``(B, N, D)``. The graph is
        kept (``create_graph=True``) so the result can itself be differentiated.

    Note:
        Each pass sums ``y[:, i]`` over the batch, so row ``b`` is a per-sample
        Jacobian only if ``y[b]`` depends on ``x[b]`` alone.
    """
    _, num_outputs = y.shape
    rows = []
    for out_idx in range(num_outputs):
        # One-hot selector over the output columns: vector-Jacobian product for column out_idx.
        selector = torch.zeros_like(y)
        selector[:, out_idx] = 1.
        grads = torch.autograd.grad(
            y,
            x,
            grad_outputs=selector,
            retain_graph=True,
            create_graph=True,
        )
        rows.append(grads[0])  # same shape as x
    # requires_grad_() keeps the result differentiable-flagged even when the
    # derivatives are constants (e.g. y linear in x) and carry no graph.
    return torch.stack(rows, dim=1).requires_grad_()
def overlap_mix(img1, img2, img_order, overlap_num):
    """Linearly cross-fade from ``img1`` to ``img2`` across an overlap window.

    Args:
        img1: first image (any array/tensor that supports scalar multiply and add).
        img2: second image, same shape as ``img1``.
        img_order: position inside the window, used to index a weight table of
            length ``overlap_num``. Position 0 gives ``img1`` full weight and
            ``overlap_num - 1`` gives ``img2`` full weight (when
            ``overlap_num >= 2``).
        overlap_num: number of positions in the overlap window.

    Returns:
        ``a * img1 + (1 - a) * img2`` where ``a`` decreases from 1 to 0 across the window.
    """
    descending = np.linspace(0, 1, overlap_num)[::-1]
    weight_first = descending[img_order]
    weight_second = (1 - descending)[img_order]
    return weight_first * img1 + weight_second * img2
class VideoFitting(Dataset):
    """Whole-video dataset for coordinate fitting, served as one shuffled batch.

    Every file in ``path`` (sorted by filename) is read as one frame. The
    frames are flattened into an RGB pixel table aligned with a
    ``(row, col, frame)`` coordinate grid in ``[-1, 1]`` from ``get_mgrid``.
    Both tables are shuffled with the same random permutation.

    Args:
        path: directory containing only the frame images.
        transform: callable applied to each PIL frame; defaults to ``ToTensor()``.

    Raises:
        ValueError: if the frames do not have exactly 3 channels.
    """

    def __init__(self, path, transform=None):
        super().__init__()
        self.path = path
        self.transform = ToTensor() if transform is None else transform
        self.video = self.get_video_tensor()
        self.num_frames, channels, self.H, self.W = self.video.size()
        # The flattening below hard-codes 3 channels; e.g. RGBA or grayscale
        # frames would otherwise be regrouped into meaningless RGB triples.
        if channels != 3:
            raise ValueError(f"expected 3-channel frames, got {channels} channels")
        # (T, C, H, W) -> (H, W, T, C) so rows follow the order of get_mgrid([H, W, T]).
        self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3)
        self.coords = get_mgrid([self.H, self.W, self.num_frames])
        # A single permutation keeps pixels and coords row-aligned.
        shuffle = torch.randperm(len(self.pixels))
        self.pixels = self.pixels[shuffle]
        self.coords = self.coords[shuffle]

    def get_video_tensor(self):
        """Load every file in ``self.path`` (sorted by name), stacked along dim 0."""
        frames = []
        for name in sorted(os.listdir(self.path)):
            # Close the handle deterministically: Pillow keeps multi-frame files
            # open after loading, and any file open if ``transform`` raises.
            with Image.open(os.path.join(self.path, name)) as img:
                frames.append(self.transform(img))
        return torch.stack(frames, 0)

    def __len__(self):
        # The whole video is a single item.
        return 1

    def __getitem__(self, idx):
        if idx > 0: raise IndexError
        return self.coords, self.pixels
class TestVideoFitting(Dataset):
    """Whole-video dataset for evaluation, served as one batch in grid order.

    Like the training variant but without shuffling: row ``k`` of ``pixels``
    and ``coords`` is the k-th ``(row, col, frame)`` location in the order
    produced by ``get_mgrid([H, W, T])``.

    Args:
        path: directory containing only the frame images (sorted by filename).
        transform: callable applied to each PIL frame; defaults to ``ToTensor()``.

    Raises:
        ValueError: if the frames do not have exactly 3 channels.
    """

    def __init__(self, path, transform=None):
        super().__init__()
        self.path = path
        self.transform = ToTensor() if transform is None else transform
        self.video = self.get_video_tensor()
        self.num_frames, channels, self.H, self.W = self.video.size()
        # The flattening below hard-codes 3 channels; other layouts would be
        # regrouped into meaningless RGB triples.
        if channels != 3:
            raise ValueError(f"expected 3-channel frames, got {channels} channels")
        # (T, C, H, W) -> (H, W, T, C) so rows follow the order of get_mgrid([H, W, T]).
        self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3)
        self.coords = get_mgrid([self.H, self.W, self.num_frames])

    def get_video_tensor(self):
        """Load every file in ``self.path`` (sorted by name), stacked along dim 0."""
        frames = []
        for name in sorted(os.listdir(self.path)):
            # Close the handle deterministically: Pillow keeps multi-frame files
            # open after loading, and any file open if ``transform`` raises.
            with Image.open(os.path.join(self.path, name)) as img:
                frames.append(self.transform(img))
        return torch.stack(frames, 0)

    def __len__(self):
        # The whole video is a single item.
        return 1

    def __getitem__(self, idx):
        if idx > 0: raise IndexError
        return self.coords, self.pixels
class GroupVideoFitting(Dataset):
    """Video + per-pixel mask dataset for coordinate fitting, one shuffled batch.

    Frames are read from ``path`` and masks from ``mask_path`` (each directory
    sorted by filename, one file per frame). RGB pixels, mask values and
    ``(row, col, frame)`` coordinates from ``get_mgrid`` are flattened into
    tables that share one random permutation. Row ``k`` of each table
    therefore refers to the same location.

    Args:
        path: directory containing only the frame images.
        mask_path: directory containing only the mask images.
        transform: callable applied to each PIL frame; defaults to ``ToTensor()``.
        mask_transform: callable applied to each PIL mask; defaults to ``ToTensor()``.

    Raises:
        ValueError: if frames are not 3-channel, or the stacked masks are not
            shaped ``(num_frames, 1, H, W)`` to match the frames.
    """

    def __init__(self, path, mask_path, transform=None, mask_transform=None):
        super().__init__()
        self.path = path
        self.mask_path = mask_path
        self.transform = ToTensor() if transform is None else transform
        self.mask_transform = ToTensor() if mask_transform is None else mask_transform
        self.video = self.get_video_tensor()
        self.mask = self.get_mask_tensor()
        self.num_frames, channels, self.H, self.W = self.video.size()
        # The flattening below hard-codes 3 colour channels and 1 mask channel.
        # Any other layout, or a different frame count or resolution for the
        # masks, would silently misalign rows under the shared permutation.
        if channels != 3:
            raise ValueError(f"expected 3-channel frames, got {channels} channels")
        expected = (self.num_frames, 1, self.H, self.W)
        if tuple(self.mask.shape) != expected:
            raise ValueError(
                f"mask tensor has shape {tuple(self.mask.shape)}, expected {expected}"
            )
        # (T, C, H, W) -> (H, W, T, C) so rows follow the order of get_mgrid([H, W, T]).
        self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3)
        self.mask_pixels = self.mask.permute(2, 3, 0, 1).contiguous().view(-1, 1)
        self.coords = get_mgrid([self.H, self.W, self.num_frames])
        # One permutation for all three tables keeps them row-aligned.
        shuffle = torch.randperm(len(self.pixels))
        self.pixels = self.pixels[shuffle]
        self.coords = self.coords[shuffle]
        self.mask_pixels = self.mask_pixels[shuffle]

    @staticmethod
    def _load_image_stack(directory, transform):
        """Read every file in ``directory`` (sorted by name), apply ``transform``, stack on dim 0."""
        items = []
        for name in sorted(os.listdir(directory)):
            # Close the handle deterministically: Pillow keeps multi-frame files
            # open after loading, and any file open if ``transform`` raises.
            with Image.open(os.path.join(directory, name)) as img:
                items.append(transform(img))
        return torch.stack(items, 0)

    def get_video_tensor(self):
        """Load the frames in ``self.path`` with ``self.transform``."""
        return self._load_image_stack(self.path, self.transform)

    def get_mask_tensor(self):
        """Load the masks in ``self.mask_path`` with ``self.mask_transform``."""
        return self._load_image_stack(self.mask_path, self.mask_transform)

    def __len__(self):
        # The whole video is a single item.
        return 1

    def __getitem__(self, idx):
        if idx > 0: raise IndexError
        return self.coords, self.pixels, self.mask_pixels
class TestGroupVideoFitting(Dataset):
    """Video + mask + back-mask dataset for evaluation, one batch in grid order.

    Frames come from ``path``, masks from ``mask_path`` and back masks from
    ``back_mask_path`` (each sorted by filename, one file per frame). The
    tables are not shuffled: row ``k`` of every table is the k-th
    ``(row, col, frame)`` location in the order of ``get_mgrid([H, W, T])``.

    Args:
        path: directory containing only the frame images.
        mask_path: directory containing only the mask images.
        back_mask_path: directory containing only the back-mask images.
        transform: callable applied to each PIL frame; defaults to ``ToTensor()``.
        mask_transform: callable applied to each PIL mask and back mask;
            defaults to ``ToTensor()``.

    Raises:
        ValueError: if frames are not 3-channel, or either mask stack is not
            shaped ``(num_frames, 1, H, W)`` to match the frames.
    """

    def __init__(self, path, mask_path, back_mask_path, transform=None, mask_transform=None):
        super().__init__()
        self.path = path
        self.mask_path = mask_path
        self.back_mask_path = back_mask_path
        self.transform = ToTensor() if transform is None else transform
        self.mask_transform = ToTensor() if mask_transform is None else mask_transform
        self.video = self.get_video_tensor()
        self.mask = self.get_mask_tensor()
        self.back_mask = self.get_back_mask_tensor()
        self.num_frames, channels, self.H, self.W = self.video.size()
        # The flattening below hard-codes 3 colour channels and 1 mask channel;
        # mismatched layouts would silently misalign rows across the tables.
        if channels != 3:
            raise ValueError(f"expected 3-channel frames, got {channels} channels")
        expected = (self.num_frames, 1, self.H, self.W)
        for label, stack in (("mask", self.mask), ("back mask", self.back_mask)):
            if tuple(stack.shape) != expected:
                raise ValueError(
                    f"{label} tensor has shape {tuple(stack.shape)}, expected {expected}"
                )
        # (T, C, H, W) -> (H, W, T, C) so rows follow the order of get_mgrid([H, W, T]).
        self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3)
        self.mask_pixels = self.mask.permute(2, 3, 0, 1).contiguous().view(-1, 1)
        self.back_mask_pixels = self.back_mask.permute(2, 3, 0, 1).contiguous().view(-1, 1)
        self.coords = get_mgrid([self.H, self.W, self.num_frames])

    @staticmethod
    def _load_image_stack(directory, transform):
        """Read every file in ``directory`` (sorted by name), apply ``transform``, stack on dim 0."""
        items = []
        for name in sorted(os.listdir(directory)):
            # Close the handle deterministically: Pillow keeps multi-frame files
            # open after loading, and any file open if ``transform`` raises.
            with Image.open(os.path.join(directory, name)) as img:
                items.append(transform(img))
        return torch.stack(items, 0)

    def get_video_tensor(self):
        """Load the frames in ``self.path`` with ``self.transform``."""
        return self._load_image_stack(self.path, self.transform)

    def get_mask_tensor(self):
        """Load the masks in ``self.mask_path`` with ``self.mask_transform``."""
        return self._load_image_stack(self.mask_path, self.mask_transform)

    def get_back_mask_tensor(self):
        """Load the back masks in ``self.back_mask_path`` with ``self.mask_transform``."""
        return self._load_image_stack(self.back_mask_path, self.mask_transform)

    def __len__(self):
        # The whole video is a single item.
        return 1

    def __getitem__(self, idx):
        if idx > 0: raise IndexError
        return self.coords, self.pixels, self.mask_pixels, self.back_mask_pixels