Spaces:
Running
Running
""" | |
Code adapted from timm https://github.com/huggingface/pytorch-image-models | |
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich | |
""" | |
import logging | |
from contextlib import suppress | |
from functools import partial | |
from itertools import repeat | |
import numpy as np | |
import torch | |
import torch.utils.data | |
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD | |
from timm.data.dataset import IterableImageDataset | |
from timm.data.loader import PrefetchLoader, _worker_init | |
from timm.data.transforms_factory import create_transform | |
# Module-level logger; adapt_to_chs uses it to warn when mean/std stats do not
# match the model's channel count.
_logger = logging.getLogger(__name__)
def fast_collate(batch, target_dtype=torch.uint8):
    """A fast collation function optimized for uint8 images (np array or torch) and target_dtype targets (labels).

    Every sample's image is written into one freshly allocated ``torch.uint8``
    tensor of shape ``(len(batch), *first_image.shape)``. No normalization is
    done here.

    Args:
        batch: non-empty sequence of tuples whose item 0 is the image
            (``np.ndarray`` or ``torch.Tensor``) and item 1 is the target.
            All images are assumed to share the first image's shape.
        target_dtype: dtype of the returned target tensor.

    Returns:
        ``(images, targets)``: a ``torch.uint8`` image batch and a 1-D target
        tensor of ``target_dtype``.

    Raises:
        ValueError: if the first image is neither an ``np.ndarray`` nor a
            ``torch.Tensor``.
    """
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    first_image = batch[0][0]
    if not isinstance(first_image, (np.ndarray, torch.Tensor)):
        raise ValueError(f"Incorrect batch type: {type(first_image)}")

    targets = torch.tensor([sample[1] for sample in batch], dtype=target_dtype)
    assert len(targets) == batch_size
    tensor = torch.zeros((batch_size, *first_image.shape), dtype=torch.uint8)
    for i, sample in enumerate(batch):
        image = sample[0]
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(image)
        # Adding into a zero buffer equals a copy for uint8 images; kept from
        # the original numpy-only implementation.
        tensor[i] += image
    return tensor, targets
def adapt_to_chs(x, n):
    """Adapt per-channel normalization stats ``x`` to ``n`` channels.

    Args:
        x: a scalar, or a tuple/list of per-channel values.
        n: number of channels expected by the model input.

    Returns:
        - scalar ``x``: a tuple with ``x`` repeated ``n`` times.
        - tuple/list with ``len(x) == n``: ``x`` unchanged.
        - tuple/list with ``2 * len(x) == n``: an ``np.ndarray`` holding ``x``
          twice, e.g. 3-channel stats reused for a 6-channel input.
        - any other length: a tuple holding ``mean(x)`` repeated ``n`` times.

        A warning is logged for the last two (shape-mismatch) cases.
    """
    if not isinstance(x, (tuple, list)):
        return tuple(repeat(x, n))
    if len(x) == n:
        return x
    if len(x) * 2 == n:
        # doubled channels: reuse the same stats for both halves of the input
        x = np.concatenate((x, x))
        _logger.warning(
            "Pretrained mean/std different shape than model (doubled channels), using concat: %s.", x
        )
        return x
    x_mean = np.mean(x).item()
    x = (x_mean,) * n
    _logger.warning("Pretrained mean/std different shape than model, using avg value %s.", x)
    return x
