File size: 1,484 Bytes
b8fae22 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 | """Build datasets / dataloaders from a Config, consistent across train & test."""
from __future__ import annotations
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from .unified_dataset import UnifiedSegDataset
from .transforms import build_transform
from ..engine.distributed import is_dist
def build_dataset(cfg, split: str) -> UnifiedSegDataset:
    """Create the ``UnifiedSegDataset`` for ``split`` and attach its transform.

    The dataset is first constructed with ``transform=None`` so that its
    in_channels / num_classes auto-detection (driven by ``cfg.in_channels`` /
    ``cfg.num_classes``) runs. The transform is then built from the dataset's
    resolved ``in_channels`` and assigned to ``dataset.transform``.

    Args:
        cfg: config object; reads ``data_root``, ``dataset``, ``protocol``,
            ``in_channels``, ``num_classes``, ``img_size``, ``aug``,
            ``normalize`` and, for the training split only,
            ``synth_train_dir``.
        split: split name. ``"train"`` selects training-mode transforms and
            the synthetic-data directory; any other value passes
            ``synth_dir=""``.

    Returns:
        The constructed dataset with its ``transform`` attribute set.
    """
    is_train = split == "train"
    # Synthetic data is only wired into the training split.
    synth_dir = "" if not is_train else cfg.synth_train_dir

    dataset = UnifiedSegDataset(
        data_root=cfg.data_root,
        dataset=cfg.dataset,
        protocol=cfg.protocol,
        split=split,
        transform=None,
        in_channels=cfg.in_channels,
        num_classes=cfg.num_classes,
        synth_dir=synth_dir,
    )

    # Use the dataset's resolved channel count rather than cfg.in_channels,
    # since the constructor may have auto-detected it.
    dataset.transform = build_transform(
        cfg.img_size,
        dataset.in_channels,
        train=is_train,
        aug=cfg.aug,
        normalize=cfg.normalize,
    )
    return dataset
def build_loader(cfg, split: str, ds: UnifiedSegDataset) -> DataLoader:
    """Wrap ``ds`` in a ``DataLoader`` configured for ``split``.

    For ``split == "train"`` samples are shuffled and the trailing incomplete
    batch is dropped. For any other split the order is preserved and the last
    partial batch is kept. When ``is_dist()`` is true, a
    ``DistributedSampler`` shards ``ds`` across ranks and takes over
    shuffling. The DataLoader's own ``shuffle`` must then be False, because
    ``shuffle`` and ``sampler`` are mutually exclusive.

    Args:
        cfg: config object; reads ``batch_size`` and ``num_workers``.
        split: split name; only ``"train"`` enables shuffling and
            dropping the last partial batch.
        ds: dataset to load from.

    Returns:
        The configured ``DataLoader``.
    """
    train = split == "train"
    sampler = None
    if is_dist():
        # NOTE(review): per the PyTorch docs, the caller must call
        # ``sampler.set_epoch(epoch)`` each epoch, otherwise every epoch
        # reuses the same shuffle order. That call is not visible here.
        # NOTE(review): with drop_last=False (non-train splits),
        # DistributedSampler pads the index list by repeating samples so it
        # divides evenly across ranks. Eval metrics may therefore count a few
        # samples twice; confirm whether downstream evaluation accounts for it.
        sampler = DistributedSampler(ds, shuffle=train, drop_last=train)
    return DataLoader(
        ds,
        batch_size=cfg.batch_size,
        shuffle=train and sampler is None,
        sampler=sampler,
        num_workers=cfg.num_workers,
        pin_memory=True,
        # The sampler's drop_last only trims the dataset to split evenly
        # across ranks; it does not drop each rank's trailing partial batch.
        # Drop it here whenever training so distributed and single-process
        # runs behave the same. A size-1 train batch can also make BatchNorm
        # layers fail, if the model uses them.
        drop_last=train,
        persistent_workers=cfg.num_workers > 0,
    )
|