|
""" |
|
Code adapted from timm https://github.com/huggingface/pytorch-image-models |
|
|
|
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich |
|
""" |
|
|
|
import logging |
|
from contextlib import suppress |
|
from functools import partial |
|
from itertools import repeat |
|
|
|
import numpy as np |
|
import torch |
|
import torch.utils.data |
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
|
from timm.data.dataset import IterableImageDataset |
|
from timm.data.loader import PrefetchLoader, _worker_init |
|
from timm.data.transforms_factory import create_transform |
|
|
|
_logger = logging.getLogger(__name__) |
|
|
|
|
|
def fast_collate(batch, target_dtype=torch.uint8):
    """Collate ``(image, target)`` samples into a uint8 image tensor and a target tensor.

    Args:
        batch: Non-empty sequence of tuples whose first item is a ``np.ndarray``
            image and whose second item is the target (label).
        target_dtype: Torch dtype used for the returned target tensor.

    Returns:
        ``(images, targets)``: ``images`` is a ``torch.uint8`` tensor of shape
        ``(len(batch), *first_image.shape)``; ``targets`` is a tensor of
        ``target_dtype`` with one entry per sample.

    Raises:
        ValueError: If the image of the first sample is not a ``np.ndarray``.
    """
    head = batch[0]
    assert isinstance(head, tuple)

    # Only numpy images are supported; anything else is rejected up front.
    reference_image = head[0]
    if not isinstance(reference_image, np.ndarray):
        raise ValueError(f"Incorrect batch type: {type(batch[0][0])}")

    num_samples = len(batch)
    labels = torch.tensor([sample[1] for sample in batch], dtype=target_dtype)
    assert len(labels) == num_samples

    # The output buffer takes its per-sample shape from the first image.
    images = torch.zeros((num_samples, *reference_image.shape), dtype=torch.uint8)
    for index, sample in enumerate(batch):
        images[index] += torch.from_numpy(sample[0])

    return images, labels
|
|
|
|
|
def adapt_to_chs(x, n):
    """Adapt normalization statistics ``x`` to ``n`` channels.

    Args:
        x: A scalar, or a tuple/list of per-channel values (e.g. mean or std).
        n: Number of channels the result must cover.

    Returns:
        * a tuple repeating ``x`` ``n`` times when ``x`` is not a tuple/list;
        * ``x`` itself when it already has ``n`` entries;
        * a numpy array with ``x`` concatenated to itself when ``n`` is exactly
          twice ``len(x)``;
        * otherwise a tuple holding the average of ``x`` repeated ``n`` times.

    The last two cases log a warning.
    """
    # Anything that is not a tuple/list is treated as a scalar and broadcast.
    if not isinstance(x, (tuple, list)):
        return tuple(repeat(x, n))

    # Already one value per channel: nothing to adapt.
    if len(x) == n:
        return x

    # Doubled channel count: reuse the same statistics for both halves.
    if 2 * len(x) == n:
        x = np.concatenate((x, x))
        _logger.warning(f"Pretrained mean/std different shape than model (doubled channes), using concat: {x}.")
        return x

    # Any other mismatch: fall back to the average value for every channel.
    average = np.mean(x).item()
    x = (average,) * n
    _logger.warning(f"Pretrained mean/std different shape than model, using avg value {x}.")
    return x
