# CoAdapter / ldm/data/dataset_laion.py -- MC-E, "first push", commit c05d22e (4.58 kB)
# -*- coding: utf-8 -*-
import numpy as np
import os
import pytorch_lightning as pl
import torch
import webdataset as wds
from torchvision.transforms import transforms
from ldm.util import instantiate_from_config
def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Collate a list of sample dictionaries into one batch dictionary.

    Only keys present in *every* sample are kept; keys appear in the order
    they have in the first sample. For each kept key the per-sample values
    are gathered and, depending on the type of the first value:

    * ``int`` / ``float``: turned into a ``np.ndarray`` if `combine_scalars`
      is True;
    * ``torch.Tensor``: stacked with ``torch.stack`` if `combine_tensors`
      is True;
    * ``np.ndarray``: turned into a single ``np.ndarray`` if
      `combine_tensors` is True;
    * anything else, or when the relevant flag is False: kept as a plain
      list of the per-sample values.

    :param list samples: list of samples, each a dictionary
    :param bool combine_tensors: whether to merge lists of tensors/ndarrays
        into a single tensor/ndarray
    :param bool combine_scalars: whether to merge lists of ints/floats into
        a single ndarray
    :returns: single dictionary holding the batched values; empty if
        `samples` is empty
    :rtype: dict
    """
    samples = list(samples)
    if not samples:
        # Nothing to collate; avoids calling set operations on no operands.
        return {}

    first, rest = samples[0], samples[1:]
    shared_keys = [key for key in first if all(key in other for other in rest)]

    result = {}
    for key in shared_keys:
        values = [sample[key] for sample in samples]
        head = values[0]
        # The type of the first value decides how the whole column is merged.
        if isinstance(head, (int, float)):
            result[key] = np.array(values) if combine_scalars else values
        elif isinstance(head, torch.Tensor):
            result[key] = torch.stack(values) if combine_tensors else values
        elif isinstance(head, np.ndarray):
            result[key] = np.array(values) if combine_tensors else values
        else:
            result[key] = values
    return result
class WebDataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module that builds a webdataset loader from a config.

    Only the training loader is implemented; the validation and test
    loaders return ``None``.
    """

    def __init__(self,
                 tar_base,
                 batch_size,
                 train=None,
                 validation=None,
                 test=None,
                 num_workers=4,
                 multinode=True,
                 min_size=None,
                 max_pwatermark=1.0,
                 **kwargs):
        """Store the loader settings.

        :param tar_base: base path that each dataset's ``shards`` entry is
            joined onto
        :param batch_size: number of samples per batch
        :param train: dataset config passed to :meth:`make_loader` by
            :meth:`train_dataloader`
        :param validation: stored, but not used by any method of this class
        :param test: stored, but not used by any method of this class
        :param num_workers: worker count handed to ``wds.WebLoader``
        :param multinode: selects ``split_by_node`` (True) or
            ``single_node_only`` (False) as the shard node splitter
        :param min_size: minimum original width and height of a sample;
            ``None`` disables the size check
        :param max_pwatermark: maximum accepted ``pwatermark`` value
        :param kwargs: accepted and ignored
        """
        super().__init__()
        print(f'Setting tar base to {tar_base}')
        self.tar_base = tar_base
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train = train
        self.validation = validation
        self.test = test
        self.multinode = multinode
        self.min_size = min_size  # filter out very small images
        self.max_pwatermark = max_pwatermark  # filter out watermarked images

    def make_loader(self, dataset_config):
        """Build a ``wds.WebLoader`` yielding collated batches.

        :param dataset_config: config object providing ``image_transforms``
            (list of instantiable configs), ``process`` (instantiable
            config), ``shards`` (path relative to ``tar_base``) and an
            optional ``shuffle`` buffer size (default 0)
        :returns: loader over batches produced by ``dict_collation_fn``
        """
        image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
        image_transforms = transforms.Compose(image_transforms)
        process = instantiate_from_config(dataset_config['process'])

        shuffle = dataset_config.get('shuffle', 0)
        # Shards are shuffled only when a positive sample-shuffle size is set.
        shardshuffle = shuffle > 0
        nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only

        tars = os.path.join(self.tar_base, dataset_config.shards)

        dset = wds.WebDataset(
            tars, nodesplitter=nodesplitter, shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)
        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')

        # Order matters: drop samples lacking image/caption before decoding,
        # then filter on metadata, transform the image, and post-process.
        dset = (
            dset.select(self.filter_keys).decode('pil',
                                                 handler=wds.warn_and_continue).select(self.filter_size).map_dict(
                                                     jpg=image_transforms, handler=wds.warn_and_continue).map(process))
        dset = (dset.batched(self.batch_size, partial=False, collation_fn=dict_collation_fn))

        # Batching already happened in the dataset pipeline, hence batch_size=None.
        loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=self.num_workers)

        return loader

    def filter_size(self, x):
        """Return whether sample `x` passes the size and watermark limits.

        With ``min_size`` set, the sample must have ``original_width`` and
        ``original_height`` of at least ``min_size`` and a ``pwatermark`` of
        at most ``max_pwatermark``. With ``min_size`` unset, only the
        watermark limit is applied, and only when ``max_pwatermark`` is
        below 1.0; otherwise every sample is accepted. Samples whose
        metadata is missing or not comparable are rejected whenever a check
        is actually performed.

        :param dict x: sample with its metadata under the ``'json'`` key
        :returns: truthy if the sample should be kept
        """
        check_size = self.min_size is not None
        if not check_size and (self.max_pwatermark is None or self.max_pwatermark >= 1.0):
            # No size limit and no effective watermark limit: keep everything,
            # including samples that carry no metadata at all.
            return True
        try:
            meta = x['json']
            if check_size and not (meta['original_width'] >= self.min_size
                                   and meta['original_height'] >= self.min_size):
                return False
            # Previously skipped entirely when min_size was None, which made
            # max_pwatermark a no-op unless a size limit was also configured.
            return meta['pwatermark'] <= self.max_pwatermark
        except Exception:
            return False

    def filter_keys(self, x):
        """Return whether sample `x` has both a ``jpg`` and a ``txt`` entry."""
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def train_dataloader(self):
        """Return the loader built from the ``train`` dataset config."""
        return self.make_loader(self.train)

    def val_dataloader(self):
        """Return ``None``; no validation loader is built by this class."""
        return None

    def test_dataloader(self):
        """Return ``None``; no test loader is built by this class."""
        return None
if __name__ == '__main__':
    # Manual smoke test: build the training loader from the canny training
    # config and print the keys and image-batch shape of every batch it yields.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/stable-diffusion/train_canny_sd_v1.yaml")
    data_params = cfg["data"]["params"]
    module = WebDataModuleFromConfig(**data_params)

    for sample_batch in module.train_dataloader():
        print(sample_batch.keys())
        print(sample_batch['jpg'].shape)