yuxin
commited on
Commit
•
faa61f9
1
Parent(s):
44c38ef
add model
Browse files- model_segvol_single.py +108 -96
model_segvol_single.py
CHANGED
@@ -89,102 +89,16 @@ class SegVolProcessor():
|
|
89 |
item['zoom_out_label'] = item_zoom_out['label']
|
90 |
return item
|
91 |
|
92 |
-
def
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
if bbox_shift is None:
|
104 |
-
corner_min = []
|
105 |
-
corner_max = []
|
106 |
-
shape = meaning_post_label.shape
|
107 |
-
for coor in min_coords:
|
108 |
-
coor_ = max(0, coor)
|
109 |
-
corner_min.append(coor_)
|
110 |
-
for idx, coor in enumerate(max_coords):
|
111 |
-
coor_ = min(shape[idx], coor)
|
112 |
-
corner_max.append(coor_)
|
113 |
-
corner_min = torch.tensor(corner_min)
|
114 |
-
corner_max = torch.tensor(corner_max)
|
115 |
-
return torch.cat((corner_min, corner_max), dim=0)
|
116 |
-
else:
|
117 |
-
# add perturbation to bounding box coordinates
|
118 |
-
corner_min = []
|
119 |
-
corner_max = []
|
120 |
-
shape = meaning_post_label.shape
|
121 |
-
for coor in min_coords:
|
122 |
-
coor_ = max(0, coor + random.randint(-bbox_shift, bbox_shift))
|
123 |
-
corner_min.append(coor_)
|
124 |
-
for idx, coor in enumerate(max_coords):
|
125 |
-
coor_ = min(shape[idx], coor + random.randint(-bbox_shift, bbox_shift))
|
126 |
-
corner_max.append(coor_)
|
127 |
-
corner_min = torch.tensor(corner_min)
|
128 |
-
corner_max = torch.tensor(corner_max)
|
129 |
-
return torch.cat((corner_min, corner_max), dim=0)
|
130 |
-
|
131 |
-
|
132 |
-
def select_points(self, preds, num_positive_extra=4, num_negative_extra=0, fix_extra_point_num=None):
|
133 |
-
spacial_dim = 3
|
134 |
-
points = torch.zeros((0, 3))
|
135 |
-
labels = torch.zeros((0))
|
136 |
-
pos_thred = 0.9
|
137 |
-
neg_thred = 0.1
|
138 |
-
|
139 |
-
# get pos/net indices
|
140 |
-
positive_indices = torch.nonzero(preds > pos_thred, as_tuple=True) # ([pos x], [pos y], [pos z])
|
141 |
-
negative_indices = torch.nonzero(preds < neg_thred, as_tuple=True)
|
142 |
-
|
143 |
-
ones_idx = (preds > pos_thred).nonzero(as_tuple=True)
|
144 |
-
if all(tmp.nelement() == 0 for tmp in ones_idx):
|
145 |
-
# all neg
|
146 |
-
num_positive_extra = 0
|
147 |
-
selected_positive_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
|
148 |
-
points = torch.cat((points, selected_positive_point), dim=0)
|
149 |
-
labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
|
150 |
-
else:
|
151 |
-
# random select a pos point
|
152 |
-
random_idx = torch.randint(len(positive_indices[0]), (1,))
|
153 |
-
selected_positive_point = torch.tensor([positive_indices[i][random_idx] for i in range(spacial_dim)]).unsqueeze(dim=0)
|
154 |
-
points = torch.cat((points, selected_positive_point), dim=0)
|
155 |
-
labels = torch.cat((labels, torch.ones((1))))
|
156 |
-
|
157 |
-
if num_positive_extra > 0:
|
158 |
-
pos_idx_list = torch.randperm(len(positive_indices[0]))[:num_positive_extra]
|
159 |
-
extra_positive_points = []
|
160 |
-
for pos_idx in pos_idx_list:
|
161 |
-
extra_positive_points.append([positive_indices[i][pos_idx] for i in range(spacial_dim)])
|
162 |
-
extra_positive_points = torch.tensor(extra_positive_points).reshape(-1, 3)
|
163 |
-
points = torch.cat((points, extra_positive_points), dim=0)
|
164 |
-
labels = torch.cat((labels, torch.ones((extra_positive_points.shape[0]))))
|
165 |
-
|
166 |
-
if num_negative_extra > 0:
|
167 |
-
neg_idx_list = torch.randperm(len(negative_indices[0]))[:num_negative_extra]
|
168 |
-
extra_negative_points = []
|
169 |
-
for neg_idx in neg_idx_list:
|
170 |
-
extra_negative_points.append([negative_indices[i][neg_idx] for i in range(spacial_dim)])
|
171 |
-
extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3)
|
172 |
-
points = torch.cat((points, extra_negative_points), dim=0)
|
173 |
-
labels = torch.cat((labels, torch.zeros((extra_negative_points.shape[0]))))
|
174 |
-
# print('extra_negative_points ', extra_negative_points, extra_negative_points.shape)
|
175 |
-
# print('==> points ', points.shape, labels)
|
176 |
-
|
177 |
-
if fix_extra_point_num is None:
|
178 |
-
left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0]
|
179 |
-
else:
|
180 |
-
left_point_num = fix_extra_point_num + 1 - labels.shape[0]
|
181 |
-
|
182 |
-
for _ in range(left_point_num):
|
183 |
-
ignore_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
|
184 |
-
points = torch.cat((points, ignore_point), dim=0)
|
185 |
-
labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
|
186 |
-
|
187 |
-
return (points, labels)
|
188 |
|
189 |
class MinMaxNormalization(transforms.Transform):
|
190 |
def __call__(self, data):
|
@@ -228,7 +142,105 @@ class ForegroundNormalization(transforms.Transform):
|
|
228 |
ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
|
229 |
ct_narray = (ct_narray - mean) / max(std, 1e-8)
|
230 |
return ct_narray
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
# SegVol
|
233 |
import torch
|
234 |
import torch.nn as nn
|
|
|
89 |
item['zoom_out_label'] = item_zoom_out['label']
|
90 |
return item
|
91 |
|
92 |
+
def point_prompt(self, label_single_resize):
|
93 |
+
point, point_label = select_points(label_single_resize, num_positive_extra=3, num_negative_extra=3)
|
94 |
+
points_single = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
|
95 |
+
binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape)
|
96 |
+
return points_single, binary_points_resize
|
97 |
+
|
98 |
+
def bbox_prompt(self, label_single_resize):
|
99 |
+
box_single = generate_box(label_single_resize).unsqueeze(0).float().cuda()
|
100 |
+
binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape)
|
101 |
+
return box_single, binary_cube_resize
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
class MinMaxNormalization(transforms.Transform):
|
104 |
def __call__(self, data):
|
|
|
142 |
ct_narray = np.clip(ct_narray, lower_bound, upper_bound)
|
143 |
ct_narray = (ct_narray - mean) / max(std, 1e-8)
|
144 |
return ct_narray
|
145 |
+
|
146 |
+
# prompts
|
147 |
+
def generate_box(pred_pre, bbox_shift=None):
|
148 |
+
meaning_post_label = pred_pre # [h, w, d]
|
149 |
+
ones_idx = (meaning_post_label > 0).nonzero(as_tuple=True)
|
150 |
+
if all(tensor.nelement() == 0 for tensor in ones_idx):
|
151 |
+
bboxes = torch.tensor([-1,-1,-1,-1,-1,-1])
|
152 |
+
# print(bboxes, bboxes.shape)
|
153 |
+
return bboxes
|
154 |
+
min_coords = [dim.min() for dim in ones_idx] # [x_min, y_min, z_min]
|
155 |
+
max_coords = [dim.max() for dim in ones_idx] # [x_max, y_max, z_max]
|
156 |
+
|
157 |
+
|
158 |
+
if bbox_shift is None:
|
159 |
+
corner_min = []
|
160 |
+
corner_max = []
|
161 |
+
shape = meaning_post_label.shape
|
162 |
+
for coor in min_coords:
|
163 |
+
coor_ = max(0, coor)
|
164 |
+
corner_min.append(coor_)
|
165 |
+
for idx, coor in enumerate(max_coords):
|
166 |
+
coor_ = min(shape[idx], coor)
|
167 |
+
corner_max.append(coor_)
|
168 |
+
corner_min = torch.tensor(corner_min)
|
169 |
+
corner_max = torch.tensor(corner_max)
|
170 |
+
return torch.cat((corner_min, corner_max), dim=0)
|
171 |
+
else:
|
172 |
+
# add perturbation to bounding box coordinates
|
173 |
+
corner_min = []
|
174 |
+
corner_max = []
|
175 |
+
shape = meaning_post_label.shape
|
176 |
+
for coor in min_coords:
|
177 |
+
coor_ = max(0, coor + random.randint(-bbox_shift, bbox_shift))
|
178 |
+
corner_min.append(coor_)
|
179 |
+
for idx, coor in enumerate(max_coords):
|
180 |
+
coor_ = min(shape[idx], coor + random.randint(-bbox_shift, bbox_shift))
|
181 |
+
corner_max.append(coor_)
|
182 |
+
corner_min = torch.tensor(corner_min)
|
183 |
+
corner_max = torch.tensor(corner_max)
|
184 |
+
return torch.cat((corner_min, corner_max), dim=0)
|
185 |
+
|
186 |
+
|
187 |
+
def select_points(preds, num_positive_extra=4, num_negative_extra=0, fix_extra_point_num=None):
|
188 |
+
spacial_dim = 3
|
189 |
+
points = torch.zeros((0, 3))
|
190 |
+
labels = torch.zeros((0))
|
191 |
+
pos_thred = 0.9
|
192 |
+
neg_thred = 0.1
|
193 |
+
|
194 |
+
# get pos/net indices
|
195 |
+
positive_indices = torch.nonzero(preds > pos_thred, as_tuple=True) # ([pos x], [pos y], [pos z])
|
196 |
+
negative_indices = torch.nonzero(preds < neg_thred, as_tuple=True)
|
197 |
+
|
198 |
+
ones_idx = (preds > pos_thred).nonzero(as_tuple=True)
|
199 |
+
if all(tmp.nelement() == 0 for tmp in ones_idx):
|
200 |
+
# all neg
|
201 |
+
num_positive_extra = 0
|
202 |
+
selected_positive_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
|
203 |
+
points = torch.cat((points, selected_positive_point), dim=0)
|
204 |
+
labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
|
205 |
+
else:
|
206 |
+
# random select a pos point
|
207 |
+
random_idx = torch.randint(len(positive_indices[0]), (1,))
|
208 |
+
selected_positive_point = torch.tensor([positive_indices[i][random_idx] for i in range(spacial_dim)]).unsqueeze(dim=0)
|
209 |
+
points = torch.cat((points, selected_positive_point), dim=0)
|
210 |
+
labels = torch.cat((labels, torch.ones((1))))
|
211 |
+
|
212 |
+
if num_positive_extra > 0:
|
213 |
+
pos_idx_list = torch.randperm(len(positive_indices[0]))[:num_positive_extra]
|
214 |
+
extra_positive_points = []
|
215 |
+
for pos_idx in pos_idx_list:
|
216 |
+
extra_positive_points.append([positive_indices[i][pos_idx] for i in range(spacial_dim)])
|
217 |
+
extra_positive_points = torch.tensor(extra_positive_points).reshape(-1, 3)
|
218 |
+
points = torch.cat((points, extra_positive_points), dim=0)
|
219 |
+
labels = torch.cat((labels, torch.ones((extra_positive_points.shape[0]))))
|
220 |
+
|
221 |
+
if num_negative_extra > 0:
|
222 |
+
neg_idx_list = torch.randperm(len(negative_indices[0]))[:num_negative_extra]
|
223 |
+
extra_negative_points = []
|
224 |
+
for neg_idx in neg_idx_list:
|
225 |
+
extra_negative_points.append([negative_indices[i][neg_idx] for i in range(spacial_dim)])
|
226 |
+
extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3)
|
227 |
+
points = torch.cat((points, extra_negative_points), dim=0)
|
228 |
+
labels = torch.cat((labels, torch.zeros((extra_negative_points.shape[0]))))
|
229 |
+
# print('extra_negative_points ', extra_negative_points, extra_negative_points.shape)
|
230 |
+
# print('==> points ', points.shape, labels)
|
231 |
|
232 |
+
if fix_extra_point_num is None:
|
233 |
+
left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0]
|
234 |
+
else:
|
235 |
+
left_point_num = fix_extra_point_num + 1 - labels.shape[0]
|
236 |
+
|
237 |
+
for _ in range(left_point_num):
|
238 |
+
ignore_point = torch.tensor([-1,-1,-1]).unsqueeze(dim=0)
|
239 |
+
points = torch.cat((points, ignore_point), dim=0)
|
240 |
+
labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
|
241 |
+
|
242 |
+
return (points, labels)
|
243 |
+
|
244 |
# SegVol
|
245 |
import torch
|
246 |
import torch.nn as nn
|