from typing import Optional

import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule

try:
    from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError:
    print("#" * 100)
    print("Datasets not yet available.")
    print("To enable them, add stable-datasets as a submodule:")
    print("run ``git submodule update --init --recursive``")
    print("and then ``pip install -e stable-datasets/`` from the root of this repo.")
    print("#" * 100)
    exit(1)


class StableDataModuleFromConfig(LightningDataModule):
    """Lightning data module that builds train/val/test datapipelines and loaders
    from OmegaConf configs. Each split config must provide the fields
    `datapipeline` and `loader`."""

    def __init__(
        self,
        train: DictConfig,
        validation: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        skip_val_loader: bool = False,
        dummy: bool = False,
    ):
        super().__init__()
        self.train_config = train
        assert (
            "datapipeline" in self.train_config and "loader" in self.train_config
        ), "train config requires the fields `datapipeline` and `loader`"

        self.val_config = validation
        if not skip_val_loader:
            if self.val_config is not None:
                assert (
                    "datapipeline" in self.val_config and "loader" in self.val_config
                ), "validation config requires the fields `datapipeline` and `loader`"
            else:
                print(
                    "Warning: no validation datapipeline defined, reusing the training config"
                )
                self.val_config = train

        self.test_config = test
        if self.test_config is not None:
            assert (
                "datapipeline" in self.test_config and "loader" in self.test_config
            ), "test config requires the fields `datapipeline` and `loader`"

        self.dummy = dummy
        if self.dummy:
            print("#" * 100)
            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
            print("#" * 100)

    def setup(self, stage: str) -> None:
        print("Preparing datasets")
        # Swap in the dummy dataset factory when debugging without real data.
        if self.dummy:
            data_fn = create_dummy_dataset
        else:
            data_fn = create_dataset

        self.train_datapipeline = data_fn(**self.train_config.datapipeline)
        if self.val_config:
            self.val_datapipeline = data_fn(**self.val_config.datapipeline)
        if self.test_config:
            self.test_datapipeline = data_fn(**self.test_config.datapipeline)

    def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
        loader = create_loader(self.train_datapipeline, **self.train_config.loader)
        return loader

    def val_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.val_datapipeline, **self.val_config.loader)

    def test_dataloader(self) -> wds.DataPipeline:
        return create_loader(self.test_datapipeline, **self.test_config.loader)
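

# Example usage (a minimal sketch): the asserts above only require that each split
# config contains `datapipeline` and `loader` sub-configs; the concrete keys inside
# them ("urls", "batch_size", "num_workers") are illustrative placeholders and
# depend on what `create_dataset` / `create_loader` in the stable-datasets
# submodule actually accept.
#
#   from omegaconf import OmegaConf
#
#   config = OmegaConf.create(
#       {
#           "train": {
#               "datapipeline": {"urls": ["data/train-{000000..000009}.tar"]},
#               "loader": {"batch_size": 4, "num_workers": 2},
#           }
#       }
#   )
#   datamodule = StableDataModuleFromConfig(train=config.train)
#   datamodule.setup("fit")
#   train_loader = datamodule.train_dataloader()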