import importlib.util
import os

import numpy as np
import SimpleITK as sitk
import torch

import HD_BET
from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti
from HD_BET.predict_case import predict_case_3D_net
from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters


def apply_bet(img, bet, out_fname):
    """Set all voxels of `img` outside the brain mask `bet` to zero and write the result to `out_fname`."""
    img_itk = sitk.ReadImage(img)
    img_npy = sitk.GetArrayFromImage(img_itk)
    img_bet = sitk.GetArrayFromImage(sitk.ReadImage(bet))
    img_npy[img_bet == 0] = 0
    out = sitk.GetImageFromArray(img_npy)
    out.CopyInformation(img_itk)  # keep origin, spacing and direction of the input image
    sitk.WriteImage(out, out_fname)
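
# Note on the masking above (hypothetical values, not part of HD_BET): for
# img_npy = [10, 20, 30] and a binary mask img_bet = [1, 0, 1], the assignment
# img_npy[img_bet == 0] = 0 yields [10, 0, 30], zeroing every non-brain voxel.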


def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
               postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
    """

    :param mri_fnames: str or list/tuple of str
    :param output_fnames: str or list/tuple of str. If list: must have the same length as output_fnames
    :param mode: fast or accurate
    :param config_file: config.py
    :param device: either int (for device id) or 'cpu'
    :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
    but the largest predicted connected component. Default False
    :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
    CPU you may want to turn that off to speed things up
    :return:
    """

    list_of_param_files = []

    if mode == 'fast':
        # 'fast' uses only the first model of the ensemble
        params_file = get_params_fname(0)
        maybe_download_parameters(0)

        list_of_param_files.append(params_file)
    elif mode == 'accurate':
        # 'accurate' uses all five models of the ensemble
        for i in range(5):
            params_file = get_params_fname(i)
            maybe_download_parameters(i)

            list_of_param_files.append(params_file)
    else:
        raise ValueError("Unknown value for mode: %s. Expected: 'fast' or 'accurate'" % mode)

    assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files"

    # load the configuration module from its file path; imp.load_source is deprecated
    # (removed in Python 3.12), importlib provides the equivalent
    spec = importlib.util.spec_from_file_location('cf', config_file)
    cf = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(cf)
    cf = cf.config()

    net, _ = cf.get_network(cf.val_use_train_mode, None)
    if device == "cpu":
        net = net.cpu()
    else:
        net.cuda(device)

    if not isinstance(mri_fnames, (list, tuple)):
        mri_fnames = [mri_fnames]

    if not isinstance(output_fnames, (list, tuple)):
        output_fnames = [output_fnames]

    assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"

    # load all checkpoints into CPU memory up front; load_state_dict later copies the
    # weights onto whichever device the network lives on
    params = []
    for p in list_of_param_files:
        params.append(torch.load(p, map_location=lambda storage, loc: storage))

    for in_fname, out_fname in zip(mri_fnames, output_fnames):
        # derive the mask filename from the output filename (assumes a '.nii.gz' suffix)
        mask_fname = out_fname[:-7] + "_mask.nii.gz"
        # skip a case only if overwrite is False and the expected outputs already exist
        if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)):
            print("File:", in_fname)
            print("preprocessing...")
            try:
                data, data_dict = load_and_preprocess(in_fname)
            except RuntimeError:
                print("\nERROR\nCould not read file", in_fname, "\n")
                continue
            except AssertionError as e:
                print(e)
                continue

            softmax_preds = []

            print("prediction (CNN id)...")
            for i, p in enumerate(params):
                print(i)
                # load the weights of ensemble member i and switch the network to inference mode
                net.load_state_dict(p)
                net.eval()
                net.apply(SetNetworkToVal(False, False))
                _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
                                                            cf.val_batch_size, cf.net_input_must_be_divisible_by,
                                                            cf.val_min_size, device, cf.da_mirror_axes)
                softmax_preds.append(softmax_pred[None])

            # average the softmax outputs over all ensemble members, then take the argmax over classes
            seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)
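            # worked example (hypothetical numbers): if two models give softmax outputs
            # (0.6, 0.4) and (0.2, 0.8) for one voxel, the mean is (0.4, 0.6) and the
            # argmax over the class axis labels that voxel 1 (brain)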

            if postprocess:
                seg = postprocess_prediction(seg)

            print("exporting segmentation...")
            save_segmentation_nifti(seg, data_dict, mask_fname)

            apply_bet(in_fname, mask_fname, out_fname)

            if not keep_mask:
                os.remove(mask_fname)
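

# Minimal usage sketch, as referenced in the run_hd_bet docstring. The file paths below
# are hypothetical placeholders; all arguments follow the signature defined above.
if __name__ == '__main__':
    # single image on GPU 0, five-model ensemble with test time augmentation (defaults)
    run_hd_bet("/data/sub01_t1.nii.gz", "/data/sub01_t1_bet.nii.gz")

    # batch of images on CPU: single 'fast' model, no test time augmentation,
    # and remove the mask files once the brain-extracted images are written
    run_hd_bet(["/data/sub02_t1.nii.gz", "/data/sub03_t1.nii.gz"],
               ["/data/sub02_t1_bet.nii.gz", "/data/sub03_t1_bet.nii.gz"],
               mode="fast", device="cpu", do_tta=False, keep_mask=False)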