yuxin commited on
Commit
faa61f9
1 Parent(s): 44c38ef
Files changed (1) hide show
  1. model_segvol_single.py +108 -96
model_segvol_single.py CHANGED
@@ -89,102 +89,16 @@ class SegVolProcessor():
89
  item['zoom_out_label'] = item_zoom_out['label']
90
  return item
91
 
92
- def generate_box(self, pred_pre, bbox_shift=None):
93
- meaning_post_label = pred_pre # [h, w, d]
94
- ones_idx = (meaning_post_label > 0).nonzero(as_tuple=True)
95
- if all(tensor.nelement() == 0 for tensor in ones_idx):
96
- bboxes = torch.tensor([-1,-1,-1,-1,-1,-1])
97
- # print(bboxes, bboxes.shape)
98
- return bboxes
99
- min_coords = [dim.min() for dim in ones_idx] # [x_min, y_min, z_min]
100
- max_coords = [dim.max() for dim in ones_idx] # [x_max, y_max, z_max]
101
-
102
-
103
- if bbox_shift is None:
104
- corner_min = []
105
- corner_max = []
106
- shape = meaning_post_label.shape
107
- for coor in min_coords:
108
- coor_ = max(0, coor)
109
- corner_min.append(coor_)
110
- for idx, coor in enumerate(max_coords):
111
- coor_ = min(shape[idx], coor)
112
- corner_max.append(coor_)
113
- corner_min = torch.tensor(corner_min)
114
- corner_max = torch.tensor(corner_max)
115
- return torch.cat((corner_min, corner_max), dim=0)
116
- else:
117
- # add perturbation to bounding box coordinates
118
- corner_min = []
119
- corner_max = []
120
- shape = meaning_post_label.shape
121
- for coor in min_coords:
122
- coor_ = max(0, coor + random.randint(-bbox_shift, bbox_shift))
123
- corner_min.append(coor_)
124
- for idx, coor in enumerate(max_coords):
125
- coor_ = min(shape[idx], coor + random.randint(-bbox_shift, bbox_shift))
126
- corner_max.append(coor_)
127
- corner_min = torch.tensor(corner_min)
128
- corner_max = torch.tensor(corner_max)
129
- return torch.cat((corner_min, corner_max), dim=0)
130
-
131
-
132
- def select_points(self, preds, num_positive_extra=4, num_negative_extra=0, fix_extra_point_num=None):
133
- spacial_dim = 3
134
- points = torch.zeros((0, 3))
135
- labels = torch.zeros((0))
136
- pos_thred = 0.9
137
- neg_thred = 0.1
138
-
139
- # get pos/net indices
140
- positive_indices = torch.nonzero(preds > pos_thred, as_tuple=True) # ([pos x], [pos y], [pos z])
141
- negative_indices = torch.nonzero(preds < neg_thred, as_tuple=True)
142
-
143
- ones_idx = (preds > pos_thred).nonzero(as_tuple=True)
144
- if all(tmp.nelement() == 0 for tmp in ones_idx):
145
- # all neg
146
- num_positive_extra = 0
147
- selected_positive_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
148
- points = torch.cat((points, selected_positive_point), dim=0)
149
- labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
150
- else:
151
- # random select a pos point
152
- random_idx = torch.randint(len(positive_indices[0]), (1,))
153
- selected_positive_point = torch.tensor([positive_indices[i][random_idx] for i in range(spacial_dim)]).unsqueeze(dim=0)
154
- points = torch.cat((points, selected_positive_point), dim=0)
155
- labels = torch.cat((labels, torch.ones((1))))
156
-
157
- if num_positive_extra > 0:
158
- pos_idx_list = torch.randperm(len(positive_indices[0]))[:num_positive_extra]
159
- extra_positive_points = []
160
- for pos_idx in pos_idx_list:
161
- extra_positive_points.append([positive_indices[i][pos_idx] for i in range(spacial_dim)])
162
- extra_positive_points = torch.tensor(extra_positive_points).reshape(-1, 3)
163
- points = torch.cat((points, extra_positive_points), dim=0)
164
- labels = torch.cat((labels, torch.ones((extra_positive_points.shape[0]))))
165
-
166
- if num_negative_extra > 0:
167
- neg_idx_list = torch.randperm(len(negative_indices[0]))[:num_negative_extra]
168
- extra_negative_points = []
169
- for neg_idx in neg_idx_list:
170
- extra_negative_points.append([negative_indices[i][neg_idx] for i in range(spacial_dim)])
171
- extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3)
172
- points = torch.cat((points, extra_negative_points), dim=0)
173
- labels = torch.cat((labels, torch.zeros((extra_negative_points.shape[0]))))
174
- # print('extra_negative_points ', extra_negative_points, extra_negative_points.shape)
175
- # print('==> points ', points.shape, labels)
176
-
177
- if fix_extra_point_num is None:
178
- left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0]
179
- else:
180
- left_point_num = fix_extra_point_num + 1 - labels.shape[0]
181
-
182
- for _ in range(left_point_num):
183
- ignore_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
184
- points = torch.cat((points, ignore_point), dim=0)
185
- labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
186
-
187
- return (points, labels)
188
 
189
  class MinMaxNormalization(transforms.Transform):
190
  def __call__(self, data):
@@ -228,7 +142,105 @@ class ForegroundNormalization(transforms.Transform):
228
  ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
229
  ct_narray = (ct_narray - mean) / max(std, 1e-8)
230
  return ct_narray
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  # SegVol
233
  import torch
234
  import torch.nn as nn
 
89
  item['zoom_out_label'] = item_zoom_out['label']
90
  return item
91
 
