# EscherNet / 6DoF / dataset.py
# (file-viewer header: kxhit, "update", commit 5f093a6, raw, history blame, 7.7 kB)
import os
import math
from pathlib import Path
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import webdataset as wds
from torch.utils.data.distributed import DistributedSampler
import matplotlib.pyplot as plt
import sys
class ObjaverseDataLoader():
    """Factory for the training and validation loaders over ``ObjaverseData``.

    Args:
        root_dir: Folder containing one sub-folder of rendered views per object.
        batch_size: Number of objects per batch.
        total_view: Number of rendered views per object, forwarded to the dataset.
        num_workers: Number of loader worker processes.
    """

    def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.total_view = total_view
        # Resize to 256x256, convert to a tensor, then normalise every channel
        # with mean 0.5 / std 0.5.
        self.image_transforms = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def _make_loader(self, validation):
        """Build the loader for either the validation or the training split."""
        dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=validation,
                                image_transforms=self.image_transforms)
        # No sampler is attached: the previous code built a DistributedSampler
        # for validation but never passed it on, and constructing one without
        # an initialised process group raises. NOTE(review): if per-rank
        # sharding is wanted, pass a sampler explicitly here.
        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def train_dataloader(self):
        """Return an unshuffled loader over the training split."""
        return self._make_loader(validation=False)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation split."""
        return self._make_loader(validation=True)
def get_pose(transformation):
    """Return the given pose unchanged.

    The argument is expected to be a 4x4 transformation; it is handed back
    as-is (same object, no copy).
    """
    pose = transformation
    return pose
class ObjaverseData(Dataset):
    """Multi-view renderings of objects together with their camera matrices.

    Every object lives in its own folder below ``root_dir`` and provides files
    named ``000.png`` / ``000.npy``, ``001.png`` / ``001.npy``, ...  One item
    of the dataset is a set of ``T_in`` input views and ``T_out`` target views
    of a single object.
    """

    def __init__(self,
                 root_dir='.objaverse/hf-objaverse-v1/views',
                 image_transforms=None,
                 total_view=12,
                 validation=False,
                 T_in=1,
                 T_out=1,
                 fix_sample=False,
                 ) -> None:
        """Set up the object list and the train/validation split.

        Args:
            root_dir: Folder holding one sub-folder of views per object id.
            image_transforms: Callable applied to every loaded PIL image.
            total_view: Number of views available per object.
            validation: If true use the last 1% of the ids, otherwise the
                first 99%.
            T_in: Number of input (conditioning) views per item.
            T_out: Number of target views per item.
            fix_sample: If true pick a deterministic set of views instead of
                sampling them at random.

        Raises:
            AssertionError: If the id file does not contain 790152 entries.
        """
        self.root_dir = Path(root_dir)
        self.total_view = total_view
        self.T_in = T_in
        self.T_out = T_out
        self.fix_sample = fix_sample

        # # include all folders
        # for folder in os.listdir(self.root_dir):
        #     if os.path.isdir(os.path.join(self.root_dir, folder)):
        #         self.paths.append(folder)
        # load ids from .npy so we have exactly the same ids/order
        self.paths = np.load("../scripts/obj_ids.npy")
        # # only use 100K objects for ablation study
        # self.paths = self.paths[:100000]
        total_objects = len(self.paths)
        assert total_objects == 790152, 'total objects %d' % total_objects

        split = math.floor(total_objects / 100. * 99.)
        if validation:
            self.paths = self.paths[split:]  # used last 1% as validation
        else:
            self.paths = self.paths[:split]  # used first 99% as training
        print('============= length of dataset %d =============' % len(self.paths))

        self.tform = image_transforms

        # Focal length 560 rescaled by 512/256 with the principal point at
        # (128, 128). NOTE(review): presumably the renders are 512 px and are
        # consumed at 256 px -- confirm against the rendering script.
        downscale = 512 / 256.
        self.fx = 560. / downscale
        self.fy = 560. / downscale
        self.intrinsic = torch.tensor([[self.fx, 0, 128., 0, self.fy, 128., 0, 0, 1.]], dtype=torch.float64).view(3, 3)

    def __len__(self):
        """Return the number of objects in the selected split."""
        return len(self.paths)

    def get_pose(self, transformation):
        """Return the 4x4 transformation unchanged."""
        # transformation: 4x4
        return transformation

    def load_im(self, path, color):
        """Read a rendering and paint its fully transparent pixels with ``color``.

        Args:
            path: Image file to read.
            color: Value assigned to every pixel whose last channel is 0.

        Returns:
            A PIL image built from the first three channels.

        Raises:
            Exception: Whatever ``plt.imread`` raised; the offending path is
                printed first so the broken file can be identified.
        """
        try:
            img = plt.imread(path)
        except Exception:
            # Report the file and let the caller decide how to recover rather
            # than terminating the whole (worker) process.
            print(path)
            raise
        # Pixels whose last channel (alpha for RGBA renders) is zero are background.
        img[img[:, :, -1] == 0.] = color
        # assumes imread returned floats in [0, 1] -- TODO confirm for non-PNG inputs
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img

    def _sample_view_indices(self):
        """Choose which views act as inputs and which as targets.

        Returns:
            ``(index_inputs, index_targets)``.
        """
        total_view = self.total_view
        if self.fix_sample:
            indexes = range(total_view)
            if self.T_out > 1:
                # First two views plus the last ``T_out - 2`` views. The start
                # is computed explicitly because a negative slice of size zero
                # (``indexes[-0:]``) would select every view.
                tail_start = max(total_view - (self.T_out - 2), 0)
                index_targets = list(indexes[:2]) + list(indexes[tail_start:])
                index_inputs = indexes[1:self.T_in + 1]  # one overlap identity
            else:
                index_targets = indexes[:self.T_out]
                index_inputs = indexes[self.T_out - 1:self.T_in + self.T_out - 1]  # one overlap identity
        else:
            assert self.T_in + self.T_out <= total_view
            # training with replace, including identity
            indexes = np.random.choice(range(total_view), self.T_in + self.T_out, replace=True)
            index_inputs = indexes[:self.T_in]
            index_targets = indexes[self.T_in:]
        return index_inputs, index_targets

    def _load_views(self, folder, view_indices, color):
        """Load the transformed image and the 4x4 pose of each requested view."""
        bottom_row = np.array([[0, 0, 0, 1]])
        images = []
        poses = []
        for view in view_indices:
            image = self.load_im(os.path.join(folder, '%03d.png' % view), color)
            images.append(self.process_im(image))
            rt = np.load(os.path.join(folder, '%03d.npy' % view))
            # Keep the first three rows and append [0, 0, 0, 1].
            poses.append(self.get_pose(np.concatenate([rt[:3, :], bottom_row], axis=0)))
        return images, poses

    def __getitem__(self, index):
        """Return the input/target images and poses for one object.

        Returns:
            A dict with keys ``image_input``, ``image_target`` (stacked
            tensors), ``pose_out``, ``pose_in`` (stacked matrices) and
            ``pose_out_inv``, ``pose_in_inv`` (their inverses, each transposed).
        """
        index_inputs, index_targets = self._sample_view_indices()
        filename = os.path.join(self.root_dir, self.paths[index])
        color = [1., 1., 1., 1.]
        try:
            input_ims, cond_Ts = self._load_views(filename, index_inputs, color)
            target_ims, target_Ts = self._load_views(filename, index_targets, color)
        except Exception:
            # very hacky solution, sorry about this: fall back to an object
            # that is known to load so one bad folder does not stop training.
            print('error loading data ', filename)
            filename = os.path.join(self.root_dir, '0a01f314e2864711aa7e33bace4bd8c8')  # this one we know is valid
            input_ims, cond_Ts = self._load_views(filename, index_inputs, color)
            target_ims, target_Ts = self._load_views(filename, index_targets, color)

        # stack to batch
        pose_out = np.stack(target_Ts)
        pose_in = np.stack(cond_Ts)
        data = {
            'image_input': torch.stack(input_ims, dim=0),
            'image_target': torch.stack(target_ims, dim=0),
            'pose_out': pose_out,
            'pose_out_inv': np.linalg.inv(pose_out).transpose([0, 2, 1]),
            'pose_in': pose_in,
            'pose_in_inv': np.linalg.inv(pose_in).transpose([0, 2, 1]),
        }
        return data

    def process_im(self, im):
        """Convert the image to RGB and apply the configured transform."""
        im = im.convert("RGB")
        return self.tform(im)