| | from typing import List |
| | from torch.utils.data import IterableDataset, Dataset |
| | from omegaconf import DictConfig |
| | import torch |
| | import numpy as np |
| | from datasets.dummy import DummyVideoDataset |
| | from datasets.openx_base import OpenXVideoDataset |
| | from datasets.droid import DroidVideoDataset |
| | from datasets.something_something import SomethingSomethingDataset |
| | from datasets.epic_kitchen import EpicKitchenDataset |
| | from datasets.pandas import PandasVideoDataset |
| | from datasets.deprecated.video_1x_wm import WorldModel1XDataset |
| | from datasets.agibot_world import AgibotWorldDataset |
| | from datasets.ego4d import Ego4DVideoDataset |
| |
|
# Every OpenX-derived subset shares the same dataset class; list the names
# once instead of repeating the class for each entry.
_OPENX_SUBSET_NAMES = (
    "austin_buds",
    "austin_sailor",
    "austin_sirius",
    "bc_z",
    "berkeley_autolab",
    "berkeley_cable",
    "berkeley_fanuc",
    "bridge",
    "cmu_stretch",
    "dlr_edan",
    "dobbe",
    "fmb",
    "fractal",
    "iamlab_cmu",
    "jaco_play",
    "language_table",
    "nyu_franka",
    "roboturk",
    "stanford_hydra",
    "taco_play",
    "toto",
    "ucsd_kitchen",
    "utaustin_mutex",
    "viola",
)

# Registry mapping subset name -> dataset class. The dedicated datasets come
# first, then the OpenX family, preserving the original insertion order.
subset_classes = {
    "dummy": DummyVideoDataset,
    "something_something": SomethingSomethingDataset,
    "epic_kitchen": EpicKitchenDataset,
    "pandas": PandasVideoDataset,
    "agibot_world": AgibotWorldDataset,
    "video_1x_wm": WorldModel1XDataset,
    "ego4d": Ego4DVideoDataset,
    "droid": DroidVideoDataset,
}
subset_classes.update({name: OpenXVideoDataset for name in _OPENX_SUBSET_NAMES})
| |
|
| |
|
class MixtureDataset(IterableDataset):
    """
    A fault-tolerant weighted mixture of video datasets.

    Subsets are instantiated from ``subset_classes`` according to the
    ``subset/<name>`` entries in ``cfg`` that also appear in the split's
    weight table. Iteration draws samples endlessly, choosing a subset with
    probability proportional to its normalized weight; per-sample errors are
    logged and skipped unless ``cfg.debug`` is set, in which case they are
    re-raised.
    """

    # Global video options copied from the top-level config into every
    # subset's config before construction.
    _SHARED_FIELDS = (
        "height",
        "width",
        "n_frames",
        "fps",
        "load_video_latent",
        "load_prompt_embed",
        "max_text_tokens",
        "image_to_video",
    )

    def __init__(self, cfg: DictConfig, split: str = "training"):
        """
        Args:
            cfg: Configuration holding global video options, one
                ``subset/<name>`` entry per available dataset, and per-split
                ``weight`` / ``weight_type`` settings.
            split: Split to build ("training", "validation", ...). Must not
                be "all".

        Raises:
            ValueError: If ``split == "all"`` or a weighted dataset name has
                no matching ``subset/<name>`` configuration.
        """
        super().__init__()
        self.cfg = cfg
        self.debug = cfg.debug
        self.split = split
        # Record the current numpy seed for reproducibility bookkeeping.
        # NOTE(review): np.random state is process-wide; with multiple
        # DataLoader workers each worker inherits the same stream — confirm
        # seeding is handled by the caller.
        self.random_seed = np.random.get_state()[1][0]
        self.subset_cfg = {
            k.split("/")[1]: v for k, v in self.cfg.items() if k.startswith("subset/")
        }
        if split == "all":
            raise ValueError("split cannot be `all` for MixtureDataset`")
        weight = dict(self.cfg[split].weight)

        for key in weight:
            if key not in self.subset_cfg:
                raise ValueError(
                    f"Dataset '{key}' specified in weights but not found in configuration"
                )
        # Keep only the subsets that were actually assigned a weight.
        self.subset_cfg = {k: v for k, v in self.subset_cfg.items() if k in weight}
        weight_type = self.cfg[split].weight_type
        self.subsets: List[Dataset] = []
        for subset_name, subset_cfg in self.subset_cfg.items():
            # Propagate the shared video options into each subset config.
            for field in self._SHARED_FIELDS:
                subset_cfg[field] = self.cfg[field]
            self.subsets.append(subset_classes[subset_name](subset_cfg, split))
            if weight_type == "relative":
                # "relative" weights scale with dataset size, so larger
                # datasets are sampled proportionally more often.
                weight[subset_name] = weight[subset_name] * len(self.subsets[-1])

        total_weight = sum(weight.values())
        self.normalized_weights = {k: v / total_weight for k, v in weight.items()}

        dataset_sizes = {
            subset_name: len(subset)
            for subset_name, subset in zip(self.subset_cfg.keys(), self.subsets)
        }

        # Human-readable summary of the mixture composition.
        print("\nDataset information for split '{}':".format(self.split))
        print("-" * 60)
        print(f"{'Dataset':<25} {'Size':<10} {'Weight':<10} {'Normalized':<10}")
        print("-" * 60)
        for subset_name, norm_weight in sorted(
            self.normalized_weights.items(), key=lambda x: -x[1]
        ):
            size = dataset_sizes[subset_name]
            orig_weight = self.cfg[split].weight[subset_name]
            print(
                f"{subset_name:<25} {size:<10,d} {orig_weight:<10.4f} {norm_weight:<10.4f}"
            )
        print("-" * 60)

        # Cumulative distribution over subsets, used for inverse-CDF sampling
        # in __iter__.
        self.cumsum_weights = {}
        cumsum = 0
        for k, v in self.normalized_weights.items():
            cumsum += v
            self.cumsum_weights[k] = cumsum

        # Precompute name -> position so __iter__ avoids rebuilding and
        # scanning a key list on every single sample.
        self._subset_index = {
            name: i for i, name in enumerate(self.subset_cfg.keys())
        }

        # Flattened record list across all subsets (exposed for callers that
        # inspect the mixture's contents).
        self.records = []
        for subset in self.subsets:
            self.records.extend(subset.records)

    def __iter__(self):
        """Yield samples forever, drawn from subsets per their weights."""
        while True:
            # Inverse-CDF sampling: pick the first subset whose cumulative
            # weight covers the uniform draw.
            rand = np.random.random()
            for subset_name, cumsum in self.cumsum_weights.items():
                if rand <= cumsum:
                    selected_subset = subset_name
                    break
            else:
                # Floating-point accumulation can leave the final cumsum
                # marginally below 1.0; fall back to the last subset instead
                # of referencing an unbound name.
                selected_subset = subset_name

            subset_idx = self._subset_index[selected_subset]

            try:
                dataset = self.subsets[subset_idx]
                idx = np.random.randint(len(dataset))
                sample = dataset[idx]
                yield sample
            except Exception as e:
                if self.debug:
                    # In debug mode surface the failure immediately
                    # (bare raise preserves the original traceback).
                    raise
                else:
                    # Fault tolerance: log and resample rather than killing
                    # the whole training iterator.
                    print(f"Error sampling from {selected_subset}: {str(e)}")
                    continue
| |
|