from typing import Any, Dict, Optional

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from omegaconf import DictConfig

import sys

import pyrootutils

# setup_root locates the project root (via marker files such as ".git" / "pyproject.toml")
# and, with pythonpath=True, adds it to sys.path so the project-local `sgm` package can be
# imported even when this file is executed directly.
root = pyrootutils.setup_root(__file__, pythonpath=True)
sys.path.append(str(root))

from sgm.data.video_dataset_latent import VideoDataset


class VideoDataModule(LightningDataModule):
    """
    A DataModule implements 5 key methods:

        def prepare_data(self):
            # things to do on 1 GPU/TPU (not on every GPU/TPU in DDP)
            # download data, pre-process, split, save to disk, etc...

        def setup(self, stage):
            # things to do on every process in DDP
            # load data, set variables, etc...

        def train_dataloader(self):
            # return train dataloader

        def val_dataloader(self):
            # return validation dataloader

        def test_dataloader(self):
            # return test dataloader

        def teardown(self):
            # called on every process in DDP
            # clean up after fit or test

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html
    """
    def __init__(
        self,
        train: DictConfig,
        validation: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        skip_val_loader: bool = False,
    ):
        super().__init__()

        # Store the per-split configs and validate that each provides the two required blocks.
        self.train_config = train
        assert (
            "datapipeline" in self.train_config and "loader" in self.train_config
        ), "train config requires the fields `datapipeline` and `loader`"

        self.val_config = validation
        if not skip_val_loader:
            if self.val_config is not None:
                assert (
                    "datapipeline" in self.val_config and "loader" in self.val_config
                ), "validation config requires the fields `datapipeline` and `loader`"
            else:
                print(
                    "Warning: no validation datapipeline defined, falling back to the training one"
                )
                self.val_config = train

        self.test_config = test
        if self.test_config is not None:
            assert (
                "datapipeline" in self.test_config and "loader" in self.test_config
            ), "test config requires the fields `datapipeline` and `loader`"
    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.train_datapipeline`, `self.val_datapipeline`,
        `self.test_datapipeline`.

        This method is called by lightning with both `trainer.fit()` and `trainer.test()`,
        so be careful not to execute things like random split twice!
        """
        print("Preparing datasets")
        self.train_datapipeline = VideoDataset(**self.train_config.datapipeline)

        self.val_datapipeline = None
        if self.val_config:
            self.val_datapipeline = VideoDataset(**self.val_config.datapipeline)

        self.test_datapipeline = None
        if self.test_config:
            self.test_datapipeline = VideoDataset(**self.test_config.datapipeline)

    def train_dataloader(self):
        return DataLoader(self.train_datapipeline, **self.train_config.loader)

    def val_dataloader(self):
        if self.val_datapipeline is not None:
            return DataLoader(self.val_datapipeline, **self.val_config.loader)
        return None

    def test_dataloader(self):
        if self.test_datapipeline is not None:
            return DataLoader(self.test_datapipeline, **self.test_config.loader)
        return None

    def teardown(self, stage: Optional[str] = None):
        """Clean up after fit or test."""
        pass

    def state_dict(self):
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Things to do when loading checkpoint."""
        pass
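
# Typical use (sketch, not exercised in this file): the module is normally built from a
# Hydra/OmegaConf config and handed to a Lightning Trainer. `cfg` and `model` below are
# placeholders for whatever config tree and LightningModule the surrounding project defines.
#
#   datamodule = VideoDataModule(
#       train=cfg.data.train,
#       validation=cfg.data.validation,
#   )
#   trainer = pytorch_lightning.Trainer(devices=1, max_epochs=1)
#   trainer.fit(model, datamodule=datamodule)
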
if __name__ == "__main__":
    import hydra
    import omegaconf
    import cv2

    root = pyrootutils.setup_root(__file__, pythonpath=True)
    cfg = omegaconf.OmegaConf.load(
        root / "configs" / "datamodule" / "image_datamodule.yaml"
    )
    # cfg.data_dir = str(root / "data")
    data = hydra.utils.instantiate(cfg)
    data.prepare_data()
    data.setup()
    print(data.data_train.__getitem__(0)[0].shape)

    batch = next(iter(data.train_dataloader()))
    identity, target = batch
    image_identity = (identity[0].permute(1, 2, 0).numpy() + 1) / 2 * 255
    image_other = (target[0].permute(1, 2, 0).numpy() + 1) / 2 * 255
    cv2.imwrite("image_identity.png", image_identity[:, :, ::-1])
    cv2.imwrite("image_other.png", image_other[:, :, ::-1])