''' |
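python3 split_data.py -bp BasePath -ip RegisteredImageFolderPath -sp RegisteredLabelFolderPath -sl LabelValue LabelName [LabelValue LabelName ...] -ti TaskId -tn TaskName [-kf K]
Given the paths to the registered images and labels (both relative to the base path),
the script randomly splits the images into K folds (default 5) and creates one nnUNet task folder per fold:
the images of that fold go to the test set (imagesTs) and the images and labels of all other folds go to the training set (imagesTr / labelsTr)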
''' |
import os |
import glob |
import random |
import shutil |
from typing import Tuple |
import numpy as np |
from collections import OrderedDict |
import json |
import argparse |
""" |
make_if_dont_exist: creates a folder at the specified folder path if it does not exist
folder_path : path of the folder which needs to be created (a relative path is taken from cur_dir)
overwrite : (default: False) if True, overwrite the existing folder
""" |
def parse_command_line(argv=None):
    """Parse the command-line options of the dataset split pipeline.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads the arguments from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with the attributes ``bp``, ``ip``,
        ``sp``, ``sl``, ``ti``, ``tn`` and ``kf`` (``kf`` defaults to 5).
    """
    print('---'*10)
    print('Parsing Command Line Arguments')
    parser = argparse.ArgumentParser(
        description='pipeline for dataset split')
    parser.add_argument('-bp', metavar='base path', type=str,
                        help="Absolute path of the base directory")
    parser.add_argument('-ip', metavar='image path', type=str,
                        help="Relative path of the image directory")
    # The help text used to be a copy of the -ip one ("image directory").
    parser.add_argument('-sp', metavar='segmentation path', type=str,
                        help="Relative path of the segmentation directory")
    parser.add_argument('-sl', metavar='segmentation information list',
                        type=str, nargs='+',
                        help='a list of label name and corresponding value')
    parser.add_argument('-ti', metavar='task id', type=int,
                        help='task id number')
    parser.add_argument('-tn', metavar='task name', type=str,
                        help='task name')
    parser.add_argument('-kf', metavar='k-fold validation', type=int,
                        default=5, help='k-fold validation')
    return parser.parse_args(argv)
def make_if_dont_exist(folder_path, overwrite=False):
    """Ensure that the directory ``folder_path`` exists.

    A missing directory is created (including parents). An existing one is
    left untouched unless ``overwrite`` is true, in which case it is removed
    with all its content and created again empty. A short status message is
    printed in every case.
    """
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
        print(f"{folder_path} created!")
        return

    if overwrite:
        print(f"{folder_path} overwritten")
        shutil.rmtree(folder_path)
        os.makedirs(folder_path)
        return

    print(f'{folder_path} exists.')
def rename(location, oldname, newname):
    """Rename the entry ``oldname`` inside directory ``location`` to ``newname``."""
    source = os.path.join(location, oldname)
    target = os.path.join(location, newname)
    os.rename(source, target)
def _fold_sizes(num_items, k_fold):
    """Split ``num_items`` into ``k_fold`` sizes that differ by at most one.

    The first ``num_items % k_fold`` folds receive one extra item, so the
    sizes always sum to ``num_items``.
    """
    size, extra = divmod(num_items, k_fold)
    return [size + 1 if idx < extra else size for idx in range(k_fold)]


def _subject_number(path):
    """Return all digits of the file name before its first '.', joined."""
    stem = os.path.basename(path).split(".")[0]
    return ''.join(ch for ch in stem if ch.isdigit())


def _parse_labels(seg_list):
    """Build the dataset.json label map from a flat [value, name, ...] list.

    Raises:
        ValueError: if the list is missing, has an odd length, a value is
            not made of digits, or a name is made of digits only.
    """
    if not seg_list or len(seg_list) % 2:
        raise ValueError(
            "-sl expects pairs of <label value> <label name>, "
            f"got: {seg_list!r}")
    labels = {"0": "background"}
    for value, label_name in zip(seg_list[::2], seg_list[1::2]):
        if not value.isdigit() or label_name.isdigit():
            raise ValueError(
                "-sl expects pairs of <label value> <label name>, "
                f"got: {value!r} {label_name!r}")
        labels[value] = label_name
    return labels


def _case_file(image_name):
    """Map an image file name '<case>_0000...' to '<case>.nii.gz'.

    Everything from the first '_0000' on is dropped; the name is expected to
    contain that marker.
    """
    return image_name[:image_name.find("_0000")] + '.nii.gz'


def main():
    """Create one nnUNet task folder per cross-validation fold.

    The images found under ``<bp>/<ip>`` are shuffled with a fixed seed and
    cut into ``kf`` folds. For fold ``j`` a task folder
    ``Task<id>_<tn>_fold<j>`` is created (the task id is incremented for
    every fold) holding:

    * ``imagesTs``: the images of fold ``j``,
    * ``imagesTr`` / ``labelsTr``: the images of all other folds and the
      labels from ``<bp>/<sp>`` that carry the same file name,
    * ``dataset.json``: the dataset description built from the copied files.

    Copied files are named ``<tn>_<digits>[_0000].nii.gz`` where ``<digits>``
    are the digits of the original file name.

    Raises:
        ValueError: if ``-kf`` is smaller than 1 or ``-sl`` is malformed.
    """
    args = parse_command_line()
    base = args.bp
    reg_data_path = args.ip
    lab_data_path = args.sp
    task_id = args.ti
    Name = args.tn
    k_fold = args.kf

    # Validate up front so nothing is copied when the arguments are unusable.
    if k_fold < 1:
        raise ValueError(f"-kf must be a positive integer, got {k_fold}")
    labels = _parse_labels(args.sl)

    # NOTE(review): the nnUNet root is hard-coded to one user's home
    # directory and is independent of -bp; confirm before running elsewhere.
    base_dir = "/home/ameen"
    nnunet_dir = "nnUNet/nnunet/nnUNet_raw_data_base/nnUNet_raw_data"
    main_dir = os.path.join(base_dir, 'nnUNet/nnunet')
    make_if_dont_exist(os.path.join(main_dir, 'nnUNet_preprocessed'))
    make_if_dont_exist(os.path.join(main_dir, 'nnUNet_trained_models'))
    os.environ['nnUNet_raw_data_base'] = os.path.join(
        main_dir, 'nnUNet_raw_data_base')
    os.environ['nnUNet_preprocessed'] = os.path.join(
        main_dir, 'nnUNet_preprocessed')
    os.environ['RESULTS_FOLDER'] = os.path.join(
        main_dir, 'nnUNet_trained_models')

    # glob order depends on the file system; sorting first makes the seeded
    # shuffle (and therefore the fold assignment) reproducible.
    image_list = sorted(
        glob.glob(os.path.join(base, reg_data_path) + "/*.nii.gz"))
    label_list = glob.glob(os.path.join(base, lab_data_path) + "/*.nii.gz")
    # Exact file-name lookup: a suffix match would pair '0.nii.gz' with
    # '10.nii.gz'.
    labels_by_name = {os.path.basename(path): path for path in label_list}
    num_images = len(image_list)

    random.seed(19)
    random.shuffle(image_list)

    folds = []
    start_point = 0
    for size in _fold_sizes(num_images, k_fold):
        folds.append(image_list[start_point:start_point + size])
        start_point += size

    for j, test_images in enumerate(folds):
        # nnUNet expects a three-digit task id (TaskXXX_...).
        task_name = f"Task{task_id:03d}_{Name}_fold{j}"
        task_id += 1
        task_folder_name = os.path.join(base_dir, nnunet_dir, task_name)
        train_image_dir = os.path.join(task_folder_name, 'imagesTr')
        train_label_dir = os.path.join(task_folder_name, 'labelsTr')
        test_dir = os.path.join(task_folder_name, 'imagesTs')
        make_if_dont_exist(task_folder_name)
        make_if_dont_exist(train_image_dir)
        make_if_dont_exist(train_label_dir)
        make_if_dont_exist(test_dir)

        num_test = len(test_images)
        num_train = num_images - num_test
        print("Number of training subjects: ", num_train,
              "\nNumber of testing subjects:", num_test, "\nTotal:", num_images)

        train_images = []
        for p, fold in enumerate(folds):
            if p != j:
                train_images.extend(fold)

        for image_path in train_images:
            number = _subject_number(image_path)
            shutil.copy(image_path, os.path.join(
                train_image_dir, Name + "_" + number + "_0000.nii.gz"))
            label_path = labels_by_name.get(os.path.basename(image_path))
            if label_path is None:
                print(f"No label found for {image_path}")
                continue
            shutil.copy(label_path, os.path.join(
                train_label_dir, Name + "_" + number + '.nii.gz'))

        for image_path in test_images:
            number = _subject_number(image_path)
            shutil.copy(image_path, os.path.join(
                test_dir, Name + "_" + number + "_0000.nii.gz"))

        # The json is built from what is actually on disk, so files already
        # present in the task folders are listed as well.
        train_ids = sorted(os.listdir(train_image_dir))
        test_ids = sorted(os.listdir(test_dir))
        json_dict = {
            'name': task_name,
            'description': Name,
            'tensorImageSize': "4D",
            'reference': "MODIFY",
            'licence': "MODIFY",
            'release': "0.0",
            'modality': {"0": "CT"},
            'labels': dict(labels),
            'numTraining': len(train_ids),
            'numTest': len(test_ids),
            'training': [
                {'image': "./imagesTr/%s" % _case_file(name),
                 "label": "./labelsTr/%s" % _case_file(name)}
                for name in train_ids
            ],
            'test': ["./imagesTs/%s" % _case_file(name) for name in test_ids],
        }
        json_path = os.path.join(task_folder_name, "dataset.json")
        with open(json_path, 'w') as f:
            json.dump(json_dict, f, indent=4, sort_keys=True)
        if os.path.exists(json_path):
            print("new json file created!")
# Run the split only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()