|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import cv2 |
|
import random |
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision.transforms.functional as TF |
|
from torch.utils.data import Dataset |
|
|
|
import kiui |
|
from core.options import Options |
|
from core.utils import get_rays, grid_distortion, orbit_camera_jitter |
|
|
|
from kiui.cam import orbit_camera |
|
|
|
import tarfile |
|
from io import BytesIO |
|
|
|
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) |
|
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) |
|
|
|
|
|
def load_np_array_from_tar(tar, path):
    """Load a ``.npy`` member from an open tar archive.

    Args:
        tar: an open ``tarfile.TarFile``.
        path: member path of the ``.npy`` file inside the archive.

    Returns:
        np.ndarray: the deserialized array.

    Raises:
        KeyError: if ``path`` is not a member of the archive.
        AttributeError: if ``extractfile`` returns ``None`` for ``path``
            (e.g. the member is a directory, not a regular file).
    """
    # BytesIO(data) yields a buffer already positioned at offset 0 —
    # no need for the manual write()/seek(0) dance.
    return np.load(BytesIO(tar.extractfile(path).read()))
|
|
|
|
|
class ObjaverseDataset(Dataset):
    """Multi-view dataset of animated Objaverse clips stored in tar archives.

    Each item is an animation; per sample, the loader reads per-view RGB
    frames, alpha masks, and orbit-camera angles from two tars per clip
    ('fixed_16_clip' holds 16 fixed viewpoints, 'random_clip' holds 32
    random viewpoints), composites frames onto a white background, and
    returns normalized input views with Plucker ray embeddings plus
    supervision views and camera matrices.

    NOTE: `_warn` declares this class an example that must be adapted
    before real use (see its message).
    """

    def _warn(self):
        # Deliberate guard: this dataset is a template, not a working loader.
        raise NotImplementedError('this dataset is just an example and cannot be used directly, you should modify it to your own setting! (search keyword TODO)')

    def __init__(self, opt: Options, training: bool = True, evaluating: bool = False):
        """Build the item list and the cached projection matrix.

        Args:
            opt: project options; fields used here include `datalist`,
                `batch_size`, `fovy`, `znear`, `zfar` (plus many more in
                `_get_batch`).
            training: select the training split (all but the last
                `batch_size` items).
            evaluating: select a sparse evaluation split (every 1000th
                item); ignored when `training` is True.
        """
        self.opt = opt
        self.training = training
        self.evaluating = evaluating

        # One item per line in the datalist file.
        self.items = []
        with open(self.opt.datalist, 'r') as f:
            for line in f.readlines():
                self.items.append(line.strip())

        # Group items that share the same key into a single '|'-joined entry,
        # so each dataset index corresponds to one group and a random member
        # is drawn per access (see `_get_batch`).
        # NOTE(review): assumes every item name contains at least one '-' and
        # that the second '-'-separated field is the grouping key — confirm
        # against the datalist format.
        anim_map = {}
        for x in self.items:
            k = x.split('-')[1]
            if k in anim_map:
                anim_map[k] += '|'+x
            else:
                anim_map[k] = x
        self.items = list(anim_map.values())

        # Split: train = all but the last `batch_size` groups; eval = every
        # 1000th group; test/val = the last `batch_size` groups.
        if self.training:
            self.items = self.items[:-self.opt.batch_size]
        elif self.evaluating:
            self.items = self.items[::1000]
        else:
            self.items = self.items[-self.opt.batch_size:]

        # Perspective projection matrix, filled in transposed layout so that
        # it right-multiplies row-vector view matrices in `_get_batch`
        # (note [3, 2] and [2, 3] placements).
        self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
        self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
        self.proj_matrix[0, 0] = 1 / self.tan_half_fov
        self.proj_matrix[1, 1] = 1 / self.tan_half_fov
        self.proj_matrix[2, 2] = (self.opt.zfar + self.opt.znear) / (self.opt.zfar - self.opt.znear)
        self.proj_matrix[3, 2] = - (self.opt.zfar * self.opt.znear) / (self.opt.zfar - self.opt.znear)
        self.proj_matrix[2, 3] = 1

    def __len__(self) -> int:
        """Number of animation groups in the selected split."""
        return len(self.items)

    def _get_batch(self, idx):
        """Load one sample: all frames and views for one animation group.

        Returns a dict with:
            'input':         (T*num_input_views, 9, input_size, input_size)
                             normalized RGB + Plucker ray embeddings,
            'images_output': (T*num_views, 3, output_size, output_size),
            'masks_output':  (T*num_views, 1, output_size, output_size),
            'cam_view', 'cam_view_proj', 'cam_pos': per-view camera tensors.

        Raises whatever the tar/image/IO layer raises on bad data;
        `__getitem__` catches and resamples.
        """
        # Training draws a random clip from the '|'-joined group; otherwise
        # deterministically take the first one.
        if self.training:
            uid = random.choice(self.items[idx].split('|'))
        else:
            uid = self.items[idx].split('|')[0]

        results = {}

        images = []
        masks = []
        cam_poses = []

        # View ids: [32, 48) index the 16 fixed viewpoints (input views come
        # first), [0, 32) the random viewpoints used for supervision.
        if self.training and self.opt.shuffle_input:
            vids = np.random.permutation(np.arange(32, 48))[:self.opt.num_input_views].tolist() + np.random.permutation(32).tolist()
        else:
            # Deterministic: 4 evenly spaced fixed views + all random views.
            vids = np.arange(32, 48, 4).tolist() + np.arange(32).tolist()

        random_tar_name = 'random_clip/' + uid
        fixed_16_tar_name = 'fixed_16_clip/' + uid

        # Local copies are flat files named with '/' replaced by '-'.
        # NOTE(review): assumes DATA_HOME ends with a path separator —
        # confirm, otherwise the names concatenate incorrectly.
        local_random_tar_name = os.environ["DATA_HOME"] + random_tar_name.replace('/', '-')
        local_fixed_16_tar_name = os.environ["DATA_HOME"] + fixed_16_tar_name.replace('/', '-')

        # NOTE(review): these handles are never closed; with many workers
        # this leaks file descriptors — consider try/finally or `with`.
        tar_random = tarfile.open(local_random_tar_name)
        tar_fixed = tarfile.open(local_fixed_16_tar_name)

        T = self.opt.num_frames
        for t_idx in range(T):
            t = t_idx
            vid_cnt = 0
            for vid in vids:
                # ids >= 32 address the fixed-viewpoint tar (folded back to
                # [0, 32) for its internal paths); the rest use the random tar.
                if vid >= 32:
                    vid = vid % 32
                    tar = tar_fixed
                else:
                    tar = tar_random

                image_path = os.path.join('.', f'{vid:03d}/img', f'{t:03d}.jpg')
                mask_path = os.path.join('.', f'{vid:03d}/mask', f'{t:03d}.png')

                elevation_path = os.path.join('.', f'{vid:03d}/camera', f'elevation.npy')
                rotation_path = os.path.join('.', f'{vid:03d}/camera', f'rotation.npy')

                # Decode JPEG from the tar; cv2 yields HWC in BGR order.
                image = np.frombuffer(tar.extractfile(image_path).read(), np.uint8)
                image = torch.from_numpy(cv2.imdecode(image, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255)

                # Per-frame camera angles; elevation is negated to match the
                # orbit_camera convention used here.
                azi = load_np_array_from_tar(tar, rotation_path)[t, None]
                elevation = load_np_array_from_tar(tar, elevation_path)[t, None] * -1
                azi = float(azi)
                elevation = float(elevation)
                c2w = torch.from_numpy(orbit_camera(elevation, azi, radius=1.5, opengl=True))

                # HWC -> CHW.
                image = image.permute(2, 0, 1)

                mask = np.frombuffer(tar.extractfile(mask_path).read(), np.uint8)
                mask = torch.from_numpy(cv2.imdecode(mask, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255).unsqueeze(0)

                image = F.interpolate(image.unsqueeze(0), size=(512, 512), mode='nearest').squeeze(0)
                mask = F.interpolate(mask.unsqueeze(0), size=(512, 512), mode='nearest').squeeze(0)

                # Composite onto a white background, then BGR -> RGB.
                image = image[:3] * mask + (1 - mask)
                image = image[[2,1,0]].contiguous()

                images.append(image)
                masks.append(mask.squeeze(0))
                cam_poses.append(c2w)

                vid_cnt += 1
                if vid_cnt == self.opt.num_views:
                    break

            # Pad by repeating the last view so every frame contributes
            # exactly num_views entries.
            if vid_cnt < self.opt.num_views:
                print(f'[WARN] dataset {uid}: not enough valid views, only {vid_cnt} views found!')
                n = self.opt.num_views - vid_cnt
                images = images + [images[-1]] * n
                masks = masks + [masks[-1]] * n
                cam_poses = cam_poses + [cam_poses[-1]] * n

        # Flat stacking: index = frame * num_views + view.
        images = torch.stack(images, dim=0)
        masks = torch.stack(masks, dim=0)
        cam_poses = torch.stack(cam_poses, dim=0)

        # Re-express all poses relative to the first camera, placing it at
        # the canonical pose at distance cam_radius on the +z axis.
        transform = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, self.opt.cam_radius], [0, 0, 0, 1]], dtype=torch.float32) @ torch.inverse(cam_poses[0])
        cam_poses = transform.unsqueeze(0) @ cam_poses

        # Slice out the first num_input_views per frame as the model input.
        images_input = F.interpolate(images.reshape(T, self.opt.num_views, *images.shape[1:])[:, :self.opt.num_input_views].reshape(-1, *images.shape[1:]).clone(), size=(self.opt.input_size, self.opt.input_size), mode='bilinear', align_corners=False)
        cam_poses_input = cam_poses.reshape(T, self.opt.num_views, *cam_poses.shape[1:])[:, :self.opt.num_input_views].reshape(-1, *cam_poses.shape[1:]).clone()

        # Training-time augmentation on all input views except the first.
        if self.training:
            images_input = images_input.reshape(T, self.opt.num_input_views, *images_input.shape[1:])
            cam_poses_input = cam_poses_input.reshape(T, self.opt.num_input_views, *cam_poses.shape[1:])

            if random.random() < self.opt.prob_grid_distortion:
                for t in range(T):
                    images_input[t, 1:] = grid_distortion(images_input[t, 1:])

            if random.random() < self.opt.prob_cam_jitter:
                for t in range(T):
                    cam_poses_input[t, 1:] = orbit_camera_jitter(cam_poses_input[t, 1:])

            images_input = images_input.reshape(-1, *images_input.shape[2:])
            cam_poses_input = cam_poses_input.reshape(-1, *cam_poses.shape[1:])

        # Freeze the non-first input views to their frame-0 content, so only
        # the first input view varies over time.
        images_input = images_input.reshape(T, self.opt.num_input_views, *images_input.shape[1:])
        images_input[1:, 1:] = images_input[0:1, 1:]
        images_input = images_input.reshape(-1, *images_input.shape[2:])

        cam_poses_input = cam_poses_input.reshape(T, self.opt.num_input_views, *cam_poses.shape[1:])
        cam_poses_input[1:, 1:] = cam_poses_input[0:1, 1:]
        cam_poses_input = cam_poses_input.reshape(-1, *cam_poses.shape[1:])

        # ImageNet normalization for the image backbone.
        images_input = TF.normalize(images_input, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

        # Supervision targets at output resolution.
        results['images_output'] = F.interpolate(images, size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False)
        results['masks_output'] = F.interpolate(masks.unsqueeze(1), size=(self.opt.output_size, self.opt.output_size), mode='bilinear', align_corners=False)

        # Plucker ray embedding (cross(o, d), d) per input view.
        rays_embeddings = []
        for i in range(self.opt.num_input_views * T):
            rays_o, rays_d = get_rays(cam_poses_input[i], self.opt.input_size, self.opt.input_size, self.opt.fovy)
            rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], dim=-1)
            rays_embeddings.append(rays_plucker)

        # (V, H, W, 6) -> (V, 6, H, W).
        rays_embeddings = torch.stack(rays_embeddings, dim=0).permute(0, 3, 1, 2).contiguous()

        # Final input: 3 RGB channels + 6 ray channels per view.
        final_input = torch.cat([images_input, rays_embeddings], dim=1)
        results['input'] = final_input

        # Flip the y/z camera axes — converts from the OpenGL convention of
        # orbit_camera to the renderer's convention.
        cam_poses[:, :3, 1:3] *= -1

        # World-to-camera, transposed for the row-vector convention used with
        # the transposed projection matrix built in __init__.
        cam_view = torch.inverse(cam_poses).transpose(1, 2)
        cam_view_proj = cam_view @ self.proj_matrix
        cam_pos = - cam_poses[:, :3, 3]

        results['cam_view'] = cam_view
        results['cam_view_proj'] = cam_view_proj
        results['cam_pos'] = cam_pos

        return results

    def __getitem__(self, idx: int):
        """Fetch a sample, resampling a random index on any failure.

        NOTE(review): the broad `except Exception` plus unbounded retry
        loops forever if every item is unreadable, and `print` swallows the
        traceback — consider logging and capping retries.
        """
        while True:
            try:
                results = self._get_batch(idx)
                break
            except Exception as e:
                print(f"{e}")
                idx = random.randint(0, len(self.items) - 1)
        return results