File size: 5,486 Bytes
476daa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a27d55f
 
476daa5
 
 
 
 
 
 
 
 
 
 
 
286a978
476daa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286a978
476daa5
 
 
 
 
 
 
 
 
 
 
 
286a978
476daa5
 
 
286a978
476daa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286a978
 
476daa5
 
 
 
286a978
 
 
476daa5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os, sys
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)  # PYTHON > 3.3 does not allow relative referencing

import h5py
import nibabel as nib
from nilearn.image import resample_img
import re
import numpy as np
from scipy.ndimage import zoom
from skimage.measure import regionprops
from tqdm import tqdm
from argparse import ArgumentParser
from scipy.ndimage.morphology import binary_dilation, generate_binary_structure

import pandas as pd

from ddmr.utils import constants as C
from ddmr.utils.misc import segmentation_cardinal_to_ohe, segmentation_ohe_to_cardinal

# Lookup tables mapping segmentation label numbers <-> human-readable names.
SEGMENTATION_NR2LBL_LUT = {0: 'background',
                           1: 'parenchyma',
                           2: 'vessel'}
SEGMENTATION_LBL2NR_LUT = {v: k for k, v in SEGMENTATION_NR2LBL_LUT.items()}

# Input locations for the CT volumes and their liver segmentations.
IMG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Volumes'
SEG_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Segmentations' # '/home/jpdefrutos/workspace/LiverSegmentation_UNet3D/data/prediction'

# BUGFIX: the dots were unescaped in a non-raw string, so '.' matched ANY
# character (e.g. 'volXniiYgz' would also match). Raw strings with escaped
# literal dots match exactly the '.nii.gz' suffix; group 1 is the base name.
IMG_NAME_PATTERN = r'(.*)\.nii\.gz'
SEG_NAME_PATTERN = r'(.*)\.nii\.gz'

# Destination for the preprocessed 128x128x128 HDF5 samples.
OUT_DIRECTORY = '/mnt/EncryptedData1/Users/javier/ext_datasets/COMET_dataset/OSLO_COMET_CT/Formatted_128x128x128/w_bboxes'


