File size: 5,446 Bytes
b5f33fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
"""
Code adapted from timm https://github.com/huggingface/pytorch-image-models

Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
"""

import logging
from contextlib import suppress
from functools import partial
from itertools import repeat

import numpy as np
import torch
import torch.utils.data
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.dataset import IterableImageDataset
from timm.data.loader import PrefetchLoader, _worker_init
from timm.data.transforms_factory import create_transform

_logger = logging.getLogger(__name__)


def fast_collate(batch, target_dtype=torch.uint8):
    """Collate ``(image, target)`` samples into a uint8 image batch and a target tensor.

    Images may be numpy arrays or torch tensors. Every image must have the same shape as
    the first one; they are written into a freshly allocated ``torch.uint8`` tensor of
    shape ``(batch_size, *image_shape)``. No normalization is done here.

    Args:
        batch: non-empty sequence of tuples whose first element is the image and whose
            second element is the target.
        target_dtype: dtype of the returned 1-D target tensor.

    Returns:
        ``(images, targets)`` tuple of tensors.

    Raises:
        ValueError: if the first image is neither an ``np.ndarray`` nor a ``torch.Tensor``.
    """
    assert isinstance(batch[0], tuple)
    first_image = batch[0][0]
    if not isinstance(first_image, (np.ndarray, torch.Tensor)):
        raise ValueError(f"Incorrect batch type: {type(first_image)}")

    targets = torch.tensor([sample[1] for sample in batch], dtype=target_dtype)
    images = torch.zeros((len(batch), *first_image.shape), dtype=torch.uint8)
    for i, sample in enumerate(batch):
        image = sample[0]
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(image)
        # Adding into a zero uint8 buffer copies uint8 images exactly; floating-point
        # images raise from torch's in-place casting rules instead of being truncated.
        images[i] += image
    return images, targets


def adapt_to_chs(x, n):
    """Adapt per-channel normalization stats ``x`` to ``n`` channels.

    - A scalar (anything that is not a tuple/list) is repeated ``n`` times.
    - A tuple/list of length ``n`` is returned unchanged.
    - A tuple/list of length ``n / 2`` is concatenated with itself.
    - Any other length is replaced by its mean, repeated ``n`` times.

    The last two cases log a warning. Adapted values are always returned as a tuple.
    """
    if not isinstance(x, (tuple, list)):
        return tuple(repeat(x, n))
    if len(x) == n:
        return x
    if len(x) * 2 == n:
        # Doubled channels: reuse the same stats for both halves.
        x = tuple(x) * 2
        _logger.warning(f"Pretrained mean/std different shape than model (doubled channels), using concat: {x}.")
        return x
    x = (np.mean(x).item(),) * n
    _logger.warning(f"Pretrained mean/std different shape than model, using avg value {x}.")
    return x


