File size: 1,484 Bytes
b8fae22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""Build datasets / dataloaders from a Config, consistent across train & test."""
from __future__ import annotations

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from .unified_dataset import UnifiedSegDataset
from .transforms import build_transform
from ..engine.distributed import is_dist


def build_dataset(cfg, split: str) -> UnifiedSegDataset:
    """Create the segmentation dataset for ``split`` and attach its transform.

    The dataset is first instantiated with ``transform=None`` so its
    constructor can auto-detect ``in_channels`` / ``num_classes``. The
    transform is then built from the resolved ``in_channels`` and assigned
    afterwards.

    Args:
        cfg: Config object. Reads ``data_root``, ``dataset``, ``protocol``,
            ``in_channels``, ``num_classes``, ``img_size``, ``aug``,
            ``normalize`` and, for the training split only,
            ``synth_train_dir``.
        split: Split name. Only ``"train"`` enables training-mode transforms
            and receives the synthetic-data directory.

    Returns:
        The dataset with its ``transform`` attribute populated.
    """
    is_train = split == "train"
    # Only the training split is given the synthetic-data directory;
    # every other split gets an empty string.
    synth_dir = "" if not is_train else cfg.synth_train_dir

    # No transform yet: the constructor resolves the channel/class counts
    # that the transform below depends on.
    dataset = UnifiedSegDataset(
        data_root=cfg.data_root,
        dataset=cfg.dataset,
        protocol=cfg.protocol,
        split=split,
        transform=None,
        in_channels=cfg.in_channels,
        num_classes=cfg.num_classes,
        synth_dir=synth_dir,
    )
    dataset.transform = build_transform(
        cfg.img_size,
        dataset.in_channels,
        train=is_train,
        aug=cfg.aug,
        normalize=cfg.normalize,
    )
    return dataset


def build_loader(cfg, split: str, ds: UnifiedSegDataset) -> DataLoader:
    """Wrap ``ds`` in a DataLoader configured for ``split``.

    Training loaders shuffle and drop the final incomplete batch. Other
    splits keep their order and every batch.

    When ``is_dist()`` is true, a ``DistributedSampler`` shards the data
    across ranks and takes over shuffling. In that case ``shuffle`` is not
    enabled on the DataLoader itself.

    Args:
        cfg: Config object. Reads ``batch_size`` and ``num_workers``.
        split: Split name. ``"train"`` enables shuffling and dropping of
            incomplete batches.
        ds: Dataset to load from.

    Returns:
        The configured DataLoader.

    Notes:
        * With a shuffling ``DistributedSampler``, the caller must call
          ``loader.sampler.set_epoch(epoch)`` every epoch. Otherwise every
          epoch uses the same ordering.
        * For non-train splits the sampler is built with ``drop_last=False``.
          It then pads its indices with repeated samples so that each rank
          gets the same count. Distributed evaluation may therefore see a
          few samples twice.
    """
    train = split == "train"
    sampler = None
    if is_dist():
        # Sampler-level drop_last only trims the index tail so every rank gets
        # the same number of samples; it does not drop incomplete batches.
        sampler = DistributedSampler(ds, shuffle=train, drop_last=train)
    return DataLoader(
        ds,
        batch_size=cfg.batch_size,
        # DataLoader rejects shuffle=True together with an explicit sampler.
        shuffle=train and sampler is None,
        sampler=sampler,
        num_workers=cfg.num_workers,
        pin_memory=True,
        # Batch-level drop_last is independent of the sampler. Applying it
        # whenever training keeps single-process and distributed runs
        # consistent: neither ever yields a short final training batch.
        drop_last=train,
        persistent_workers=cfg.num_workers > 0,
    )