import math
import time

import pytorch_lightning as L
import torch
import webdataset as wds
from PIL import Image
from torch.utils.data import DataLoader, default_collate, random_split


class ImageDataModule(L.LightningDataModule):
    def __init__(
        self,
        train_dataset,
        val_dataset,
        test_dataset,
        full_batch_size,
        num_workers,
        eval_batch_size=None,
        num_nodes=1,
        num_devices=1,
        val_proportion=0.1,
    ):
        super().__init__()
        # Dataset builders are stored here and only called in `setup()`.
        self._builders = {
            "train": train_dataset,
            "val": val_dataset,
            "test": test_dataset,
        }
        self.num_workers = num_workers
        self.collate_fn = dict_collate_fn()
        self.full_batch_size = full_batch_size
        # Split the global (full) batch size evenly across all devices.
        self.train_batch_size = full_batch_size // (num_nodes * num_devices)
        if eval_batch_size is None:
            self.eval_batch_size = self.train_batch_size
            self.full_eval_batch_size = self.full_batch_size
        else:
            self.eval_batch_size = eval_batch_size // (num_nodes * num_devices)
            self.full_eval_batch_size = eval_batch_size
        print(f"Each GPU will receive {self.train_batch_size} images for training")
        print(f"Each GPU will receive {self.eval_batch_size} images for evaluation")
        self.val_proportion = val_proportion
        self.world_size = num_nodes * num_devices
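
    # Illustrative arithmetic (example numbers, not from the original source):
    # with full_batch_size=512, num_nodes=2 and num_devices=4, the global batch
    # is split across 2 * 4 = 8 GPUs, so each GPU receives
    # 512 // 8 = 64 images per training step.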

    def setup(self, stage=None):
        """Set up the datamodule.

        Args:
            stage (str): Stage of the datamodule.
                Must be one of "fit", "test", or None.
        """
        print("Stage", stage)
        start_time = time.time()
        if stage == "fit" or stage is None:
            self.train_dataset = self._builders["train"]()
            self.train_dataset, self.num_train_batches = self.get_webdataset_length(
                self.train_dataset,
                dict_collate_fn(),
                self.full_batch_size,
                self.train_batch_size,
            )
            self.val_dataset = self._builders["val"]()
            self.val_dataset, self.num_val_batches = self.get_webdataset_length(
                self.val_dataset,
                dict_collate_fn(),
                self.full_eval_batch_size,
                self.eval_batch_size,
                0,
            )
            print(f"Train dataset size: {len(self.train_dataset)}")
            print(f"Val dataset size: {len(self.val_dataset)}")
        else:
            self.test_dataset = self._builders["test"]()
            self.test_dataset, self.num_test_batches = self.get_webdataset_length(
                self.test_dataset,
                dict_collate_fn(),
                self.full_eval_batch_size,
                self.eval_batch_size,
                self.num_workers,
            )
            print(f"Test dataset size: {len(self.test_dataset)}")
        end_time = time.time()
        print(f"Setup took {(end_time - start_time):.2f} seconds")

    def train_dataloader(self):
        # Batching already happens inside the WebDataset pipeline, so the
        # loader itself is created with batch_size=None.
        return wds.WebLoader(
            self.train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=self.num_workers,
            # persistent_workers=self.num_workers > 1,
        ).with_length(self.num_train_batches)
        # return DataLoader(
        #     self.train_dataset,
        #     batch_size=self.batch_size,
        #     shuffle=True,
        #     pin_memory=False,
        #     drop_last=True,
        #     num_workers=self.num_workers,
        #     collate_fn=self.train_dataset.collate_fn,
        # )

    def val_dataloader(self):
        return wds.WebLoader(
            self.val_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=0,
        ).with_length(self.num_val_batches)

    def test_dataloader(self):
        return wds.WebLoader(
            self.test_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=0,
        ).with_length(self.num_test_batches)

    def get_webdataset_length(
        self, dataset, collate_fn, full_batch_size, batch_size, num_workers=0
    ):
        # Batch inside the WebDataset pipeline; the last, partial batch is kept
        # only when running on more than one device.
        dataset = dataset.compose(
            wds.batched(
                batch_size,
                partial=self.world_size > 1,
                collation_fn=collate_fn,
                # dict_collate_and_pad(["flan_t5_xl"], max_length=256),
            )
        )
        num_samples = dataset.num_samples
        if self.world_size > 1:
            # Round the number of global batches up to a multiple of the number
            # of dataloader workers, then set the per-worker epoch length.
            num_batches = math.ceil(num_samples / full_batch_size)
            num_workers = max(1, num_workers)
            num_worker_batches = math.ceil(num_batches / num_workers)
            num_batches = num_worker_batches * num_workers
            num_samples = num_batches * full_batch_size
            dataset = dataset.with_epoch(num_worker_batches).with_length(
                num_worker_batches
            )
        else:
            num_batches = math.ceil(num_samples / batch_size)
            dataset = dataset.with_epoch(num_batches).with_length(num_batches)
        return dataset, num_batches
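
    # Worked example (illustrative numbers, not from the original source): with
    # num_samples=10_000, full_batch_size=512 and num_workers=3,
    # num_batches = ceil(10_000 / 512) = 20 and
    # num_worker_batches = ceil(20 / 3) = 7, so the epoch is padded to
    # 7 * 3 = 21 batches and each dataloader worker yields 7 batches per epoch.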


def dict_collate_fn():
    def dict_collate(batch):
        # Recursively collate a list of (possibly nested) dicts into a dict of
        # collated values.
        output_dict = {}
        if isinstance(batch[0], dict):
            for key in batch[0].keys():
                output_dict[key] = dict_collate([item[key] for item in batch])
        else:
            # Check if the batch contains PIL images
            if isinstance(batch[0], Image.Image):
                output_dict = batch  # Return list of PIL images directly
            else:
                output_dict = default_collate(batch)
        return output_dict

    return dict_collate
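

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It demonstrates
    # the nested-dict collation: tensor leaves are stacked by default_collate,
    # while PIL images would be passed through as a plain Python list.
    collate = dict_collate_fn()
    toy_batch = [
        {"pixel_values": torch.zeros(3, 8, 8), "meta": {"idx": i}}
        for i in range(4)
    ]
    collated = collate(toy_batch)
    print(collated["pixel_values"].shape)  # torch.Size([4, 3, 8, 8])
    print(collated["meta"]["idx"])  # tensor([0, 1, 2, 3])

    # Hypothetical wiring of the datamodule: `make_train_ds`, `make_val_ds`
    # and `make_test_ds` stand in for zero-argument callables that return
    # WebDataset pipelines exposing a `num_samples` attribute.
    # datamodule = ImageDataModule(
    #     train_dataset=make_train_ds,
    #     val_dataset=make_val_ds,
    #     test_dataset=make_test_ds,
    #     full_batch_size=256,
    #     num_workers=4,
    #     num_nodes=1,
    #     num_devices=8,
    # )
    # datamodule.setup("fit")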