"""Blender-style multi-view image datasets and their Lightning data module.

A scene directory (``config.root_dir``) holds one ``transforms_<split>.json``
per split, plus the ``.png`` frames that file references.
"""

import os
import json
import math
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF

import pytorch_lightning as pl

import datasets
from models.ray_utils import get_ray_directions
from utils.misc import get_rank


class BlenderDatasetBase:
    """Shared loading logic: parse a split's JSON and preload every frame.

    After :meth:`setup`, the instance exposes these float tensors, all moved
    to device ``self.rank``:

    * ``all_c2w``: stacked 3x4 camera-to-world matrices, shape (N, 3, 4).
    * ``all_images``: RGB images, shape (N, h, w, 3).
    * ``all_fg_masks``: alpha-channel foreground masks, shape (N, h, w).

    It also exposes ``directions``, the per-pixel ray directions shared by
    all frames, produced by ``get_ray_directions``.
    """

    def setup(self, config, split):
        """Load camera poses, images and masks for ``split``.

        Args:
            config: dataset config. Reads ``root_dir``, ``near_plane``,
                ``far_plane`` and one of ``img_wh`` (target ``(w, h)``) or
                ``img_downscale`` (resize factor). ``img_wh`` wins if both
                are set.
            split: split name; selects ``transforms_<split>.json``.

        Raises:
            KeyError: if neither ``img_wh`` nor ``img_downscale`` is set.
            ValueError: if ``img_wh`` does not keep the source aspect ratio.
        """
        self.config = config
        self.split = split
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True

        meta_path = os.path.join(self.config.root_dir, f"transforms_{self.split}.json")
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)

        # Source resolution; 800x800 is the fallback when the JSON omits it.
        if "w" in meta and "h" in meta:
            W, H = int(meta["w"]), int(meta["h"])
        else:
            W, H = 800, 800

        if "img_wh" in self.config:
            w, h = (int(v) for v in self.config.img_wh)
            # A single focal length (derived from the width) is used for both
            # axes below, so the target size must keep the source aspect
            # ratio. Raise rather than assert so the check survives `python -O`.
            if round(W / w * h) != H:
                raise ValueError(
                    f"img_wh={(w, h)} does not preserve the aspect ratio of "
                    f"the source images ({W}x{H})."
                )
        elif "img_downscale" in self.config:
            # int() so a float factor (e.g. 2.0) still yields integer sizes,
            # which Image.resize requires.
            w = int(W // self.config.img_downscale)
            h = int(H // self.config.img_downscale)
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        self.near, self.far = self.config.near_plane, self.config.far_plane

        # Focal length in pixels of the *resized* image. camera_angle_x is
        # used as the horizontal field of view, in radians (math.tan).
        self.focal = 0.5 * w / math.tan(0.5 * meta["camera_angle_x"])

        # Ray directions for all pixels. They are the same for every frame,
        # because all frames share h, w and focal. The last two arguments
        # look like the principal point (image centre); assumption based on
        # this call only.
        self.directions = get_ray_directions(
            self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2
        ).to(self.rank)  # (h, w, 3)

        c2w_list, image_list, mask_list = [], [], []
        for frame in meta["frames"]:
            # Keep the top-left 3x4 block of the transform matrix.
            c2w_list.append(
                torch.from_numpy(np.array(frame["transform_matrix"])[:3, :4])
            )

            img_path = os.path.join(self.config.root_dir, f"{frame['file_path']}.png")
            # The context manager closes the file handle. convert("RGBA")
            # guarantees four channels, so the last channel is always alpha:
            # RGB inputs get an all-opaque mask instead of their blue channel.
            with Image.open(img_path) as pil_img:
                img = pil_img.convert("RGBA").resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)
            mask_list.append(img[..., -1])  # (h, w)
            image_list.append(img[..., :3])  # (h, w, 3)

        self.all_c2w = torch.stack(c2w_list, dim=0).float().to(self.rank)
        self.all_images = torch.stack(image_list, dim=0).float().to(self.rank)
        self.all_fg_masks = torch.stack(mask_list, dim=0).float().to(self.rank)


class BlenderDataset(Dataset, BlenderDatasetBase):
    """Map-style dataset over the preloaded frames of one split.

    Items carry only the frame index; the preloaded tensors stay on the
    dataset object.
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
        return {"index": index}


class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    """Endless iterable dataset for one split; yields empty dicts forever."""

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        while True:
            yield {}


@datasets.register("blender")
class VideoNVSDataModule(pl.LightningDataModule):
    """Lightning data module that builds the per-stage Blender datasets."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them when None).

        The predict dataset is built from ``config.train_split``.
        """
        if stage in [None, "fit"]:
            self.train_dataset = BlenderIterableDataset(
                self.config, self.config.train_split
            )
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, "test"]:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, "predict"]:
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)

    def prepare_data(self):
        # Nothing to download or preprocess; all data is read in setup().
        pass

    def general_loader(self, dataset, batch_size):
        """Wrap ``dataset`` in a DataLoader with one worker per logical CPU."""
        # os.cpu_count() returns None when the CPU count cannot be determined,
        # and DataLoader rejects num_workers=None. Fall back to loading in the
        # main process (num_workers=0) instead of crashing.
        return DataLoader(
            dataset,
            num_workers=os.cpu_count() or 0,
            batch_size=batch_size,
            pin_memory=True,
            sampler=None,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)