import os
import math

import numpy as np
import torch
import torchvision.transforms.functional as TF
import pytorch_lightning as pl
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, IterableDataset
from torchvision.utils import make_grid, save_image
from einops import rearrange
from mediapy import read_video
from rembg import remove, new_session

import datasets
from models.ray_utils import get_ray_directions
from utils.misc import get_rank
from utils.dpt import DPT
from datasets.ortho import (
    inv_RT,
    camNormal2worldNormal,
    RT_opengl2opencv,
    normal_opengl2opencv,
)


def get_c2w_from_up_and_look_at(up, look_at, pos, opengl=False):
    """Build a camera-to-world matrix from an up vector, a look-at target,
    and a camera position (OpenCV convention; OpenGL if requested)."""
    up = up / np.linalg.norm(up)
    z = look_at - pos
    z = z / np.linalg.norm(z)
    y = -up
    x = np.cross(y, z)
    x /= np.linalg.norm(x)
    y = np.cross(z, x)

    c2w = np.zeros([4, 4], dtype=np.float32)
    c2w[:3, 0] = x
    c2w[:3, 1] = y
    c2w[:3, 2] = z
    c2w[:3, 3] = pos
    c2w[3, 3] = 1.0

    # OpenCV to OpenGL: flip the y and z axes.
    if opengl:
        c2w[..., 1:3] *= -1

    return c2w


def get_uniform_poses(num_frames, radius, elevation, opengl=False):
    """Generate num_frames camera poses evenly spaced over a full azimuth
    orbit at the given radius and elevation, all looking at the origin."""
    T = num_frames
    azimuths = np.deg2rad(np.linspace(0, 360, T + 1)[:T])
    elevations = np.full_like(azimuths, np.deg2rad(elevation))
    cam_dists = np.full_like(azimuths, radius)

    campos = np.stack(
        [
            cam_dists * np.cos(elevations) * np.cos(azimuths),
            cam_dists * np.cos(elevations) * np.sin(azimuths),
            cam_dists * np.sin(elevations),
        ],
        axis=-1,
    )

    center = np.array([0, 0, 0], dtype=np.float32)
    up = np.array([0, 0, 1], dtype=np.float32)
    poses = []
    for t in range(T):
        poses.append(get_c2w_from_up_and_look_at(up, center, campos[t], opengl=opengl))

    return np.stack(poses, axis=0)


def blender2midas(img):
    """Blender: RUB; MiDaS: LUB. Flips all three normal axes."""
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


def midas2blender(img):
    """Inverse of blender2midas (the axis flip is its own inverse)."""
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


class BlenderDatasetBase:
    def setup(self, config, split):
        self.config = config
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True

        # Monocular normal estimator used to pseudo-label per-frame normals.
        dpt = DPT(device=self.rank, mode="normal")

        frames = read_video(Path(self.config.root_dir) / f"{self.config.scene}")
        rembg_session = new_session()

        num_frames, H, W = frames.shape[:3]
        if "img_wh" in self.config:
            w, h = self.config.img_wh
            # The requested size must preserve the video's aspect ratio.
            assert round(W / w * h) == H
        elif "img_downscale" in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        # Focal length in pixels for a 60 degree field of view.
        self.focal = 0.5 * w / math.tan(0.5 * np.deg2rad(60))

        # Ray directions for all pixels, same for all images (same H, W, focal).
        self.directions = get_ray_directions(
            self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2
        ).to(self.rank)  # (h, w, 3)

        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        # A synthetic orbit of cameras around the object, one pose per frame.
        radius = 2.0
        elevation = 0.0
        poses = get_uniform_poses(num_frames, radius, elevation, opengl=True)
        for c2w, frame in zip(poses, frames):
            c2w = torch.from_numpy(np.array(c2w)[:3, :4])
            self.all_c2w.append(c2w)
            img = Image.fromarray(frame)
            img = remove(img, session=rembg_session)
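            # rembg returns an RGBA image: the alpha channel becomes the
            # foreground mask and the RGB channels the training image below.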
            img = img.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)
            self.all_fg_masks.append(img[..., -1])  # (h, w)
            self.all_images.append(img[..., :3])

        self.all_c2w, self.all_images, self.all_fg_masks = (
            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
            torch.stack(self.all_images, dim=0).float().to(self.rank),
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
        )

        # Predict per-frame normals, map them from [0, 1] to [-1, 1], convert
        # MiDaS to Blender axes, and zero out everything outside the masks.
        self.normals = dpt(self.all_images)
        self.all_masks = self.all_fg_masks.cpu().numpy() > 0.1

        self.normals = self.normals * 2.0 - 1.0
        self.normals = midas2blender(self.normals).cpu().numpy()
        self.normals[..., 0] *= -1
        self.normals[~self.all_masks] = [0, 0, 0]

        normals = rearrange(self.normals, "b h w c -> b c h w")
        normals = normals * 0.5 + 0.5
        normals = torch.from_numpy(normals)
        # save_image(make_grid(normals, nrow=4), "tmp/normals.png")  # debug grid

        (
            self.all_poses,
            self.all_normals,
            self.all_normals_world,
            self.all_w2cs,
            self.all_color_masks,
        ) = ([], [], [], [], [])

        # Convert each OpenGL camera-to-world pose to the OpenCV convention
        # and lift the predicted camera-space normals into world space.
        for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals):
            RT_opengl = inv_RT(c2w_opengl)
            RT_opencv = RT_opengl2opencv(RT_opengl)
            c2w_opencv = inv_RT(RT_opencv)
            self.all_poses.append(c2w_opencv)
            self.all_w2cs.append(RT_opencv)
            normal = normal_opengl2opencv(normal)
            normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal)
            self.all_normals.append(normal)
            self.all_normals_world.append(normal_world)

        self.directions = torch.stack([self.directions] * len(self.all_images))
        self.origins = self.directions
        self.all_poses = np.stack(self.all_poses)
        self.all_normals = np.stack(self.all_normals)
        self.all_normals_world = np.stack(self.all_normals_world)
        self.all_w2cs = np.stack(self.all_w2cs)

        self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank)
        self.all_images = self.all_images.to(self.rank)
        self.all_fg_masks = self.all_fg_masks.to(self.rank)
        self.all_rgb_masks = self.all_fg_masks.to(self.rank)
        self.all_normals_world = (
            torch.from_numpy(self.all_normals_world).float().to(self.rank)
        )


class BlenderDataset(Dataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
        # Samples are addressed by index; the heavy per-frame tensors live on
        # the dataset object itself and are looked up by the model.
        return {"index": index}


class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        # Infinite stream of empty dicts: training batches are assembled from
        # the preloaded tensors rather than from per-item payloads.
        while True:
            yield {}


@datasets.register("v3d")
class BlenderDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in [None, "fit"]:
            self.train_dataset = BlenderIterableDataset(
                self.config, self.config.train_split
            )
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, "test"]:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, "predict"]:
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        sampler = None
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=sampler,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
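        # Prediction reuses the train split (see setup), so predict renders
        # the same orbit trajectory the model was fitted on.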
        return self.general_loader(self.predict_dataset, batch_size=1)
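

# ---------------------------------------------------------------------------
# Minimal usage sketch. The config keys mirror the attributes accessed in
# setup(); the directory, filename, and split names below are hypothetical,
# and OmegaConf is assumed only because the config object must support both
# attribute access and the `in` operator.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from omegaconf import OmegaConf

    config = OmegaConf.create(
        {
            "root_dir": "data",     # hypothetical folder containing the video
            "scene": "object.mp4",  # hypothetical orbit video of the object
            "img_wh": [512, 512],   # must match the video's aspect ratio
            "train_split": "train",
            "val_split": "val",
            "test_split": "test",
        }
    )
    dm = BlenderDataModule(config)
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))  # training batches are empty dicts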