| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import List, Optional |
|
|
| import lightning.pytorch as pl |
| import torch |
| from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS |
| from torch.utils.data import DataLoader, Dataset |
|
|
| from nemo.lightning.pytorch.plugins import MegatronDataSampler |
|
|
|
|
class MockDataModule(pl.LightningDataModule):
    """
    A PyTorch Lightning DataModule for creating mock datasets for training, validation, and testing.

    Args:
        image_h (int): Height of the images in the dataset. Default is 1024.
        image_w (int): Width of the images in the dataset. Default is 1024.
        micro_batch_size (int): Micro batch size for the data sampler. Default is 4.
        global_batch_size (int): Global batch size for the data sampler. Default is 8.
        rampup_batch_size (Optional[List[int]]): Ramp-up batch size for the data sampler. Default is None.
        num_train_samples (int): Number of training samples. Default is 10,000.
        num_val_samples (int): Number of validation samples. Default is 10,000.
        num_test_samples (int): Number of testing samples. Default is 10,000.
        num_workers (int): Number of worker threads for data loading. Default is 8.
        pin_memory (bool): Whether to use pinned memory for data loading. Default is True.
        persistent_workers (bool): Whether to use persistent workers for data loading. Default is False.
        image_precached (bool): Whether the images are pre-cached. Default is False.
        text_precached (bool): Whether the text data is pre-cached. Default is False.
    """

    def __init__(
        self,
        image_h: int = 1024,
        image_w: int = 1024,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        num_train_samples: int = 10_000,
        num_val_samples: int = 10_000,
        num_test_samples: int = 10_000,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        image_precached: bool = False,
        text_precached: bool = False,
    ):
        super().__init__()
        self.image_h = image_h
        self.image_w = image_w
        self.num_train_samples = num_train_samples
        self.num_val_samples = num_val_samples
        self.num_test_samples = num_test_samples
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.image_precached = image_precached
        self.text_precached = text_precached
        self.global_batch_size = global_batch_size
        self.micro_batch_size = micro_batch_size
        # No tokenizer is needed for mock data; seq_length is only consumed by
        # the Megatron data sampler below.
        self.tokenizer = None
        self.seq_length = 10

        self.data_sampler = MegatronDataSampler(
            seq_len=self.seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            rampup_batch_size=rampup_batch_size,
        )

    def setup(self, stage: str = "") -> None:
        """
        Sets up datasets for training, validation, and testing.

        Args:
            stage (str): The stage of the process (e.g., 'fit', 'test'). Default is an empty string.
        """
        # BUG FIX: previously the image size was hard-coded to 1024x1024 here,
        # silently ignoring the image_h/image_w constructor arguments. All three
        # splits now honor the configured dimensions.
        self._train_ds = _MockT2IDataset(
            image_H=self.image_h,
            image_W=self.image_w,
            length=self.num_train_samples,
            image_precached=self.image_precached,
            text_precached=self.text_precached,
        )
        self._validation_ds = _MockT2IDataset(
            image_H=self.image_h,
            image_W=self.image_w,
            length=self.num_val_samples,
            image_precached=self.image_precached,
            text_precached=self.text_precached,
        )
        self._test_ds = _MockT2IDataset(
            image_H=self.image_h,
            image_W=self.image_w,
            length=self.num_test_samples,
            image_precached=self.image_precached,
            text_precached=self.text_precached,
        )

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """
        Returns the training DataLoader.

        Returns:
            TRAIN_DATALOADERS: DataLoader for the training dataset.
        """
        # Lazily build the datasets if setup() was not called by the trainer.
        if not hasattr(self, "_train_ds"):
            self.setup()
        return self._create_dataloader(self._train_ds)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """
        Returns the validation DataLoader.

        Returns:
            EVAL_DATALOADERS: DataLoader for the validation dataset.
        """
        if not hasattr(self, "_validation_ds"):
            self.setup()
        return self._create_dataloader(self._validation_ds)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """
        Returns the testing DataLoader.

        Returns:
            EVAL_DATALOADERS: DataLoader for the testing dataset.
        """
        if not hasattr(self, "_test_ds"):
            self.setup()
        return self._create_dataloader(self._test_ds)

    def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
        """
        Creates a DataLoader for the given dataset.

        Args:
            dataset: The dataset to load.
            **kwargs: Additional arguments for the DataLoader.

        Returns:
            DataLoader: Configured DataLoader for the dataset.
        """
        # Batching is driven externally by the MegatronDataSampler, so no
        # batch_size is passed here unless a caller overrides it via kwargs.
        return DataLoader(
            dataset,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            persistent_workers=self.persistent_workers,
            **kwargs,
        )