class PrefetchLoaderForMultiInput(PrefetchLoader):
    """Iterate a loader, moving each batch to ``device`` and normalizing images there.

    Images are cast to ``img_dtype`` and normalized as
    ``(x - mean * 255) / (std * 255)``, so inputs are expected to hold raw 0-255 values
    (as produced by ``fast_collate``). Batches are yielded one step behind the loader:
    on CUDA, the copy and normalization of batch ``k + 1`` are queued on a side stream
    before batch ``k`` is yielded, so the transfer can overlap the consumer's work.

    ``PrefetchLoader.__init__`` is not called; this class sets its own attributes.
    NOTE(review): inherited timm helpers (e.g. ``__len__``) presumably only rely on
    ``self.loader`` — confirm against the installed timm version.
    """

    def __init__(
        self,
        loader,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        channels=3,
        device=torch.device("cuda"),
        img_dtype=torch.float32,
    ):
        """
        Args:
            loader: iterable yielding ``(images, targets)`` tensor pairs.
            mean: per-channel mean (scalar or sequence; adapted with ``adapt_to_chs``).
            std: per-channel std (scalar or sequence; adapted with ``adapt_to_chs``).
            channels: number of image channels.
            device: ``torch.device`` or device string (e.g. ``"cuda:0"``, ``"cpu"``).
            img_dtype: dtype images are cast to before normalization.
        """
        # Normalize so ``device.type`` below also works when a string is passed.
        device = torch.device(device)
        mean = adapt_to_chs(mean, channels)
        std = adapt_to_chs(std, channels)
        # Broadcasts over a 4-D batch with channels on dim 1 (assumes NCHW layout).
        normalization_shape = (1, channels, 1, 1)

        self.loader = loader
        self.device = device
        self.img_dtype = img_dtype
        # Stats are scaled by 255 because images arrive as raw 0-255 values.
        self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
        self.std = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)

        self.is_cuda = torch.cuda.is_available() and device.type == "cuda"

    def __iter__(self):
        """Yield normalized ``(images, targets)`` on ``self.device``; yields nothing for an empty loader."""
        first = True
        if self.is_cuda:
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        else:
            stream = None
            # ``suppress()`` with no exception types is a no-op context manager.
            stream_context = suppress

        prev_input = prev_target = None
        for next_input, next_target in self.loader:

            with stream_context():
                next_input = next_input.to(device=self.device, non_blocking=True)
                next_target = next_target.to(device=self.device, non_blocking=True)
                next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)

            if not first:
                yield prev_input, prev_target
            else:
                first = False

            if stream is not None:
                # Make the current stream wait for the side-stream work on this batch
                # before it is handed out on the next iteration.
                torch.cuda.current_stream().wait_stream(stream)

            prev_input = next_input
            prev_target = next_target

        # Flush the last prefetched batch; an empty loader has nothing to flush.
        if not first:
            yield prev_input, prev_target


def create_loader(
    dataset,
    input_size,
    batch_size,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    num_workers=1,
    crop_pct=None,
    crop_mode=None,
    pin_memory=False,
    img_dtype=torch.float32,
    device=torch.device("cuda"),
    persistent_workers=True,
    worker_seeding="all",
    target_type=torch.int64,
):
    """Build a non-shuffled evaluation loader that normalizes batches on ``device``.

    The dataset's ``transform`` is replaced with timm's eval transform built with
    ``use_prefetcher=True``. Batches are collated by ``fast_collate`` into uint8 tensors
    and then normalized by :class:`PrefetchLoaderForMultiInput`.

    Args:
        dataset: map-style dataset; its ``transform`` attribute is overwritten.
        input_size: image size whose first element is the channel count
            (presumably ``(C, H, W)``).
        batch_size: samples per batch (the last partial batch is kept).
        mean, std: normalization stats passed to both the transform and the prefetcher.
        num_workers: DataLoader worker processes.
        crop_pct, crop_mode: forwarded to timm's ``create_transform``.
        pin_memory: forwarded to the DataLoader.
        img_dtype: dtype images are cast to before normalization.
        device: target device for batches.
        persistent_workers: keep workers alive between epochs; ignored when
            ``num_workers == 0``, where DataLoader does not allow it.
        worker_seeding: forwarded to timm's worker init function.
        target_type: dtype of the collated target tensor.

    Returns:
        A :class:`PrefetchLoaderForMultiInput` wrapping the DataLoader.

    Raises:
        ValueError: if ``dataset`` is an ``IterableImageDataset`` (not supported).
    """
    # Reject unsupported datasets before modifying them.
    if isinstance(dataset, IterableImageDataset):
        raise ValueError("Incorrect dataset type: IterableImageDataset")

    dataset.transform = create_transform(
        input_size,
        is_training=False,
        use_prefetcher=True,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        crop_mode=crop_mode,
    )

    loader_class = torch.utils.data.DataLoader
    loader_args = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        sampler=None,
        # functools.partial is picklable (a lambda is not), which worker processes
        # started with the "spawn" method require.
        collate_fn=partial(fast_collate, target_dtype=target_type),
        pin_memory=pin_memory,
        drop_last=False,
        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
        # DataLoader raises ValueError for persistent_workers=True with num_workers=0.
        persistent_workers=persistent_workers and num_workers > 0,
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        loader_args.pop("persistent_workers")  # only in Pytorch 1.7+
        loader = loader_class(dataset, **loader_args)

    return PrefetchLoaderForMultiInput(
        loader,
        mean=mean,
        std=std,
        channels=input_size[0],
        device=device,
        img_dtype=img_dtype,
    )