|
import os.path |
|
|
|
from torch.utils.data import Dataset, DataLoader |
|
import torch |
|
import numpy as np |
|
import pandas as pd |
|
from skimage import io |
|
from Utils.Augmentations import Augmentations, Resize |
|
|
|
|
|
class Datasets(Dataset):
    """Segmentation dataset backed by a CSV index file.

    The CSV (read with its first column as the index) must contain an
    'img' column and a 'label' column, each holding a path to an image
    file on disk.
    """

    def __init__(self, data_file, transform=None, phase='train', *args, **kwargs):
        """
        :param data_file: path to the CSV index file (columns 'img', 'label').
        :param transform: optional callable applied to (image, label) pairs.
        :param phase: dataset phase tag, e.g. 'train' / 'val' / 'test'.
        """
        self.transform = transform
        self.data_info = pd.read_csv(data_file, index_col=0)
        self.phase = phase

    def __len__(self):
        """Number of samples listed in the CSV index."""
        return len(self.data_info)

    def __getitem__(self, index):
        data = self.pull_item_seg(index)
        return data

    def pull_item_seg(self, index):
        """Load one (image, label) pair and build a one-hot binary mask.

        :param index: integer sample index into the CSV index.
        :return: dict with keys
            'img'      - float/uint8 tensor of shape (C, H, W),
            'label'    - one-hot tensor of shape (2, H, W)
                         (channel 0 = background, channel 1 = foreground),
            'img_name' - basename of the image file.
        """
        data = self.data_info.iloc[index]
        img_name = data['img']
        label_name = data['label']

        ori_img = io.imread(img_name, as_gray=False)
        ori_label = io.imread(label_name, as_gray=True)
        assert (ori_img is not None and ori_label is not None), f'{img_name} or {label_name} is not valid'

        if self.transform is not None:
            img, label = self.transform((ori_img, ori_label))
        else:
            # Fall back to the raw arrays; the original code left img/label
            # unbound here, raising UnboundLocalError when transform was None.
            img, label = ori_img, ori_label

        # np.float was removed in NumPy 1.24; use an explicit dtype instead.
        one_hot_label = np.zeros([2] + list(label.shape), dtype=np.float32)
        one_hot_label[0] = label == 0   # background mask
        one_hot_label[1] = label > 0    # foreground mask
        return_dict = {
            'img': torch.from_numpy(img).permute(2, 0, 1),  # HWC -> CHW
            'label': torch.from_numpy(one_hot_label),
            'img_name': os.path.basename(img_name)
        }
        return return_dict
|
|
|
|
|
def get_data_loader(config, test_mode=False):
    """Build the DataLoaders enabled by *config*.

    :param config: dict providing dataset paths ('DATASET', 'VAL_DATASET',
        'TEST_DATASET'), loader settings ('BATCH_SIZE', 'VAL_BATCH_SIZE',
        'IS_SHUFFLE', 'NUM_WORKERS'), augmentation settings ('IMG_SIZE',
        'PRIOR_MEAN', 'PRIOR_STD', 'PHASE'), and the flags 'IS_VAL' /
        'IS_TEST'.
    :param test_mode: when True, the training loader is not built.
    :return: dict mapping phase name ('train'/'val'/'test') to its DataLoader.
    """
    def _eval_params():
        # Shared loader settings for the non-shuffled val/test phases.
        return {
            'batch_size': config['VAL_BATCH_SIZE'],
            'shuffle': False,
            'drop_last': False,
            'collate_fn': collate_fn,
            'num_workers': config['NUM_WORKERS'],
            'pin_memory': False
        }

    def _make_set(dataset_key, aug_phase):
        # Build one Datasets instance with its phase-specific augmentations.
        return Datasets(
            config[dataset_key],
            Augmentations(
                config['IMG_SIZE'], config['PRIOR_MEAN'], config['PRIOR_STD'],
                aug_phase, config['PHASE'], config
            ),
            config['PHASE'],
            config
        )

    # Explicit registry (phase -> (dataset, loader params)) instead of the
    # original eval(x + '_set') / eval(x + '_params') string-lookup hack,
    # which was fragile and unsafe. Insertion order preserves the original
    # train -> val -> test ordering.
    spec = {}

    if not test_mode:
        train_params = {
            'batch_size': config['BATCH_SIZE'],
            'shuffle': config['IS_SHUFFLE'],
            'drop_last': False,
            'collate_fn': collate_fn,
            'num_workers': config['NUM_WORKERS'],
            'pin_memory': False
        }
        spec['train'] = (_make_set('DATASET', 'train'), train_params)

    if config['IS_VAL']:
        spec['val'] = (_make_set('VAL_DATASET', 'val'), _eval_params())

    if config['IS_TEST']:
        spec['test'] = (_make_set('TEST_DATASET', 'test'), _eval_params())

    data_loaders = {}
    for phase, (dataset, params) in spec.items():
        data_loaders[phase] = DataLoader(dataset, **params)
    return data_loaders
|
|
|
|
|
def collate_fn(batch):
    """Collate a list of sample dicts into one batched dict.

    numpy arrays are converted to float tensors; tensors are kept as-is;
    strings, lists and dicts pass through unchanged. Every key except
    'img_name' is stacked along a new leading batch dimension.

    :param batch: list of dicts, all sharing the same keys.
    :return: dict with the same keys; tensor values stacked to shape
        (batch, ...), 'img_name' left as a list.
    """
    def to_tensor(item):
        if torch.is_tensor(item):
            return item
        # Proper type checks: the original compared against
        # type(np.array(0)) and type('0') instead of np.ndarray / str.
        if isinstance(item, np.ndarray):
            return torch.from_numpy(item).float()
        # Strings, lists, dicts and anything else pass through untouched;
        # the original silently returned None for unhandled types.
        return item

    return_data = {key: [] for key in batch[0].keys()}

    for sample in batch:
        for key, value in sample.items():
            return_data[key].append(to_tensor(value))

    # 'img_name' stays a plain list; everything else becomes one tensor.
    keys = set(batch[0].keys()) - {'img_name'}
    for key in keys:
        return_data[key] = torch.stack(return_data[key], dim=0)

    return return_data
|
|
|
|