|
|
|
|
| class _MockT2IDataset(Dataset): |
| """ |
| A mock dataset class for text-to-image tasks, simulating data samples for training and testing. |
| |
| This dataset generates synthetic data for both image and text inputs, with options to use |
| pre-cached latent representations or raw data. The class is designed for use in testing and |
| prototyping machine learning models. |
| |
| Attributes: |
| image_H (int): Height of the generated images. |
| image_W (int): Width of the generated images. |
| length (int): Total number of samples in the dataset. |
| image_key (str): Key for accessing image data in the output dictionary. |
| txt_key (str): Key for accessing text data in the output dictionary. |
| hint_key (str): Key for accessing hint data in the output dictionary. |
| image_precached (bool): Whether to use pre-cached latent representations for images. |
| text_precached (bool): Whether to use pre-cached embeddings for text. |
| prompt_seq_len (int): Sequence length for text prompts. |
| pooled_prompt_dim (int): Dimensionality of pooled text embeddings. |
| context_dim (int): Dimensionality of the text embedding context. |
| vae_scale_factor (int): Scaling factor for the VAE latent representation. |
| vae_channels (int): Number of channels in the VAE latent representation. |
| latent_shape (tuple): Shape of the latent representation for images (if pre-cached). |
| prompt_embeds_shape (tuple): Shape of the text prompt embeddings (if pre-cached). |
| pooped_prompt_embeds_shape (tuple): Shape of pooled text embeddings (if pre-cached). |
| text_ids_shape (tuple): Shape of the text token IDs (if pre-cached). |
| |
| Methods: |
| __getitem__(index): |
| Retrieves a single sample from the dataset based on the specified index. |
| __len__(): |
| Returns the total number of samples in the dataset. |
| """ |
|
|
| def __init__( |
| self, |
| image_H, |
| image_W, |
| length=100000, |
| image_key='images', |
| txt_key='txt', |
| hint_key='hint', |
| image_precached=False, |
| text_precached=False, |
| prompt_seq_len=256, |
| pooled_prompt_dim=768, |
| context_dim=4096, |
| vae_scale_factor=8, |
| vae_channels=16, |
| ): |
| super().__init__() |
| self.length = length |
| self.H = image_H |
| self.W = image_W |
| self.image_key = image_key |
| self.txt_key = txt_key |
| self.hint_key = hint_key |
| self.image_precached = image_precached |
| self.text_precached = text_precached |
| if self.image_precached: |
| self.latent_shape = (vae_channels, int(image_H // vae_scale_factor), int(image_W // vae_scale_factor)) |
| if self.text_precached: |
| self.prompt_embeds_shape = (prompt_seq_len, context_dim) |
| self.pooped_prompt_embeds_shape = (pooled_prompt_dim,) |
| self.text_ids_shape = (prompt_seq_len, 3) |
|
|
| def __getitem__(self, index): |
| """ |
| Retrieves a single sample from the dataset. |
| |
| The sample can include raw image and text data or pre-cached latent representations, |
| depending on the configuration. |
| |
| Args: |
| index (int): Index of the sample to retrieve. |
| |
| Returns: |
| dict: A dictionary containing the generated data sample. The keys and values |
| depend on whether `image_precached` and `text_precached` are set. |
| Possible keys include: |
| - 'latents': Pre-cached latent representation of the image. |
| - 'control_latents': Pre-cached control latent representation. |
| - 'images': Raw image tensor. |
| - 'hint': Hint tensor for the image. |
| - 'prompt_embeds': Pre-cached text prompt embeddings. |
| - 'pooled_prompt_embeds': Pooled text prompt embeddings. |
| - 'text_ids': Text token IDs. |
| - 'txt': Text input string (if text is not pre-cached). |
| """ |
| item = {} |
| if self.image_precached: |
| item['latents'] = torch.randn(self.latent_shape) |
| item['control_latents'] = torch.randn(self.latent_shape) |
| else: |
| item[self.image_key] = torch.randn(3, self.H, self.W) |
| item[self.hint_key] = torch.randn(3, self.H, self.W) |
|
|
| if self.text_precached: |
| item['prompt_embeds'] = torch.randn(self.prompt_embeds_shape) |
| item['pooled_prompt_embeds'] = torch.randn(self.pooped_prompt_embeds_shape) |
| item['text_ids'] = torch.randn(self.text_ids_shape) |
| else: |
| item[self.txt_key] = "This is a sample caption input" |
|
|
| return item |
|
|
| def __len__(self): |
| """ |
| Returns the total number of samples in the dataset. |
| |
| Returns: |
| int: Total number of samples (`length` attribute). |
| """ |
| return self.length |
|
|