|
|
|
|
|
class PrefetchLoaderForMultiInput(PrefetchLoader):
    """Loader wrapper that moves batches to ``device`` and normalizes the inputs.

    Iterating yields ``(input, target)`` pairs taken from the wrapped loader after
    the input has been moved to ``device``, cast to ``img_dtype`` and normalized
    as ``(input - mean * 255) / (std * 255)``. When the target device is CUDA,
    that work is issued on a side stream for the next batch before the previous
    batch is handed to the consumer.
    """

    def __init__(
        self,
        loader,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        channels=3,
        device=torch.device("cpu"),
        img_dtype=torch.float32,
    ):
        """Store the wrapped loader and precompute the normalization tensors.

        Args:
            loader: Iterable yielding ``(input, target)`` tensor pairs.
            mean: Scalar or per-channel mean, adapted to ``channels`` entries
                with ``adapt_to_chs``.
            std: Scalar or per-channel std, adapted the same way as ``mean``.
            channels: Number of input channels; the statistics are reshaped to
                ``(1, channels, 1, 1)``.
            device: Device the batches are moved to (``torch.device`` or a
                device string).
            img_dtype: Dtype the inputs are cast to before normalization.
        """
        # NOTE: the parent initializer is not called; every attribute read by
        # __iter__ is set right here.

        # Accept plain strings such as "cuda:0" as well as torch.device objects.
        device = torch.device(device)

        mean = adapt_to_chs(mean, channels)
        std = adapt_to_chs(std, channels)
        normalization_shape = (1, channels, 1, 1)

        self.loader = loader
        self.device = device
        self.img_dtype = img_dtype
        # Statistics are scaled by 255 - presumably because batches arrive as
        # 0-255 pixel values (fast_collate builds uint8 tensors); confirm with callers.
        self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
        self.std = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)

        # A CUDA side stream is only meaningful when batches really go to a CUDA device.
        self.is_cuda = torch.cuda.is_available() and device.type == "cuda"

    def __iter__(self):
        """Yield normalized ``(input, target)`` batches located on ``self.device``.

        Nothing is yielded when the wrapped loader is empty.
        """
        if self.is_cuda:
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        else:
            stream = None
            # suppress() without arguments acts as a no-op context manager.
            stream_context = suppress

        # Batch that has been prepared but not yet handed to the consumer.
        pending = None

        for next_input, next_target in self.loader:
            with stream_context():
                next_input = next_input.to(device=self.device, non_blocking=True)
                next_target = next_target.to(device=self.device, non_blocking=True)
                next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)

            # Hand out the previous batch once the work for the next one has been issued.
            if pending is not None:
                yield pending

            if stream is not None:
                # Make the current stream wait for the side-stream work before the batch is used.
                torch.cuda.current_stream().wait_stream(stream)

            pending = (next_input, next_target)

        # Flush the last prepared batch; an empty loader leaves nothing to flush.
        if pending is not None:
            yield pending
|
|
|
|
|
def create_loader(
    dataset,
    input_size,
    batch_size,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    num_workers=1,
    crop_pct=None,
    crop_mode=None,
    pin_memory=False,
    img_dtype=torch.float32,
    device=torch.device("cpu"),
    persistent_workers=True,
    worker_seeding="all",
    target_type=torch.int64,
):
    """Build an evaluation data loader wrapped in ``PrefetchLoaderForMultiInput``.

    The dataset gets an evaluation transform assigned to ``dataset.transform``,
    is wrapped in a non-shuffling ``torch.utils.data.DataLoader`` that collates
    with ``fast_collate``, and is finally wrapped in a prefetching loader that
    moves batches to ``device`` and normalizes them.

    Args:
        dataset: Map-style dataset exposing a writable ``transform`` attribute.
        input_size: Sequence passed to ``create_transform``; its first element
            is used as the number of channels.
        batch_size: Number of samples per batch.
        mean: Normalization mean forwarded to the transform and the prefetcher.
        std: Normalization std forwarded to the transform and the prefetcher.
        num_workers: Number of ``DataLoader`` worker processes.
        crop_pct: Forwarded to ``create_transform``.
        crop_mode: Forwarded to ``create_transform``.
        pin_memory: Forwarded to ``DataLoader``.
        img_dtype: Dtype the prefetcher casts inputs to.
        device: Device the prefetcher moves batches to.
        persistent_workers: Keep workers alive between epochs; only applied
            when ``num_workers > 0``.
        worker_seeding: Forwarded to timm's ``_worker_init``.
        target_type: Torch dtype of the collated targets.

    Returns:
        A ``PrefetchLoaderForMultiInput`` wrapping the created ``DataLoader``.

    Raises:
        ValueError: If ``dataset`` is an ``IterableImageDataset``.
    """
    # Reject unsupported datasets before touching them in any way.
    if isinstance(dataset, IterableImageDataset):
        raise ValueError("Incorrect dataset type: IterableImageDataset")

    # use_prefetcher=True presumably leaves tensor conversion/normalization to
    # the prefetch loader created below - confirm against timm's create_transform.
    dataset.transform = create_transform(
        input_size,
        is_training=False,
        use_prefetcher=True,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        crop_mode=crop_mode,
    )

    loader_class = torch.utils.data.DataLoader
    loader_args = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        sampler=None,
        # partial (unlike a lambda) can be pickled for spawn-based worker processes.
        collate_fn=partial(fast_collate, target_dtype=target_type),
        pin_memory=pin_memory,
        drop_last=False,
        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
        # DataLoader rejects persistent_workers=True when there are no workers.
        persistent_workers=bool(persistent_workers and num_workers > 0),
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        # Fallback for DataLoader versions without the persistent_workers argument.
        loader_args.pop("persistent_workers")
        loader = loader_class(dataset, **loader_args)

    return PrefetchLoaderForMultiInput(
        loader,
        mean=mean,
        std=std,
        channels=input_size[0],
        device=device,
        img_dtype=img_dtype,
    )
|
|