# AutoSeg4ETICA/preprocessing/split_data.py
# Chris Xiao
# upload files
# c642393
'''
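python3 split_data.py -bp BasePath -ip RegisteredImageFolderPath -sp RegisteredLabelFolderPath -sl LabelValue LabelName [LabelValue LabelName ...] -ti TaskId -tn TaskName [-kf K]
Given the paths to the registered images and labels (both relative to BasePath),
the script shuffles the images, splits them into K folds (default 5) and creates one nnUNet task folder per fold:
the images of that fold are copied into imagesTs, all remaining images and their labels into imagesTr / labelsTr.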
'''
import os
import glob
import random
import shutil
from typing import Tuple
import numpy as np
from collections import OrderedDict
import json
import argparse
"""
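Documentation for make_if_dont_exist (defined further below):
creates a folder at the specified folder path if it does not exist
folder_path : path of the folder which needs to be created (a relative path is resolved from the current working directory)
overwrite : (default: False) if True, an existing folder is deleted and recreated empty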
"""
def parse_command_line():
    """Parse the command-line options of the dataset-split pipeline.

    Options:
        -bp  absolute path of the base directory (required)
        -ip  image directory, relative to the base directory (required)
        -sp  segmentation directory, relative to the base directory (required)
        -sl  flat list of ``value name`` pairs describing the labels
        -ti  integer id of the first task (required)
        -tn  task name (required)
        -kf  number of cross-validation folds (default: 5)

    Returns:
        The ``argparse.Namespace`` holding the attributes ``bp``, ``ip``,
        ``sp``, ``sl``, ``ti``, ``tn`` and ``kf``.
    """
    print('---'*10)
    print('Parsing Command Line Arguments')
    parser = argparse.ArgumentParser(
        description='pipeline for dataset split')
    # The options marked required are all dereferenced unconditionally by
    # main(); failing here gives a usage message instead of a TypeError later.
    parser.add_argument('-bp', metavar='base path', type=str, required=True,
                        help="Absolute path of the base directory")
    parser.add_argument('-ip', metavar='image path', type=str, required=True,
                        help="Relative path of the image directory")
    parser.add_argument('-sp', metavar='segmentation path', type=str, required=True,
                        help="Relative path of the segmentation directory")
    parser.add_argument('-sl', metavar='segmentation information list', type=str, nargs='+',
                        help='a list of label name and corresponding value')
    parser.add_argument('-ti', metavar='task id', type=int, required=True,
                        help='task id number')
    parser.add_argument('-tn', metavar='task name', type=str, required=True,
                        help='task name')
    parser.add_argument('-kf', metavar='k-fold validation', type=int, default=5,
                        help='k-fold validation')
    return parser.parse_args()
def make_if_dont_exist(folder_path, overwrite=False):
    """Ensure that a directory exists at ``folder_path``.

    folder_path : path of the folder to create (intermediate folders are
                  created as needed)
    overwrite   : (default: False) when True and the path already exists,
                  the existing tree is removed and an empty folder is
                  created in its place; when False an existing path is
                  left untouched
    """
    already_there = os.path.exists(folder_path)

    # Fresh path: just create it.
    if not already_there:
        os.makedirs(folder_path)
        print(f"{folder_path} created!")
        return

    # Existing path, caller asked for a clean slate.
    if overwrite:
        print(f"{folder_path} overwritten")
        shutil.rmtree(folder_path)
        os.makedirs(folder_path)
        return

    # Existing path, keep it as it is.
    print(f'{folder_path} exists.')
def rename(location, oldname, newname):
    """Rename the entry ``oldname`` inside the folder ``location`` to ``newname``."""
    source = os.path.join(location, oldname)
    destination = os.path.join(location, newname)
    os.rename(source, destination)
def _split_into_folds(items, k_fold):
    """Cut ``items`` into ``k_fold`` consecutive chunks whose sizes differ by at most one."""
    base_size, remainder = divmod(len(items), k_fold)
    folds = []
    start = 0
    for fold_index in range(k_fold):
        # the first `remainder` folds each take one of the left-over items
        size = base_size + (1 if fold_index < remainder else 0)
        folds.append(items[start:start + size])
        start += size
    return folds


def _case_number(path):
    """Return all digits of the file name of ``path`` that precede its first dot, concatenated."""
    stem = os.path.basename(path).split(".")[0]
    return ''.join(ch for ch in stem if ch.isdigit())


def _parse_label_map(seg_list):
    """Turn the flat ``[value, name, value, name, ...]`` list into the dataset.json label dict."""
    labels = {"0": "background"}
    pairs = seg_list or []
    if len(pairs) % 2:
        raise ValueError(
            "-sl expects pairs of <label value> <label name>, got an odd number of items")
    for value, label_name in zip(pairs[::2], pairs[1::2]):
        if not value.isdigit() or label_name.isdigit():
            raise ValueError(
                f"-sl expects <label value> <label name> pairs, got '{value} {label_name}'")
        labels[value] = label_name
    return labels


def main():
    """Create one nnUNet task folder per cross-validation fold.

    The options come from parse_command_line(). All ``*.nii.gz`` files found
    in ``<bp>/<ip>`` are shuffled with a fixed seed and cut into ``kf`` folds.
    For fold ``j`` a folder ``Task<id>_<tn>_fold<j>`` is created (the id
    starts at ``ti`` and increases by one per fold) containing:

    * ``imagesTs``  - the images of fold ``j``, named ``<tn>_<digits>_0000.nii.gz``
    * ``imagesTr``  - the images of every other fold, same naming
    * ``labelsTr``  - for each training image, the file with the same file
                      name from ``<bp>/<sp>``, named ``<tn>_<digits>.nii.gz``
    * ``dataset.json`` - task description listing the copied files

    ``<digits>`` is every digit of the original file name before its first dot.

    Raises:
        SystemExit: a required option is missing.
        ValueError: ``kf`` is smaller than 1 or ``sl`` is not a list of
            ``value name`` pairs.
        FileNotFoundError: no image matches ``<bp>/<ip>/*.nii.gz``.
    """
    args = parse_command_line()
    base = args.bp
    reg_data_path = args.ip
    lab_data_path = args.sp
    task_id = args.ti
    Name = args.tn
    k_fold = args.kf

    # Validate everything before the first folder is created.
    required = {'-bp': base, '-ip': reg_data_path, '-sp': lab_data_path,
                '-ti': task_id, '-tn': Name}
    missing = [flag for flag, value in required.items() if value is None]
    if missing:
        raise SystemExit("missing required option(s): " + ", ".join(missing))
    if k_fold < 1:
        raise ValueError(f"-kf must be at least 1, got {k_fold}")
    labels = _parse_label_map(args.sl)

    # sorted(): glob returns files in arbitrary order, so without it the
    # seeded shuffle below would not give the same split on every machine.
    image_list = sorted(glob.glob(os.path.join(base, reg_data_path, "*.nii.gz")))
    label_list = glob.glob(os.path.join(base, lab_data_path, "*.nii.gz"))
    num_images = len(image_list)
    if num_images == 0:
        raise FileNotFoundError(
            f"no *.nii.gz images found in {os.path.join(base, reg_data_path)}")
    # Exact file-name lookup; a suffix comparison would pair '1.nii.gz'
    # with the label '11.nii.gz'.
    label_by_name = {os.path.basename(path): path for path in label_list}

    # NOTE(review): the nnUNet root is hard-coded to one user's home directory
    # and does not depend on -bp; confirm this is intended before reuse.
    base_dir = "/home/ameen"
    nnunet_dir = "nnUNet/nnunet/nnUNet_raw_data_base/nnUNet_raw_data"
    main_dir = os.path.join(base_dir, 'nnUNet/nnunet')
    make_if_dont_exist(os.path.join(main_dir, 'nnUNet_preprocessed'))
    make_if_dont_exist(os.path.join(main_dir, 'nnUNet_trained_models'))
    os.environ['nnUNet_raw_data_base'] = os.path.join(
        main_dir, 'nnUNet_raw_data_base')
    os.environ['nnUNet_preprocessed'] = os.path.join(
        main_dir, 'nnUNet_preprocessed')
    os.environ['RESULTS_FOLDER'] = os.path.join(
        main_dir, 'nnUNet_trained_models')

    random.seed(19)
    random.shuffle(image_list)
    folds = _split_into_folds(image_list, k_fold)

    for j, test_images in enumerate(folds):
        # zero-pad to three digits so that one- and three-digit ids also
        # produce a TaskXXX name (two-digit ids give the same name as before)
        task_name = f"Task{task_id:03d}_{Name}_fold{j}"  # MODIFY
        task_id += 1
        task_folder_name = os.path.join(base_dir, nnunet_dir, task_name)
        train_image_dir = os.path.join(task_folder_name, 'imagesTr')
        train_label_dir = os.path.join(task_folder_name, 'labelsTr')
        test_dir = os.path.join(task_folder_name, 'imagesTs')
        for folder in (task_folder_name, train_image_dir, train_label_dir, test_dir):
            make_if_dont_exist(folder)

        # fold j is held out for testing, all other folds are used for training
        num_test = len(test_images)
        num_train = num_images - num_test
        print("Number of training subjects: ", num_train,
              "\nNumber of testing subjects:", num_test, "\nTotal:", num_images)
        train_images = []
        for p, fold in enumerate(folds):
            if p != j:
                train_images.extend(fold)

        # prepare for nnUNet training scans and labels
        for image_path in train_images:
            image_name = os.path.basename(image_path)
            case = Name + "_" + _case_number(image_path)
            shutil.copy(image_path, os.path.join(train_image_dir, case + "_0000.nii.gz"))
            label_path = label_by_name.get(image_name)
            if label_path is None:
                print(f"WARNING: no label file named {image_name} found; "
                      f"{case} has no entry in labelsTr")
                continue
            shutil.copy(label_path, os.path.join(train_label_dir, case + '.nii.gz'))

        # prepare for nnUNet testing scans
        for image_path in test_images:
            case = Name + "_" + _case_number(image_path)
            shutil.copy(image_path, os.path.join(test_dir, case + "_0000.nii.gz"))

        # create json file
        json_dict = OrderedDict()
        json_dict['name'] = task_name
        json_dict['description'] = Name
        json_dict['tensorImageSize'] = "4D"
        json_dict['reference'] = "MODIFY"
        json_dict['licence'] = "MODIFY"
        json_dict['release'] = "0.0"
        json_dict['modality'] = {
            "0": "CT"
        }
        json_dict['labels'] = dict(labels)

        # sorted so that the generated file does not depend on listing order
        train_ids = sorted(os.listdir(train_image_dir))
        test_ids = sorted(os.listdir(test_dir))
        json_dict['numTraining'] = len(train_ids)
        json_dict['numTest'] = len(test_ids)
        # dataset.json refers to cases without the "_0000" modality suffix
        json_dict['training'] = [
            {
                'image': "./imagesTr/%s" % (i[:i.find("_0000")] + '.nii.gz'),
                "label": "./labelsTr/%s" % (i[:i.find("_0000")] + '.nii.gz'),
            }
            for i in train_ids
        ]
        json_dict['test'] = [
            "./imagesTs/%s" % (i[:i.find("_0000")] + '.nii.gz') for i in test_ids
        ]

        with open(os.path.join(task_folder_name, "dataset.json"), 'w') as f:
            json.dump(json_dict, f, indent=4, sort_keys=True)
        if os.path.exists(os.path.join(task_folder_name, 'dataset.json')):
            print("new json file created!")
# Run the dataset split only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()