yuxin committed
Commit f192b2c · Parent: 005719f

add model

model_segvol_single.py CHANGED (+24 −20)
@@ -26,6 +26,8 @@ class SegVolModel(PreTrainedModel):
 
         self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
 
+        self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
     def forward(self, image, zoomed_image=None, text_prompt=None, bbox_prompt=None, point_prompt=None, **kwargs):
         print(image.shape, zoomed_image.shape, text_prompt)
         print(bbox_prompt[0].shape, bbox_prompt[1].shape, point_prompt[0].shape, point_prompt[1].shape)
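The pattern this commit applies throughout the file is visible in the hunk above: instead of calling `.cuda()` unconditionally, the device is resolved once at construction time and every later transfer goes through `.to(...)`. A minimal sketch of the idea, assuming nothing beyond PyTorch:

```python
import torch

# Resolve the device once; fall back to CPU when no GPU is visible.
# The old unconditional .cuda() calls raise on CPU-only machines.
custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# .to() is a no-op when the tensor already lives on the target device,
# so the same code path now works on both GPU and CPU hosts.
image = torch.randn(1, 1, 32, 256, 256)
image = image.to(custom_device)
```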
@@ -51,6 +53,7 @@ class SegVolProcessor():
             ]
         )
         self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
+        self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
     # ct_path is path for a ct scan file with nii.gz format
     # gt_path is path for a ground truth file with nii.gz format
@@ -97,15 +100,14 @@ class SegVolProcessor():
         item['zoom_out_label'] = item_zoom_out['label']
         return item
 
-    def point_prompt_b(self, label_single_resize):
-        num_positive_extra, num_negative_extra = 4, 0
-        point, point_label = select_points(label_single_resize, num_positive_extra=num_positive_extra, num_negative_extra=num_negative_extra)
-        points_single = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
+    def point_prompt_b(self, label_single_resize, num_positive_extra=4, num_negative_extra=0):
+        point, point_label = select_points(label_single_resize, num_positive_extra=num_positive_extra, num_negative_extra=num_negative_extra)
+        points_single = (point.unsqueeze(0).float().to(self.custom_device), point_label.unsqueeze(0).float().to(self.custom_device))
         binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape).unsqueeze(0).unsqueeze(0)
         return points_single, binary_points_resize
 
     def bbox_prompt_b(self, label_single_resize):
-        box_single = generate_box(label_single_resize).unsqueeze(0).float().cuda()
+        box_single = generate_box(label_single_resize).unsqueeze(0).float().to(self.custom_device)
         binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
         return box_single, binary_cube_resize
 
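With the new signatures, the two prompt builders can be driven directly from a resized label volume. A hypothetical usage sketch, assuming `model_segvol_single.py` is importable and the spatial size matches the processor's configuration (the mask and shapes are illustrative, not from the commit):

```python
import torch
from model_segvol_single import SegVolProcessor  # assumes this file is on the path

processor = SegVolProcessor(spatial_size=(32, 256, 256))  # illustrative size

# A sparse binary mask standing in for a resized ground-truth organ label.
label_single_resize = (torch.rand(32, 256, 256) > 0.999).float()

# (coords, labels) land on processor.custom_device; the binary volumes
# are built on CPU for later resizing.
points_single, binary_points_resize = processor.point_prompt_b(label_single_resize)
box_single, binary_cube_resize = processor.bbox_prompt_b(label_single_resize)
```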
@@ -269,14 +271,15 @@ class SegVol(nn.Module):
         test_mode=False,
         ):
         super().__init__()
+        self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.image_encoder = image_encoder
         self.mask_decoder = mask_decoder
         self.prompt_encoder = prompt_encoder
         self.text_encoder = TextEncoder()
         self.feat_shape = np.array(roi_size)/np.array(patch_size)
         self.test_mode = test_mode
-        self.dice_loss = BinaryDiceLoss().cuda()
-        self.bce_loss = BCELoss().cuda()
+        self.dice_loss = BinaryDiceLoss().to(self.custom_device)
+        self.bce_loss = BCELoss().to(self.custom_device)
         self.decoder_iter = 6
 
     def forward(self, image, text=None, boxes=None, points=None, **kwargs):
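One detail worth noting in this hunk: `.to()` behaves differently for modules and tensors. `nn.Module.to()` moves parameters and buffers in place and returns `self`, while `Tensor.to()` returns a new tensor, so reassigning the result, as the diff does, is correct in both cases. A small illustration:

```python
import torch
from torch import nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Module: parameters/buffers are moved in place; .to() returns self.
loss_fn = nn.BCELoss().to(device)

# Tensor: .to() returns a new tensor and leaves the original untouched,
# which is why the assignment (rather than a bare call) matters.
t = torch.zeros(4)
t = t.to(device)
```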
@@ -379,14 +382,14 @@ class SegVol(nn.Module):
             iter_points_ax.append(point)
             iter_point_labels.append(point_label)
         # batched prompt
-        iter_points_ax = torch.stack(iter_points_ax, dim=0).cuda()
-        iter_point_labels = torch.stack(iter_point_labels, dim=0).cuda()
+        iter_points_ax = torch.stack(iter_points_ax, dim=0).to(self.custom_device)
+        iter_point_labels = torch.stack(iter_point_labels, dim=0).to(self.custom_device)
         iter_points = (iter_points_ax, iter_point_labels)
-        iter_bboxes = torch.stack(iter_bboxes, dim=0).float().cuda()
+        iter_bboxes = torch.stack(iter_bboxes, dim=0).float().to(self.custom_device)
         return iter_points, iter_bboxes, iter_organs
 
     def build_pseudo_point_prompt_label(self, input_shape, seg_labels):
-        pseudo_labels = torch.zeros(input_shape).cuda()
+        pseudo_labels = torch.zeros(input_shape).to(self.custom_device)
         # generate points
         points = []
         point_labels = []
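`torch.stack` allocates the batched tensor on the same device as its inputs, so a single `.to(self.custom_device)` after stacking replaces one `.cuda()` per element. A short sketch with illustrative shapes:

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Per-sample point prompts: six (z, y, x) coordinates plus point labels.
per_sample_points = [torch.rand(6, 3) for _ in range(2)]
per_sample_labels = [torch.randint(0, 2, (6,)).float() for _ in range(2)]

points = torch.stack(per_sample_points, dim=0).to(device)        # (B, N, 3)
point_labels = torch.stack(per_sample_labels, dim=0).to(device)  # (B, N)
iter_points = (points, point_labels)
```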
@@ -408,13 +411,13 @@ class SegVol(nn.Module):
                                         fix_extra_point_num=num_positive_extra_max + num_negative_extra_max)
             points.append(point)
             point_labels.append(point_label)
-        points = torch.stack(points, dim=0).cuda()
-        point_labels = torch.stack(point_labels, dim=0).cuda()
+        points = torch.stack(points, dim=0).to(self.custom_device)
+        point_labels = torch.stack(point_labels, dim=0).to(self.custom_device)
         pseudo_points_prompt = (points, point_labels)
         return pseudo_labels, pseudo_points_prompt
 
     def build_pseudo_box_prompt_label(self, input_shape, seg_labels_cleaned):
-        pseudo_labels = torch.zeros(input_shape).cuda()
+        pseudo_labels = torch.zeros(input_shape).to(self.custom_device)
         iter_bboxes = []
         # generate boxes
         for batch_idx in range(input_shape[0]):
@@ -445,12 +448,13 @@ class SegVol(nn.Module):
                 # print(f"Mask value {value} has IOU > 0.90 in binary cube.")
                 pseudo_labels[batch_idx][seg_labels_cleaned[batch_idx]==value] = 1
 
-        bboxes = torch.stack(iter_bboxes, dim=0).float().cuda()
+        bboxes = torch.stack(iter_bboxes, dim=0).float().to(self.custom_device)
         return pseudo_labels, bboxes
 
 class TextEncoder(nn.Module):
     def __init__(self):
         super().__init__()
+        self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         config = CLIPTextConfig()
         self.clip_text_model = CLIPTextModel(config)
         self.tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-base-patch32')
@@ -463,7 +467,7 @@ class TextEncoder(nn.Module):
         text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
         tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
         for key in tokens.keys():
-            tokens[key] = tokens[key].cuda()
+            tokens[key] = tokens[key].to(self.custom_device)
         return tokens
 
     def forward(self, text):
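Since `self.tokenizer(...)` returns a `transformers.BatchEncoding` when `return_tensors="pt"` is set, the per-key loop above could equivalently be a single `BatchEncoding.to(device)` call. For reference:

```python
import torch
from transformers import AutoTokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-base-patch32')

tokens = tokenizer(['A computerized tomography of a liver.'],
                   padding=True, return_tensors='pt')
# Moves every tensor field (input_ids, attention_mask, ...) at once.
tokens = tokens.to(device)
```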
@@ -752,16 +756,16 @@ def sliding_window_inference(
         if use_point:
             window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
             point, point_label = select_points(window_binary_prompt_map.squeeze())
-            points = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
+            points = (point.unsqueeze(0).float().to(device), point_label.unsqueeze(0).float().to(device))
             pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
-            boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().cuda()
+            boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().to(device)
         if use_box:
             if num_win == 1:
                 window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
-                boxes = generate_box(window_binary_prompt_map.squeeze()).unsqueeze(0).float().cuda()
+                boxes = generate_box(window_binary_prompt_map.squeeze()).unsqueeze(0).float().to(device)
             else:
                 pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
-                boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().cuda()
+                boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().to(device)
         seg_prob_out = predictor(window_data, text, boxes, points)  # batched patch segmentation
         #############
         # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
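The call sites above all repeat the same move-to-device step on tensors nested inside tuples. A hypothetical helper (not part of this commit) that generalizes the pattern the diff applies one line at a time:

```python
import torch

def to_device(obj, device):
    """Recursively move tensors inside tuples, lists, and dicts."""
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device(x, device) for x in obj)
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    return obj

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
points = (torch.rand(1, 6, 3), torch.rand(1, 6))
points = to_device(points, device)  # both tensors land on `device`
```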