# Sharp-It: src/data/objaverse_zero123plus.py
# Source: Hugging Face file view (author: YiftachEde, commit a1d8bef "add src", 7.68 kB).
import os
import json
import numpy as np
import webdataset as wds
import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from PIL import Image
from pathlib import Path
from tqdm import tqdm
from src.utils.train_util import instantiate_from_config
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose datasets are built from config objects.

    Each of ``train``, ``validation`` and ``test`` is an optional dataset
    config that :meth:`setup` passes to ``instantiate_from_config``; configs
    left as ``None`` are skipped. All loaders are ``webdataset.WebLoader``.

    Args:
        batch_size: Batch size of the train and test loaders. The validation
            loader always uses a batch size of 4.
        num_workers: Number of loader worker processes for every split.
        train: Training dataset config (format is whatever
            ``instantiate_from_config`` expects).
        validation: Validation dataset config.
        test: Test dataset config. If omitted, the test loader falls back to
            the validation dataset.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        batch_size=8,
        num_workers=4,
        train=None,
        validation=None,
        test=None,
        **kwargs,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Keep only the splits that were actually configured, in a fixed order.
        self.dataset_configs = {
            name: cfg
            for name, cfg in (('train', train), ('validation', validation), ('test', test))
            if cfg is not None
        }

    def setup(self, stage):
        """Instantiate every configured dataset into ``self.datasets``.

        Supported for the ``'fit'``, ``'validate'`` and ``'test'`` stages so
        that ``trainer.validate()`` / ``trainer.test()`` work as well as
        ``trainer.fit()``.

        Raises:
            NotImplementedError: for any other stage (e.g. ``'predict'``).
        """
        if stage not in ('fit', 'validate', 'test'):
            raise NotImplementedError('setup stage %r is not supported' % (stage,))
        self.datasets = {
            name: instantiate_from_config(cfg)
            for name, cfg in self.dataset_configs.items()
        }

    def train_dataloader(self):
        """Return a shuffled loader over the training dataset."""
        return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size,
                             num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the validation dataset (batch size 4)."""
        return wds.WebLoader(self.datasets['validation'], batch_size=4,
                             num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        """Return an unshuffled loader over the test dataset.

        Falls back to the validation dataset when no test config was given.
        """
        split = 'test' if 'test' in self.datasets else 'validation'
        return wds.WebLoader(self.datasets[split], batch_size=self.batch_size,
                             num_workers=self.num_workers, shuffle=False)
class RefinementData(Dataset):
    """Paired latent dataset: ground-truth (refined) vs. predicted (unrefined).

    For every object id (uuid) in the selected split, ``__getitem__`` loads::

        <root_dir>/<uuid>/<gt_subpath>/<light_index>/latent.pt   -> 'refined_imgs'
        <root_dir>/<uuid>/<pred_subpath>/latent.pt               -> 'unrefined_imgs'

    Args:
        root_dir: Dataset root. ``caption_path`` and ``split_path`` are
            resolved relative to it.
        gt_subpath: Per-object subdirectory containing one folder per light index.
        pred_subpath: Per-object subdirectory containing the predicted latent.
        validation: Use the ``'val'`` list of the split file instead of ``'train'``.
        overfit: Accepted for config compatibility; not used.
        caption_path: Optional JSON file mapping uuid -> caption. Without it,
            every sample's caption is the empty string.
        split_path: JSON file with ``'train'`` and ``'val'`` lists of uuids.
            Required despite the ``None`` default.
        single_view: Crop one random 40x40 tile (see ``__getitem__``) instead
            of returning the full latent.
        single_light: Always use light index 0 instead of a random one.

    Raises:
        ValueError: if ``split_path`` is not given.
    """

    # Light index -> caption suffix. Not referenced by this class at present.
    lights_to_caption = {
        0: "Morning light",
        1: "Midday light, clear sky",
        2: "Afternoon light, cloudy sky",
    }

    def __init__(self,
                 root_dir='refinement_dataset/',
                 gt_subpath='gt',
                 pred_subpath='shap_e',
                 validation=False,
                 overfit=False,
                 caption_path=None,
                 split_path=None,
                 single_view=False,
                 single_light=False,
                 ) -> None:
        self.root_dir = Path(root_dir)
        self.gt_subpath = gt_subpath
        self.pred_subpath = pred_subpath
        self.single_view = single_view
        self.single_light = single_light

        # None means "no captions": __getitem__ then returns '' as the caption.
        self.captions_dict = None
        if caption_path is not None:
            with open(self.root_dir / caption_path, encoding='utf-8') as f:
                self.captions_dict = json.load(f)

        if split_path is None:
            raise ValueError(
                "RefinementData requires split_path: a JSON file with 'train' and 'val' uuid lists")
        with open(self.root_dir / split_path, encoding='utf-8') as f:
            split_dict = json.load(f)

        uuids = split_dict['val'] if validation else split_dict['train']
        self.paths = [self.root_dir / uuid for uuid in uuids]
        print('============= length of dataset %d =============' % len(self.paths))

    def __len__(self):
        return len(self.paths)

    def load_im(self, path, color):
        """Load an image as a float tensor, compositing RGBA over ``color``.

        Args:
            path: Image file path.
            color: Background color broadcast against the (H, W, 3) RGB part,
                e.g. ``[1., 1., 1.]`` for white. Only used for 4-channel images.

        Returns:
            ``(image, alpha)`` float tensors in (C, H, W) layout. For a
            4-channel input the RGB part is alpha-composited over ``color`` and
            ``alpha`` is the 4th channel; otherwise the image is returned as is
            with an all-ones single-channel alpha. Pixel values are divided by
            255 (so in [0, 1] for 8-bit images).
        """
        # Close the file handle deterministically; np.asarray copies the pixels.
        with Image.open(path) as pil_img:
            image = np.asarray(pil_img, dtype=np.float32) / 255.
        if image.shape[2] == 4:
            alpha = image[:, :, 3:]
            image = image[:, :, :3] * alpha + color * (1 - alpha)
        else:
            alpha = np.ones_like(image[:, :, :1])
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def __getitem__(self, index):
        """Return one (refined, unrefined) latent pair with its caption.

        Returns:
            dict with
              * ``'refined_imgs'``: latent from ``<gt_subpath>/<light_index>/latent.pt``,
              * ``'unrefined_imgs'``: latent from ``<pred_subpath>/latent.pt``,
              * ``'caption'``: caption for the uuid ('' if no caption file),
              * ``'index'``: the requested index.
            Both latents are loaded onto the CPU and ``squeeze()``-d. With
            ``single_view`` both are cropped to the same 40x40 tile of their
            last two dims; tiles form a 3 (rows) x 2 (cols) grid indexed by a
            random view index in [0, 6).

        Raises:
            KeyError: if a caption file was given but has no entry for the uuid.
        """
        obj_dir = self.paths[index]

        # With a lights.json present the object is assumed to have 3 lights;
        # otherwise count the entries in the ground-truth directory.
        if (obj_dir / 'lights.json').exists():
            num_lights = 3
        else:
            num_lights = len(os.listdir(obj_dir / self.gt_subpath))

        # The view is drawn before the light; keep this order so seeded runs
        # reproduce the same samples.
        if self.single_view:
            view_index = np.random.randint(0, 6)
        light_index = 0 if self.single_light else np.random.randint(0, num_lights)

        uuid = obj_dir.name
        caption = '' if self.captions_dict is None else self.captions_dict[uuid]

        gt_path = os.path.join(obj_dir, self.gt_subpath, str(light_index), "latent.pt")
        pred_path = os.path.join(obj_dir, self.pred_subpath, "latent.pt")
        try:
            latent_gt = torch.load(gt_path, map_location='cpu').squeeze()
            latent_pred = torch.load(pred_path, map_location='cpu').squeeze()
        except Exception:
            print("Error loading latent tensors, gt path %s, pred path %s" % (gt_path, pred_path))
            raise

        if self.single_view:
            tile = 40
            row, col = divmod(view_index, 2)
            rows = slice(row * tile, (row + 1) * tile)
            cols = slice(col * tile, (col + 1) * tile)
            latent_gt = latent_gt[:, rows, cols]
            latent_pred = latent_pred[:, rows, cols]

        return {
            'refined_imgs': latent_gt,
            'unrefined_imgs': latent_pred,
            'caption': caption,
            'index': index,
        }
class ObjaverseData(Dataset):
    """Multi-view rendered image dataset.

    ``<root_dir>/<meta_fname>`` is a JSON object whose values are lists of
    per-object relative paths; the lists are concatenated in file order. The
    last 16 paths form the validation split and all earlier ones the training
    split. Each object directory ``<root_dir>/<image_dir>/<path>`` is expected
    to contain RGBA images ``000.png`` .. ``006.png``.

    Args:
        root_dir: Dataset root.
        meta_fname: Name of the JSON path index inside ``root_dir``.
        image_dir: Subdirectory of ``root_dir`` holding the per-object renders.
        validation: Select the 16-path validation split instead of training.
    """

    def __init__(self,
                 root_dir='objaverse/',
                 meta_fname='valid_paths.json',
                 image_dir='rendering_zero123plus',
                 validation=False,
                 ):
        self.root_dir = Path(root_dir)
        self.image_dir = image_dir
        with open(os.path.join(root_dir, meta_fname), encoding='utf-8') as f:
            lvis_dict = json.load(f)

        # Flatten every group's path list into one list, keeping file order.
        paths = [p for group in lvis_dict.values() for p in group]
        # The last 16 entries are held out for validation.
        self.paths = paths[-16:] if validation else paths[:-16]
        print('============= length of dataset %d =============' % len(self.paths))

    def __len__(self):
        return len(self.paths)

    def load_im(self, path, color):
        """Load a 4-channel (RGBA) image and composite it over ``color``.

        Args:
            path: Image file path.
            color: Background color broadcast against the (H, W, 3) RGB part,
                e.g. ``[1., 1., 1.]`` for white.

        Returns:
            ``(image, alpha)`` float tensors of shapes (3, H, W) and (1, H, W).
            Pixel values are divided by 255 (so in [0, 1] for 8-bit images).

        Raises:
            ValueError: if the decoded image is not H x W x 4.
        """
        # Close the file handle deterministically; np.asarray copies the pixels.
        with Image.open(path) as pil_img:
            image = np.asarray(pil_img, dtype=np.float32) / 255.
        if image.ndim != 3 or image.shape[2] != 4:
            raise ValueError('expected a 4-channel (RGBA) image at %s, got array of shape %s'
                             % (path, image.shape))
        alpha = image[:, :, 3:]
        image = image[:, :, :3] * alpha + color * (1 - alpha)
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def __getitem__(self, index):
        """Return the conditioning view and the six target views of one object.

        Loads ``000.png`` .. ``006.png`` from
        ``<root_dir>/<image_dir>/<paths[index]>``, composited over white. If
        any image fails to load, the error is printed and a random index is
        tried instead.

        Returns:
            dict with ``'cond_imgs'`` (image 000, shape (3, H, W)) and
            ``'target_imgs'`` (images 001-006 stacked, shape (6, 3, H, W)).
        """
        bkg_color = [1., 1., 1.]  # white background
        while True:
            image_path = os.path.join(self.root_dir, self.image_dir, self.paths[index])
            try:
                img_list = [
                    self.load_im(os.path.join(image_path, '%03d.png' % idx), bkg_color)[0]
                    for idx in range(7)
                ]
                break
            except Exception as e:
                # Best-effort: skip unreadable samples by resampling a random index.
                # NOTE(review): this loops forever if no sample can be loaded.
                print('Failed to load %s: %s' % (image_path, e))
                index = np.random.randint(0, len(self.paths))

        imgs = torch.stack(img_list, dim=0).float()
        return {
            'cond_imgs': imgs[0],       # (3, H, W)
            'target_imgs': imgs[1:],    # (6, 3, H, W)
        }