Hemaxi commited on
Commit
630e8d8
·
verified ·
1 Parent(s): 0addf11

Upload prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +163 -163
prediction.py CHANGED
@@ -1,164 +1,164 @@
1
- # example of using saved cycleGAN models for image translation
2
- #based on https://machinelearningmastery.com/cyclegan-tutorial-with-keras/
3
- from keras.models import load_model
4
- import numpy as np
5
- import tensorflow_addons as tfa
6
- from scipy.ndimage import zoom
7
- from tqdm import tqdm
8
- import warnings
9
- warnings.filterwarnings("ignore")
10
- from huggingface_hub import hf_hub_download
11
- from skimage.morphology import binary_erosion, binary_dilation
12
- from skimage import draw
13
-
14
-
15
- def predict_mask(image, dim_x, dim_y, dim_z, _resize=True, norm_=True, mode_='test', patch_size=(64,128,128,1), _step=64, _step_z=32, _patch_size_z=64):
16
-
17
- cust={'InstanceNormalization': tfa.layers.InstanceNormalization}
18
- #load the model
19
- # Download the model from Hugging Face Model Hub
20
- model_dir = hf_hub_download(repo_id="Hemaxi/3DCycleGAN", filename="CycleGANVesselSegmentation.h5")
21
- model_BtoA = load_model(model_dir, cust)
22
-
23
- print('Mode: {}'.format(mode_))
24
-
25
- _patch_size = patch_size[1]
26
- _nbslices = patch_size[0]
27
-
28
- perceqmin = 1
29
- perceqmax = 99
30
-
31
- image = ((image/(np.max(image)))*255).astype('uint8')
32
-
33
- print('Image Shape: {}'.format(image.shape))
34
- print('----------------------------------------')
35
-
36
- initial_image_x = np.shape(image)[0]
37
- initial_image_y = np.shape(image)[1]
38
- initial_image_z = np.shape(image)[2]
39
-
40
- #percentile equalization
41
- if norm_:
42
- minval = np.percentile(image, perceqmin)
43
- maxval = np.percentile(image, perceqmax)
44
- image = np.clip(image, minval, maxval)
45
- image = (((image - minval) / (maxval - minval)) * 255).astype('uint8')
46
-
47
- if _resize:
48
- image = zoom(image, (dim_x/0.333, dim_y/0.333, dim_z/0.5), order=3, mode='nearest')
49
- image = ((image/np.max(image))*255.0).astype('uint8')
50
-
51
-
52
- #image size
53
- size_y = np.shape(image)[0]
54
- size_x = np.shape(image)[1]
55
- size_depth = np.shape(image)[2]
56
- aux_sizes_or = [size_y, size_x, size_depth]
57
-
58
-
59
- #patch size
60
- new_size_y = int((size_y/_patch_size) + 1) * _patch_size
61
- new_size_x = int((size_x/_patch_size) + 1) * _patch_size
62
- new_size_z = int((size_depth/_patch_size_z) + 1) * _patch_size_z
63
- aux_sizes = [new_size_y, new_size_x, new_size_z]
64
-
65
- ## zero padding
66
- aux_img = np.random.randint(1,50,(aux_sizes[0], aux_sizes[1], aux_sizes[2]))
67
- aux_img[0:aux_sizes_or[0], 0:aux_sizes_or[1],0:aux_sizes_or[2]] = image
68
- image = aux_img
69
- del aux_img
70
-
71
- final_mask_foreground = np.zeros((np.shape(image)[0], np.shape(image)[1], np.shape(image)[2]))
72
- final_mask_background = np.zeros((np.shape(image)[0], np.shape(image)[1], np.shape(image)[2]))
73
- final_mask_background = final_mask_background.astype('uint8')
74
- final_mask_foreground = final_mask_foreground.astype('uint8')
75
-
76
-
77
- total_iterations = int(image.shape[0]/_patch_size)
78
-
79
- with tqdm(total=total_iterations) as pbar:
80
- i=0
81
- while i+_patch_size<=image.shape[0]:
82
- j=0
83
- while j+_patch_size<=image.shape[1]:
84
- k=0
85
- while k+_patch_size_z<=image.shape[2]:
86
-
87
- B_real = np.zeros((1,_nbslices,_patch_size,_patch_size,1),dtype='float32')
88
- _slice = image[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z]
89
-
90
- _slice = _slice.transpose(2,0,1)
91
- _slice = np.expand_dims(_slice, axis=-1)
92
-
93
- B_real[0,:]=(_slice-127.5) /127.5
94
-
95
- A_generated = model_BtoA.predict(B_real)
96
-
97
- A_generated = (A_generated + 1)/2 #from [-1,1] to [0,1]
98
-
99
- A_generated = A_generated[0,:,:,:,0]
100
- A_generated = A_generated.transpose(1,2,0)
101
-
102
- #print(np.unique(A_generated))
103
- A_generated = (A_generated>0.5)*1
104
-
105
- A_generated = A_generated.astype('uint8')
106
-
107
- final_mask_foreground[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z] = final_mask_foreground[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z] + A_generated
108
- final_mask_background[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z] = final_mask_background[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z] + (1-A_generated)
109
-
110
- k=k+_step_z
111
- j=j+_step
112
- i=i+_step
113
- pbar.update(1)
114
-
115
-
116
- del _slice
117
- del A_generated
118
- del B_real
119
-
120
- final_mask = (final_mask_foreground>=final_mask_background)*1
121
-
122
- image = image[0:aux_sizes_or[0], 0:aux_sizes_or[1],0:size_depth]
123
- print('Image Shape: {}'.format(image.shape))
124
- print('----------------------------------------')
125
-
126
- final_mask = final_mask[0:aux_sizes_or[0], 0:aux_sizes_or[1],0:aux_sizes_or[2]]
127
-
128
-
129
- if _resize:
130
- final_mask = zoom(final_mask, (0.333/dim_x, 0.333/dim_y, 0.5/dim_z), order=3, mode='nearest')
131
- final_mask = (final_mask*255.0).astype('uint8')
132
-
133
- final_size_x = np.shape(final_mask)[0]
134
- final_size_y = np.shape(final_mask)[1]
135
- final_size_z = np.shape(final_mask)[2]
136
-
137
- aux_mask = np.zeros((initial_image_x, initial_image_y, initial_image_z)).astype('uint8')
138
- aux_mask[0:min(initial_image_x, final_size_x),0:min(initial_image_y, final_size_y),0:min(initial_image_z, final_size_z)] = final_mask[0:min(initial_image_x, final_size_x),0:min(initial_image_y, final_size_y),0:min(initial_image_z, final_size_z)]
139
-
140
- final_mask = aux_mask.copy()
141
-
142
-
143
- print('Mask Shape: {}'.format(final_mask.shape))
144
- print('----------------------------------------')
145
- final_mask = final_mask/np.max(final_mask)
146
- final_mask = final_mask*255.0
147
- final_mask = final_mask.astype('uint8')
148
-
149
-
150
- #closing operation to fill small holes
151
- mask = final_mask
152
- mask[mask!=0] = 1
153
- mask = mask.astype('uint8')
154
-
155
- ellipsoid = draw.ellipsoid(9,9,3, spacing=(1,1,1), levelset=False)
156
- ellipsoid = ellipsoid.astype('uint8')
157
- ellipsoid = ellipsoid[1:-1,1:-1,1:-1]
158
-
159
- #perform closing operation on the mask
160
- dil = binary_dilation(mask, ellipsoid)
161
- closed_mask = binary_erosion(dil, ellipsoid)
162
- closed_mask = (closed_mask*255.0).astype('uint8')
163
-
164
  return closed_mask
 
