Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import SimpleITK as sitk | |
| from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti | |
| from HD_BET.predict_case import predict_case_3D_net | |
| import imp | |
| from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters | |
| import os | |
| import HD_BET | |
def apply_bet(img, bet, out_fname):
    """
    Mask an image with a brain-extraction mask and write the result to disk.

    :param img: filename of the image to be masked (read with SimpleITK)
    :param bet: filename of the mask image; every voxel where the mask equals 0 is set to 0 in the output
    :param out_fname: filename the masked image is written to
    :return: None
    """
    source_image = sitk.ReadImage(img)
    voxels = sitk.GetArrayFromImage(source_image)
    mask_voxels = sitk.GetArrayFromImage(sitk.ReadImage(bet))

    # Boolean-mask assignment: blank out everything the mask marks as background.
    background = mask_voxels == 0
    voxels[background] = 0

    masked_image = sitk.GetImageFromArray(voxels)
    # The array round-trip drops the image information, so copy it back from the input image.
    masked_image.CopyInformation(source_image)
    sitk.WriteImage(masked_image, out_fname)
def _mask_fname_for(out_fname):
    """
    Build the filename of the mask that accompanies the output file out_fname.

    A trailing ".nii.gz" or ".nii" (case-insensitive) is stripped before "_mask.nii.gz" is appended; any other
    name gets the suffix appended unchanged. Blindly cutting 7 characters would mangle names that do not end in
    ".nii.gz" and could map different outputs onto the same mask file.
    """
    lowered = out_fname.lower()
    for ext in (".nii.gz", ".nii"):
        if lowered.endswith(ext):
            return out_fname[:-len(ext)] + "_mask.nii.gz"
    return out_fname + "_mask.nii.gz"


def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
               postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
    """
    Run brain extraction on one or several images and write the masked images (and optionally the masks).

    :param mri_fnames: str or list/tuple of str
    :param output_fnames: str or list/tuple of str. If list: must have the same length as mri_fnames
    :param mode: fast or accurate. fast uses only parameter set 0, accurate averages the predictions of
    parameter sets 0-4
    :param config_file: config.py
    :param device: either int (for device id) or 'cpu'
    :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
    but the largest predicted connected component. Default False
    :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
    CPU you may want to turn that off to speed things up
    :param keep_mask: if False, the mask file is deleted again after the masked image has been written.
    Default: True
    :param overwrite: if False, a case is skipped when its output file already exists (and, if keep_mask is set,
    its mask file exists as well). Default: True
    :return: None
    :raises ValueError: if mode is neither 'fast' nor 'accurate'
    :raises AssertionError: if a parameter file is missing or the filename lists differ in length
    """
    if mode == 'fast':
        fold_ids = [0]
    elif mode == 'accurate':
        fold_ids = range(5)
    else:
        raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)

    list_of_param_files = []
    for fold in fold_ids:
        params_file = get_params_fname(fold)
        maybe_download_parameters(fold)
        list_of_param_files.append(params_file)

    assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files"

    # NOTE(review): the imp module is deprecated and was removed in Python 3.12; migrating to importlib also
    # requires changing the file-level import, so it is left as is here.
    cf = imp.load_source('cf', config_file)
    cf = cf.config()

    net, _ = cf.get_network(cf.val_use_train_mode, None)
    if device == "cpu":
        net = net.cpu()
    else:
        net.cuda(device)

    if not isinstance(mri_fnames, (list, tuple)):
        mri_fnames = [mri_fnames]

    if not isinstance(output_fnames, (list, tuple)):
        output_fnames = [output_fnames]

    assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"

    # Load every parameter set once up front; map_location keeps the tensors on the storage they are
    # deserialized to instead of the device they were saved from.
    params = [torch.load(p, map_location=lambda storage, loc: storage) for p in list_of_param_files]

    for in_fname, out_fname in zip(mri_fnames, output_fnames):
        mask_fname = _mask_fname_for(out_fname)

        # A case is complete when the output exists and the mask is either present or not wanted. The mask is
        # removed at the end when keep_mask is False, so its absence must not force a recomputation then.
        already_done = os.path.isfile(out_fname) and (not keep_mask or os.path.isfile(mask_fname))
        if not overwrite and already_done:
            continue

        print("File:", in_fname)
        print("preprocessing...")
        try:
            data, data_dict = load_and_preprocess(in_fname)
        except RuntimeError:
            # Best effort: report the unreadable file and carry on with the remaining cases.
            print("\nERROR\nCould not read file", in_fname, "\n")
            continue
        except AssertionError as e:
            print(e)
            continue

        softmax_preds = []

        print("prediction (CNN id)...")
        for i, p in enumerate(params):
            print(i)
            net.load_state_dict(p)
            net.eval()
            net.apply(SetNetworkToVal(False, False))
            _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
                                                        cf.val_batch_size, cf.net_input_must_be_divisible_by,
                                                        cf.val_min_size, device, cf.da_mirror_axes)
            # Add a leading axis so that vstack stacks the predictions per parameter set.
            softmax_preds.append(softmax_pred[None])

        # Average over the parameter sets, then take the argmax over the first remaining axis
        # (presumably the class axis - confirm against predict_case_3D_net).
        seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)

        if postprocess:
            seg = postprocess_prediction(seg)

        print("exporting segmentation...")
        save_segmentation_nifti(seg, data_dict, mask_fname)

        apply_bet(in_fname, mask_fname, out_fname)

        if not keep_mask:
            os.remove(mask_fname)