from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import pytorch3d
import torch
from torch.utils.data import SequentialSampler
from omegaconf import DictConfig
from pytorch3d.implicitron.dataset.data_loader_map_provider import \
    SequenceDataLoaderMapProvider
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2, registry)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import CamerasBase
from torch.utils.data import DataLoader
from pytorch3d.datasets import R2N2, collate_batched_meshes

# NOTE(review): this re-import shadows typing.Optional with the project's
# Optional — kept for backward compatibility; confirm it is intentional.
from configs.structured import CO3DConfig, DataloaderConfig, ProjectConfig, Optional

from .utils import DatasetMap


def get_dataset(cfg: ProjectConfig):
    """Build train/val/vis dataloaders for the dataset family named by ``cfg.dataset.type``.

    Supported types: ``co3dv2``, ``shapenet_r2n2``, the ``behave*`` variants,
    and ``shape``. Dataset-specific classes are imported lazily inside each
    branch so that only the selected family's dependencies are required.

    Args:
        cfg: full project configuration (dataset, dataloader, run, model,
            augmentations sub-configs are read depending on the branch).

    Returns:
        Tuple ``(dataloader_train, dataloader_val, dataloader_vis)``.

    Raises:
        NotImplementedError: if ``cfg.dataset.type`` is not recognized.
    """
    if cfg.dataset.type == 'co3dv2':
        from .exclude_sequence import EXCLUDE_SEQUENCE, LOW_QUALITY_SEQUENCE
        dataset_cfg: CO3DConfig = cfg.dataset
        dataloader_cfg: DataloaderConfig = cfg.dataloader

        # Exclude bad and low-quality sequences for this category.
        # XH: why this is needed?
        exclude_sequence = []
        exclude_sequence.extend(EXCLUDE_SEQUENCE.get(dataset_cfg.category, []))
        exclude_sequence.extend(LOW_QUALITY_SEQUENCE.get(dataset_cfg.category, []))

        # Arguments forwarded to the underlying JsonIndexDataset
        # (point clouds are always loaded; one frame per sequence).
        kwargs = dict(
            remove_empty_masks=True,
            n_frames_per_sequence=1,
            load_point_clouds=True,
            max_points=dataset_cfg.max_points,
            image_height=dataset_cfg.image_size,
            image_width=dataset_cfg.image_size,
            mask_images=dataset_cfg.mask_images,
            exclude_sequence=exclude_sequence,
            pick_sequence=() if dataset_cfg.restrict_model_ids is None else dataset_cfg.restrict_model_ids,
        )

        # Get dataset mapper
        dataset_map_provider_type = registry.get(JsonIndexDatasetMapProviderV2, "JsonIndexDatasetMapProviderV2")
        expand_args_fields(dataset_map_provider_type)
        dataset_map_provider = dataset_map_provider_type(
            category=dataset_cfg.category,
            subset_name=dataset_cfg.subset_name,
            dataset_root=dataset_cfg.root,
            test_on_train=False,
            only_test_set=False,
            load_eval_batches=True,
            dataset_JsonIndexDataset_args=DictConfig(kwargs),
        )

        # Get datasets. how to select specific frames??
        datasets = dataset_map_provider.get_dataset_map()

        # PATCH BUG WITH POINT CLOUD LOCATIONS! Rewrite each annotation's
        # point-cloud path so the last 3 path components are re-rooted under
        # this dataset_root, and verify the file actually exists.
        for dataset in (datasets["train"], datasets["val"]):
            for key, ann in dataset.seq_annots.items():
                correct_point_cloud_path = Path(dataset.dataset_root) / Path(*Path(ann.point_cloud.path).parts[-3:])
                assert correct_point_cloud_path.is_file(), correct_point_cloud_path
                ann.point_cloud.path = str(correct_point_cloud_path)

        # Get dataloader mapper
        data_loader_map_provider_type = registry.get(SequenceDataLoaderMapProvider, "SequenceDataLoaderMapProvider")
        expand_args_fields(data_loader_map_provider_type)
        data_loader_map_provider = data_loader_map_provider_type(
            batch_size=dataloader_cfg.batch_size,
            num_workers=dataloader_cfg.num_workers,
        )

        # QUICK HACK: Patch the train dataset because it is not used but it throws an error
        if (len(datasets['train']) == 0 and len(datasets[dataset_cfg.eval_split]) > 0
                and dataset_cfg.restrict_model_ids is not None and cfg.run.job == 'sample'):
            # XH: why all eval split?
            datasets = DatasetMap(train=datasets[dataset_cfg.eval_split],
                                  val=datasets[dataset_cfg.eval_split],
                                  test=datasets[dataset_cfg.eval_split])
            print('Note: You used restrict_model_ids and there were no ids in the train set.')

        # Get dataloaders
        dataloaders = data_loader_map_provider.get_data_loader_map(datasets)
        dataloader_train = dataloaders['train']
        dataloader_val = dataloader_vis = dataloaders[dataset_cfg.eval_split]

        # Replace validation dataloader sampler with SequentialSampler so that
        # evaluation order is deterministic (the provider's default sampler
        # appears to sample randomly with a fixed seed, with no control over
        # which image is drawn).
        dataloader_val.batch_sampler.sampler = SequentialSampler(dataloader_val.batch_sampler.sampler.data_source)

        # Modify for accelerate
        dataloader_train.batch_sampler.drop_last = True
        dataloader_val.batch_sampler.drop_last = False

    elif cfg.dataset.type == 'shapenet_r2n2':
        from .r2n2_my import R2N2Sample
        # FIX: this import was commented out, leaving the annotation below
        # referring to an undefined name.
        from configs.structured import ShapeNetR2N2Config
        dataset_cfg: ShapeNetR2N2Config = cfg.dataset

        # One R2N2Sample per split; all views returned, no textures.
        datasets = [R2N2Sample(dataset_cfg.max_points, dataset_cfg.fix_sample,
                               dataset_cfg.image_size, cfg.augmentations,
                               s, dataset_cfg.shapenet_dir,
                               dataset_cfg.r2n2_dir, dataset_cfg.splits_file,
                               load_textures=False, return_all_views=True)
                    for s in ['train', 'val', 'test']]
        dataloader_train = DataLoader(datasets[0], batch_size=cfg.dataloader.batch_size,
                                      collate_fn=collate_batched_meshes,
                                      num_workers=cfg.dataloader.num_workers, shuffle=True)
        dataloader_val = DataLoader(datasets[1], batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=False)
        dataloader_vis = DataLoader(datasets[2], batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=False)

    elif cfg.dataset.type in ['behave', 'behave-objonly', 'behave-humonly', 'behave-dtransl',
                              'behave-objonly-segm', 'behave-humonly-segm', 'behave-attn',
                              'behave-test', 'behave-attn-test', 'behave-hum-pe',
                              'behave-hum-noscale', 'behave-hum-surf', 'behave-objv2v']:
        from .behave_dataset import BehaveDataset, NTUDataset, BehaveObjOnly, BehaveHumanOnly, BehaveHumanOnlyPosEnc
        from .behave_dataset import BehaveHumanOnlySegmInput, BehaveObjOnlySegmInput, BehaveTestOnly, BehaveHumNoscale
        from .behave_dataset import BehaveHumanOnlySurfSample
        from .behave_dataset import BehaveObjOnlyV2V
        from .dtransl_dataset import DirectTranslDataset
        from .behave_paths import DataPaths
        from configs.structured import BehaveDatasetConfig
        from .behave_crossattn import BehaveCrossAttnDataset, BehaveCrossAttnTest

        dataset_cfg: BehaveDatasetConfig = cfg.dataset
        train_paths, val_paths = DataPaths.load_splits(dataset_cfg.split_file, dataset_cfg.behave_dir)

        # Restrict validation paths to the batch window selected by
        # cfg.run.batch_start / cfg.run.batch_end (None end = all batches).
        bs = cfg.dataloader.batch_size
        num_batches_total = int(np.ceil(len(val_paths) / bs))
        end_idx = cfg.run.batch_end if cfg.run.batch_end is not None else num_batches_total
        val_paths = val_paths[cfg.run.batch_start * bs:end_idx * bs]

        # Map dataset.type to the train/val dataset classes.
        if cfg.dataset.type == 'behave':
            train_type = BehaveDataset
            val_datatype = BehaveDataset if 'ntu' not in dataset_cfg.split_file else NTUDataset
        elif cfg.dataset.type == 'behave-test':
            train_type = BehaveDataset
            val_datatype = BehaveTestOnly
        elif cfg.dataset.type == 'behave-objonly':
            train_type = BehaveObjOnly
            val_datatype = BehaveObjOnly
            assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
        elif cfg.dataset.type == 'behave-humonly':
            train_type = BehaveHumanOnly
            val_datatype = BehaveHumanOnly
            assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
        elif cfg.dataset.type == 'behave-hum-noscale':
            train_type = BehaveHumNoscale
            val_datatype = BehaveHumNoscale
        elif cfg.dataset.type == 'behave-hum-pe':
            train_type = BehaveHumanOnlyPosEnc
            val_datatype = BehaveHumanOnlyPosEnc
        elif cfg.dataset.type == 'behave-hum-surf':
            train_type = BehaveHumanOnlySurfSample
            val_datatype = BehaveHumanOnlySurfSample
        elif cfg.dataset.type == 'behave-humonly-segm':
            assert cfg.dataset.ho_segm_pred_path is not None, 'please specify predicted HO segmentation!'
            train_type = BehaveHumanOnly
            val_datatype = BehaveHumanOnlySegmInput
            assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
        elif cfg.dataset.type == 'behave-objonly-segm':
            assert cfg.dataset.ho_segm_pred_path is not None, 'please specify predicted HO segmentation!'
            train_type = BehaveObjOnly
            val_datatype = BehaveObjOnlySegmInput
            assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
        elif cfg.dataset.type == 'behave-dtransl':
            train_type = DirectTranslDataset
            val_datatype = DirectTranslDataset
        elif cfg.dataset.type == 'behave-attn':
            train_type = BehaveCrossAttnDataset
            val_datatype = BehaveCrossAttnDataset
        elif cfg.dataset.type == 'behave-attn-test':
            train_type = BehaveCrossAttnDataset
            val_datatype = BehaveCrossAttnTest
        elif cfg.dataset.type == 'behave-objv2v':
            train_type = BehaveObjOnlyV2V
            val_datatype = BehaveObjOnlyV2V
        else:
            raise NotImplementedError

        # Train always uses GT SMPL; val uses the configured smpl_type and
        # test_transl_type, and omits the camera-noise/blur augmentations.
        dataset_train = train_type(train_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
                                   (dataset_cfg.image_size, dataset_cfg.image_size),
                                   split='train', sample_ratio_hum=dataset_cfg.sample_ratio_hum,
                                   normalize_type=dataset_cfg.normalize_type, smpl_type='gt',
                                   load_corr_points=dataset_cfg.load_corr_points,
                                   uniform_obj_sample=dataset_cfg.uniform_obj_sample,
                                   bkg_type=dataset_cfg.bkg_type,
                                   bbox_params=dataset_cfg.bbox_params,
                                   pred_binary=cfg.model.predict_binary,
                                   ho_segm_pred_path=cfg.dataset.ho_segm_pred_path,
                                   compute_closest_points=cfg.model.model_name == 'pc2-diff-ho-tune-newloss',
                                   use_gt_transl=cfg.dataset.use_gt_transl,
                                   cam_noise_std=cfg.dataset.cam_noise_std,
                                   sep_same_crop=cfg.dataset.sep_same_crop,
                                   aug_blur=cfg.dataset.aug_blur,
                                   std_coverage=cfg.dataset.std_coverage,
                                   v2v_path=cfg.dataset.v2v_path)
        dataset_val = val_datatype(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
                                   (dataset_cfg.image_size, dataset_cfg.image_size),
                                   split='val', sample_ratio_hum=dataset_cfg.sample_ratio_hum,
                                   normalize_type=dataset_cfg.normalize_type, smpl_type=dataset_cfg.smpl_type,
                                   load_corr_points=dataset_cfg.load_corr_points,
                                   test_transl_type=dataset_cfg.test_transl_type,
                                   uniform_obj_sample=dataset_cfg.uniform_obj_sample,
                                   bkg_type=dataset_cfg.bkg_type,
                                   bbox_params=dataset_cfg.bbox_params,
                                   pred_binary=cfg.model.predict_binary,
                                   ho_segm_pred_path=cfg.dataset.ho_segm_pred_path,
                                   compute_closest_points=cfg.model.model_name == 'pc2-diff-ho-tune-newloss',
                                   use_gt_transl=cfg.dataset.use_gt_transl,
                                   sep_same_crop=cfg.dataset.sep_same_crop,
                                   std_coverage=cfg.dataset.std_coverage,
                                   v2v_path=cfg.dataset.v2v_path)

        dataloader_train = DataLoader(dataset_train, batch_size=cfg.dataloader.batch_size,
                                      collate_fn=collate_batched_meshes,
                                      num_workers=cfg.dataloader.num_workers, shuffle=True)
        # Val/vis loaders are shuffled only during training jobs so that
        # sampling/eval jobs see a deterministic order.
        shuffle = cfg.run.job == 'train'
        dataloader_val = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
        dataloader_vis = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=shuffle)

    elif cfg.dataset.type in ['shape']:
        from .shape_dataset import ShapeDataset
        from .behave_paths import DataPaths
        from configs.structured import ShapeDatasetConfig

        dataset_cfg: ShapeDatasetConfig = cfg.dataset
        train_paths, _ = DataPaths.load_splits(dataset_cfg.split_file, dataset_cfg.behave_dir)
        val_paths = train_paths  # same as training, this is for overfitting

        # Restrict validation paths to the selected batch window (see behave branch).
        bs = cfg.dataloader.batch_size
        num_batches_total = int(np.ceil(len(val_paths) / bs))
        end_idx = cfg.run.batch_end if cfg.run.batch_end is not None else num_batches_total
        val_paths = val_paths[cfg.run.batch_start * bs:end_idx * bs]

        # NOTE(review): val also uses split='train' — consistent with the
        # overfitting setup above; confirm this is intentional.
        dataset_train = ShapeDataset(train_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
                                     (dataset_cfg.image_size, dataset_cfg.image_size),
                                     split='train')
        dataset_val = ShapeDataset(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
                                   (dataset_cfg.image_size, dataset_cfg.image_size),
                                   split='train')
        dataloader_train = DataLoader(dataset_train, batch_size=cfg.dataloader.batch_size,
                                      collate_fn=collate_batched_meshes,
                                      num_workers=cfg.dataloader.num_workers, shuffle=True)
        shuffle = cfg.run.job == 'train'
        dataloader_val = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
        dataloader_vis = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
                                    collate_fn=collate_batched_meshes,
                                    num_workers=cfg.dataloader.num_workers, shuffle=shuffle)

    else:
        raise NotImplementedError(cfg.dataset.type)

    return dataloader_train, dataloader_val, dataloader_vis