|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Implementations of dataset settings and augmentations for tokenization |
|
|
|
|
|
Run this command to interactively debug: |
|
|
python3 -m cosmos_predict1.tokenizer.training.datasets.dataset_provider |
|
|
|
|
|
""" |
|
|
|
|
|
from cosmos_predict1.tokenizer.training.datasets.augmentation_provider import ( |
|
|
video_train_augmentations, |
|
|
video_val_augmentations, |
|
|
) |
|
|
from cosmos_predict1.tokenizer.training.datasets.utils import categorize_aspect_and_store |
|
|
from cosmos_predict1.tokenizer.training.datasets.video_dataset import Dataset |
|
|
from cosmos_predict1.utils.lazy_config import instantiate |
|
|
|
|
|
# Registry mapping a dataset name (looked up lowercase) to the glob pattern
# locating its video files; add an entry here to register a new dataset.
_VIDEO_PATTERN_DICT = {
    "hdvila_video": "datasets/hdvila/videos/*.mp4",
}
|
|
|
|
|
|
|
|
def apply_augmentations(data_dict, augmentations_dict):
    """Instantiate each LazyCall augmentation and apply it to ``data_dict``.

    Args:
        data_dict: Mapping holding the sample data to transform.
        augmentations_dict: Mapping of augmentation name -> LazyCall config.
            Values are instantiated and applied in insertion order.

    Returns:
        The transformed data dict (an augmentation may mutate it in place
        and/or return a new mapping; the latest result is threaded through).
    """
    # The dict keys are only labels, so iterate values directly instead of
    # unpacking an unused name from .items() (ruff PERF102).
    for lazy_aug in augmentations_dict.values():
        aug_instance = instantiate(lazy_aug)
        data_dict = aug_instance(data_dict)
    return data_dict
|
|
|
|
|
|
|
|
class AugmentDataset(Dataset):
    """Wraps a video dataset and runs a fixed augmentation pipeline per sample.

    The LazyCall augmentation configs are instantiated once at construction
    time, so each ``__getitem__`` call only pays the cost of applying them.
    """

    def __init__(self, base_dataset, augmentations_dict):
        """
        Args:
            base_dataset: the video dataset instance.
            augmentations_dict: the dictionary returned by
                video_train_augmentations() or video_val_augmentations().
        """
        # NOTE(review): mirrors the original — the Dataset base __init__ is
        # intentionally not invoked; this class only wraps base_dataset.
        self.base_dataset = base_dataset
        # Instantiate every LazyCall config up front, preserving dict order.
        self.augmentations = [
            (name, instantiate(cfg)) for name, cfg in augmentations_dict.items()
        ]

    def __len__(self):
        """Number of samples, delegated to the wrapped dataset."""
        return len(self.base_dataset)

    def __getitem__(self, index):
        """Fetch sample ``index``, tag its aspect bucket, then augment it."""
        sample = self.base_dataset[index]
        sample = categorize_aspect_and_store(sample)
        for _name, transform in self.augmentations:
            sample = transform(sample)
        return sample
|
|
|
|
|
|
|
|
def dataset_entry(
    dataset_name: str,
    dataset_type: str,
    is_train: bool = True,
    resolution="720",
    crop_height=256,
    num_video_frames=25,
) -> AugmentDataset:
    """Build an AugmentDataset for the requested video dataset.

    Args:
        dataset_name: Key into ``_VIDEO_PATTERN_DICT`` (case-insensitive).
        dataset_type: Only ``"video"`` is supported.
        is_train: Select the train pipeline if True, else the val pipeline.
        resolution: Target resolution label forwarded to the augmentations.
        crop_height: Crop height forwarded to the augmentations.
        num_video_frames: Number of frames sampled per clip.

    Returns:
        The wrapped, augmentation-ready dataset.

    Raises:
        ValueError: If ``dataset_type`` is not ``"video"`` or
            ``dataset_name`` is not registered.
    """
    if dataset_type != "video":
        raise ValueError(f"Dataset type {dataset_type} is not supported")

    # Fail with an actionable message instead of a bare KeyError.
    key = dataset_name.lower()
    if key not in _VIDEO_PATTERN_DICT:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; "
            f"expected one of {sorted(_VIDEO_PATTERN_DICT)}"
        )

    base_dataset = Dataset(
        video_pattern=_VIDEO_PATTERN_DICT[key],
        num_video_frames=num_video_frames,
    )

    # Train and val builders share the same signature; only the choice differs.
    aug_builder = video_train_augmentations if is_train else video_val_augmentations
    aug_dict = aug_builder(
        input_keys=["video"],
        resolution=resolution,
        crop_height=crop_height,
    )

    return AugmentDataset(base_dataset, aug_dict)
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
dataset = dataset_entry( |
|
|
dataset_name="davis_video", |
|
|
dataset_type="video", |
|
|
is_train=False, |
|
|
resolution="720", |
|
|
crop_height=256, |
|
|
num_video_frames=25, |
|
|
) |
|
|
|
|
|
|
|
|
print(f"Total samples in dataset: {len(dataset)}") |
|
|
|
|
|
|
|
|
if len(dataset) > 0: |
|
|
sample_idx = 0 |
|
|
sample = dataset[sample_idx] |
|
|
print(f"Sample index {sample_idx} keys: {list(sample.keys())}") |
|
|
if "video" in sample: |
|
|
print("Video shape:", sample["video"].shape) |
|
|
if "video_name" in sample: |
|
|
print("Video metadata:", sample["video_name"]) |
|
|
print("---\nSample loaded successfully.\n") |
|
|
else: |
|
|
print("Dataset has no samples!") |
|
|
|