Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,697 Bytes
5f093a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
import os
import math
from pathlib import Path
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import webdataset as wds
from torch.utils.data.distributed import DistributedSampler
import matplotlib.pyplot as plt
import sys
class ObjaverseDataLoader():
    """Factory for the training and validation loaders over ``ObjaverseData``.

    Args:
        root_dir: Directory forwarded to ``ObjaverseData`` as ``root_dir``.
        batch_size: Batch size used by both loaders.
        total_view: Forwarded to ``ObjaverseData`` as ``total_view``.
        num_workers: Number of worker processes used by both loaders.
    """

    def __init__(self, root_dir, batch_size, total_view=12, num_workers=4):
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.total_view = total_view

        # Resize to 256x256, convert to a tensor, then normalize with
        # mean 0.5 and std 0.5.
        image_transforms = [
            torchvision.transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
        self.image_transforms = torchvision.transforms.Compose(image_transforms)

    def train_dataloader(self):
        """Return an unshuffled loader over the training split."""
        dataset = ObjaverseData(
            root_dir=self.root_dir,
            total_view=self.total_view,
            validation=False,
            image_transforms=self.image_transforms,
        )
        # NOTE: no sampler is attached, so this loader does not shard the
        # dataset across distributed ranks by itself.
        return wds.WebLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the validation split."""
        dataset = ObjaverseData(
            root_dir=self.root_dir,
            total_view=self.total_view,
            validation=True,
            image_transforms=self.image_transforms,
        )
        # No DistributedSampler is created here: the loader would not use it,
        # and building one requires an initialized torch.distributed process
        # group, which would make validation fail in single-process runs.
        return wds.WebLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
def get_pose(transformation):
    """Return the pose corresponding to ``transformation``.

    This is an identity mapping: the very same object that was passed in is
    handed back, without copying or validation.

    Args:
        transformation: Transformation matrix (noted as 4x4 by the original
            author; the shape is not checked here).

    Returns:
        The ``transformation`` argument, unchanged.
    """
    pose = transformation
    return pose
class ObjaverseData(Dataset):
    """Multi-view dataset over rendered objects.

    Each item reads ``T_in`` input views and ``T_out`` target views of one
    object from ``root_dir/<object id>/``. View ``k`` is read from the image
    ``'%03d.png' % k`` and the matrix file ``'%03d.npy' % k``.
    """

    def __init__(self,
                 root_dir='.objaverse/hf-objaverse-v1/views',
                 image_transforms=None,
                 total_view=12,
                 validation=False,
                 T_in=1,
                 T_out=1,
                 fix_sample=False,
                 ) -> None:
        """Load the object id list and select the train or validation split.

        Args:
            root_dir: Directory containing one sub-folder per object id.
            image_transforms: Callable applied to every loaded RGB PIL image
                (see ``process_im``).
            total_view: Number of views per object to sample indices from.
            validation: If true, keep the last 1% of the ids; otherwise keep
                the first 99%.
            T_in: Number of input (conditioning) views per item.
            T_out: Number of target views per item.
            fix_sample: If true, use a deterministic choice of views instead
                of random sampling.

        Raises:
            AssertionError: If the id file does not hold exactly 790152 ids.
        """
        self.root_dir = Path(root_dir)
        self.total_view = total_view
        self.T_in = T_in
        self.T_out = T_out
        self.fix_sample = fix_sample

        # load ids from .npy so we have exactly the same ids/order
        self.paths = np.load("../scripts/obj_ids.npy")

        total_objects = len(self.paths)
        assert total_objects == 790152, 'total objects %d' % total_objects

        split = math.floor(total_objects / 100. * 99.)
        if validation:
            self.paths = self.paths[split:]  # used last 1% as validation
        else:
            self.paths = self.paths[:split]  # used first 99% as training
        print('============= length of dataset %d =============' % len(self.paths))

        self.tform = image_transforms

        # Focal length 560 divided by the 512 -> 256 downscale factor;
        # principal point at (128, 128).
        downscale = 512 / 256.
        self.fx = 560. / downscale
        self.fy = 560. / downscale
        self.intrinsic = torch.tensor(
            [[self.fx, 0, 128., 0, self.fy, 128., 0, 0, 1.]],
            dtype=torch.float64,
        ).view(3, 3)

    def __len__(self):
        """Return the number of object ids in the selected split."""
        return len(self.paths)

    def get_pose(self, transformation):
        """Return ``transformation`` unchanged (identity mapping)."""
        # transformation: 4x4
        return transformation

    def load_im(self, path, color):
        """Read an image and paint its fully transparent pixels with ``color``.

        Args:
            path: Image file to read with ``plt.imread``.
            color: Value assigned to every pixel whose last channel equals 0.

        Returns:
            A PIL image built from the first three channels scaled by 255.
            (Assumes ``plt.imread`` yields float values in [0, 1] with an
            alpha channel last, as it does for RGBA PNG files.)

        Raises:
            Exception: Whatever ``plt.imread`` raised; the offending path is
                printed first so the bad file can be identified.
        """
        try:
            img = plt.imread(path)
        except Exception:
            # Report the file and propagate the real error instead of calling
            # sys.exit(), which would end the process with a success status.
            print(path)
            raise
        img[img[:, :, -1] == 0.] = color
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img

    def _select_view_indices(self):
        """Return ``(index_inputs, index_targets)`` view indices for one item."""
        total_view = self.total_view
        if self.fix_sample:
            indexes = range(total_view)
            if self.T_out > 1:
                # First two views plus the last (T_out - 2) views. The tail is
                # sliced from an explicit start because ``indexes[-0:]`` is the
                # whole range, which would add every view when T_out == 2.
                tail_start = max(total_view - (self.T_out - 2), 0)
                index_targets = list(indexes[:2]) + list(indexes[tail_start:])
                index_inputs = indexes[1:self.T_in + 1]  # one overlap identity
            else:
                index_targets = indexes[:self.T_out]
                index_inputs = indexes[self.T_out - 1:self.T_in + self.T_out - 1]  # one overlap identity
            return index_inputs, index_targets

        assert self.T_in + self.T_out <= total_view
        # training with replace, including identity
        indexes = np.random.choice(range(total_view), self.T_in + self.T_out, replace=True)
        return indexes[:self.T_in], indexes[self.T_in:]

    def _load_views(self, filename, view_indices, color):
        """Load the images and 4x4 poses of ``view_indices`` from ``filename``."""
        images = []
        poses = []
        for view_index in view_indices:
            image = self.process_im(
                self.load_im(os.path.join(filename, '%03d.png' % view_index), color))
            images.append(image)
            RT = np.load(os.path.join(filename, '%03d.npy' % view_index))
            # Keep the first three rows and append [0, 0, 0, 1] as fourth row.
            poses.append(self.get_pose(
                np.concatenate([RT[:3, :], np.array([[0, 0, 0, 1]])], axis=0)))
        return images, poses

    def __getitem__(self, index):
        """Return the input/target images and poses for object ``index``.

        If any file of the requested object cannot be loaded, the same view
        indices are loaded from a known-valid object instead.

        Returns:
            dict with keys ``image_input`` and ``image_target`` (stacked
            outputs of ``process_im``), ``pose_in`` and ``pose_out`` (stacked
            4x4 matrices), and ``pose_in_inv`` and ``pose_out_inv`` (their
            inverses with the last two axes transposed).
        """
        index_inputs, index_targets = self._select_view_indices()

        filename = os.path.join(self.root_dir, self.paths[index])
        color = [1., 1., 1., 1.]

        try:
            input_ims, cond_Ts = self._load_views(filename, index_inputs, color)
            target_ims, target_Ts = self._load_views(filename, index_targets, color)
        except Exception:
            # Only real errors trigger the fallback; KeyboardInterrupt and
            # SystemExit are allowed to propagate.
            print('error loading data ', filename)
            filename = os.path.join(self.root_dir, '0a01f314e2864711aa7e33bace4bd8c8')  # this one we know is valid
            input_ims, cond_Ts = self._load_views(filename, index_inputs, color)
            target_ims, target_Ts = self._load_views(filename, index_targets, color)

        # stack to batch
        data = {}
        data['image_input'] = torch.stack(input_ims, dim=0)
        data['image_target'] = torch.stack(target_ims, dim=0)
        data['pose_out'] = np.stack(target_Ts)
        data['pose_out_inv'] = np.linalg.inv(np.stack(target_Ts)).transpose([0, 2, 1])
        data['pose_in'] = np.stack(cond_Ts)
        data['pose_in_inv'] = np.linalg.inv(np.stack(cond_Ts)).transpose([0, 2, 1])
        return data

    def process_im(self, im):
        """Convert ``im`` to RGB and apply the configured image transforms."""
        im = im.convert("RGB")
        return self.tform(im)