|
import json |
|
import os |
|
from sklearn.model_selection import train_test_split |
|
|
|
from monai.data import DataLoader, Dataset |
|
from monai import transforms |
|
|
|
def datafold_read(datalist, basedir, fold=0, key="training"):
    """Load a JSON datalist, resolve its paths against ``basedir`` and split by fold.

    The file at ``datalist`` is parsed as JSON and the list stored under the
    top-level ``key`` is used. In every record:

    - each element of a list value is joined onto ``basedir`` with
      ``os.path.join``;
    - each non-empty string value is joined onto ``basedir``;
    - empty strings and values of any other type (e.g. an integer ``"fold"``)
      are left untouched.

    Args:
        datalist: Path to the JSON datalist file (read as UTF-8).
        basedir: Directory prepended to every path-like value in the records.
        fold: Records whose ``"fold"`` entry equals this value are returned in
            the second list.
        key: Top-level JSON key holding the list of records.

    Returns:
        ``(train, val)``: records whose ``"fold"`` equals ``fold`` go to
        ``val``. All others, including records without a ``"fold"`` key, go to
        ``train``. Input order is preserved within each list.

    Raises:
        KeyError: If ``key`` is not present at the top level of the JSON.
    """
    # JSON is UTF-8 by specification. Relying on the platform's locale encoding
    # would misdecode non-ASCII file names on systems whose default is not UTF-8.
    with open(datalist, encoding="utf-8") as f:
        records = json.load(f)[key]

    for record in records:
        # Only values of existing keys are reassigned, so mutating the dict
        # while iterating over its items is safe.
        for name, value in record.items():
            if isinstance(value, list):
                record[name] = [os.path.join(basedir, item) for item in value]
            elif isinstance(value, str) and value:
                # Empty strings are kept as-is rather than becoming `basedir`.
                record[name] = os.path.join(basedir, value)

    # NOTE(review): a string-valued "fold" would be path-joined above and never
    # match `fold`; this assumes folds are stored as JSON numbers.
    train, val = [], []
    for record in records:
        if "fold" in record and record["fold"] == fold:
            val.append(record)
        else:
            train.append(record)

    return train, val
|
|
|
|
|
def split_train_test(datalist, basedir, fold, test_size=0.2, volume: float = None, random_state: int = 42):
    """Split the records outside ``fold`` into train / validation / test lists.

    Records are loaded with ``datafold_read``. The records assigned to ``fold``
    (its second return value) are discarded; only the remaining records are
    split. Every ``train_test_split`` call uses ``random_state``, so the result
    is reproducible for a given input.

    Args:
        datalist: Path to the JSON datalist file.
        basedir: Directory prepended to the file paths in each record.
        fold: Fold whose records are excluded from every returned list.
        test_size: Fraction passed to both hold-out splits (see Returns).
        volume: If not ``None``, first sub-sample the records by dropping this
            fraction of them. Only the "train" side of
            ``train_test_split(..., test_size=volume)`` is kept, so
            ``volume=0.25`` keeps about 75% of the data. It must be a value
            ``train_test_split`` accepts as ``test_size``.
        random_state: Seed for every split. The default of 42 matches the
            previously hard-coded value.

    Returns:
        ``(train_files, validation_files, test_files)``. The test set is carved
        out of the held-out part, so with the default ``test_size=0.2`` the
        split is roughly 80% / 16% / 4% of the retained records.
    """
    train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=fold)

    if volume is not None:
        # Sub-sample the dataset: keep only the larger ("train") side.
        train_files, _ = train_test_split(train_files, test_size=volume, random_state=random_state)

    # Hold out `test_size` of the records for evaluation ...
    train_files, held_out = train_test_split(train_files, test_size=test_size, random_state=random_state)
    # ... then split that held-out part again into validation and test sets.
    validation_files, test_files = train_test_split(held_out, test_size=test_size, random_state=random_state)

    return train_files, validation_files, test_files
|
|
|
|
|
def get_loader(batch_size, data_dir, json_list, fold, roi, volume: float = None, test_size=0.2, num_workers: int = 2):
    """Build MONAI train, validation and test DataLoaders from a JSON datalist.

    Records are split with ``split_train_test``. Each record must provide
    ``"image"`` and ``"label"`` entries, which are loaded by ``LoadImaged``.

    Training pipeline:

    - load image and label;
    - convert the label with ``ConvertToMultiChannelBasedOnBratsClassesd``;
    - crop to the image foreground, padded to a multiple of ``roi``;
    - take a random ``roi``-sized crop;
    - apply random flips on all three spatial axes;
    - normalize intensity (non-zero voxels, per channel);
    - apply random intensity scale and shift.

    Validation/test pipeline: load, convert the label and normalize, with no
    cropping or augmentation.

    Args:
        batch_size: Batch size of the training loader.
        data_dir: Base directory for the paths in the datalist.
        json_list: Path to the JSON datalist file.
        fold: Fold whose records are excluded (see ``split_train_test``).
        roi: Indexable with at least three entries; ``roi[0..2]`` give the
            spatial crop size.
        volume: Optional fraction of records to drop before splitting.
        test_size: Hold-out fraction forwarded to ``split_train_test``.
        num_workers: Worker processes per DataLoader. The default of 2 matches
            the previously hard-coded value.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    train_files, validation_files, test_files = split_train_test(
        datalist=json_list,
        basedir=data_dir,
        test_size=test_size,
        fold=fold,
        volume=volume,
    )

    roi_size = [roi[0], roi[1], roi[2]]

    train_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"]),
            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            # Crop to the foreground, padded so each spatial size is divisible by the ROI.
            transforms.CropForegroundd(
                keys=["image", "label"],
                source_key="image",
                k_divisible=roi_size,
            ),
            transforms.RandSpatialCropd(
                keys=["image", "label"],
                roi_size=roi_size,
                random_size=False,
            ),
            transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
            transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
            transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
            transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        ]
    )
    val_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"]),
            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ]
    )

    train_loader = DataLoader(
        Dataset(data=train_files, transform=train_transform),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    def make_eval_loader(files):
        # Evaluation volumes are not cropped to a common size, so they are
        # served one at a time, in a fixed order.
        return DataLoader(
            Dataset(data=files, transform=val_transform),
            batch_size=1,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
        )

    val_loader = make_eval_loader(validation_files)
    test_loader = make_eval_loader(test_files)

    return train_loader, val_loader, test_loader