import logging
from enum import Enum
from typing import Any, Callable, List, Optional, TypeVar

import torch
from torch.utils.data import Sampler

from .datasets import ImageNet, ImageNet22k
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler


logger = logging.getLogger("dinov2")


class SamplerType(Enum):
    DISTRIBUTED = 0
    EPOCH = 1
    INFINITE = 2
    SHARDED_INFINITE = 3
    SHARDED_INFINITE_NEW = 4


def _make_bool_str(b: bool) -> str:
    return "yes" if b else "no"


def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
    def transform(sample):
        image, target = sample
        if image_transform is not None:
            image = image_transform(image)
        if target_transform is not None:
            target = target_transform(target)
        return image, target

    return transform
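

# Illustrative usage sketch (not executed at import time): combining an image
# transform and a target transform into a single per-sample callable.
# `my_image_transform`, `my_target_transform`, `raw_image` and `raw_target`
# are hypothetical names.
#
#   sample_transform = _make_sample_transform(
#       image_transform=my_image_transform,
#       target_transform=my_target_transform,
#   )
#   image, target = sample_transform((raw_image, raw_target))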


def _parse_dataset_str(dataset_str: str):
    tokens = dataset_str.split(":")

    name = tokens[0]
    kwargs = {}

    for token in tokens[1:]:
        key, value = token.split("=")
        assert key in ("root", "extra", "split")
        kwargs[key] = value

    if name == "ImageNet":
        class_ = ImageNet
        if "split" in kwargs:
            kwargs["split"] = ImageNet.Split[kwargs["split"]]
    elif name == "ImageNet22k":
        class_ = ImageNet22k
    else:
        raise ValueError(f'Unsupported dataset "{name}"')

    return class_, kwargs
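

# Illustrative sketch of the dataset string grammar accepted above: a dataset
# name followed by colon-separated key=value pairs, where only "root", "extra"
# and "split" are valid keys. The path below is hypothetical.
#
#   class_, kwargs = _parse_dataset_str("ImageNet:split=TRAIN:root=/data/imagenet")
#   # class_ is ImageNet; kwargs == {"split": ImageNet.Split.TRAIN, "root": "/data/imagenet"}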


def make_dataset(
    *,
    dataset_str: str,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
):
    """
    Creates a dataset with the specified parameters.

    Args:
        dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
        transform: A transform to apply to images.
        target_transform: A transform to apply to targets.

    Returns:
        The created dataset.
    """
    logger.info(f'using dataset: "{dataset_str}"')

    class_, kwargs = _parse_dataset_str(dataset_str)
    dataset = class_(transform=transform, target_transform=target_transform, **kwargs)

    logger.info(f"# of dataset samples: {len(dataset):,d}")

    # Some dataset classes may not expose these attributes; set them so that
    # downstream code can rely on their presence.
    if not hasattr(dataset, "transform"):
        setattr(dataset, "transform", transform)
    if not hasattr(dataset, "target_transform"):
        setattr(dataset, "target_transform", target_transform)

    return dataset
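

# Illustrative usage sketch, assuming an ImageNet layout at hypothetical paths:
#
#   dataset = make_dataset(
#       dataset_str="ImageNet:split=TRAIN:root=/data/imagenet:extra=/data/imagenet-extra",
#       transform=some_image_transform,  # hypothetical transform
#   )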


def _make_sampler(
    *,
    dataset,
    type: Optional[SamplerType] = None,
    shuffle: bool = False,
    seed: int = 0,
    size: int = -1,
    advance: int = 0,
) -> Optional[Sampler]:
    sample_count = len(dataset)

    if type == SamplerType.INFINITE:
        logger.info("sampler: infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        return InfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
        )
    elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
        logger.info("sampler: sharded infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        # SHARDED_INFINITE_NEW selects the newer shuffled-tensor slicing implementation.
        use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
        return ShardedInfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
            use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
        )
    elif type == SamplerType.EPOCH:
        logger.info("sampler: epoch")
        if advance > 0:
            raise NotImplementedError("sampler advance > 0 is not supported")
        size = size if size > 0 else sample_count
        logger.info(f"# of samples / epoch: {size:,d}")
        return EpochSampler(
            size=size,
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
        )
    elif type == SamplerType.DISTRIBUTED:
        logger.info("sampler: distributed")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        if advance > 0:
            raise ValueError("sampler advance > 0 is invalid")
        return torch.utils.data.DistributedSampler(
            dataset=dataset,
            shuffle=shuffle,
            seed=seed,
            drop_last=False,
        )

    logger.info("sampler: none")
    return None
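

# Illustrative usage sketch: infinite samplers are meant for training loops that
# iterate indefinitely, with `advance` skipping already-seen samples (e.g. when
# resuming mid-stream). The offset variable below is hypothetical.
#
#   sampler = _make_sampler(
#       dataset=dataset,
#       type=SamplerType.SHARDED_INFINITE,
#       shuffle=True,
#       seed=0,
#       advance=num_consumed_samples,
#   )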


T = TypeVar("T")


def make_data_loader(
    *,
    dataset,
    batch_size: int,
    num_workers: int,
    shuffle: bool = True,
    seed: int = 0,
    sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
    sampler_size: int = -1,
    sampler_advance: int = 0,
    drop_last: bool = True,
    persistent_workers: bool = False,
    collate_fn: Optional[Callable[[List[T]], Any]] = None,
):
    """
    Creates a data loader with the specified parameters.

    Args:
        dataset: A dataset (third party, LaViDa or WebDataset).
        batch_size: The size of batches to generate.
        num_workers: The number of workers to use.
        shuffle: Whether to shuffle samples.
        seed: The random seed to use.
        sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
        sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
        sampler_advance: How many samples to skip (when applicable).
        drop_last: Whether the last non-full batch of data should be dropped.
        persistent_workers: Whether to keep the worker Dataset instances alive after the dataset has been consumed once.
        collate_fn: Function that performs batch collation.

    Returns:
        The created data loader.
    """

    sampler = _make_sampler(
        dataset=dataset,
        type=sampler_type,
        shuffle=shuffle,
        seed=seed,
        size=sampler_size,
        advance=sampler_advance,
    )

    logger.info("using PyTorch data loader")
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last,
        persistent_workers=persistent_workers,
        collate_fn=collate_fn,
    )

    # Infinite samplers have no length, so len(data_loader) raises TypeError.
    try:
        logger.info(f"# of batches: {len(data_loader):,d}")
    except TypeError:
        logger.info("infinite data loader")
    return data_loader
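

# Illustrative usage sketch: building a training loader with hypothetical batch
# size and worker count, then iterating over (image, target) batches.
#
#   data_loader = make_data_loader(
#       dataset=dataset,
#       batch_size=256,
#       num_workers=8,
#       shuffle=True,
#       sampler_type=SamplerType.SHARDED_INFINITE,
#   )
#   for images, targets in data_loader:
#       ...  # training step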