import os
import cv2
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset

import kiui

from core.options import Options
from core.utils import get_rays, grid_distortion, orbit_camera_jitter

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


class ObjaverseDataset(Dataset):

    def _warn(self):
        raise NotImplementedError('this dataset is just an example and cannot be used directly; modify it to your own setting (search for the keyword TODO)!')
    def __init__(self, opt: Options, training=True):

        self.opt = opt
        self.training = training

        # TODO: remove this barrier once the loader is adapted to your own data
        self._warn()

        # TODO: load the list of objects
        self.items = []
        with open('TODO: file containing the list', 'r') as f:
            for line in f.readlines():
                self.items.append(line.strip())

        # naive split: hold out the last batch_size items for testing
        if self.training:
            self.items = self.items[:-self.opt.batch_size]
        else:
            self.items = self.items[-self.opt.batch_size:]

        # default camera intrinsics
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear)
        self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear)
        self.proj_matrix[2, 3] = 1
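        # note: proj_matrix is stored transposed relative to the usual
        # column-vector OpenGL projection (hence the 1 at [2, 3] rather than
        # [3, 2]), because it is right-multiplied against row-vector matrices
        # below: cam_view_proj = cam_view @ proj_matrix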

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):

        uid = self.items[idx]
        results = {}

        # load num_views images and their camera poses
        images = []
        masks = []
        cam_poses = []

        vid_cnt = 0

        # TODO: choose views based on your own rendering settings
        if self.training:
            # input views are sampled from indices [36, 72]; the remaining views are drawn randomly
            vids = np.random.permutation(np.arange(36, 73))[:self.opt.num_input_views].tolist() + np.random.permutation(100).tolist()
        else:
            # fixed views for evaluation
            vids = np.arange(36, 73, 4).tolist() + np.arange(100).tolist()
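        # (these index ranges assume each object ships 100 rendered views, with
        # views 36-72 lying on the orbit the input views come from; this is an
        # assumption about the rendering layout and must match your own)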

        for vid in vids:

            image_path = os.path.join(uid, 'rgb', f'{vid:03d}.png')
            camera_path = os.path.join(uid, 'pose', f'{vid:03d}.txt')

            try:
                # TODO: load data (modify self.client here; it is an undefined placeholder
                # for your storage backend and must return raw bytes for a given path)
                image = np.frombuffer(self.client.get(image_path), np.uint8)
                image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255) # [512, 512, 4] in [0, 1]
                c2w = [float(t) for t in self.client.get(camera_path).decode().strip().split(' ')]
                c2w = torch.tensor(c2w, dtype=torch.float32).reshape(4, 4)
            except Exception as e:
                # print(f'[WARN] dataset {uid} {vid}: {e}')
                continue
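            # a minimal client for plain local files could look like this
            # (hypothetical helper, not part of the codebase; assign an
            # instance to self.client in __init__):
            #
            #   class FileClient:
            #       def get(self, path):
            #           with open(path, 'rb') as f:
            #               return f.read()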

            # TODO: you may have a different camera system
            # blender world + OpenCV camera --> OpenGL world & camera
            c2w[1] *= -1
            c2w[[1, 2]] = c2w[[2, 1]]
            c2w[:3, 1:3] *= -1 # invert the up and forward directions

            # scale up the radius to fully use the [-1, 1]^3 space!
            c2w[:3, 3] *= self.opt.cam_radius / 1.5 # 1.5 is the default render-time camera radius

            image = image.permute(2, 0, 1) # [4, 512, 512]
            mask = image[3:4] # [1, 512, 512]
            image = image[:3] * mask + (1 - mask) # [3, 512, 512], composite onto a white background
            image = image[[2, 1, 0]].contiguous() # BGR (cv2) to RGB

            images.append(image)
            masks.append(mask.squeeze(0))
            cam_poses.append(c2w)

            vid_cnt += 1
            if vid_cnt == self.opt.num_views:
                break

        if vid_cnt < self.opt.num_views:
            print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!')
            n = self.opt.num_views - vid_cnt
            # pad by repeating the last valid view
            images = images + [images[-1]] * n
            masks = masks + [masks[-1]] * n
            cam_poses = cam_poses + [cam_poses[-1]] * n

        images = torch.stack(images, dim=0) # [V, 3, H, W]
        masks = torch.stack(masks, dim=0) # [V, H, W]
        cam_poses = torch.stack(cam_poses, dim=0) # [V, 4, 4]
        # normalize the cameras: map the first (input) pose to a fixed canonical
        # position and apply the same rigid transform to all other views
        transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0])
        cam_poses = transform.unsqueeze(0) @ cam_poses # [V, 4, 4]
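        # (after this, cam_poses[0] is exactly the canonical pose: identity
        # rotation with the camera placed at (0, 0, cam_radius))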

        images_input = F.interpolate(images[:self.opt.num_input_views].clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False) # [V, C, H, W]
        cam_poses_input = cam_poses[:self.opt.num_input_views].clone()

        # data augmentation (training only)
        if self.training:
            # apply random grid distortion to simulate 3D inconsistency between views
            if random.random() < self.opt.prob_grid_distortion:
                images_input[1:] = grid_distortion(images_input[1:])
            # apply camera jittering (input views only; the first view stays fixed)
            if random.random() < self.opt.prob_cam_jitter:
                cam_poses_input[1:] = orbit_camera_jitter(cam_poses_input[1:])

        images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

        # resized ground-truth renderings for supervision, still in [0, 1]
        results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, C, output_size, output_size]
        results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False) # [V, 1, output_size, output_size]

        # build per-pixel ray embeddings for the input views
        rays_embeddings = []
        for i in range(self.opt.num_input_views):
            rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy) # [h, w, 3]
            rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1) # [h, w, 6]
            rays_embeddings.append(rays_plucker)
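        # (the 6D Plucker coordinates (o x d, d) identify a ray independently of
        # which origin o along the ray is used, since (o + t*d) x d = o x d)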

        rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous() # [V, 6, h, w]
        final_input = torch.cat([images_input, rays_embeddings], dim=1) # [V, 9, H, W]
        results['input'] = final_input

        # OpenGL to COLMAP camera convention for the gaussian renderer
        cam_poses[:, :3, 1:3] *= -1 # invert up & forward directions

        # cameras needed by the gaussian rasterizer
        cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4], transposed world-to-camera matrices
        cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
        cam_pos = - cam_poses[:, :3, 3] # [V, 3]

        results['cam_view'] = cam_view
        results['cam_view_proj'] = cam_view_proj
        results['cam_pos'] = cam_pos

        return results
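

# Example usage (a minimal sketch, assuming the TODO placeholders above have
# been filled in and that Options can be constructed with defaults; both are
# assumptions about your setup):
#
#   from torch.utils.data import DataLoader
#
#   opt = Options()
#   dataset = ObjaverseDataset(opt, training=True)
#   loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=4)
#   batch = next(iter(loader))
#   # batch['input']         : [B, num_input_views, 9, input_size, input_size]
#   # batch['images_output'] : [B, num_views, 3, output_size, output_size]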