if __name__ == '__main__':
    # Preprocess paired CT volumes + segmentations into 128^3 HDF5 samples:
    # resample to a 1 mm isotropic grid, optionally crop to the parenchyma
    # bounding box (+offset), resize to 128x128x128, and store the image plus
    # cardinal and one-hot segmentations per case in OUT_DIRECTORY.
    parser = ArgumentParser()
    parser.add_argument('--crop', action='store_true')  # If present, args.crop = True, else args.crop = False
    parser.add_argument('--offset', type=int, default=C.MAX_AUG_DISP_ISOT + 10, help='Crop offset in mm')
    # BUGFIX: the original used type=bool, but bool('False') is True (any
    # non-empty string is truthy), so the flag could never be disabled from
    # the CLI. store_true keeps the same default (False) with flag semantics.
    parser.add_argument('--dilate-segmentations', action='store_true')
    args = parser.parse_args()

    img_list = [os.path.join(IMG_DIRECTORY, f) for f in os.listdir(IMG_DIRECTORY) if f.endswith('.nii.gz')]
    img_list.sort()

    seg_list = [os.path.join(SEG_DIRECTORY, f) for f in os.listdir(SEG_DIRECTORY) if f.endswith('.nii.gz')]
    seg_list.sort()

    # BUGFIX: DataFrame.append was deprecated in pandas 1.4 and removed in
    # 2.0 (and is O(n^2) when called in a loop). Accumulate plain dict rows
    # and build the DataFrame once after the loop.
    zoom_records = []
    os.makedirs(OUT_DIRECTORY, exist_ok=True)
    binary_ball = generate_binary_structure(3, 1)  # 6-connected 3D structuring element
    for seg_file in tqdm(seg_list):
        # Base name of the case; the matching image is looked up by name.
        img_name = re.match(SEG_NAME_PATTERN, os.path.split(seg_file)[-1])[1]
        img_file = os.path.join(IMG_DIRECTORY, img_name + '.nii.gz')

        # Resample both volumes to 1 mm isotropic voxels (identity affine);
        # nearest-neighbour interpolation keeps segmentation labels intact.
        img = resample_img(nib.load(img_file), np.eye(3))
        seg = resample_img(nib.load(seg_file), np.eye(3), interpolation='nearest')

        img = np.asarray(img.dataobj)
        seg = np.asarray(seg.dataobj)

        # A 4D segmentation with >1 channels is one-hot encoded, else cardinal.
        segs_are_ohe = bool(len(seg.shape) > 3 and seg.shape[3] > 1)
        bbox = [0]*6
        if args.crop:
            # NOTE(review): seg[..., 0] is a 3D label volume only when seg is
            # one-hot (4D); for a cardinal 3D seg this slices an axis away —
            # confirm --crop is only used with one-hot segmentations.
            parenchyma = regionprops(seg[..., 0])[0]
            # Grow the bbox by the offset on every side, then clamp it to the
            # image limits.
            bbox = np.asarray(parenchyma.bbox) + [*[-args.offset]*3, *[args.offset]*3]
            bbox[:3] = np.maximum(bbox[:3], [0, 0, 0])
            bbox[3:] = np.minimum(bbox[3:], img.shape)
            img = img[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5]]
            seg = seg[bbox[0]:bbox[3], bbox[1]:bbox[4], bbox[2]:bbox[5], ...]
        # Resize to 128x128x128, keeping the isotropic shape and volumes too.
        isot_shape = img.shape

        zoom_factors = (np.asarray([128]*3) / np.asarray(img.shape)).tolist()
        img_isotropic = np.copy(img)
        img = zoom(img, zoom_factors, order=3)  # cubic spline for the image
        if args.dilate_segmentations:
            # NOTE(review): binary_dilation binarises its input, so a
            # multi-label segmentation collapses to foreground/background —
            # confirm this is intended before enabling the flag.
            seg = binary_dilation(seg, binary_ball, iterations=1)
        seg_isotropic = np.copy(seg)
        # Extra unit factor covers the channel axis when seg is one-hot (4D).
        seg = zoom(seg, zoom_factors + [1]*(len(seg.shape) - len(img.shape)), order=0)
        zoom_records.append({'scale_i': zoom_factors[0],
                             'scale_j': zoom_factors[1],
                             'scale_k': zoom_factors[2]})

        # seg -> cardinal
        # seg_expanded -> OHE
        if segs_are_ohe:
            seg_expanded = seg.copy()
            seg = segmentation_ohe_to_cardinal(seg)  # Ordinal encoded. argmax returns the first ocurrence of the maximum. Hence the previoous multiplication operation
        else:
            seg_expanded = segmentation_cardinal_to_ohe(seg)

        # BUGFIX: use a context manager so the HDF5 handle is closed even if
        # a create_dataset call raises (the original leaked it on exceptions).
        with h5py.File(os.path.join(OUT_DIRECTORY, img_name + '.h5'), 'w') as h5_file:
            h5_file.create_dataset('image', data=img[..., np.newaxis], dtype=np.float32)
            h5_file.create_dataset('image_isotropic', data=img_isotropic[..., np.newaxis], dtype=np.float32)
            h5_file.create_dataset('segmentation_isotropic', data=seg_isotropic.astype(np.uint8), dtype=np.uint8)
            h5_file.create_dataset('segmentation', data=seg.astype(np.uint8), dtype=np.uint8)
            h5_file.create_dataset('segmentation_expanded', data=seg_expanded.astype(np.uint8), dtype=np.uint8)
            h5_file.create_dataset('segmentation_labels', data=np.unique(seg)[1:])  # Remove the 0 (background label)
            h5_file.create_dataset('isotropic_shape', data=isot_shape)
            if args.crop:
                h5_file.create_dataset('bounding_box_origin', data=bbox[:3])
                h5_file.create_dataset('bounding_box_shape', data=bbox[3:] - bbox[:3])
        print('{}: Segmentation labels {}'.format(img_name, np.unique(seg)[1:]))

    # Build the scale-factor table once and report summary statistics.
    zoom_file = pd.DataFrame(zoom_records, columns=['scale_i', 'scale_j', 'scale_k'])
    zoom_file.to_csv(os.path.join(OUT_DIRECTORY, 'zoom_factors.csv'))
    print("Average")
    print(zoom_file.mean().to_list())

    print("Standard deviation")
    print(zoom_file.std().to_list())

    print("Average + STD")
    print((zoom_file.mean() + zoom_file.std()).to_list())