DuyTa's picture
d2f9d5b1696d723607b469880b3d5616ee5b225a5296662d3b494a9f93762c27
864c14f verified
raw
history blame
3.65 kB
import json
import os
from sklearn.model_selection import train_test_split
from monai.data import DataLoader, Dataset
from monai import transforms
def datafold_read(datalist, basedir, fold=0, key="training"):
    """Load records from a JSON datalist and partition them by fold.

    Args:
        datalist: Path to a JSON file whose top-level ``key`` entry is a list
            of dict records.
        basedir: Directory joined (``os.path.join``) in front of every
            non-empty string value and every element of every list value in
            each record. Empty strings are left as-is; other value types
            (e.g. an integer ``"fold"``) are untouched.
        fold: Records that have a ``"fold"`` field equal to this value are
            placed in the second returned list.
        key: Top-level JSON key holding the list of records.

    Returns:
        A ``(train, val)`` tuple. ``val`` holds records whose ``"fold"``
        equals ``fold``; ``train`` holds every other record, including
        records with no ``"fold"`` field.

    Note:
        Every string field is joined with ``basedir``, not only path-like
        ones.
    """
    # Read as UTF-8 explicitly: the default text encoding is locale-dependent
    # and would mis-decode non-ASCII paths on some platforms.
    with open(datalist, encoding="utf-8") as f:
        records = json.load(f)[key]

    for record in records:
        # Only existing keys are reassigned, so mutating while iterating
        # items() is safe (the dict's size never changes).
        for field, value in record.items():
            if isinstance(value, list):
                record[field] = [os.path.join(basedir, item) for item in value]
            elif isinstance(value, str) and value:
                record[field] = os.path.join(basedir, value)

    train, val = [], []
    for record in records:
        if "fold" in record and record["fold"] == fold:
            val.append(record)
        else:
            train.append(record)
    return train, val
def split_train_test(datalist, basedir, fold, test_size=0.2, volume: float = None):
    """Split a JSON datalist into train, validation and test record lists.

    Args:
        datalist: Path to the JSON datalist (read via ``datafold_read``).
        basedir: Base directory joined onto each record's paths.
        fold: Records whose ``"fold"`` equals this value are excluded
            entirely; only the remaining records are split.
        test_size: Passed as ``test_size`` to both of the final two
            ``train_test_split`` calls. With the default 0.2, validation gets
            0.2 * 0.8 = 16% of the remaining records and test gets
            0.2 * 0.2 = 4%.
        volume: If not ``None``, a first ``train_test_split`` with
            ``test_size=volume`` is applied and its second part is
            discarded. In effect, ``volume`` is the fraction (or, for an
            int, the count) of records dropped before splitting.

    Returns:
        ``(train_files, validation_files, test_files)``. All splits use
        ``random_state=42``, so they are reproducible.
    """
    # The second return value (records of the selected fold) is discarded.
    train_files, _ = datafold_read(datalist=datalist, basedir=basedir, fold=fold)
    if volume is not None:
        # Keep only the first part; the ``volume`` portion is thrown away.
        train_files, _ = train_test_split(train_files, test_size=volume, random_state=42)
    train_files, validation_files = train_test_split(
        train_files, test_size=test_size, random_state=42
    )
    # The held-out part is split again: ``test_size`` of it becomes test.
    validation_files, test_files = train_test_split(
        validation_files, test_size=test_size, random_state=42
    )
    return train_files, validation_files, test_files
def _make_loader(files, transform, batch_size, shuffle, num_workers, pin_memory):
    """Wrap ``files`` in a MONAI ``Dataset`` with ``transform`` and return its ``DataLoader``."""
    return DataLoader(
        Dataset(data=files, transform=transform),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


def get_loader(
    batch_size,
    data_dir,
    json_list,
    fold,
    roi,
    volume: float = None,
    test_size=0.2,
    *,
    num_workers=2,
    pin_memory=True,
):
    """Build train, validation and test DataLoaders from a JSON datalist.

    Records are split with ``split_train_test``. Each record is expected to
    have ``"image"`` and ``"label"`` entries.

    Args:
        batch_size: Batch size of the training loader. The validation and
            test loaders always use a batch size of 1.
        data_dir: Base directory joined onto the datalist's paths.
        json_list: Path to the JSON datalist.
        fold: Fold whose records are excluded (see ``split_train_test``).
        roi: Spatial crop size. Only ``roi[0]``, ``roi[1]`` and ``roi[2]``
            are used, both as the random crop size and as the
            ``k_divisible`` value of the foreground crop.
        volume: Forwarded to ``split_train_test``.
        test_size: Forwarded to ``split_train_test``.
        num_workers: Worker processes for every loader (default 2, the
            previously hard-coded value).
        pin_memory: ``pin_memory`` flag for every loader (default True, the
            previously hard-coded value).

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles. Validation and test share one deterministic pipeline: load,
        convert BraTS labels to multi-channel, normalize.
    """
    train_files, validation_files, test_files = split_train_test(
        datalist=json_list,
        basedir=data_dir,
        test_size=test_size,
        fold=fold,
        volume=volume,
    )
    roi_size = [roi[0], roi[1], roi[2]]

    train_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"]),
            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            # Crop to the image foreground, padded so each spatial size is
            # divisible by the ROI and the random crop below always fits.
            transforms.CropForegroundd(
                keys=["image", "label"],
                source_key="image",
                k_divisible=roi_size,
            ),
            transforms.RandSpatialCropd(
                keys=["image", "label"],
                roi_size=roi_size,
                random_size=False,
            ),
            transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
            transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
            transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
            transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        ]
    )
    eval_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"]),
            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ]
    )

    train_loader = _make_loader(
        train_files, train_transform, batch_size, True, num_workers, pin_memory
    )
    val_loader = _make_loader(
        validation_files, eval_transform, 1, False, num_workers, pin_memory
    )
    test_loader = _make_loader(
        test_files, eval_transform, 1, False, num_workers, pin_memory
    )
    return train_loader, val_loader, test_loader