matiasky commited on
Commit
50f870e
·
1 Parent(s): 0cd7e3b

update app

Browse files
Files changed (4) hide show
  1. app.py +6 -310
  2. utils/plotting.py +77 -0
  3. utils/segmentation.py +259 -0
  4. utils/utils.py +0 -103
app.py CHANGED
@@ -1,307 +1,5 @@
1
- import numpy as np
2
  import gradio as gr
3
- import cv2
4
- import matplotlib.pyplot as plt
5
- from PIL import Image
6
- from io import BytesIO
7
- from mpl_toolkits.axes_grid1 import make_axes_locatable
8
-
9
- from models.HybridGNet2IGSC import Hybrid
10
- from utils.utils import scipy_to_torch_sparse, genMatrixesLungsHeart
11
- import scipy.sparse as sp
12
- import torch
13
- import pandas as pd
14
- from zipfile import ZipFile
15
-
16
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
17
- hybrid = None
18
-
19
- def getDenseMask(landmarks, h, w):
20
-
21
- RL = landmarks[0:44]
22
- LL = landmarks[44:94]
23
- H = landmarks[94:]
24
-
25
- img = np.zeros([h, w], dtype = 'uint8')
26
-
27
- RL = RL.reshape(-1, 1, 2).astype('int')
28
- LL = LL.reshape(-1, 1, 2).astype('int')
29
- H = H.reshape(-1, 1, 2).astype('int')
30
-
31
- img = cv2.drawContours(img, [RL], -1, 1, -1)
32
- img = cv2.drawContours(img, [LL], -1, 1, -1)
33
- img = cv2.drawContours(img, [H], -1, 2, -1)
34
-
35
- return img
36
-
37
- def getMasks(landmarks, h, w):
38
-
39
- RL = landmarks[0:44]
40
- LL = landmarks[44:94]
41
- H = landmarks[94:]
42
-
43
- RL = RL.reshape(-1, 1, 2).astype('int')
44
- LL = LL.reshape(-1, 1, 2).astype('int')
45
- H = H.reshape(-1, 1, 2).astype('int')
46
-
47
- RL_mask = np.zeros([h, w], dtype = 'uint8')
48
- LL_mask = np.zeros([h, w], dtype = 'uint8')
49
- H_mask = np.zeros([h, w], dtype = 'uint8')
50
-
51
- RL_mask = cv2.drawContours(RL_mask, [RL], -1, 255, -1)
52
- LL_mask = cv2.drawContours(LL_mask, [LL], -1, 255, -1)
53
- H_mask = cv2.drawContours(H_mask, [H], -1, 255, -1)
54
-
55
- return RL_mask, LL_mask, H_mask
56
-
57
- def drawOnTop(img, landmarks, original_shape):
58
- h, w = original_shape
59
- output = getDenseMask(landmarks, h, w)
60
-
61
- image = np.zeros([h, w, 3])
62
- image[:,:,0] = img + 0.3 * (output == 1).astype('float') - 0.1 * (output == 2).astype('float')
63
- image[:,:,1] = img + 0.3 * (output == 2).astype('float') - 0.1 * (output == 1).astype('float')
64
- image[:,:,2] = img - 0.1 * (output == 1).astype('float') - 0.2 * (output == 2).astype('float')
65
-
66
- image = np.clip(image, 0, 1)
67
-
68
- RL, LL, H = landmarks[0:44], landmarks[44:94], landmarks[94:]
69
-
70
- # Draw the landmarks as dots
71
-
72
- for l in RL:
73
- image = cv2.circle(image, (int(l[0]), int(l[1])), 5, (1, 0, 1), -1)
74
- for l in LL:
75
- image = cv2.circle(image, (int(l[0]), int(l[1])), 5, (1, 0, 1), -1)
76
- for l in H:
77
- image = cv2.circle(image, (int(l[0]), int(l[1])), 5, (1, 1, 0), -1)
78
-
79
- return image
80
-
81
-
82
- def loadModel(device):
83
- A, AD, D, U = genMatrixesLungsHeart()
84
- N1 = A.shape[0]
85
- N2 = AD.shape[0]
86
-
87
- A = sp.csc_matrix(A).tocoo()
88
- AD = sp.csc_matrix(AD).tocoo()
89
- D = sp.csc_matrix(D).tocoo()
90
- U = sp.csc_matrix(U).tocoo()
91
-
92
- D_ = [D.copy()]
93
- U_ = [U.copy()]
94
- A_ = [A.copy(), A.copy(), A.copy(), AD.copy(), AD.copy(), AD.copy()]
95
-
96
- config = {}
97
- config['n_nodes'] = [N1, N1, N1, N2, N2, N2]
98
-
99
- A_t, D_t, U_t = ([scipy_to_torch_sparse(x).to(device) for x in X] for X in (A_, D_, U_))
100
-
101
- config['latents'] = 64
102
- config['inputsize'] = 1024
103
-
104
- f = 32
105
- config['filters'] = [2, f, f, f, f//2, f//2, f//2]
106
- config['skip_features'] = f
107
- config['eval_sampling'] = True
108
-
109
- hybrid = Hybrid(config.copy(), D_t, U_t, A_t).to(device)
110
- hybrid.load_state_dict(torch.load("weights/weights.pt", map_location=torch.device(device)))
111
- hybrid.eval()
112
-
113
- return hybrid
114
-
115
-
116
- def pad_to_square(img):
117
- h, w = img.shape[:2]
118
-
119
- if h > w:
120
- padw = (h - w)
121
- auxw = padw % 2
122
- img = np.pad(img, ((0, 0), (padw//2, padw//2 + auxw)), 'constant')
123
-
124
- padh = 0
125
- auxh = 0
126
-
127
- else:
128
- padh = (w - h)
129
- auxh = padh % 2
130
- img = np.pad(img, ((padh//2, padh//2 + auxh), (0, 0)), 'constant')
131
-
132
- padw = 0
133
- auxw = 0
134
-
135
- return img, (padh, padw, auxh, auxw)
136
-
137
-
138
- def preprocess(input_img):
139
- img, padding = pad_to_square(input_img)
140
-
141
- h, w = img.shape[:2]
142
- if h != 1024 or w != 1024:
143
- img = cv2.resize(img, (1024, 1024), interpolation = cv2.INTER_CUBIC)
144
-
145
- return img, (h, w, padding)
146
-
147
-
148
- def removePreprocess(output, info):
149
- """
150
- output: np.array of shape (n_samples, N_landmarks, 2)
151
- info: (h, w, padding)
152
- """
153
- h, w, padding = info
154
- padh, padw, auxh, auxw = padding
155
-
156
- # Scale
157
- if h != 1024 or w != 1024:
158
- output = output * h
159
- else:
160
- output = output * 1024
161
-
162
- # Subtract padding
163
- output[:, :, 0] = output[:, :, 0] - padw//2
164
- output[:, :, 1] = output[:, :, 1] - padh//2
165
-
166
- return output
167
-
168
-
169
- def zip_files(files):
170
- with ZipFile("complete_results.zip", "w") as zipObj:
171
- for idx, file in enumerate(files):
172
- zipObj.write(file, arcname=file.split("/")[-1])
173
- return "complete_results.zip"
174
-
175
-
176
- def plot_landmarks_with_uncertainty(img, landmarks, uncertainty, figsize=(6,6)):
177
- # Get dense mask as in drawOnTop
178
- h, w = img.shape[:2]
179
- dense_mask = getDenseMask(landmarks, h, w)
180
-
181
- # Start with image overlay
182
- overlay = np.zeros([h, w, 3])
183
- overlay[:,:,0] = img + 0.3 * (dense_mask == 1).astype('float') - 0.1 * (dense_mask == 2).astype('float')
184
- overlay[:,:,1] = img + 0.3 * (dense_mask == 2).astype('float') - 0.1 * (dense_mask == 1).astype('float')
185
- overlay[:,:,2] = img - 0.1 * (dense_mask == 1).astype('float') - 0.2 * (dense_mask == 2).astype('float')
186
- overlay = np.clip(overlay, 0, 1)
187
-
188
- # Plot
189
- fig, ax = plt.subplots(figsize=figsize)
190
- ax.imshow(overlay)
191
-
192
- # Scatter landmarks colored by uncertainty
193
- scatter = ax.scatter(
194
- landmarks[:,0], landmarks[:,1],
195
- c=uncertainty, cmap='hot',
196
- s=50, vmin=0, vmax=np.max(uncertainty)
197
- )
198
-
199
- # Colorbar
200
- divider = make_axes_locatable(ax)
201
- cax = divider.append_axes("right", size="5%", pad=0.05)
202
- plt.colorbar(scatter, cax=cax, label='Node uncertainty')
203
-
204
- ax.set_xlim(0, img.shape[1])
205
- ax.set_ylim(img.shape[0], 0)
206
- ax.axis('off')
207
- fig.tight_layout()
208
- return fig
209
-
210
- def segment(input_img, noise_std=0.0):
211
- global hybrid, device
212
-
213
- if hybrid is None:
214
- hybrid = loadModel(device)
215
-
216
- # ------------------ HANDLE SKETCH / INPAINT ------------------
217
- if isinstance(input_img, dict):
218
- original = input_img["image"].astype(np.float32)/255.0
219
- mask = input_img["mask"]
220
- if mask.ndim == 3:
221
- mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
222
- mask = mask.astype(np.float32)/255.0
223
- mask = 1.0 - mask # black strokes
224
- input_img = np.minimum(original, mask)
225
- else:
226
- input_img = input_img.astype(np.float32)/255.0
227
-
228
- # ------------------ ADD GAUSSIAN NOISE ------------------
229
- if noise_std > 0:
230
- noise = np.random.normal(0, noise_std, input_img.shape)
231
- input_img = np.clip(input_img + noise, 0.0, 1.0)
232
-
233
- # ------------------ PREPROCESS & PREDICT ------------------
234
- original_shape = input_img.shape[:2]
235
- img, (h, w, padding) = preprocess(input_img)
236
- data = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device).float()
237
- n_samples = 100
238
-
239
- with torch.no_grad():
240
- mu, log_var, conv6, conv5 = hybrid.encode(data)
241
- latent_var = np.exp(log_var.cpu().numpy())
242
-
243
- # Sample N latent vectors to decode
244
- zs = [hybrid.sampling(mu, log_var) for _ in range(n_samples)]
245
- z_exp = torch.stack(zs, dim=0)
246
-
247
- # Expand skip connections to match batch size
248
- conv6_exp = conv6.repeat(n_samples, 1, 1, 1)
249
- conv5_exp = conv5.repeat(n_samples, 1, 1, 1)
250
-
251
- # Decode in batch
252
- output, _, _ = hybrid.decode(z_exp, conv6_exp, conv5_exp)
253
- output = output.cpu().numpy().reshape(n_samples, -1, 2)
254
- output = removePreprocess(output, (h, w, padding)).astype('int')
255
-
256
- # Compute mean and std per node
257
- means = np.mean(output, axis=0)
258
- stds = np.std(output, axis=0)
259
-
260
- # ------------------ SAVE LANDMARKS & MASKS ------------------
261
- RL = means[0:44]
262
- LL = means[44:94]
263
- H = means[94:]
264
-
265
- np.savetxt("tmp/RL_landmarks.txt", RL, delimiter=" ", fmt="%d")
266
- np.savetxt("tmp/LL_landmarks.txt", LL, delimiter=" ", fmt="%d")
267
- np.savetxt("tmp/H_landmarks.txt", H, delimiter=" ", fmt="%d")
268
-
269
- RL_mask, LL_mask, H_mask = getMasks(means, original_shape[0], original_shape[1])
270
- cv2.imwrite("tmp/RL_mask.png", RL_mask)
271
- cv2.imwrite("tmp/LL_mask.png", LL_mask)
272
- cv2.imwrite("tmp/H_mask.png", H_mask)
273
-
274
- RL_std = stds[0:44]
275
- LL_std = stds[44:94]
276
- H_std = stds[94:]
277
-
278
- # Save as text files
279
- np.savetxt("tmp/RL_std.txt", RL_std, delimiter=" ", fmt="%.4f")
280
- np.savetxt("tmp/LL_std.txt", LL_std, delimiter=" ", fmt="%.4f")
281
- np.savetxt("tmp/H_std.txt", H_std, delimiter=" ", fmt="%.4f")
282
-
283
- zipf = zip_files([
284
- "tmp/RL_landmarks.txt", "tmp/LL_landmarks.txt", "tmp/H_landmarks.txt",
285
- "tmp/RL_mask.png", "tmp/LL_mask.png", "tmp/H_mask.png",
286
- "tmp/RL_std.txt", "tmp/LL_std.txt", "tmp/H_std.txt"
287
- ])
288
-
289
- # ------------------ RANDOM UNCERTAINTY ------------------
290
- node_uncertainty = np.mean(stds, axis=1)
291
- fig = plot_landmarks_with_uncertainty(input_img, means, node_uncertainty)
292
-
293
- output_path = "tmp/segmentation_with_uncertainty.png"
294
- fig.savefig(output_path, format="png", dpi=150)
295
- plt.close(fig)
296
-
297
- return output_path, [
298
- "tmp/RL_landmarks.txt", "tmp/LL_landmarks.txt", "tmp/H_landmarks.txt",
299
- "tmp/RL_mask.png", "tmp/LL_mask.png", "tmp/H_mask.png",
300
- "tmp/RL_std.txt", "tmp/LL_std.txt", "tmp/H_std.txt",
301
- zipf
302
- ]
303
-
304
-
305
 
306
  # ------------------------- GRADIO -------------------------
307
  if __name__ == "__main__":
@@ -310,19 +8,18 @@ if __name__ == "__main__":
310
 
311
  with gr.Tab("Segment Image"):
312
  with gr.Row():
313
- with gr.Column():
314
  image_input = gr.Image(
315
  type="numpy",
316
  tool="sketch",
317
  image_mode="L",
318
- height=512,
319
- shape=(512, 512)
320
  )
321
 
322
  noise_slider = gr.Slider(
323
  label="Gaussian Noise Std Dev",
324
  minimum=0.0,
325
- maximum=0.25, # max = strong noise / blurry
326
  step=0.01,
327
  value=0.0
328
  )
@@ -336,9 +33,8 @@ if __name__ == "__main__":
336
  'utils/example3.png','utils/example4.jpg'
337
  ])
338
 
339
- with gr.Column():
340
- #image_output = gr.Image(type="numpy", height=750)
341
- image_output = gr.Image(type="filepath", height=512)
342
  results = gr.File()
343
 
344
  gr.Markdown("""
 
 
1
  import gradio as gr
2
+ from utils.segmentation import segment
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # ------------------------- GRADIO -------------------------
5
  if __name__ == "__main__":
 
8
 
9
  with gr.Tab("Segment Image"):
10
  with gr.Row():
11
+ with gr.Column(scale=1):
12
  image_input = gr.Image(
13
  type="numpy",
14
  tool="sketch",
15
  image_mode="L",
16
+ height=450,
 
17
  )
18
 
19
  noise_slider = gr.Slider(
20
  label="Gaussian Noise Std Dev",
21
  minimum=0.0,
22
+ maximum=0.25,
23
  step=0.01,
24
  value=0.0
25
  )
 
33
  'utils/example3.png','utils/example4.jpg'
34
  ])
35
 
36
+ with gr.Column(scale=2):
37
+ image_output = gr.Image(type="filepath", height=450)
 
38
  results = gr.File()
39
 
40
  gr.Markdown("""
utils/plotting.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import matplotlib.pyplot as plt
4
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
5
+
6
+ def getDenseMask(landmarks, h, w):
7
+ RL, LL, H = landmarks[:44], landmarks[44:94], landmarks[94:]
8
+ img = np.zeros([h, w], dtype='uint8')
9
+ RL = RL.reshape(-1, 1, 2).astype('int')
10
+ LL = LL.reshape(-1, 1, 2).astype('int')
11
+ H = H.reshape(-1, 1, 2).astype('int')
12
+ img = cv2.drawContours(img, [RL], -1, 1, -1)
13
+ img = cv2.drawContours(img, [LL], -1, 1, -1)
14
+ img = cv2.drawContours(img, [H], -1, 2, -1)
15
+ return img
16
+
17
+ def drawOnTop(img, landmarks, original_shape):
18
+ h, w = original_shape
19
+ output = getDenseMask(landmarks, h, w)
20
+ image = np.zeros([h,w,3])
21
+ image[:,:,0] = img + 0.3*(output==1).astype('float') - 0.1*(output==2).astype('float')
22
+ image[:,:,1] = img + 0.3*(output==2).astype('float') - 0.1*(output==1).astype('float')
23
+ image[:,:,2] = img - 0.1*(output==1).astype('float') - 0.2*(output==2).astype('float')
24
+ image = np.clip(image,0,1)
25
+ RL, LL, H = landmarks[:44], landmarks[44:94], landmarks[94:]
26
+ for l in RL: image = cv2.circle(image,(int(l[0]),int(l[1])),5,(1,0,1),-1)
27
+ for l in LL: image = cv2.circle(image,(int(l[0]),int(l[1])),5,(1,0,1),-1)
28
+ for l in H: image = cv2.circle(image,(int(l[0]),int(l[1])),5,(1,1,0),-1)
29
+ return image
30
+
31
+ def create_overlay(img, landmarks):
32
+ h, w = img.shape[:2]
33
+ dense_mask = getDenseMask(landmarks, h, w)
34
+ overlay = np.zeros([h, w, 3])
35
+
36
+ overlay[:,:,0] = img + 0.3 * (dense_mask == 1).astype('float') - 0.1 * (dense_mask == 2).astype('float')
37
+ overlay[:,:,1] = img + 0.3 * (dense_mask == 2).astype('float') - 0.1 * (dense_mask == 1).astype('float')
38
+ overlay[:,:,2] = img - 0.1 * (dense_mask == 1).astype('float') - 0.2 * (dense_mask == 2).astype('float')
39
+ overlay = np.clip(overlay, 0, 1)
40
+
41
+ return overlay
42
+
43
+ def plot_side_by_side_comparison(img_orig, means_orig, uncertainty_orig, img_corr, means_corr, uncertainty_corr):
44
+
45
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 7))
46
+
47
+ fig.set_constrained_layout(True)
48
+
49
+ vmax = max(np.max(np.mean(uncertainty_orig, axis=1)), np.max(np.mean(uncertainty_corr, axis=1)))
50
+
51
+ # --- Original ---
52
+ overlay_orig = create_overlay(img_orig, means_orig)
53
+ ax1.imshow(overlay_orig)
54
+ scatter1 = ax1.scatter(
55
+ means_orig[:, 0], means_orig[:, 1],
56
+ c=np.mean(uncertainty_orig, axis=1),
57
+ cmap='hot', s=50, vmin=0, vmax=vmax
58
+ )
59
+ ax1.set_title("Original", fontsize=16, pad=10)
60
+ ax1.axis('off')
61
+
62
+ # --- Corrupted ---
63
+ overlay_corr = create_overlay(img_corr, means_corr)
64
+ ax2.imshow(overlay_corr)
65
+ scatter2 = ax2.scatter(
66
+ means_corr[:, 0], means_corr[:, 1],
67
+ c=np.mean(uncertainty_corr, axis=1),
68
+ cmap='hot', s=50, vmin=0, vmax=vmax
69
+ )
70
+ ax2.set_title("Corrupted", fontsize=16, pad=10)
71
+ ax2.axis('off')
72
+
73
+ # Shared colorbar
74
+ cbar = fig.colorbar(scatter2, ax=[ax1, ax2], fraction=0.046, pad=0.01, shrink=0.85)
75
+ cbar.ax.tick_params(labelsize=10)
76
+
77
+ return fig
utils/segmentation.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import torch
4
+ import scipy.sparse as sp
5
+ import sys
6
+ import os
7
+ from zipfile import ZipFile
8
+ from .plotting import plot_side_by_side_comparison
9
+
10
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
+ from models.HybridGNet2IGSC import Hybrid
12
+
13
+ hybrid = None
14
+
15
+ def scipy_to_torch_sparse(scp_matrix):
16
+ values = scp_matrix.data
17
+ indices = np.vstack((scp_matrix.row, scp_matrix.col))
18
+ i = torch.LongTensor(indices)
19
+ v = torch.FloatTensor(values)
20
+ shape = scp_matrix.shape
21
+
22
+ sparse_tensor = torch.sparse.FloatTensor(i, v, torch.Size(shape))
23
+ return sparse_tensor
24
+
25
+ ## Adjacency Matrix
26
+ def mOrgan(N):
27
+ sub = np.zeros([N, N])
28
+ for i in range(0, N):
29
+ sub[i, i-1] = 1
30
+ sub[i, (i+1)%N] = 1
31
+ return sub
32
+
33
+ ## Downsampling Matrix
34
+ def mOrganD(N):
35
+ N2 = int(np.ceil(N/2))
36
+ sub = np.zeros([N2, N])
37
+
38
+ for i in range(0, N2):
39
+ if (2*i+1) == N:
40
+ sub[i, 2*i] = 1
41
+ else:
42
+ sub[i, 2*i] = 1/2
43
+ sub[i, 2*i+1] = 1/2
44
+
45
+ return sub
46
+
47
+ def mOrganU(N):
48
+ N2 = int(np.ceil(N/2))
49
+ sub = np.zeros([N, N2])
50
+
51
+ for i in range(0, N):
52
+ if i % 2 == 0:
53
+ sub[i, i//2] = 1
54
+ else:
55
+ sub[i, i//2] = 1/2
56
+ sub[i, (i//2 + 1) % N2] = 1/2
57
+
58
+ return sub
59
+
60
+ def genMatrixesLungsHeart():
61
+ RLUNG = 44
62
+ LLUNG = 50
63
+ HEART = 26
64
+
65
+ Asub1 = mOrgan(RLUNG)
66
+ Asub2 = mOrgan(LLUNG)
67
+ Asub3 = mOrgan(HEART)
68
+
69
+ ADsub1 = mOrgan(int(np.ceil(RLUNG / 2)))
70
+ ADsub2 = mOrgan(int(np.ceil(LLUNG / 2)))
71
+ ADsub3 = mOrgan(int(np.ceil(HEART / 2)))
72
+
73
+ Dsub1 = mOrganD(RLUNG)
74
+ Dsub2 = mOrganD(LLUNG)
75
+ Dsub3 = mOrganD(HEART)
76
+
77
+ Usub1 = mOrganU(RLUNG)
78
+ Usub2 = mOrganU(LLUNG)
79
+ Usub3 = mOrganU(HEART)
80
+
81
+ p1 = RLUNG
82
+ p2 = p1 + LLUNG
83
+ p3 = p2 + HEART
84
+
85
+ p1_ = int(np.ceil(RLUNG / 2))
86
+ p2_ = p1_ + int(np.ceil(LLUNG / 2))
87
+ p3_ = p2_ + int(np.ceil(HEART / 2))
88
+
89
+ A = np.zeros([p3, p3])
90
+
91
+ A[:p1, :p1] = Asub1
92
+ A[p1:p2, p1:p2] = Asub2
93
+ A[p2:p3, p2:p3] = Asub3
94
+
95
+ AD = np.zeros([p3_, p3_])
96
+
97
+ AD[:p1_, :p1_] = ADsub1
98
+ AD[p1_:p2_, p1_:p2_] = ADsub2
99
+ AD[p2_:p3_, p2_:p3_] = ADsub3
100
+
101
+ D = np.zeros([p3_, p3])
102
+
103
+ D[:p1_, :p1] = Dsub1
104
+ D[p1_:p2_, p1:p2] = Dsub2
105
+ D[p2_:p3_, p2:p3] = Dsub3
106
+
107
+ U = np.zeros([p3, p3_])
108
+
109
+ U[:p1, :p1_] = Usub1
110
+ U[p1:p2, p1_:p2_] = Usub2
111
+ U[p2:p3, p2_:p3_] = Usub3
112
+
113
+ return A, AD, D, U
114
+
115
+ def zip_files(files, output_name="complete_results.zip"):
116
+ with ZipFile(output_name, "w") as zipObj:
117
+ for file in files:
118
+ zipObj.write(file, arcname=file.split("/")[-1])
119
+ return output_name
120
+
121
+ def getMasks(landmarks, h, w):
122
+ RL, LL, H = landmarks[:44], landmarks[44:94], landmarks[94:]
123
+ RL_mask, LL_mask, H_mask = [np.zeros([h, w], dtype='uint8') for _ in range(3)]
124
+ RL_mask = cv2.drawContours(RL_mask, [RL.reshape(-1,1,2).astype('int')], -1, 255, -1)
125
+ LL_mask = cv2.drawContours(LL_mask, [LL.reshape(-1,1,2).astype('int')], -1, 255, -1)
126
+ H_mask = cv2.drawContours(H_mask, [H.reshape(-1,1,2).astype('int')], -1, 255, -1)
127
+ return RL_mask, LL_mask, H_mask
128
+
129
+ def pad_to_square(img):
130
+ h, w = img.shape[:2]
131
+ if h > w:
132
+ padw = h - w
133
+ auxw = padw % 2
134
+ img = np.pad(img, ((0,0),(padw//2, padw//2+auxw)), 'constant')
135
+ return img, (0, padw, 0, auxw)
136
+ else:
137
+ padh = w - h
138
+ auxh = padh % 2
139
+ img = np.pad(img, ((padh//2, padh//2+auxh),(0,0)), 'constant')
140
+ return img, (padh, 0, auxh, 0)
141
+
142
+ def preprocess(img):
143
+ img, padding = pad_to_square(img)
144
+ h, w = img.shape[:2]
145
+ if h != 1024 or w != 1024:
146
+ img = cv2.resize(img, (1024,1024), interpolation=cv2.INTER_CUBIC)
147
+ return img, (h, w, padding)
148
+
149
+ def removePreprocess(output, info):
150
+ h, w, padding = info
151
+ padh, padw, auxh, auxw = padding
152
+ if h != 1024 or w != 1024:
153
+ output = output * h
154
+ else:
155
+ output = output * 1024
156
+ output[:,:,0] -= padw//2
157
+ output[:,:,1] -= padh//2
158
+ return output
159
+
160
+ def loadModel(device):
161
+ global hybrid
162
+ A, AD, D, U = genMatrixesLungsHeart()
163
+ N1, N2 = A.shape[0], AD.shape[0]
164
+ A, AD, D, U = [sp.csc_matrix(x).tocoo() for x in [A, AD, D, U]]
165
+ D_, U_ = [D.copy()], [U.copy()]
166
+ A_ = [A.copy(), A.copy(), A.copy(), AD.copy(), AD.copy(), AD.copy()]
167
+ config = {'n_nodes':[N1,N1,N1,N2,N2,N2], 'latents':64, 'inputsize':1024,
168
+ 'filters':[2,32,32,32,16,16,16], 'skip_features':32, 'eval_sampling':True}
169
+ A_t, D_t, U_t = ([scipy_to_torch_sparse(x).to(device) for x in X] for X in (A_,D_,U_))
170
+ hybrid = Hybrid(config.copy(), D_t, U_t, A_t).to(device)
171
+ hybrid.load_state_dict(torch.load("weights/weights.pt", map_location=device))
172
+ hybrid.eval()
173
+ return hybrid
174
+
175
+ def predict_landmarks(img, n_samples=100):
176
+ global hybrid
177
+ img_proc, (h, w, padding) = preprocess(img)
178
+ data = torch.from_numpy(img_proc).unsqueeze(0).unsqueeze(0).to(next(hybrid.parameters()).device).float()
179
+ with torch.no_grad():
180
+ mu, log_var, conv6, conv5 = hybrid.encode(data)
181
+ zs = [hybrid.sampling(mu, log_var) for _ in range(n_samples)]
182
+ z_exp = torch.stack(zs, dim=0)
183
+ conv6_exp, conv5_exp = conv6.repeat(n_samples,1,1,1), conv5.repeat(n_samples,1,1,1)
184
+ output, _, _ = hybrid.decode(z_exp, conv6_exp, conv5_exp)
185
+ output = output.cpu().numpy().reshape(n_samples,-1,2)
186
+ output = removePreprocess(output, (h,w,padding)).astype('int')
187
+ means, stds = np.mean(output,axis=0), np.std(output,axis=0)
188
+ return means, stds
189
+
190
+
191
+ def segment(input_img, noise_std=0.0):
192
+ """
193
+ input_img: dict with keys "image" (numpy array) and optionally "mask"
194
+ noise_std: standard deviation of Gaussian noise to add for robustness
195
+ Returns: path to comparison figure, list of saved files
196
+ """
197
+ global hybrid
198
+
199
+ if hybrid is None:
200
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
201
+ hybrid = loadModel(device)
202
+
203
+ # Original image and corrupted version
204
+ img_orig = input_img["image"].astype(np.float32) / 255.0
205
+ mask = input_img.get("mask", None)
206
+ if mask is not None:
207
+ mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
208
+ mask = 1.0 - mask
209
+ img_corr = np.minimum(img_orig, mask)
210
+ else:
211
+ img_corr = img_orig.copy()
212
+
213
+ if noise_std > 0:
214
+ noise = np.random.normal(0, noise_std, img_corr.shape)
215
+ img_corr = np.clip(img_corr + noise, 0.0, 1.0)
216
+
217
+ # Predict landmarks
218
+ means_orig, stds_orig = predict_landmarks(img_orig)
219
+ means_corr, stds_corr = predict_landmarks(img_corr)
220
+
221
+ # Save landmarks and masks
222
+ os.makedirs("tmp", exist_ok=True)
223
+
224
+ RL, LL, H = means_orig[:44], means_orig[44:94], means_orig[94:]
225
+ np.savetxt("tmp/RL_landmarks.txt", RL, delimiter=" ", fmt="%d")
226
+ np.savetxt("tmp/LL_landmarks.txt", LL, delimiter=" ", fmt="%d")
227
+ np.savetxt("tmp/H_landmarks.txt", H, delimiter=" ", fmt="%d")
228
+
229
+ RL_mask, LL_mask, H_mask = getMasks(means_orig, img_orig.shape[0], img_orig.shape[1])
230
+ cv2.imwrite("tmp/RL_mask.png", RL_mask)
231
+ cv2.imwrite("tmp/LL_mask.png", LL_mask)
232
+ cv2.imwrite("tmp/H_mask.png", H_mask)
233
+
234
+ RL_std, LL_std, H_std = stds_orig[:44], stds_orig[44:94], stds_orig[94:]
235
+ np.savetxt("tmp/RL_std.txt", RL_std, delimiter=" ", fmt="%.4f")
236
+ np.savetxt("tmp/LL_std.txt", LL_std, delimiter=" ", fmt="%.4f")
237
+ np.savetxt("tmp/H_std.txt", H_std, delimiter=" ", fmt="%.4f")
238
+
239
+ zipf = zip_files([
240
+ "tmp/RL_landmarks.txt","tmp/LL_landmarks.txt","tmp/H_landmarks.txt",
241
+ "tmp/RL_mask.png","tmp/LL_mask.png","tmp/H_mask.png",
242
+ "tmp/RL_std.txt","tmp/LL_std.txt","tmp/H_std.txt"
243
+ ])
244
+
245
+ # Optional: plot side-by-side comparison
246
+ fig = plot_side_by_side_comparison(img_orig, means_orig, stds_orig, img_corr, means_corr, stds_corr)
247
+ output_path = "tmp/segmentation_comparison.png"
248
+ fig.savefig(output_path, dpi=300)
249
+ import matplotlib.pyplot as plt
250
+ plt.close(fig)
251
+
252
+ saved_files = [
253
+ "tmp/RL_landmarks.txt","tmp/LL_landmarks.txt","tmp/H_landmarks.txt",
254
+ "tmp/RL_mask.png","tmp/LL_mask.png","tmp/H_mask.png",
255
+ "tmp/RL_std.txt","tmp/LL_std.txt","tmp/H_std.txt",
256
+ zipf
257
+ ]
258
+
259
+ return output_path, saved_files
utils/utils.py DELETED
@@ -1,103 +0,0 @@
1
- import numpy as np
2
- import scipy.sparse as sp
3
- import torch
4
-
5
- def scipy_to_torch_sparse(scp_matrix):
6
- values = scp_matrix.data
7
- indices = np.vstack((scp_matrix.row, scp_matrix.col))
8
- i = torch.LongTensor(indices)
9
- v = torch.FloatTensor(values)
10
- shape = scp_matrix.shape
11
-
12
- sparse_tensor = torch.sparse.FloatTensor(i, v, torch.Size(shape))
13
- return sparse_tensor
14
-
15
- ## Adjacency Matrix
16
- def mOrgan(N):
17
- sub = np.zeros([N, N])
18
- for i in range(0, N):
19
- sub[i, i-1] = 1
20
- sub[i, (i+1)%N] = 1
21
- return sub
22
-
23
- ## Downsampling Matrix
24
- def mOrganD(N):
25
- N2 = int(np.ceil(N/2))
26
- sub = np.zeros([N2, N])
27
-
28
- for i in range(0, N2):
29
- if (2*i+1) == N:
30
- sub[i, 2*i] = 1
31
- else:
32
- sub[i, 2*i] = 1/2
33
- sub[i, 2*i+1] = 1/2
34
-
35
- return sub
36
-
37
- def mOrganU(N):
38
- N2 = int(np.ceil(N/2))
39
- sub = np.zeros([N, N2])
40
-
41
- for i in range(0, N):
42
- if i % 2 == 0:
43
- sub[i, i//2] = 1
44
- else:
45
- sub[i, i//2] = 1/2
46
- sub[i, (i//2 + 1) % N2] = 1/2
47
-
48
- return sub
49
-
50
- def genMatrixesLungsHeart():
51
- RLUNG = 44
52
- LLUNG = 50
53
- HEART = 26
54
-
55
- Asub1 = mOrgan(RLUNG)
56
- Asub2 = mOrgan(LLUNG)
57
- Asub3 = mOrgan(HEART)
58
-
59
- ADsub1 = mOrgan(int(np.ceil(RLUNG / 2)))
60
- ADsub2 = mOrgan(int(np.ceil(LLUNG / 2)))
61
- ADsub3 = mOrgan(int(np.ceil(HEART / 2)))
62
-
63
- Dsub1 = mOrganD(RLUNG)
64
- Dsub2 = mOrganD(LLUNG)
65
- Dsub3 = mOrganD(HEART)
66
-
67
- Usub1 = mOrganU(RLUNG)
68
- Usub2 = mOrganU(LLUNG)
69
- Usub3 = mOrganU(HEART)
70
-
71
- p1 = RLUNG
72
- p2 = p1 + LLUNG
73
- p3 = p2 + HEART
74
-
75
- p1_ = int(np.ceil(RLUNG / 2))
76
- p2_ = p1_ + int(np.ceil(LLUNG / 2))
77
- p3_ = p2_ + int(np.ceil(HEART / 2))
78
-
79
- A = np.zeros([p3, p3])
80
-
81
- A[:p1, :p1] = Asub1
82
- A[p1:p2, p1:p2] = Asub2
83
- A[p2:p3, p2:p3] = Asub3
84
-
85
- AD = np.zeros([p3_, p3_])
86
-
87
- AD[:p1_, :p1_] = ADsub1
88
- AD[p1_:p2_, p1_:p2_] = ADsub2
89
- AD[p2_:p3_, p2_:p3_] = ADsub3
90
-
91
- D = np.zeros([p3_, p3])
92
-
93
- D[:p1_, :p1] = Dsub1
94
- D[p1_:p2_, p1:p2] = Dsub2
95
- D[p2_:p3_, p2:p3] = Dsub3
96
-
97
- U = np.zeros([p3, p3_])
98
-
99
- U[:p1, :p1_] = Usub1
100
- U[p1:p2, p1_:p2_] = Usub2
101
- U[p2:p3, p2_:p3_] = Usub3
102
-
103
- return A, AD, D, U