1
+ # example of using saved cycleGAN models for image translation
2
+ #based on https://machinelearningmastery.com/cyclegan-tutorial-with-keras/
3
+ from keras.models import load_model
4
+ import numpy as np
5
+ import tensorflow_addons as tfa
6
+ from scipy.ndimage import zoom
7
+ from tqdm import tqdm
8
+ import warnings
9
+ warnings.filterwarnings("ignore")
10
+ from huggingface_hub import hf_hub_download
11
+ from skimage.morphology import binary_erosion, binary_dilation
12
+ from skimage import draw
13
+
14
+
15
+ def predict_mask(image, dim_x, dim_y, dim_z, _resize=True, norm_=True, mode_='test', patch_size=(64,128,128,1), _step=64, _step_z=32, _patch_size_z=64):
16
+
17
+ cust={'InstanceNormalization': tfa.layers.InstanceNormalization}
18
+ #load the model
19
+ # Download the model from Hugging Face Model Hub
20
+ model_dir = hf_hub_download(repo_id="Hemaxi/3DCycleGAN", filename="CycleGANVesselSegmentation.h5")
21
+ model_BtoA = load_model(model_dir, cust)
22
+
23
+ print('Mode: {}'.format(mode_))
24
+
25
+ _patch_size = patch_size[1]
26
+ _nbslices = patch_size[0]
27
+
28
+ perceqmin = 1
29
+ perceqmax = 99
30
+
31
+ image = ((image/(np.max(image)))*255).astype('uint8')
32
+
33
+ print('Image Shape: {}'.format(image.shape))
34
+ print('----------------------------------------')
35
+
36
+ initial_image_x = np.shape(image)[0]
37
+ initial_image_y = np.shape(image)[1]
38
+ initial_image_z = np.shape(image)[2]
39
+
40
+ #percentile equalization
41
+ if norm_:
42
+ minval = np.percentile(image, perceqmin)
43
+ maxval = np.percentile(image, perceqmax)
44
+ image = np.clip(image, minval, maxval)
45
+ image = (((image - minval) / (maxval - minval)) * 255).astype('uint8')
46
+
47
+ if _resize:
48
+ image = zoom(image, (dim_x/0.333, dim_y/0.333, dim_z/0.5), order=3, mode='nearest')
49
+ image = ((image/np.max(image))*255.0).astype('uint8')
50
+
51
+
52
+ #image size
53
+ size_y = np.shape(image)[0]
54
+ size_x = np.shape(image)[1]
55
+ size_depth = np.shape(image)[2]
56
+ aux_sizes_or = [size_y, size_x, size_depth]
57
+
58
+
59
+ #patch size
60
+ new_size_y = int((size_y/_patch_size) + 1) * _patch_size
61
+ new_size_x = int((size_x/_patch_size) + 1) * _patch_size
62
+ new_size_z = int((size_depth/_patch_size_z) + 1) * _patch_size_z
63
+ aux_sizes = [new_size_y, new_size_x, new_size_z]
64
+
65
+ ## zero padding
66
+ aux_img = np.random.randint(1,50,(aux_sizes[0], aux_sizes[1], aux_sizes[2]))
67
+ aux_img[0:aux_sizes_or[0], 0:aux_sizes_or[1],0:aux_sizes_or[2]] = image
68
+ image = aux_img
69
+ del aux_img
70
+
71
+ final_mask_foreground = np.zeros((np.shape(image)[0], np.shape(image)[1], np.shape(image)[2]))
72
+ final_mask_background = np.zeros((np.shape(image)[0], np.shape(image)[1], np.shape(image)[2]))
73
+ final_mask_background = final_mask_background.astype('uint8')
74
+ final_mask_foreground = final_mask_foreground.astype('uint8')
75
+
76
+
77
+ total_iterations = int(image.shape[0]/_patch_size)
78
+
79
+ with tqdm(total=total_iterations) as pbar:
80
+ i=0
81
+ while i+_patch_size<=image.shape[0]:
82
+ j=0
83
+ while j+_patch_size<=image.shape[1]:
84
+ k=0
85
+ while k+_patch_size_z<=image.shape[2]:
86
+
87
+ B_real = np.zeros((1,_nbslices,_patch_size,_patch_size,1),dtype='float32')
88
+ _slice = image[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z]
89
+
90
+ _slice = _slice.transpose(2,0,1)
91
+ _slice = np.expand_dims(_slice, axis=-1)
92
+
93
+ B_real[0,:]=(_slice-127.5) /127.5
94
+
95
+ A_generated = model_BtoA.predict(B_real)
96
+
97
+ A_generated = (A_generated + 1)/2 #from [-1,1] to [0,1]
98
+
99
+ A_generated = A_generated[0,:,:,:,0]
100
+ A_generated = A_generated.transpose(1,2,0)
101
+
102
+ #print(np.unique(A_generated))
103
+ A_generated = (A_generated>0.5)*1
104
+
105
+ A_generated = A_generated.astype('uint8')
106
+
107
+ final_mask_foreground[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z] = final_mask_foreground[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z] + A_generated
108
+ final_mask_background[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z] = final_mask_background[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z] + (1-A_generated)
109
+
110
+ k=k+_step_z
111
+ j=j+_step
112
+ i=i+_step
113
+ pbar.update(1)
114
+
115
+
116
+ del _slice
117
+ del A_generated
118
+ del B_real
119
+
120
+ final_mask = (final_mask_foreground>=final_mask_background)*1
121
+
122
+ image = image[0:aux_sizes_or[0], 0:aux_sizes_or[1],0:size_depth]
123
+ print('Image Shape: {}'.format(image.shape))
124
+ print('----------------------------------------')
125
+
126
+ final_mask = final_mask[0:aux_sizes_or[0], 0:aux_sizes_or[1],0:aux_sizes_or[2]]
127
+
128
+
129
+ if _resize:
130
+ final_mask = zoom(final_mask, (0.333/dim_x, 0.333/dim_y, 0.5/dim_z), order=3, mode='nearest')
131
+ final_mask = (final_mask*255.0).astype('uint8')
132
+
133
+ final_size_x = np.shape(final_mask)[0]
134
+ final_size_y = np.shape(final_mask)[1]
135
+ final_size_z = np.shape(final_mask)[2]
136
+
137
+ aux_mask = np.zeros((initial_image_x, initial_image_y, initial_image_z)).astype('uint8')
138
+ aux_mask[0:min(initial_image_x, final_size_x),0:min(initial_image_y, final_size_y),0:min(initial_image_z, final_size_z)] = final_mask[0:min(initial_image_x, final_size_x),0:min(initial_image_y, final_size_y),0:min(initial_image_z, final_size_z)]
139
+
140
+ final_mask = aux_mask.copy()
141
+
142
+
143
+ print('Mask Shape: {}'.format(final_mask.shape))
144
+ print('----------------------------------------')
145
+ final_mask = final_mask/np.max(final_mask)
146
+ final_mask = final_mask*255.0
147
+ final_mask = final_mask.astype('uint8')
148
+
149
+
150
+ #closing operation to fill small holes
151
+ mask = final_mask
152
+ mask[mask!=0] = 1
153
+ mask = mask.astype('uint8')
154
+
155
+ ellipsoid = draw.ellipsoid(9,9,3, spacing=(1,1,1), levelset=False)
156
+ ellipsoid = ellipsoid.astype('uint8')
157
+ ellipsoid = ellipsoid[1:-1,1:-1,1:-1]
158
+
159
+ #perform closing operation on the mask
160
+ dil = binary_dilation(mask, ellipsoid)
161
+ closed_mask = binary_erosion(dil, ellipsoid)
162
+ closed_mask = (closed_mask*255.0).astype('uint8')
163
+
164
  return closed_mask