# -*- coding: utf-8 -*-

import numpy as np
import os
import pytorch_lightning as pl
import torch
import webdataset as wds
from torchvision.transforms import transforms

from ldm.util import instantiate_from_config


def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
    """Take a list of samples (each a dict) and create a batch, preserving the keys.

    :param list samples: list of sample dicts
    :param bool combine_tensors: stack lists of tensors into batch tensors and
        lists of ndarrays into batch ndarrays
    :param bool combine_scalars: combine lists of scalars into a single ndarray
    :returns: single sample consisting of a batch
    :rtype: dict
    """
    keys = set.intersection(*[set(sample.keys()) for sample in samples])
    batched = {key: [] for key in keys}

    for s in samples:
        for key in batched:
            batched[key].append(s[key])

    # Combine each key's values into a batch; if a combine flag is False,
    # that key is silently dropped from the result.
    result = {}
    for key in batched:
        if isinstance(batched[key][0], (int, float)):
            if combine_scalars:
                result[key] = np.array(list(batched[key]))
        elif isinstance(batched[key][0], torch.Tensor):
            if combine_tensors:
                result[key] = torch.stack(list(batched[key]))
        elif isinstance(batched[key][0], np.ndarray):
            if combine_tensors:
                result[key] = np.array(list(batched[key]))
        else:
            result[key] = list(batched[key])
    return result
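
# Illustrative usage of dict_collation_fn (a sketch; key names are arbitrary):
#
#   samples = [
#       {'jpg': torch.zeros(3, 64, 64), 'txt': 'a photo'},
#       {'jpg': torch.ones(3, 64, 64), 'txt': 'a drawing'},
#   ]
#   batch = dict_collation_fn(samples)
#   batch['jpg'].shape  # torch.Size([2, 3, 64, 64])
#   batch['txt']        # ['a photo', 'a drawing']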


class WebDataModuleFromConfig(pl.LightningDataModule):
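    """LightningDataModule that streams image/caption training batches from
    WebDataset tar shards, with optional size and watermark filtering."""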

    def __init__(self,
                 tar_base,
                 batch_size,
                 train=None,
                 validation=None,
                 test=None,
                 num_workers=4,
                 multinode=True,
                 min_size=None,
                 max_pwatermark=1.0,
                 **kwargs):
        super().__init__()
        print(f'Setting tar base to {tar_base}')
        self.tar_base = tar_base
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train = train
        self.validation = validation
        self.test = test
        self.multinode = multinode
        self.min_size = min_size  # if set, filter out very small images
        self.max_pwatermark = max_pwatermark  # filter out likely watermarked images (applied together with min_size)

    def make_loader(self, dataset_config):
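        """Build a WebLoader for one dataset split from its config."""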
        image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
        image_transforms = transforms.Compose(image_transforms)

        # 'process' maps each decoded sample dict to the final model inputs.
        process = instantiate_from_config(dataset_config['process'])

        shuffle = dataset_config.get('shuffle', 0)  # sample shuffle buffer size
        shardshuffle = shuffle > 0  # shuffle shard order whenever samples are shuffled

        # Split shards across nodes for multi-node training; otherwise insist
        # on a single-node setup.
        nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only

        # dataset_config.shards is a shard spec relative to tar_base, e.g. a
        # brace pattern that webdataset expands into a list of tar files.
        tars = os.path.join(self.tar_base, dataset_config.shards)

        # repeat() makes the stream effectively infinite for training;
        # shuffle(n) keeps an n-sample in-memory buffer (0 disables shuffling).
        dset = wds.WebDataset(
            tars, nodesplitter=nodesplitter, shardshuffle=shardshuffle,
            handler=wds.warn_and_continue).repeat().shuffle(shuffle)
        print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')

        # Keep samples with both image and caption, decode to PIL, filter by
        # size/watermark, transform the 'jpg' field, then build model inputs.
        dset = (
            dset
            .select(self.filter_keys)
            .decode('pil', handler=wds.warn_and_continue)
            .select(self.filter_size)
            .map_dict(jpg=image_transforms, handler=wds.warn_and_continue)
            .map(process))
        # Batch in the dataset itself; the WebLoader below uses batch_size=None.
        dset = dset.batched(self.batch_size, partial=False, collation_fn=dict_collation_fn)

        loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=self.num_workers)

        return loader

    def filter_size(self, x):
        """Drop samples whose original image is too small or too likely
        watermarked; only active when ``min_size`` is set."""
        if self.min_size is None:
            return True
        try:
            meta = x['json']
            return (meta['original_width'] >= self.min_size
                    and meta['original_height'] >= self.min_size
                    and meta['pwatermark'] <= self.max_pwatermark)
        except Exception:
            return False

    def filter_keys(self, x):
        """Keep only samples that contain both an image and a caption."""
        try:
            return ("jpg" in x) and ("txt" in x)
        except Exception:
            return False

    def train_dataloader(self):
        return self.make_loader(self.train)

    def val_dataloader(self):
        return self.make_loader(self.validation) if self.validation is not None else None

    def test_dataloader(self):
        return self.make_loader(self.test) if self.test is not None else None


if __name__ == '__main__':
    from omegaconf import OmegaConf
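    # Expected config layout under data.params (a sketch; real values live in
    # the YAML file, and the shard pattern below is hypothetical):
    #
    #   data:
    #     params:
    #       tar_base: /path/to/shards
    #       batch_size: 4
    #       num_workers: 4
    #       train:
    #         shards: 'canny-{00000..00099}.tar'
    #         shuffle: 1000
    #         image_transforms: [...]
    #         process: {...}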
    config = OmegaConf.load("configs/stable-diffusion/train_canny_sd_v1.yaml")
    datamod = WebDataModuleFromConfig(**config["data"]["params"])
    dataloader = datamod.train_dataloader()

    # Sanity check: pull one batch and inspect it.
    for batch in dataloader:
        print(batch.keys())
        print(batch['jpg'].shape)
        break