import json
import os
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
import webdataset as wds
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from src.utils.train_util import instantiate_from_config


def _load_image_with_alpha(path, color):
    """Load an image file and composite it over a solid background colour.

    Args:
        path: Path of the image file to open with PIL.
        color: Background colour; it is multiplied with ``(1 - alpha)``, so
            it must broadcast against an ``(H, W, 3)`` array (e.g. a list of
            three floats).

    Returns:
        A tuple ``(image, alpha)`` of float tensors laid out channel-first.
        For 4-channel inputs the first three channels are blended with
        ``color`` using the fourth channel as alpha; for any other channel
        count the pixels are returned unchanged and ``alpha`` is all ones.
    """
    # The pixel data must be read while the file is still open, hence the
    # conversion to an array inside the ``with`` block.
    with Image.open(path) as pil_img:
        image = np.asarray(pil_img, dtype=np.float32) / 255.

    if image.shape[2] == 4:
        alpha = image[:, :, 3:]
        image = image[:, :, :3] * alpha + color * (1 - alpha)
    else:
        alpha = np.ones_like(image[:, :, :1])

    image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
    alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
    return image, alpha


class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose datasets are built from config entries.

    Each of ``train`` / ``validation`` / ``test`` is a config that is passed
    to ``instantiate_from_config`` during ``setup('fit')``.
    """

    # Batch size used by ``val_dataloader`` (independent of ``batch_size``).
    VAL_BATCH_SIZE = 4

    def __init__(
        self,
        batch_size=8,
        num_workers=4,
        train=None,
        validation=None,
        test=None,
        **kwargs,
    ):
        """Store loader settings and the dataset configs that were provided.

        Args:
            batch_size: Batch size of the train and test loaders.
            num_workers: Worker count forwarded to every loader.
            train: Optional config of the training dataset.
            validation: Optional config of the validation dataset.
            test: Optional config of the test dataset.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        candidates = (('train', train), ('validation', validation), ('test', test))
        self.dataset_configs = {
            name: cfg for name, cfg in candidates if cfg is not None
        }

    def setup(self, stage):
        """Instantiate every configured dataset.

        Raises:
            NotImplementedError: If ``stage`` is anything other than ``'fit'``.
        """
        if stage != 'fit':
            raise NotImplementedError(
                "stage %r is not supported; only 'fit' is implemented" % (stage,)
            )
        self.datasets = {
            name: instantiate_from_config(cfg)
            for name, cfg in self.dataset_configs.items()
        }

    def train_dataloader(self):
        """Return a shuffled loader over the training dataset."""
        # NOTE: no explicit DistributedSampler is attached here.
        return wds.WebLoader(
            self.datasets['train'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the validation dataset."""
        return wds.WebLoader(
            self.datasets['validation'],
            batch_size=self.VAL_BATCH_SIZE,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def test_dataloader(self):
        """Return an unshuffled loader over the test dataset.

        Falls back to the validation dataset when no ``test`` config was
        given, which is what this method always served previously.
        """
        split = 'test' if 'test' in self.datasets else 'validation'
        return wds.WebLoader(
            self.datasets[split],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )


class RefinementData(Dataset):
    """Paired latents read from ``<root>/<uuid>/...`` directories.

    Each item pairs a ground-truth latent (``<gt_subpath>/<light>/latent.pt``)
    with a predicted latent (``<pred_subpath>/latent.pt``) and a caption.
    """

    # NOTE: appending these to the caption is currently disabled; the table
    # is only used to know how many lights a ``lights.json`` sample has.
    lights_to_caption = {
        0: "Morning light",
        1: "Midday light, clear sky",
        2: "Afternoon light, cloudy sky",
    }

    # Layout used when cropping a single view out of a latent: views are
    # indexed row-major in a grid with GRID_COLS columns of square tiles.
    NUM_VIEWS = 6
    GRID_COLS = 2
    TILE_SIZE = 40

    def __init__(self,
        root_dir='refinement_dataset/',
        gt_subpath='gt',
        pred_subpath='shap_e',
        validation=False,
        overfit=False,
        caption_path=None,
        split_path=None,
        single_view=False,
        single_light=False,
    ) -> None:
        """Read the split (and optional caption) JSON files under ``root_dir``.

        Args:
            root_dir: Dataset root; every other path is relative to it.
            gt_subpath: Per-sample sub-directory holding ground-truth latents.
            pred_subpath: Per-sample sub-directory holding the predicted latent.
            validation: Use the ``'val'`` split instead of ``'train'``.
            overfit: Accepted for interface compatibility; not used.
            caption_path: Optional JSON file mapping uuid -> caption. When
                omitted, every item gets an empty caption.
            split_path: JSON file with ``'train'`` and ``'val'`` uuid lists.
            single_view: Crop one random view tile out of each latent.
            single_light: Always use light 0 instead of a random light.

        Raises:
            ValueError: If ``split_path`` is not provided.
        """
        if split_path is None:
            raise ValueError("split_path is required to build RefinementData")

        self.root_dir = Path(root_dir)
        self.gt_subpath = gt_subpath
        self.pred_subpath = pred_subpath
        self.single_view = single_view
        self.single_light = single_light

        # Always define the attribute so __getitem__ never hits AttributeError.
        self.captions_dict = None
        if caption_path is not None:
            with open(self.root_dir / caption_path) as f:
                self.captions_dict = json.load(f)

        with open(self.root_dir / split_path) as f:
            split_dict = json.load(f)

        uuids = split_dict['val' if validation else 'train']
        self.paths = [self.root_dir / uuid for uuid in uuids]
        print('============= length of dataset %d =============' % len(self.paths))

    def __len__(self):
        return len(self.paths)

    def load_im(self, path, color):
        """Load an image and blend it over ``color``; returns (image, alpha)."""
        return _load_image_with_alpha(path, color)

    def __getitem__(self, index):
        """Return the latent pair, caption and index of one sample.

        Raises:
            KeyError: If captions were loaded but lack this sample's uuid.
            Exception: Whatever ``torch.load`` raises; the offending paths
                are printed first.
        """
        sample_dir = self.paths[index]

        # Random draws keep their original order: view first, then light.
        view_index = None
        if self.single_view:
            view_index = np.random.randint(0, self.NUM_VIEWS)

        if self.single_light:
            light_index = 0
        else:
            if (sample_dir / 'lights.json').exists():
                num_lights = len(self.lights_to_caption)
            else:
                num_lights = len(os.listdir(sample_dir / self.gt_subpath))
            light_index = np.random.randint(0, num_lights)

        uuid = sample_dir.name
        caption = '' if self.captions_dict is None else self.captions_dict[uuid]

        image_path_gt = sample_dir / self.gt_subpath / str(light_index) / "latent.pt"
        image_path_pred = sample_dir / self.pred_subpath / "latent.pt"

        # NOTE(review): torch.load unpickles the file; only use trusted data.
        try:
            imgs_gt = torch.load(image_path_gt, map_location='cpu').squeeze()
            imgs_pred = torch.load(image_path_pred, map_location='cpu').squeeze()
        except Exception:
            print("Error loading latent tensors, gt path %s, pred path %s" % (image_path_gt, image_path_pred))
            raise

        if self.single_view:
            row, col = divmod(view_index, self.GRID_COLS)
            tile = self.TILE_SIZE
            rows = slice(row * tile, (row + 1) * tile)
            cols = slice(col * tile, (col + 1) * tile)
            imgs_gt = imgs_gt[:, rows, cols]
            imgs_pred = imgs_pred[:, rows, cols]

        # NOTE(review): shapes depend on the stored latents; the original
        # comments claimed (6, 3, H, W) - confirm against the data.
        return {
            'refined_imgs': imgs_gt,
            'unrefined_imgs': imgs_pred,
            'caption': caption,
            'index': index,
        }


class ObjaverseData(Dataset):
    """Multi-image samples read from ``<root>/<image_dir>/<path>/NNN.png``."""

    # The last NUM_VAL_OBJECTS listed paths form the validation split.
    NUM_VAL_OBJECTS = 16
    # Images per sample: 000.png is the condition, the rest are targets.
    NUM_IMAGES = 7
    # Upper bound on resampling attempts when a sample fails to load.
    MAX_LOAD_ATTEMPTS = 100

    def __init__(self,
        root_dir='objaverse/',
        meta_fname='valid_paths.json',
        image_dir='rendering_zero123plus',
        validation=False,
    ):
        """Read the path listing and keep the train or validation part.

        Args:
            root_dir: Dataset root directory.
            meta_fname: JSON file under ``root_dir`` mapping keys to lists of
                sample paths; all lists are concatenated in file order.
            image_dir: Sub-directory of ``root_dir`` containing the samples.
            validation: Keep the last 16 paths instead of all but the last 16.
        """
        self.root_dir = Path(root_dir)
        self.image_dir = image_dir

        with open(self.root_dir / meta_fname) as f:
            lvis_dict = json.load(f)

        all_paths = [p for group in lvis_dict.values() for p in group]
        if validation:
            self.paths = all_paths[-self.NUM_VAL_OBJECTS:]
        else:
            self.paths = all_paths[:-self.NUM_VAL_OBJECTS]
        print('============= length of dataset %d =============' % len(self.paths))

    def __len__(self):
        return len(self.paths)

    def load_im(self, path, color):
        """Load an image and blend it over ``color``; returns (image, alpha).

        Images without an alpha channel are returned unchanged with an
        all-ones alpha instead of failing.
        """
        return _load_image_with_alpha(path, color)

    def __getitem__(self, index):
        """Return the condition image and target images of one sample.

        If any image of the requested sample fails to load, the error is
        printed and a random other sample is tried instead.

        Raises:
            RuntimeError: If ``MAX_LOAD_ATTEMPTS`` consecutive samples fail.
        """
        # Background colour used for alpha compositing: white.
        bkg_color = [1., 1., 1.]

        last_error = None
        for _ in range(self.MAX_LOAD_ATTEMPTS):
            image_path = os.path.join(self.root_dir, self.image_dir, self.paths[index])
            try:
                img_list = [
                    self.load_im(os.path.join(image_path, '%03d.png' % idx), bkg_color)[0]
                    for idx in range(self.NUM_IMAGES)
                ]
            except Exception as e:
                # Best-effort: report and fall back to a random sample.
                print(e)
                last_error = e
                index = np.random.randint(0, len(self.paths))
                continue
            break
        else:
            raise RuntimeError(
                "failed to load a sample after %d attempts" % self.MAX_LOAD_ATTEMPTS
            ) from last_error

        imgs = torch.stack(img_list, dim=0).float()

        return {
            'cond_imgs': imgs[0],       # first image: (3, H, W)
            'target_imgs': imgs[1:],    # remaining 6 images: (6, 3, H, W)
        }