class PrefetchLoaderForMultiInput(PrefetchLoader):
    """Wrap a loader, moving each ``(input, target)`` batch to ``device`` and normalizing it.

    Each input batch is cast to ``img_dtype`` and normalized in place with
    ``(x - mean * 255) / (std * 255)``. When the target device is a CUDA GPU,
    transfer and normalization of the next batch are issued on a side CUDA
    stream while the previous batch is being yielded.

    ``PrefetchLoader.__init__`` is intentionally not called; all state used by
    ``__iter__`` is set in ``__init__`` below.
    """

    def __init__(
        self,
        loader,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        channels=3,
        device=torch.device("cpu"),
        img_dtype=torch.float32,
    ):
        """
        Args:
            loader: iterable yielding ``(input, target)`` tensor pairs.
            mean, std: per-channel stats (scalar or sequence), adapted to
                ``channels`` via ``adapt_to_chs``.
            channels: number of input channels.
            device: target device (``torch.device`` or a string such as ``"cuda"``).
            img_dtype: dtype the inputs are cast to before normalization.
        """
        # Accept strings like "cuda:0" as well as torch.device instances.
        device = torch.device(device)
        mean = adapt_to_chs(mean, channels)
        std = adapt_to_chs(std, channels)
        # Broadcastable against (N, C, H, W) batches; assumes 4-D image input.
        normalization_shape = (1, channels, 1, 1)
        self.loader = loader
        self.device = device
        self.img_dtype = img_dtype
        # Stats are scaled by 255, presumably because the wrapped loader yields
        # raw uint8 pixel values (as fast_collate produces).
        self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
        self.std = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
        # A side stream only makes sense when the data is actually going to a GPU.
        self.is_cuda = torch.cuda.is_available() and device.type == "cuda"

    def __iter__(self):
        """Yield normalized ``(input, target)`` pairs, one batch behind the underlying loader."""
        if self.is_cuda:
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        else:
            stream = None
            stream_context = suppress

        pending = None  # (input, target) prepared on the previous iteration
        for next_input, next_target in self.loader:
            with stream_context():
                next_input = next_input.to(device=self.device, non_blocking=True)
                next_target = next_target.to(device=self.device, non_blocking=True)
                next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)

            if pending is not None:
                yield pending

            if stream is not None:
                # Make the default stream wait for the side-stream work before
                # the batch prepared there is handed out.
                torch.cuda.current_stream().wait_stream(stream)

            pending = (next_input, next_target)

        # Guard so an empty loader yields nothing instead of raising UnboundLocalError.
        if pending is not None:
            yield pending
def create_loader(
    dataset,
    input_size,
    batch_size,
    mean=IMAGENET_DEFAULT_MEAN,
    std=IMAGENET_DEFAULT_STD,
    num_workers=1,
    crop_pct=None,
    crop_mode=None,
    pin_memory=False,
    img_dtype=torch.float32,
    device=torch.device("cpu"),
    persistent_workers=True,
    worker_seeding="all",
    target_type=torch.int64,
):
    """Build an evaluation loader that collates uint8 batches and normalizes them on ``device``.

    The dataset's ``transform`` is replaced by a timm eval transform created with
    ``use_prefetcher=True``. Samples are collated with ``fast_collate``, and the
    resulting DataLoader is wrapped in ``PrefetchLoaderForMultiInput``, which
    performs the device transfer and mean/std normalization.

    Args:
        dataset: map-style dataset exposing a writable ``transform`` attribute.
        input_size: passed to ``create_transform``; ``input_size[0]`` is used as
            the channel count (presumably ``(C, H, W)``).
        batch_size: samples per batch (no shuffling, last partial batch kept).
        mean, std: normalization stats for the transform and the prefetcher.
        num_workers: DataLoader worker processes.
        crop_pct, crop_mode: forwarded to ``create_transform``.
        pin_memory: forwarded to the DataLoader.
        img_dtype: dtype of the normalized input batches.
        device: device batches are moved to.
        persistent_workers: keep workers alive between epochs; ignored when
            ``num_workers == 0``.
        worker_seeding: forwarded to timm's ``_worker_init``.
        target_type: dtype of the collated target tensor.

    Returns:
        A ``PrefetchLoaderForMultiInput`` wrapping the DataLoader.

    Raises:
        ValueError: if ``dataset`` is an ``IterableImageDataset`` (unsupported).
    """
    # Validate before mutating the dataset.
    if isinstance(dataset, IterableImageDataset):
        raise ValueError("Incorrect dataset type: IterableImageDataset")

    dataset.transform = create_transform(
        input_size,
        is_training=False,
        use_prefetcher=True,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        crop_mode=crop_mode,
    )

    loader_args = dict(
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        sampler=None,
        # partial (not a lambda) so the collate_fn can be pickled into worker
        # processes started with the "spawn" method.
        collate_fn=partial(fast_collate, target_dtype=target_type),
        pin_memory=pin_memory,
        drop_last=False,
        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=persistent_workers and num_workers > 0,
    )
    try:
        loader = torch.utils.data.DataLoader(dataset, **loader_args)
    except TypeError:
        loader_args.pop("persistent_workers")  # only in Pytorch 1.7+
        loader = torch.utils.data.DataLoader(dataset, **loader_args)

    return PrefetchLoaderForMultiInput(
        loader,
        mean=mean,
        std=std,
        channels=input_size[0],
        device=device,
        img_dtype=img_dtype,
    )