Spaces:
Runtime error
Runtime error
import os.path | |
from torch.utils.data import Dataset, DataLoader | |
import torch | |
import numpy as np | |
import pandas as pd | |
from skimage import io | |
from Utils.Augmentations import Augmentations, Resize | |
class Datasets(Dataset): | |
def __init__(self, data_file, transform=None, phase='train', *args, **kwargs): | |
self.transform = transform | |
self.data_info = pd.read_csv(data_file, index_col=0) | |
self.phase = phase | |
def __len__(self): | |
return len(self.data_info) | |
def __getitem__(self, index): | |
data = self.pull_item_seg(index) | |
return data | |
def pull_item_seg(self, index): | |
""" | |
:param index: image index | |
""" | |
data = self.data_info.iloc[index] | |
img_name = data['img'] | |
label_name = data['label'] | |
ori_img = io.imread(img_name, as_gray=False) | |
ori_label = io.imread(label_name, as_gray=True) | |
assert (ori_img is not None and ori_label is not None), f'{img_name} or {label_name} is not valid' | |
if self.transform is not None: | |
img, label = self.transform((ori_img, ori_label)) | |
one_hot_label = np.zeros([2] + list(label.shape), dtype=np.float) | |
one_hot_label[0] = label == 0 | |
one_hot_label[1] = label > 0 | |
return_dict = { | |
'img': torch.from_numpy(img).permute(2, 0, 1), | |
'label': torch.from_numpy(one_hot_label), | |
'img_name': os.path.basename(img_name) | |
} | |
return return_dict | |
def get_data_loader(config, test_mode=False): | |
if not test_mode: | |
train_params = { | |
'batch_size': config['BATCH_SIZE'], | |
'shuffle': config['IS_SHUFFLE'], | |
'drop_last': False, | |
'collate_fn': collate_fn, | |
'num_workers': config['NUM_WORKERS'], | |
'pin_memory': False | |
} | |
# data_file, config, transform=None | |
train_set = Datasets( | |
config['DATASET'], | |
Augmentations( | |
config['IMG_SIZE'], config['PRIOR_MEAN'], config['PRIOR_STD'], 'train', config['PHASE'], config | |
), | |
config['PHASE'], | |
config | |
) | |
patterns = ['train'] | |
else: | |
patterns = [] | |
if config['IS_VAL']: | |
val_params = { | |
'batch_size': config['VAL_BATCH_SIZE'], | |
'shuffle': False, | |
'drop_last': False, | |
'collate_fn': collate_fn, | |
'num_workers': config['NUM_WORKERS'], | |
'pin_memory': False | |
} | |
val_set = Datasets( | |
config['VAL_DATASET'], | |
Augmentations( | |
config['IMG_SIZE'], config['PRIOR_MEAN'], config['PRIOR_STD'], 'val', config['PHASE'], config | |
), | |
config['PHASE'], | |
config | |
) | |
patterns += ['val'] | |
if config['IS_TEST']: | |
test_params = { | |
'batch_size': config['VAL_BATCH_SIZE'], | |
'shuffle': False, | |
'drop_last': False, | |
'collate_fn': collate_fn, | |
'num_workers': config['NUM_WORKERS'], | |
'pin_memory': False | |
} | |
test_set = Datasets( | |
config['TEST_DATASET'], | |
Augmentations( | |
config['IMG_SIZE'], config['PRIOR_MEAN'], config['PRIOR_STD'], 'test', config['PHASE'], config | |
), | |
config['PHASE'], | |
config | |
) | |
patterns += ['test'] | |
data_loaders = {} | |
for x in patterns: | |
data_loaders[x] = DataLoader(eval(x+'_set'), **eval(x+'_params')) | |
return data_loaders | |
def collate_fn(batch): | |
def to_tensor(item): | |
if torch.is_tensor(item): | |
return item | |
elif isinstance(item, type(np.array(0))): | |
return torch.from_numpy(item).float() | |
elif isinstance(item, type('0')): | |
return item | |
elif isinstance(item, list): | |
return item | |
elif isinstance(item, dict): | |
return item | |
return_data = {} | |
for key in batch[0].keys(): | |
return_data[key] = [] | |
for sample in batch: | |
for key, value in sample.items(): | |
return_data[key].append(to_tensor(value)) | |
keys = set(batch[0].keys()) - {'img_name'} | |
for key in keys: | |
return_data[key] = torch.stack(return_data[key], dim=0) | |
return return_data | |