# Page header captured with this file (Hugging Face Spaces: "Running on Zero");
# not part of the module's code.
import os | |
import math | |
from pathlib import Path | |
import torch | |
import torchvision | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import transforms | |
from PIL import Image | |
import numpy as np | |
import webdataset as wds | |
from torch.utils.data.distributed import DistributedSampler | |
import matplotlib.pyplot as plt | |
import sys | |
class ObjaverseDataLoader:
    """Build WebLoader-wrapped train/validation loaders over ``ObjaverseData``.

    Every image goes through the same transform: resize to 256x256, convert to
    a tensor, then ``Normalize([0.5], [0.5])`` (maps [0, 1] values to [-1, 1]).
    """

    def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
        """
        Args:
            root_dir: directory of per-object view folders, forwarded to ``ObjaverseData``.
            batch_size: batch size of both loaders.
            total_view: number of rendered views per object, forwarded to ``ObjaverseData``.
            num_workers: DataLoader worker processes.
        """
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.total_view = total_view
        image_transforms = [torchvision.transforms.Resize((256, 256)),
                            transforms.ToTensor(),
                            transforms.Normalize([0.5], [0.5])]
        self.image_transforms = torchvision.transforms.Compose(image_transforms)

    def _make_loader(self, validation):
        """Create the ``ObjaverseData`` split and wrap it in an unshuffled WebLoader."""
        dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=validation,
                                image_transforms=self.image_transforms)
        return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def train_dataloader(self):
        """Return the loader over the training split."""
        # NOTE(review): training data is not shuffled (shuffle=False, no sampler);
        # confirm the training framework handles shuffling/sharding.
        return self._make_loader(validation=False)

    def val_dataloader(self):
        """Return the loader over the validation split."""
        # No DistributedSampler is built here: the loader would not use it, and
        # constructing one raises when no torch.distributed process group is
        # initialised (e.g. single-process evaluation).
        return self._make_loader(validation=True)
def get_pose(transformation):
    """Identity pass-through for a camera pose.

    Returns *transformation* itself (same object, no copy or conversion).
    The expected input is presumably a 4x4 matrix; this is not checked.
    """
    pose = transformation
    return pose
