"""Write the train/test split description (JSON) for a preprocessed task.

Usage:
    python3 split_data.py --config /abs/path/to/config.json [--train_only]

The configuration file is a JSON object that must provide at least the keys
``task_name``, ``num_fold`` and ``folder_name``.  Images and labels are looked
up as ``*.nii.gz`` files under

    <ROOT_DIR>/deepatlas_preprocessed/<task_name>/Training_dataset/images
    <ROOT_DIR>/deepatlas_preprocessed/<task_name>/Training_dataset/labels

where ``ROOT_DIR`` is two levels above the current working directory.

Without ``--train_only`` the scans are shuffled and distributed over
``num_fold`` folds (labelled and unlabelled scans are balanced separately) and
the result is written to ``data_info/<folder_name>/info.json``.  With
``--train_only`` every scan is listed in a single flat list written to
``data_info/<folder_name>/info_train_only.json``.
"""
import os
import glob
import random
import shutil
from pathlib import Path
from typing import Tuple
import numpy as np
from collections import OrderedDict
import json
import argparse
import sys
from collections import namedtuple

ROOT_DIR = str(Path(os.getcwd()).parent.parent.absolute())
sys.path.insert(0, os.path.join(ROOT_DIR, 'deepatlas/utils'))
from utils import (
    make_if_dont_exist, load_json
)

# NOTE(review): the text below was a stray string literal at this position in
# the original file.  It appears to describe the imported
# ``make_if_dont_exist`` helper rather than anything defined here -- confirm
# against deepatlas/utils:
#   creates a folder at a specified folder path if it does not exists
#   folder_path : relative path of the folder (from cur_dir) which needs to
#                 be created
#   over_write  : (default: False) if True overwrite the existing folder


def parse_command_line():
    """Parse the command-line arguments of the script.

    Returns:
        argparse.Namespace with ``config`` (path to the JSON configuration
        file) and ``train_only`` (bool flag).
    """
    print('Parsing Command Line Arguments')
    parser = argparse.ArgumentParser(
        description='pipeline for dataset split')
    # The script cannot do anything without a configuration file, so let
    # argparse report a missing --config instead of failing later on None.
    parser.add_argument('--config', metavar='path to the configuration file',
                        type=str, required=True,
                        help='absolute path to the configuration file')
    parser.add_argument('--train_only', action='store_true',
                        help='only training or training plus test')
    argv = parser.parse_args()
    return argv


def split(img, seg, seg_path):
    """Partition images into labelled and unlabelled entries.

    An image counts as labelled when ``os.path.join(seg_path, <basename>)``
    is one of the paths in ``seg``.

    Args:
        img: iterable of image file paths.
        seg: iterable of segmentation (label) file paths.
        seg_path: directory that the label paths in ``seg`` live in.

    Returns:
        Tuple ``(label, unlabel, total)``:
        ``label`` is a list of ``{"img": ..., "seg": ...}`` dicts,
        ``unlabel`` is a list of ``{"img": ...}`` dicts and ``total`` holds
        every entry (the same dict objects) in the order of ``img``.
    """
    # Build the lookup once; testing membership against the raw list made the
    # loop quadratic in the number of files.
    known_segs = set(seg)
    label = []
    unlabel = []
    total = []
    for image in img:
        seg_name = os.path.join(seg_path, os.path.basename(image))
        if seg_name in known_segs:
            item = {"img": image, "seg": seg_name}
            label.append(item)
        else:
            item = {"img": image}
            unlabel.append(item)
        total.append(item)
    return label, unlabel, total


def _fold_sizes(num_items, num_folds):
    """Return ``num_folds`` sizes summing to ``num_items``, as even as possible.

    The first ``num_items % num_folds`` folds receive one extra item.

    Raises:
        ValueError: if ``num_folds`` is smaller than 1.
    """
    if num_folds < 1:
        raise ValueError(f'num_fold must be at least 1, got {num_folds!r}')
    base, remainder = divmod(num_items, num_folds)
    return [base + 1 if index < remainder else base
            for index in range(num_folds)]


def _assign_folds(label, unlabel, k_fold):
    """Distribute (already shuffled) labelled and unlabelled entries over folds.

    Returns:
        dict mapping ``'fold_1'`` .. ``'fold_<k_fold>'`` to lists of entries;
        within each fold the labelled entries come first.
    """
    scans_per_fold = _fold_sizes(len(label) + len(unlabel), k_fold)
    # Size the label folds from the labels that actually matched an image,
    # not from the number of label files on disk: an orphan label file would
    # otherwise make the per-fold unlabelled count negative.
    segs_per_fold = _fold_sizes(len(label), k_fold)

    folds = {}
    label_start = 0
    unlabel_start = 0
    for fold_index, (num_scan, num_seg) in enumerate(
            zip(scans_per_fold, segs_per_fold), start=1):
        num_unlabel = num_scan - num_seg
        folds[f'fold_{fold_index}'] = (
            label[label_start:label_start + num_seg]
            + unlabel[unlabel_start:unlabel_start + num_unlabel]
        )
        label_start += num_seg
        unlabel_start += num_unlabel
    return folds


def main():
    """Read the configuration, compute the split and write it as JSON."""
    random.seed(2938649572)
    args = parse_command_line()
    config = load_json(args.config)
    config = namedtuple("config", config.keys())(*config.values())
    task_id = config.task_name
    k_fold = config.num_fold
    folder_name = config.folder_name
    train_only = args.train_only

    deepatlas_path = str(Path(os.getcwd()).parent.parent.absolute())
    base_path = os.path.join(deepatlas_path, "deepatlas_preprocessed")
    task_path = os.path.join(base_path, task_id)
    img_path = os.path.join(task_path, 'Training_dataset', 'images')
    seg_path = os.path.join(task_path, 'Training_dataset', 'labels')

    # glob returns files in arbitrary (filesystem-dependent) order; sort so
    # that the fixed random seed above really yields a reproducible split.
    image_list = sorted(glob.glob(img_path + "/*.nii.gz"))
    label_list = sorted(glob.glob(seg_path + "/*.nii.gz"))
    label, unlabel, total = split(image_list, label_list, seg_path)

    info_path = os.path.join(task_path, 'Training_dataset', 'data_info')
    folder_path = os.path.join(info_path, folder_name)
    make_if_dont_exist(info_path)
    make_if_dont_exist(folder_path)

    if not train_only:
        random.shuffle(unlabel)
        random.shuffle(label)
        piece_data = _assign_folds(label, unlabel, k_fold)
        info_json_path = os.path.join(folder_path, 'info.json')
    else:
        piece_data = total
        info_json_path = os.path.join(folder_path, 'info_train_only.json')

    with open(info_json_path, 'w', encoding='utf-8') as f:
        json.dump(piece_data, f, indent=4, sort_keys=True)

    if os.path.exists(info_json_path):
        print("new info json file created!")


if __name__ == '__main__':
    main()