# Provenance (from the hosting page header): uploaded by pengHTYX, commit a875c68 ("test"), 14 kB.
from typing import Dict
import numpy as np
import torch
from torch.utils.data import Dataset
from pathlib import Path
import json
from PIL import Image
from torchvision import transforms
from einops import rearrange, repeat
from typing import Literal, Tuple, Optional, Any
import cv2
import random
import json
import os, sys
import math
from PIL import Image, ImageOps
from normal_utils import worldNormal2camNormal, plot_grid_images, img2normal, norm_normalize, deg2rad
import pdb
from icecream import ic
def shift_list(lst, n):
    """Rotate ``lst`` right by ``n`` positions and return the result as a new sequence.

    A negative ``n`` rotates left, and an ``n`` larger than the length wraps around.
    An empty input returns an empty copy instead of raising ``ZeroDivisionError``
    from ``n % 0``.

    >>> shift_list([1, 2, 3, 4], 1)
    [4, 1, 2, 3]
    """
    length = len(lst)
    if length == 0:
        return lst[:]
    n %= length  # normalise into [0, length) so both slices below are well defined
    return lst[-n:] + lst[:-n]
class ObjaverseDataset(Dataset):
    """Multi-view dataset of per-object renderings (RGB images and normal maps).

    Each sample pairs one conditioning RGB view with ``len(predict_relative_views)``
    target views (RGB + normal map). For every object the loader reads::

        {root_dir}/{object}/camera/{vid:03d}.npy    # pickled dict with 'camera', 'extrinsic', 'focal', 'elevation'
        {root_dir}/{object}/image/{vid:03d}{exten}  # falls back to image_relit/ if loading fails
        {root_dir}/{object}/normal/{vid:03d}{exten}

    View ids are laid out as ``fixed_views`` groups of ``random_views + 1``
    consecutive ids. The first id of each group is a "fixed" view whose
    extrinsic is cached in ``fix_cam_poses``.
    """

    def __init__(self,
                 root_dir: str,
                 azi_interval: float,
                 random_views: int,
                 predict_relative_views: list,
                 bg_color: Any,
                 object_list: Any,
                 prompt_embeds_path: str,
                 img_wh: Tuple[int, int],
                 validation: bool = False,
                 num_validation_samples: int = 64,
                 num_samples: Optional[int] = None,
                 invalid_list: Optional[Any] = None,
                 trans_norm_system: bool = True,  # if True, transform all normals map into the cam system of front view
                 side_views_rate: float = 0.,
                 read_normal: bool = True,
                 read_color: bool = False,
                 read_depth: bool = False,
                 mix_color_normal: bool = False,
                 random_view_and_domain: bool = False,
                 load_cache: bool = False,
                 exten: str = '.png',
                 elevation_list: Optional[str] = None,
                 ) -> None:
        """Index the objects, cache the fixed camera poses and pre-load a fallback sample.

        Args:
            root_dir: directory containing one sub-folder per object.
            azi_interval: azimuth step; ``360 // azi_interval`` fixed views are assumed.
            random_views: number of extra view ids that follow each fixed view id.
            predict_relative_views: offsets, in fixed-view steps, of the target views
                relative to the conditioning view's fixed view.
            bg_color: background spec, see ``get_bg_color``.
            object_list: a JSON file path, or an iterable of such paths, each holding a
                list of object names. ``None`` lists ``root_dir`` instead.
            prompt_embeds_path: directory with ``normal_embeds.pt`` and ``clr_embeds.pt``.
            img_wh: size passed to ``PIL.Image.resize`` (PIL expects ``(width, height)``).
            validation: if True use the held-out tail of the object list, else the rest.
            num_validation_samples: number of objects held out for validation.
            num_samples: optional cap on the number of objects after the split.
            invalid_list: path or iterable of paths listing objects to exclude
                (``.txt`` = one name per line, anything else = JSON list).
            trans_norm_system, read_normal, read_color, read_depth, random_view_and_domain:
                stored as attributes; not read by the methods defined in this class.
            side_views_rate: controls how often a non-front conditioning view is
                drawn during training (see ``_sample_cond_ele0_idx``).
            mix_color_normal: build the fallback sample with ``__getitem_mix__``.
            load_cache: stored as an attribute; no cache is implemented in this class.
            exten: file extension of the image and normal files.
            elevation_list: optional text file (one name per line). Only objects
                listed there are kept.

        Raises:
            ValueError: if no candidate objects are found.
        """
        self.root_dir = root_dir
        self.fixed_views = int(360 // azi_interval)
        self.bg_color = bg_color
        self.validation = validation
        self.num_samples = num_samples
        self.trans_norm_system = trans_norm_system
        self.invalid_list = invalid_list
        self.img_wh = img_wh
        self.read_normal = read_normal
        self.read_color = read_color
        self.read_depth = read_depth
        self.mix_color_normal = mix_color_normal  # mix load color and normal maps
        self.random_view_and_domain = random_view_and_domain  # load normal or rgb of a single view
        self.random_views = random_views
        self.load_cache = load_cache
        # Every fixed view id is followed by `random_views` extra view ids.
        self.total_views = int(self.fixed_views * (self.random_views + 1))
        self.predict_relative_views = predict_relative_views
        self.pred_view_nums = len(self.predict_relative_views)
        self.exten = exten
        self.side_views_rate = side_views_rate
        ic(self.total_views)
        ic(self.fixed_views)
        ic(self.predict_relative_views)

        # Candidate objects: names from the JSON list file(s), or every entry of root_dir.
        if object_list is not None:
            self.objects = []
            for list_path in self._as_path_list(object_list):
                with open(list_path, 'r') as f:
                    self.objects.extend(json.load(f))
        else:
            self.objects = os.listdir(self.root_dir)
        if not self.objects:
            raise ValueError(f'no objects found (object_list={object_list!r}, root_dir={root_dir!r})')

        # Cache the extrinsics of the fixed views. They are read from the first listed
        # object only, so the code assumes every object shares the same camera rig.
        self.trans_cv2gl_mat = np.linalg.inv(np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]))
        self.fix_cam_poses = []
        camera_path = os.path.join(self.root_dir, self.objects[0], 'camera')
        for vid in range(0, self.total_views, self.random_views + 1):
            # NOTE: allow_pickle unpickles file contents; only use with trusted renders.
            cam_info = np.load(f'{camera_path}/{vid:03d}.npy', allow_pickle=True).item()
            assert cam_info['camera'] == 'ortho', 'Only support predict ortho camera !!!'
            self.fix_cam_poses.append(cam_info['extrinsic'])

        random.shuffle(self.objects)

        invalid_objects = []
        for list_path in self._as_path_list(self.invalid_list):
            with open(list_path, 'r') as f:
                if os.fspath(list_path).endswith('.txt'):
                    invalid_objects.extend(line.strip() for line in f)
                else:
                    invalid_objects.extend(json.load(f))
        self.invalid_objects = invalid_objects
        ic(len(self.invalid_objects))

        if elevation_list:
            with open(elevation_list, 'r') as f:
                allowed = {line.strip() for line in f}
            self.objects = [o for o in self.objects if o in allowed]

        # Filtering preserves the (seeded) shuffled order. Going through list(set(...))
        # would make the order, and hence the train/validation split, depend on the
        # per-process string hash seed.
        self.all_objects = self._filter_objects(self.objects, self.invalid_objects)
        self.all_objects = self._split_objects(
            self.all_objects, validation, num_validation_samples, num_samples)
        ic(len(self.all_objects))
        print("loading ", len(self.all_objects), " objects in the dataset")

        self.normal_prompt_embedding = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
        self.color_prompt_embedding = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')

        # Fixed sample that __getitem__ returns when loading a real sample fails.
        if self.mix_color_normal:
            # NOTE(review): __getitem_mix__ is not defined in this class, so this branch
            # raises AttributeError unless a subclass provides it.
            self.backup_data = self.__getitem_mix__(0, '8609cf7e67bf413487a7d94c73aeaa3e')
        else:
            self.backup_data = self.__getitem_norm__(0, '8609cf7e67bf413487a7d94c73aeaa3e')

    @staticmethod
    def _as_path_list(paths):
        """Normalise ``None``, a single path, or an iterable of paths to a list of paths."""
        if paths is None:
            return []
        if isinstance(paths, (str, os.PathLike)):
            return [paths]
        return list(paths)

    @staticmethod
    def _filter_objects(objects, invalid_objects):
        """Return ``objects`` deduplicated and without ``invalid_objects``, in original order."""
        invalid = set(invalid_objects)
        return [name for name in dict.fromkeys(objects) if name not in invalid]

    @staticmethod
    def _split_objects(objects, validation, num_validation_samples, num_samples=None):
        """Hold out the last ``num_validation_samples`` objects for validation.

        Returns the held-out tail when ``validation`` is True, else the remaining head,
        optionally truncated to ``num_samples``. Explicit boundaries avoid the
        ``lst[:-0]`` / ``lst[-0:]`` trap when no validation samples are requested.
        """
        n_val = max(0, min(num_validation_samples, len(objects)))
        boundary = len(objects) - n_val
        selected = objects[boundary:] if validation else objects[:boundary]
        if num_samples is not None:
            selected = selected[:num_samples]
        return selected

    def trans_cv2gl(self, rt):
        """Apply ``trans_cv2gl_mat`` to the rotation and translation of a pose.

        ``rt`` is indexed as a matrix whose top-left 3x3 is the rotation and whose last
        column is the translation. Returns a 3x4 ``[R | t]`` array.
        """
        r, t = rt[:3, :3], rt[:3, -1]
        r = np.matmul(self.trans_cv2gl_mat, r)
        t = np.matmul(self.trans_cv2gl_mat, t)
        return np.concatenate([r, t[:, None]], axis=-1)

    def get_bg_color(self):
        """Return the background colour as a float32 RGB array of shape (3,).

        ``self.bg_color`` may be 'white', 'black', 'gray', 'random' (uniform per call),
        'three_choices' (one of white/black/gray per call), or a number used for all
        three channels.

        Raises:
            NotImplementedError: for any other value.
        """
        white = np.array([1., 1., 1.], dtype=np.float32)
        black = np.array([0., 0., 0.], dtype=np.float32)
        gray = np.array([0.5, 0.5, 0.5], dtype=np.float32)
        if self.bg_color == 'white':
            return white
        if self.bg_color == 'black':
            return black
        if self.bg_color == 'gray':
            return gray
        if self.bg_color == 'random':
            return np.random.rand(3).astype(np.float32)
        if self.bg_color == 'three_choices':
            return random.choice([white, black, gray])
        if isinstance(self.bg_color, (int, float)) and not isinstance(self.bg_color, bool):
            return np.array([self.bg_color] * 3, dtype=np.float32)
        raise NotImplementedError

    def load_image(self, img_path, bg_color, alpha=None, return_type='np'):
        """Load an image, resize it to ``img_wh`` and composite its RGB onto ``bg_color``.

        If ``alpha`` is None, the file must have 4 channels and its last channel is
        used as alpha. Returns ``(img, alpha)`` as numpy arrays (``return_type='np'``)
        or torch tensors (``'pt'``).

        Raises:
            AssertionError: if alpha must come from the file but it has no 4th
                channel, or if the alpha mask is empty.
            NotImplementedError: for an unknown ``return_type``.
        """
        # PIL (not cv2) so that 8-bit images always load as uint8.
        with Image.open(img_path) as im:
            rgba = np.array(im.resize(self.img_wh))
        rgba = rgba.astype(np.float32) / 255.  # [0, 1]
        img = rgba[..., :3]
        if alpha is None:
            assert rgba.shape[-1] == 4
            alpha = rgba[..., 3:4]
            assert alpha.sum() > 1e-8, 'w/o foreground'
        img = img * alpha + bg_color * (1 - alpha)
        if return_type == "np":
            pass
        elif return_type == "pt":
            img = torch.from_numpy(img)
            alpha = torch.from_numpy(alpha)
        else:
            raise NotImplementedError
        return img, alpha

    def load_depth(self, img_path, bg_color, alpha, input_type='png', return_type='np'):
        """Load a single-channel depth image as a 3-channel map composited onto ``bg_color``.

        Raw values are divided by 65535 (a 16-bit encoding is assumed; TODO confirm
        against the renderer). Values above 0.4 are zeroed and the rest rescaled by 1/0.4.
        ``input_type`` is accepted for interface compatibility but unused.
        """
        with Image.open(img_path) as im:
            depth = np.array(im.resize(self.img_wh))
        depth = depth.astype(np.float32) / 65535.
        depth[depth > 0.4] = 0
        depth = depth / 0.4
        assert depth.ndim == 2  # depth must be single-channel
        img = np.stack([depth] * 3, axis=-1)
        if alpha.shape[-1] != 1:
            alpha = alpha[:, :, None]
        img = img * alpha + bg_color * (1 - alpha)
        if return_type == "np":
            pass
        elif return_type == "pt":
            img = torch.from_numpy(img)
        else:
            raise NotImplementedError
        return img

    def load_normal(self, img_path, bg_color, alpha, RT_w2c_cond=None, return_type='np'):
        """Load a world-space normal map, express it relative to ``RT_w2c_cond`` and encode it.

        The normals are rotated by ``RT_w2c_cond[:3, :3]`` and re-normalised. Their y and z
        components are then negated (the same flip as ``trans_cv2gl_mat``), mapped from
        [-1, 1] to [0, 1] and composited onto ``bg_color`` using ``alpha``.

        Raises:
            ValueError: if ``RT_w2c_cond`` is not given.
            AssertionError: if the normal image is constant.
            NotImplementedError: for an unknown ``return_type``.
        """
        if RT_w2c_cond is None:
            raise ValueError('load_normal requires RT_w2c_cond')
        with Image.open(img_path) as im:
            normal_np = np.array(im.resize(self.img_wh))[:, :, :3]
        assert np.var(normal_np) > 1e-8, 'pure normal'
        normal_cv = img2normal(normal_np)
        normal_relative = worldNormal2camNormal(RT_w2c_cond[:3, :3], normal_cv)
        normal_relative = norm_normalize(normal_relative)
        # Negate y and z in place (OpenCV -> OpenGL camera axes). The encoded image
        # below is built from these flipped normals.
        normal_relative[..., 1:] = -normal_relative[..., 1:]
        img = (normal_relative * 0.5 + 0.5).astype(np.float32)  # [0, 1]
        if alpha.shape[-1] != 1:
            alpha = alpha[:, :, None]
        img = img[..., :3] * alpha + bg_color * (1 - alpha)
        if return_type == "np":
            pass
        elif return_type == "pt":
            img = torch.from_numpy(img)
        else:
            raise NotImplementedError
        return img

    def __len__(self):
        return len(self.all_objects)

    def _sample_cond_ele0_idx(self):
        """Pick the fixed-view index of the conditioning view.

        Validation always uses 12 (the "front view"). In training, with
        r = ``side_views_rate``: one of 8/0 with probability r, one of 10/14 with a
        further probability 2r (capped at 1 overall), otherwise 12.
        """
        if self.validation:
            return 12
        rand = random.random()
        if rand < self.side_views_rate:
            return random.choice([8, 0])
        if rand < 3 * self.side_views_rate:
            return random.choice([10, 14])
        return 12

    def __getitem_norm__(self, index, debug_object=None):
        """Build one sample: a repeated conditioning image plus target RGB and normal views.

        Args:
            index: object index (wrapped modulo the dataset length).
            debug_object: if given, load this object name instead of indexing.

        Returns:
            dict with 'elevations_cond', 'focal_cond', 'id', 'vid', 'imgs_in',
            'imgs_out', 'normals_out', 'normal_prompt_embeddings' and
            'color_prompt_embeddings'. Image stacks are shaped (Nv, 3, H, W).
        """
        bg_color = self.get_bg_color()
        if debug_object is not None:
            object_name = debug_object
        else:
            object_name = self.all_objects[index % len(self.all_objects)]
        obj_dir = f'{self.root_dir}/{object_name}'
        stride = self.random_views + 1

        # Conditioning view: a fixed view index plus a random offset inside its group.
        cond_ele0_idx = self._sample_cond_ele0_idx()
        cond_random_idx = random.choice(range(stride))
        cond_ele0_vid = cond_ele0_idx * stride
        cond_vid = cond_ele0_vid + cond_random_idx
        cond_ele0_w2c = self.fix_cam_poses[cond_ele0_idx]
        cond_info = np.load(f'{obj_dir}/camera/{cond_vid:03d}.npy', allow_pickle=True).item()
        cond_type = cond_info['camera']
        focal_len = cond_info['focal']
        cond_eles = np.array([deg2rad(cond_info['elevation'])])
        cond_img = self.load_image(
            f"{obj_dir}/image/{cond_vid:03d}{self.exten}", bg_color, return_type='pt')[0].permute(2, 0, 1)
        img_tensors_in = [cond_img] * self.pred_view_nums

        # Target views: fixed views at the requested offsets from the conditioning view.
        pred_vids = [(cond_ele0_vid + i * stride) % self.total_views for i in self.predict_relative_views]
        img_tensors_out = []
        normal_tensors_out = []
        for vid in pred_vids:
            try:
                img_tensor, alpha_ = self.load_image(
                    f"{obj_dir}/image/{vid:03d}{self.exten}", bg_color, return_type='pt')
            except Exception:
                # Fall back to the relit rendering when the primary image fails to load.
                img_tensor, alpha_ = self.load_image(
                    f"{obj_dir}/image_relit/{vid:03d}{self.exten}", bg_color, return_type='pt')
            img_tensors_out.append(img_tensor.permute(2, 0, 1))  # (3, H, W)
            normal_tensor = self.load_normal(
                f"{obj_dir}/normal/{vid:03d}{self.exten}", bg_color, alpha_.numpy(),
                RT_w2c_cond=cond_ele0_w2c[:3, :], return_type="pt").permute(2, 0, 1)
            normal_tensors_out.append(normal_tensor)

        img_tensors_in = torch.stack(img_tensors_in, dim=0).float()  # (Nv, 3, H, W)
        img_tensors_out = torch.stack(img_tensors_out, dim=0).float()  # (Nv, 3, H, W)
        normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float()  # (Nv, 3, H, W)
        elevations_cond = torch.as_tensor(cond_eles).float()
        if cond_type == 'ortho':
            focal_embed = torch.tensor([0.])
        else:
            focal_embed = torch.tensor([24. / focal_len])

        # `load_cache` has no implementation here. Earlier code silently returned None
        # when it was set, so the sample is now always returned.
        return {
            'elevations_cond': elevations_cond,
            'focal_cond': focal_embed,
            'id': object_name,
            'vid': cond_vid,
            'imgs_in': img_tensors_in,
            'imgs_out': img_tensors_out,
            'normals_out': normal_tensors_out,
            'normal_prompt_embeddings': self.normal_prompt_embedding,
            'color_prompt_embeddings': self.color_prompt_embedding
        }

    def __getitem__(self, index):
        """Return sample ``index``, or the pre-loaded backup sample if loading fails."""
        try:
            return self.__getitem_norm__(index)
        except Exception as err:
            # Best-effort: report and substitute, but let KeyboardInterrupt/SystemExit through.
            print("load error ", self.all_objects[index % len(self.all_objects)], repr(err))
            return self.backup_data