class ObjaverseData(Dataset):
    """Multi-view Objaverse renderings served as input views, target views and poses.

    Each item is one object directory under ``root_dir`` that holds per-view
    ``%03d.png`` images and ``%03d.npy`` camera matrices. ``T_in`` input views
    and ``T_out`` target views are chosen per item: at random (with
    replacement) by default, or deterministically when ``fix_sample`` is set.
    """

    # Object id substituted when an item fails to load; assumed to be present
    # and complete under root_dir.
    _FALLBACK_OBJECT_ID = '0a01f314e2864711aa7e33bace4bd8c8'

    def __init__(self,
                 root_dir='.objaverse/hf-objaverse-v1/views',
                 image_transforms=None,
                 total_view=12,
                 validation=False,
                 T_in=1,
                 T_out=1,
                 fix_sample=False,
                 ids_path="../scripts/obj_ids.npy",
                 expected_total=790152,
                 ) -> None:
        """Load the ordered object-id list and keep the train or validation split.

        Args:
            root_dir: directory holding one sub-directory per object id.
            image_transforms: callable applied by ``process_im`` to each RGB
                PIL image; must be set before items are fetched.
            total_view: number of rendered views per object; view indices are
                drawn from ``range(total_view)``.
            validation: if True keep the last 1% of ids, else the first 99%.
            T_in: number of input (conditioning) views per item.
            T_out: number of target views per item.
            fix_sample: pick views deterministically instead of at random.
            ids_path: ``.npy`` file with the ordered object ids. The default is
                a relative path, resolved against the current working directory.
            expected_total: required number of ids in ``ids_path``; ``None``
                disables the check.

        Raises:
            ValueError: if the number of ids differs from ``expected_total``.
        """
        self.root_dir = Path(root_dir)
        self.total_view = total_view
        self.T_in = T_in
        self.T_out = T_out
        self.fix_sample = fix_sample
        # Ids are loaded from a saved file so every run sees the same ids/order.
        self.paths = np.load(ids_path)
        total_objects = len(self.paths)
        if expected_total is not None and total_objects != expected_total:
            raise ValueError('total objects %d' % total_objects)
        split = math.floor(total_objects / 100. * 99.)
        if validation:
            self.paths = self.paths[split:]  # last 1% as validation
        else:
            self.paths = self.paths[:split]  # first 99% as training
        print('============= length of dataset %d =============' % len(self.paths))
        self.tform = image_transforms
        # Pinhole intrinsics: fx = fy = 560 / (512 / 256) = 280, principal point
        # (128, 128); presumably a 512px focal length rescaled to 256px images.
        downscale = 512 / 256.
        self.fx = 560. / downscale
        self.fy = 560. / downscale
        self.intrinsic = torch.tensor(
            [[self.fx, 0, 128., 0, self.fy, 128., 0, 0, 1.]], dtype=torch.float64
        ).view(3, 3)

    def __len__(self):
        return len(self.paths)

    def get_pose(self, transformation):
        """Identity pass-through: return *transformation* (a 4x4 pose) unchanged."""
        return transformation

    def load_im(self, path, color):
        """Read an image and replace fully transparent pixels with *color*.

        Args:
            path: image file read with ``plt.imread`` (float values for PNG).
            color: fill value written into every pixel whose last channel is 0.
                Assumes a 4-channel RGBA image so a 4-element color fits.

        Returns:
            RGB ``PIL.Image`` from the first three channels scaled by 255.

        Raises:
            OSError: if the file cannot be read; the underlying error is chained.
        """
        try:
            img = plt.imread(path)
        except Exception as err:
            # Raise rather than exit the process, so callers (e.g. the
            # fallback in __getitem__) can recover from a bad file.
            raise OSError('could not read image %s' % path) from err
        img[img[:, :, -1] == 0.] = color
        return Image.fromarray(np.uint8(img[:, :, :3] * 255.))

    def _sample_view_indices(self):
        """Return ``(index_inputs, index_targets)`` view indices for one item."""
        total_view = self.total_view
        if self.fix_sample:
            indexes = range(total_view)
            if self.T_out > 1:
                # First two views plus the last T_out - 2. The explicit start
                # avoids indexes[-0:] (every view) when T_out == 2.
                tail_start = max(total_view - (self.T_out - 2), 0)
                index_targets = list(indexes[:2]) + list(indexes[tail_start:])
                index_inputs = indexes[1:self.T_in + 1]  # one overlap identity
            else:
                index_targets = indexes[:self.T_out]
                index_inputs = indexes[self.T_out - 1:self.T_in + self.T_out - 1]  # one overlap identity
        else:
            assert self.T_in + self.T_out <= total_view
            # training with replacement, so identity pairs can occur
            indexes = np.random.choice(range(total_view), self.T_in + self.T_out, replace=True)
            index_inputs = indexes[:self.T_in]
            index_targets = indexes[self.T_in:]
        return index_inputs, index_targets

    def _load_views(self, object_dir, index_inputs, index_targets, color):
        """Load processed images and 4x4 poses for the requested views of one object.

        Returns:
            ``(input_ims, cond_Ts, target_ims, target_Ts)`` as lists.
        """
        def load_view(view):
            im = self.process_im(self.load_im(os.path.join(object_dir, '%03d.png' % view), color))
            rt = np.load(os.path.join(object_dir, '%03d.npy' % view))
            # Keep the top three rows and append [0, 0, 0, 1] to form a 4x4
            # matrix; assumes each .npy stores a 3x4 (or 4x4) camera matrix.
            pose = self.get_pose(np.concatenate([rt[:3, :], np.array([[0, 0, 0, 1]])], axis=0))
            return im, pose

        inputs = [load_view(v) for v in index_inputs]
        targets = [load_view(v) for v in index_targets]
        return ([im for im, _ in inputs], [pose for _, pose in inputs],
                [im for im, _ in targets], [pose for _, pose in targets])

    def __getitem__(self, index):
        """Return one example as a dict of stacked images and poses.

        Keys: ``image_input`` / ``image_target`` (``torch.stack`` of the
        transformed views), ``pose_in`` / ``pose_out`` (stacked 4x4 poses) and
        ``pose_in_inv`` / ``pose_out_inv`` (their inverses, transposed per view).
        If any file of the object fails to load, the fallback object is used.
        """
        index_inputs, index_targets = self._sample_view_indices()
        filename = os.path.join(self.root_dir, self.paths[index])
        color = [1., 1., 1., 1.]  # fill value 1.0 on all channels (white after *255)
        try:
            input_ims, cond_Ts, target_ims, target_Ts = self._load_views(
                filename, index_inputs, index_targets, color)
        except Exception:
            print('error loading data ', filename)
            filename = os.path.join(self.root_dir, self._FALLBACK_OBJECT_ID)
            input_ims, cond_Ts, target_ims, target_Ts = self._load_views(
                filename, index_inputs, index_targets, color)
        target_poses = np.stack(target_Ts)
        cond_poses = np.stack(cond_Ts)
        return {
            'image_input': torch.stack(input_ims, dim=0),
            'image_target': torch.stack(target_ims, dim=0),
            'pose_out': target_poses,
            'pose_out_inv': np.linalg.inv(target_poses).transpose([0, 2, 1]),
            'pose_in': cond_poses,
            'pose_in_inv': np.linalg.inv(cond_poses).transpose([0, 2, 1]),
        }

    def process_im(self, im):
        """Convert *im* to RGB and apply the configured image transform."""
        im = im.convert("RGB")
        return self.tform(im)