|
from torch.utils.data import ( |
|
DataLoader |
|
) |
|
from dataloader.dataset import ( |
|
SegmentationDataset, |
|
AugmentedSegmentationDataset |
|
) |
|
|
|
def sam_dataloader(cfg):
    """Build the train/val/test DataLoaders for SAM fine-tuning.

    Since the output of SAM's mask decoder is 256x256 by default (without
    using a postprocessing step), the mask ground truth is resized to
    256x256 so predictions can be compared directly, while images stay at
    SAM's expected 1024 input resolution.

    Args:
        cfg: Hydra-style config object providing ``cfg.base`` (dataset_name,
            num_workers, pin_memory), ``cfg.dataloader`` (train/valid/test
            image and mask directories) and ``cfg.train`` (batch sizes).

    Returns:
        Tuple of ``(train_loader, val_loader, test_loader, val_dataset,
        test_dataset)``.

    Raises:
        NameError: if ``cfg.base.dataset_name`` is not a supported dataset.
    """
    loader_args = dict(num_workers=cfg.base.num_workers,
                       pin_memory=cfg.base.pin_memory)

    # (image_size, mask_size): SAM takes 1024x1024 images; its mask decoder
    # emits 256x256, so ground-truth masks are resized to match.
    scale = (1024, 256)

    # Pick the train dataset class: some datasets are used as-is, others
    # go through the augmented pipeline. Only training data is augmented.
    if cfg.base.dataset_name in ["buidnewprocess", "kvasir", "isiconlytrain", "drive"]:
        train_dataset_cls = SegmentationDataset
    elif cfg.base.dataset_name in ["bts", "las_mri", "las_ct"]:
        train_dataset_cls = AugmentedSegmentationDataset
    else:
        raise NameError(f"[Error] Dataset {cfg.base.dataset_name} is either in wrong format or not yet implemented!")

    train_dataset = train_dataset_cls(cfg.base.dataset_name,
                                      cfg.dataloader.train_dir_img,
                                      cfg.dataloader.train_dir_mask,
                                      scale=scale)
    # Validation and test sets are never augmented.
    val_dataset = SegmentationDataset(cfg.base.dataset_name,
                                      cfg.dataloader.valid_dir_img,
                                      cfg.dataloader.valid_dir_mask,
                                      scale=scale)
    test_dataset = SegmentationDataset(cfg.base.dataset_name,
                                       cfg.dataloader.test_dir_img,
                                       cfg.dataloader.test_dir_mask,
                                       scale=scale)

    # NOTE: "fork" is POSIX-only; keep it here since the project assumes a
    # Linux training environment — revisit if Windows support is needed.
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=cfg.train.train_batch_size,
                              multiprocessing_context="fork",
                              **loader_args)
    # drop_last=True keeps val/test batch shapes uniform for metric code.
    val_loader = DataLoader(val_dataset,
                            shuffle=False,
                            drop_last=True,
                            batch_size=cfg.train.valid_batch_size,
                            multiprocessing_context="fork",
                            **loader_args)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=cfg.train.test_batch_size,
                             drop_last=True,
                             multiprocessing_context="fork",
                             **loader_args)

    return train_loader, val_loader, test_loader, val_dataset, test_dataset