DIRECTOR-demo / src /datasets /multimodal_dataset.py
robin-courant's picture
Add app
f7a5cb1 verified
raw
history blame
3.01 kB
from copy import deepcopy as dp
from pathlib import Path
from torch.utils.data import Dataset
class MultimodalDataset(Dataset):
def __init__(
self,
name,
dataset_name,
dataset_dir,
trajectory,
feature_type,
num_rawfeats,
num_feats,
num_cams,
num_cond_feats,
standardization,
augmentation=None,
**modalities,
):
self.dataset_dir = Path(dataset_dir)
self.name = name
self.dataset_name = dataset_name
self.feature_type = feature_type
self.num_rawfeats = num_rawfeats
self.num_feats = num_feats
self.num_cams = num_cams
self.trajectory_dataset = trajectory
self.standardization = standardization
self.modality_datasets = modalities
if augmentation is not None:
self.augmentation = True
self.augmentation_rate = augmentation.rate
self.trajectory_dataset.set_augmentation(augmentation.trajectory)
if hasattr(augmentation, "modalities"):
for modality, augments in augmentation.modalities:
self.modality_datasets[modality].set_augmentation(augments)
else:
self.augmentation = False
# --------------------------------------------------------------------------------- #
def set_split(self, split: str, train_rate: float = 1.0):
self.split = split
# Get trajectory split
self.trajectory_dataset = dp(self.trajectory_dataset).set_split(
split, train_rate
)
self.root_filenames = self.trajectory_dataset.filenames
# Get modality split
for modality_name in self.modality_datasets.keys():
self.modality_datasets[modality_name].filenames = self.root_filenames
self.get_feature = self.trajectory_dataset.get_feature
self.get_matrix = self.trajectory_dataset.get_matrix
return self
# --------------------------------------------------------------------------------- #
def __getitem__(self, index):
traj_out = self.trajectory_dataset[index]
traj_filename, traj_feature, padding_mask, intrinsics = traj_out
out = {
"traj_filename": traj_filename,
"traj_feat": traj_feature,
"padding_mask": padding_mask,
"intrinsics": intrinsics,
}
for modality_name, modality_dataset in self.modality_datasets.items():
modality_filename, modality_feature, modality_raw = modality_dataset[index]
assert traj_filename.split(".")[0] == modality_filename.split(".")[0]
out[f"{modality_name}_filename"] = modality_filename
out[f"{modality_name}_feat"] = modality_feature
out[f"{modality_name}_raw"] = modality_raw
out[f"{modality_name}_padding_mask"] = padding_mask
return out
def __len__(self):
return len(self.trajectory_dataset)