# NOTE(review): removed extraction artifacts ("Spaces:", "Runtime error" x2) — not valid Python.
# -*- coding: utf-8 -*- | |
import numpy as np | |
import os | |
import pytorch_lightning as pl | |
import torch | |
import webdataset as wds | |
from torchvision.transforms import transforms | |
from ldm.util import instantiate_from_config | |
def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Collate a list of dict samples into one batched dict, preserving keys.

    Only keys present in *every* sample are kept. Values are combined by the
    type of the first element: int/float -> ``np.ndarray`` (when
    ``combine_scalars``), ``torch.Tensor`` -> stacked tensor (when
    ``combine_tensors``), ``np.ndarray`` -> stacked ``np.ndarray`` (when
    ``combine_tensors``). Anything else — or a disabled combine flag — yields
    a plain Python list, so no shared key is ever dropped from the result.

    :param list samples: list of dict samples
    :param bool combine_tensors: stack tensor/ndarray values into batches
    :param bool combine_scalars: pack scalar values into an ndarray
    :returns: single dict mapping each shared key to its batched values
    :rtype: dict
    """
    # Keep only keys common to all samples so every batch entry has equal length.
    keys = set.intersection(*[set(sample.keys()) for sample in samples])
    batched = {key: [sample[key] for sample in samples] for key in keys}

    result = {}
    for key, values in batched.items():
        first = values[0]
        if isinstance(first, (int, float)) and combine_scalars:
            result[key] = np.array(values)
        elif isinstance(first, torch.Tensor) and combine_tensors:
            result[key] = torch.stack(values)
        elif isinstance(first, np.ndarray) and combine_tensors:
            result[key] = np.array(values)
        else:
            # Fallback keeps the key as a plain list. This also covers disabled
            # combine flags, which previously dropped the key entirely.
            result[key] = list(values)
    return result
class WebDataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module that builds WebDataset loaders from tar shards.

    Each split config (``train``/``validation``/``test``) is expected to carry
    ``shards`` (glob relative to ``tar_base``), ``image_transforms``,
    ``process``, and optionally ``shuffle`` — see ``make_loader``.

    Fix vs. original: ``val_dataloader``/``test_dataloader`` previously
    returned ``None`` unconditionally even when a ``validation``/``test``
    config was supplied; they now honor a provided config and still return
    ``None`` when the config is absent (backward compatible).
    """

    def __init__(self,
                 tar_base,
                 batch_size,
                 train=None,
                 validation=None,
                 test=None,
                 num_workers=4,
                 multinode=True,
                 min_size=None,
                 max_pwatermark=1.0,
                 **kwargs):
        """
        :param tar_base: directory containing the .tar shard files
        :param batch_size: per-loader batch size
        :param train: dataset config for the training split (or None)
        :param validation: dataset config for the validation split (or None)
        :param test: dataset config for the test split (or None)
        :param num_workers: workers per WebLoader
        :param multinode: split shards across nodes when True
        :param min_size: filter out images smaller than this (None disables)
        :param max_pwatermark: filter out images above this watermark score
        """
        super().__init__()
        print(f'Setting tar base to {tar_base}')
        self.tar_base = tar_base
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train = train
        self.validation = validation
        self.test = test
        self.multinode = multinode
        self.min_size = min_size  # filter out very small images
        self.max_pwatermark = max_pwatermark  # filter out watermarked images

    def make_loader(self, dataset_config):
        """Build a WebLoader for one split from its dataset config.

        Pipeline: key filter -> PIL decode -> size/watermark filter ->
        per-sample 'jpg' transform -> 'process' mapping -> batching with
        ``dict_collation_fn``. The dataset repeats indefinitely.
        """
        image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
        image_transforms = transforms.Compose(image_transforms)
        process = instantiate_from_config(dataset_config['process'])

        shuffle = dataset_config.get('shuffle', 0)
        shardshuffle = shuffle > 0  # only shuffle shards when sample shuffling is on
        nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only

        tars = os.path.join(self.tar_base, dataset_config.shards)
        dset = wds.WebDataset(
            tars, nodesplitter=nodesplitter, shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)
        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')

        dset = (
            dset.select(self.filter_keys).decode('pil',
                handler=wds.warn_and_continue).select(self.filter_size).map_dict(
                    jpg=image_transforms, handler=wds.warn_and_continue).map(process))
        # Batching happens inside the dataset, so the loader itself gets batch_size=None.
        dset = (dset.batched(self.batch_size, partial=False, collation_fn=dict_collation_fn))

        loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=self.num_workers)
        return loader

    def filter_size(self, x):
        """Keep samples meeting the min-size and watermark thresholds.

        Missing/malformed metadata rejects the sample (best-effort filter).
        """
        if self.min_size is None:
            return True
        try:
            return x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size and x[
                'json']['pwatermark'] <= self.max_pwatermark
        except Exception:
            return False

    def filter_keys(self, x):
        """Keep only samples that carry both an image ('jpg') and a caption ('txt')."""
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def train_dataloader(self):
        return self.make_loader(self.train)

    def val_dataloader(self):
        # Honor a provided validation config; None still disables validation.
        if self.validation is None:
            return None
        return self.make_loader(self.validation)

    def test_dataloader(self):
        # Honor a provided test config; None still disables testing.
        if self.test is None:
            return None
        return self.make_loader(self.test)
if __name__ == '__main__':
    from omegaconf import OmegaConf

    def _main():
        # Smoke-test: build the data module from the training config and
        # stream batches, printing the keys and image-batch shape of each.
        cfg = OmegaConf.load("configs/stable-diffusion/train_canny_sd_v1.yaml")
        data_module = WebDataModuleFromConfig(**cfg["data"]["params"])
        loader = data_module.train_dataloader()
        for batch in loader:
            print(batch.keys())
            print(batch['jpg'].shape)

    _main()