# -*- coding: utf-8 -*-
import os

import numpy as np
import pytorch_lightning as pl
import torch
import webdataset as wds
from torchvision.transforms import transforms

from ldm.util import instantiate_from_config


def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Take a list of samples (as dictionaries) and create a batch, preserving the keys.

    :param list samples: list of sample dicts
    :param bool combine_tensors: stack lists of tensors/ndarrays into batched tensors/ndarrays
    :param bool combine_scalars: turn lists of int/float scalars into a single ndarray
    :returns: single sample consisting of a batch
    :rtype: dict
    """
    # Only keys present in every sample survive, so the batch stays rectangular.
    keys = set.intersection(*[set(sample.keys()) for sample in samples])
    batched = {key: [] for key in keys}

    for s in samples:
        for key in batched:
            batched[key].append(s[key])

    result = {}
    for key in batched:
        if isinstance(batched[key][0], (int, float)):
            if combine_scalars:
                result[key] = np.array(list(batched[key]))
        elif isinstance(batched[key][0], torch.Tensor):
            if combine_tensors:
                result[key] = torch.stack(list(batched[key]))
        elif isinstance(batched[key][0], np.ndarray):
            if combine_tensors:
                result[key] = np.array(list(batched[key]))
        else:
            result[key] = list(batched[key])
    return result


class WebDataModuleFromConfig(pl.LightningDataModule):
    def __init__(self, tar_base, batch_size, train=None, validation=None,
                 test=None, num_workers=4, multinode=True, min_size=None,
                 max_pwatermark=1.0, **kwargs):
        super().__init__()
        print(f'Setting tar base to {tar_base}')
        self.tar_base = tar_base
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train = train
        self.validation = validation
        self.test = test
        self.multinode = multinode
        self.min_size = min_size  # filter out very small images
        self.max_pwatermark = max_pwatermark  # filter out watermarked images

    def make_loader(self, dataset_config):
        image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
        image_transforms = transforms.Compose(image_transforms)

        process = instantiate_from_config(dataset_config['process'])

        # `shuffle` doubles as the sample-shuffle buffer size; any value > 0
        # also enables shard-level shuffling.
        shuffle = dataset_config.get('shuffle', 0)
        shardshuffle = shuffle > 0

        nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only

        tars = os.path.join(self.tar_base, dataset_config.shards)

        # `.repeat()` makes the stream infinite, which is what a step-based
        # training loop expects.
        dset = wds.WebDataset(
            tars,
            nodesplitter=nodesplitter,
            shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)
        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')

        dset = (dset
                .select(self.filter_keys)
                .decode('pil', handler=wds.warn_and_continue)
                .select(self.filter_size)
                .map_dict(jpg=image_transforms, handler=wds.warn_and_continue)
                .map(process))
        dset = dset.batched(self.batch_size, partial=False, collation_fn=dict_collation_fn)

        # Batching already happens in the pipeline above, so the loader itself
        # must not batch again.
        loader = wds.WebLoader(dset, batch_size=None, shuffle=False,
                               num_workers=self.num_workers)
        return loader

    def filter_size(self, x):
        if self.min_size is None:
            return True
        try:
            return (x['json']['original_width'] >= self.min_size
                    and x['json']['original_height'] >= self.min_size
                    and x['json']['pwatermark'] <= self.max_pwatermark)
        except Exception:
            # Samples with missing or malformed metadata are dropped.
            return False

    def filter_keys(self, x):
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def train_dataloader(self):
        return self.make_loader(self.train)

    def val_dataloader(self):
        return None

    def test_dataloader(self):
        return None
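

# Illustrative only: `process` is instantiated from the dataset config and applied
# to every decoded sample via `.map(process)` above. The class below is a minimal
# sketch of such a hook; the name `IdentityProcess` and the returned keys are
# assumptions for illustration, not part of the training configs.
class IdentityProcess:
    def __call__(self, sample):
        # Pass the transformed image and its caption through unchanged; a real
        # hook would typically build conditioning inputs (e.g. hint images) here.
        return {"jpg": sample["jpg"], "txt": sample["txt"]}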


if __name__ == '__main__':
    from omegaconf import OmegaConf

    config = OmegaConf.load("configs/stable-diffusion/train_canny_sd_v1.yaml")
    datamod = WebDataModuleFromConfig(**config["data"]["params"])
    dataloader = datamod.train_dataloader()

    for batch in dataloader:
        print(batch.keys())
        print(batch['jpg'].shape)
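
# For reference, a sketch of the YAML `data` section this module consumes, in the
# `target`/`params` layout used by ldm.util.instantiate_from_config. All values
# below are placeholders, not taken from an actual config file:
#
# data:
#   params:
#     tar_base: /path/to/shards          # directory containing the .tar files
#     batch_size: 4
#     num_workers: 4
#     multinode: true
#     min_size: 512                      # drop images smaller than this
#     max_pwatermark: 0.5                # drop likely-watermarked images
#     train:
#       shards: "{00000..00099}.tar"     # brace pattern expanded by webdataset
#       shuffle: 1000                    # sample-shuffle buffer size
#       image_transforms:
#         - target: torchvision.transforms.RandomCrop
#           params:
#             size: 512
#       process:
#         target: some_module.IdentityProcess  # hypothetical, see sketch above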