92
+ def point_prompt(self, label_single_resize):
93
+ point, point_label = select_points(label_single_resize, num_positive_extra=3, num_negative_extra=3)
94
+ points_single = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
95
+ binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape)
96
+ return points_single, binary_points_resize
97
+
98
+ def bbox_prompt(self, label_single_resize):
99
+ box_single = generate_box(label_single_resize).unsqueeze(0).float().cuda()
100
+ binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape)
101
+ return box_single, binary_cube_resize
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  class MinMaxNormalization(transforms.Transform):
104
  def __call__(self, data):
 
142
  ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
143
  ct_narray = (ct_narray - mean) / max(std, 1e-8)
144
  return ct_narray
145
+
146
+ # prompts
147
+ def generate_box(pred_pre, bbox_shift=None):
148
+ meaning_post_label = pred_pre # [h, w, d]
149
+ ones_idx = (meaning_post_label > 0).nonzero(as_tuple=True)
150
+ if all(tensor.nelement() == 0 for tensor in ones_idx):
151
+ bboxes = torch.tensor([-1,-1,-1,-1,-1,-1])
152
+ # print(bboxes, bboxes.shape)
153
+ return bboxes
154
+ min_coords = [dim.min() for dim in ones_idx] # [x_min, y_min, z_min]
155
+ max_coords = [dim.max() for dim in ones_idx] # [x_max, y_max, z_max]
156
+
157
+
158
+ if bbox_shift is None:
159
+ corner_min = []
160
+ corner_max = []
161
+ shape = meaning_post_label.shape
162
+ for coor in min_coords:
163
+ coor_ = max(0, coor)
164
+ corner_min.append(coor_)
165
+ for idx, coor in enumerate(max_coords):
166
+ coor_ = min(shape[idx], coor)
167
+ corner_max.append(coor_)
168
+ corner_min = torch.tensor(corner_min)
169
+ corner_max = torch.tensor(corner_max)
170
+ return torch.cat((corner_min, corner_max), dim=0)
171
+ else:
172
+ # add perturbation to bounding box coordinates
173
+ corner_min = []
174
+ corner_max = []
175
+ shape = meaning_post_label.shape
176
+ for coor in min_coords:
177
+ coor_ = max(0, coor + random.randint(-bbox_shift, bbox_shift))
178
+ corner_min.append(coor_)
179
+ for idx, coor in enumerate(max_coords):
180
+ coor_ = min(shape[idx], coor + random.randint(-bbox_shift, bbox_shift))
181
+ corner_max.append(coor_)
182
+ corner_min = torch.tensor(corner_min)
183
+ corner_max = torch.tensor(corner_max)
184
+ return torch.cat((corner_min, corner_max), dim=0)
185
+
186
+
187
+ def select_points(preds, num_positive_extra=4, num_negative_extra=0, fix_extra_point_num=None):
188
+ spacial_dim = 3
189
+ points = torch.zeros((0, 3))
190
+ labels = torch.zeros((0))
191
+ pos_thred = 0.9
192
+ neg_thred = 0.1
193
+
194
+ # get pos/net indices
195
+ positive_indices = torch.nonzero(preds > pos_thred, as_tuple=True) # ([pos x], [pos y], [pos z])
196
+ negative_indices = torch.nonzero(preds < neg_thred, as_tuple=True)
197
+
198
+ ones_idx = (preds > pos_thred).nonzero(as_tuple=True)
199
+ if all(tmp.nelement() == 0 for tmp in ones_idx):
200
+ # all neg
201
+ num_positive_extra = 0
202
+ selected_positive_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
203
+ points = torch.cat((points, selected_positive_point), dim=0)
204
+ labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
205
+ else:
206
+ # random select a pos point
207
+ random_idx = torch.randint(len(positive_indices[0]), (1,))
208
+ selected_positive_point = torch.tensor([positive_indices[i][random_idx] for i in range(spacial_dim)]).unsqueeze(dim=0)
209
+ points = torch.cat((points, selected_positive_point), dim=0)
210
+ labels = torch.cat((labels, torch.ones((1))))
211
+
212
+ if num_positive_extra > 0:
213
+ pos_idx_list = torch.randperm(len(positive_indices[0]))[:num_positive_extra]
214
+ extra_positive_points = []
215
+ for pos_idx in pos_idx_list:
216
+ extra_positive_points.append([positive_indices[i][pos_idx] for i in range(spacial_dim)])
217
+ extra_positive_points = torch.tensor(extra_positive_points).reshape(-1, 3)
218
+ points = torch.cat((points, extra_positive_points), dim=0)
219
+ labels = torch.cat((labels, torch.ones((extra_positive_points.shape[0]))))
220
+
221
+ if num_negative_extra > 0:
222
+ neg_idx_list = torch.randperm(len(negative_indices[0]))[:num_negative_extra]
223
+ extra_negative_points = []
224
+ for neg_idx in neg_idx_list:
225
+ extra_negative_points.append([negative_indices[i][neg_idx] for i in range(spacial_dim)])
226
+ extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3)
227
+ points = torch.cat((points, extra_negative_points), dim=0)
228
+ labels = torch.cat((labels, torch.zeros((extra_negative_points.shape[0]))))
229
+ # print('extra_negative_points ', extra_negative_points, extra_negative_points.shape)
230
+ # print('==> points ', points.shape, labels)
231
 
232
+ if fix_extra_point_num is None:
233
+ left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0]
234
+ else:
235
+ left_point_num = fix_extra_point_num + 1 - labels.shape[0]
236
+
237
+ for _ in range(left_point_num):
238
+ ignore_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
239
+ points = torch.cat((points, ignore_point), dim=0)
240
+ labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
241
+
242
+ return (points, labels)
243
+
244
  # SegVol
245
  import torch
246
  import torch.nn as nn