import os.path

import numpy as np
import pandas as pd
import torch
from skimage import io
from torch.utils.data import Dataset, DataLoader

from Utils.Augmentations import Augmentations, Resize


class Datasets(Dataset):
    """Segmentation dataset described by a CSV file of image/label paths.

    The CSV is read with its first column as the index; each row must
    provide an ``img`` and a ``label`` entry that ``skimage.io.imread``
    can open.
    """

    def __init__(self, data_file, transform=None, phase='train', *args, **kwargs):
        """
        :param data_file: path of the CSV file listing the samples
        :param transform: optional callable taking an ``(image, label)``
            tuple and returning the transformed ``(image, label)`` pair
        :param phase: stored on the instance as ``self.phase``
        :param args: accepted for call-site compatibility, unused here
        :param kwargs: accepted for call-site compatibility, unused here
        """
        self.transform = transform
        self.data_info = pd.read_csv(data_file, index_col=0)
        self.phase = phase

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data_info)

    def __getitem__(self, index):
        """Return the sample dict for row ``index`` (see ``pull_item_seg``)."""
        return self.pull_item_seg(index)

    def pull_item_seg(self, index):
        """Load one image/label pair and build its two-channel label.

        :param index: image index (positional row of the CSV)
        :return: dict with
            ``img`` - image tensor with axes permuted ``(2, 0, 1)``
            (presumably HWC -> CHW; confirm against the transform output),
            ``label`` - float64 tensor whose channel 0 marks ``label == 0``
            and channel 1 marks ``label > 0``,
            ``img_name`` - base name of the image file.
        """
        row = self.data_info.iloc[index]
        img_name = row['img']
        label_name = row['label']

        ori_img = io.imread(img_name, as_gray=False)
        ori_label = io.imread(label_name, as_gray=True)
        # Defensive check kept from the original implementation.
        assert (ori_img is not None and ori_label is not None), f'{img_name} or {label_name} is not valid'

        if self.transform is not None:
            img, label = self.transform((ori_img, ori_label))
        else:
            # Without a transform the raw arrays are used; previously this
            # path failed because ``img``/``label`` were never assigned.
            img, label = ori_img, ori_label

        # ``np.float`` was removed in NumPy 1.24; it was an alias of the
        # builtin ``float``, i.e. float64, so the dtype is unchanged.
        one_hot_label = np.zeros([2] + list(label.shape), dtype=np.float64)
        one_hot_label[0] = label == 0
        one_hot_label[1] = label > 0

        return {
            'img': torch.from_numpy(img).permute(2, 0, 1),
            'label': torch.from_numpy(one_hot_label),
            'img_name': os.path.basename(img_name)
        }


def _loader_params(config, batch_size_key, shuffle):
    """Build the ``DataLoader`` keyword arguments shared by every split."""
    return {
        'batch_size': config[batch_size_key],
        'shuffle': shuffle,
        'drop_last': False,
        'collate_fn': collate_fn,
        'num_workers': config['NUM_WORKERS'],
        'pin_memory': False
    }


def _build_dataset(config, dataset_key, split):
    """Create the ``Datasets`` instance (with its ``Augmentations``) for a split."""
    return Datasets(
        config[dataset_key],
        Augmentations(
            config['IMG_SIZE'], config['PRIOR_MEAN'], config['PRIOR_STD'],
            split, config['PHASE'], config
        ),
        config['PHASE'],
        config
    )


def get_data_loader(config, test_mode=False):
    """Create the data loaders requested by ``config``.

    :param config: mapping providing ``DATASET``, ``VAL_DATASET``,
        ``TEST_DATASET``, ``BATCH_SIZE``, ``VAL_BATCH_SIZE``, ``IS_SHUFFLE``,
        ``NUM_WORKERS``, ``IMG_SIZE``, ``PRIOR_MEAN``, ``PRIOR_STD``,
        ``PHASE``, ``IS_VAL`` and ``IS_TEST`` (only the keys of the enabled
        splits are read)
    :param test_mode: when true, the ``train`` loader is not created
    :return: dict mapping ``'train'`` / ``'val'`` / ``'test'`` to a
        ``DataLoader``; a split is present only when it is enabled
    """
    # split name -> (dataset, DataLoader kwargs); replaces the former
    # eval()-based lookup of ``<split>_set`` / ``<split>_params`` locals.
    splits = {}

    if not test_mode:
        train_params = _loader_params(config, 'BATCH_SIZE', config['IS_SHUFFLE'])
        splits['train'] = (_build_dataset(config, 'DATASET', 'train'), train_params)

    if config['IS_VAL']:
        val_params = _loader_params(config, 'VAL_BATCH_SIZE', False)
        splits['val'] = (_build_dataset(config, 'VAL_DATASET', 'val'), val_params)

    if config['IS_TEST']:
        # The test split reuses the validation batch size, as before.
        test_params = _loader_params(config, 'VAL_BATCH_SIZE', False)
        splits['test'] = (_build_dataset(config, 'TEST_DATASET', 'test'), test_params)

    # Loaders are instantiated only after every dataset has been built,
    # matching the original construction order.
    return {
        name: DataLoader(dataset, **params)
        for name, (dataset, params) in splits.items()
    }


def collate_fn(batch):
    """Merge a list of sample dicts into one batch dict.

    Every key except ``'img_name'`` is stacked along a new leading
    dimension with ``torch.stack``; ``'img_name'`` stays a plain list.

    :param batch: non-empty list of dicts sharing the same keys
    :return: dict with the keys of ``batch[0]``
    """

    def to_tensor(item):
        """Convert numpy arrays to float tensors; pass known types through."""
        if torch.is_tensor(item):
            return item
        if isinstance(item, np.ndarray):
            return torch.from_numpy(item).float()
        if isinstance(item, (str, list, dict)):
            return item
        # Any other type collapses to None, as in the original implementation.
        return None

    return_data = {key: [] for key in batch[0].keys()}
    for sample in batch:
        for key, value in sample.items():
            return_data[key].append(to_tensor(value))

    for key in set(batch[0].keys()) - {'img_name'}:
        return_data[key] = torch.stack(return_data[key], dim=0)

    return return_data