yuxin commited on
Commit
f192b2c
1 Parent(s): 005719f
Files changed (1) hide show
  1. model_segvol_single.py +24 -20
model_segvol_single.py CHANGED
@@ -26,6 +26,8 @@ class SegVolModel(PreTrainedModel):
26
 
27
  self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
28
 
 
 
29
  def forward(self, image, zoomed_image=None, text_prompt=None, bbox_prompt=None, point_prompt=None, **kwargs):
30
  print(image.shape, zoomed_image.shape, text_prompt)
31
  print(bbox_prompt[0].shape, bbox_prompt[1].shape, point_prompt[0].shape, point_prompt[1].shape)
@@ -51,6 +53,7 @@ class SegVolProcessor():
51
  ]
52
  )
53
  self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
 
54
 
55
  # ct_path is path for a ct scan file with nii.gz format
56
  # gt_path is path for a ground truth file with nii.gz format
@@ -97,15 +100,14 @@ class SegVolProcessor():
97
  item['zoom_out_label'] = item_zoom_out['label']
98
  return item
99
 
100
- def point_prompt_b(self, label_single_resize):
101
- print(label_single_resize.shape, torch.unique(label_single_resize))
102
- point, point_label = select_points(label_single_resize, num_positive_extra=3, num_negative_extra=3)
103
- points_single = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
104
  binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape).unsqueeze(0).unsqueeze(0)
105
  return points_single, binary_points_resize
106
 
107
  def bbox_prompt_b(self, label_single_resize):
108
- box_single = generate_box(label_single_resize).unsqueeze(0).float().cuda()
109
  binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
110
  return box_single, binary_cube_resize
111
 
@@ -269,14 +271,15 @@ class SegVol(nn.Module):
269
  test_mode=False,
270
  ):
271
  super().__init__()
 
272
  self.image_encoder = image_encoder
273
  self.mask_decoder = mask_decoder
274
  self.prompt_encoder = prompt_encoder
275
  self.text_encoder = TextEncoder()
276
  self.feat_shape = np.array(roi_size)/np.array(patch_size)
277
  self.test_mode = test_mode
278
- self.dice_loss = BinaryDiceLoss().cuda()
279
- self.bce_loss = BCELoss().cuda()
280
  self.decoder_iter = 6
281
 
282
  def forward(self, image, text=None, boxes=None, points=None, **kwargs):
@@ -379,14 +382,14 @@ class SegVol(nn.Module):
379
  iter_points_ax.append(point)
380
  iter_point_labels.append(point_label)
381
  # batched prompt
382
- iter_points_ax = torch.stack(iter_points_ax, dim=0).cuda()
383
- iter_point_labels = torch.stack(iter_point_labels, dim=0).cuda()
384
  iter_points = (iter_points_ax, iter_point_labels)
385
- iter_bboxes = torch.stack(iter_bboxes, dim=0).float().cuda()
386
  return iter_points, iter_bboxes, iter_organs
387
 
388
  def build_pseudo_point_prompt_label(self, input_shape, seg_labels):
389
- pseudo_labels = torch.zeros(input_shape).cuda()
390
  # generate points
391
  points = []
392
  point_labels = []
@@ -408,13 +411,13 @@ class SegVol(nn.Module):
408
  fix_extra_point_num=num_positive_extra_max + num_negative_extra_max)
409
  points.append(point)
410
  point_labels.append(point_label)
411
- points = torch.stack(points, dim=0).cuda()
412
- point_labels = torch.stack(point_labels, dim=0).cuda()
413
  pseudo_points_prompt = (points, point_labels)
414
  return pseudo_labels, pseudo_points_prompt
415
 
416
  def build_pseudo_box_prompt_label(self, input_shape, seg_labels_cleaned):
417
- pseudo_labels = torch.zeros(input_shape).cuda()
418
  iter_bboxes = []
419
  # generate boxes
420
  for batch_idx in range(input_shape[0]):
@@ -445,12 +448,13 @@ class SegVol(nn.Module):
445
  # print(f"Mask value {value} has IOU > 0.90 in binary cube.")
446
  pseudo_labels[batch_idx][seg_labels_cleaned[batch_idx]==value] = 1
447
 
448
- bboxes = torch.stack(iter_bboxes, dim=0).float().cuda()
449
  return pseudo_labels, bboxes
450
 
451
  class TextEncoder(nn.Module):
452
  def __init__(self):
453
  super().__init__()
 
454
  config = CLIPTextConfig()
455
  self.clip_text_model = CLIPTextModel(config)
456
  self.tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-base-patch32')
@@ -463,7 +467,7 @@ class TextEncoder(nn.Module):
463
  text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
464
  tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
465
  for key in tokens.keys():
466
- tokens[key] = tokens[key].cuda()
467
  return tokens
468
 
469
  def forward(self, text):
@@ -752,16 +756,16 @@ def sliding_window_inference(
752
  if use_point:
753
  window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
754
  point, point_label = select_points(window_binary_prompt_map.squeeze())
755
- points = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
756
  pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
757
- boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().cuda()
758
  if use_box:
759
  if num_win == 1:
760
  window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
761
- boxes = generate_box(window_binary_prompt_map.squeeze()).unsqueeze(0).float().cuda()
762
  else:
763
  pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
764
- boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().cuda()
765
  seg_prob_out = predictor(window_data, text, boxes, points) # batched patch segmentation
766
  #############
767
  # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
 
26
 
27
  self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
28
 
29
+ self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
30
+
31
  def forward(self, image, zoomed_image=None, text_prompt=None, bbox_prompt=None, point_prompt=None, **kwargs):
32
  print(image.shape, zoomed_image.shape, text_prompt)
33
  print(bbox_prompt[0].shape, bbox_prompt[1].shape, point_prompt[0].shape, point_prompt[1].shape)
 
53
  ]
54
  )
55
  self.zoom_out_transform = transforms.Resized(keys=["image", "label"], spatial_size=spatial_size, mode='nearest-exact')
56
+ self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
57
 
58
  # ct_path is path for a ct scan file with nii.gz format
59
  # gt_path is path for a ground truth file with nii.gz format
 
100
  item['zoom_out_label'] = item_zoom_out['label']
101
  return item
102
 
103
+ def point_prompt_b(self, label_single_resize, num_positive_extra=4, num_negative_extra=0):
104
+ point, point_label = select_points(label_single_resize, num_positive_extra=num_positive_extra, num_negative_extra=num_negative_extra)
105
+ points_single = (point.unsqueeze(0).float().to(self.custom_device), point_label.unsqueeze(0).float().to(self.custom_device))
 
106
  binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape).unsqueeze(0).unsqueeze(0)
107
  return points_single, binary_points_resize
108
 
109
  def bbox_prompt_b(self, label_single_resize):
110
+ box_single = generate_box(label_single_resize).unsqueeze(0).float().to(self.custom_device)
111
  binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
112
  return box_single, binary_cube_resize
113
 
 
271
  test_mode=False,
272
  ):
273
  super().__init__()
274
+ self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
275
  self.image_encoder = image_encoder
276
  self.mask_decoder = mask_decoder
277
  self.prompt_encoder = prompt_encoder
278
  self.text_encoder = TextEncoder()
279
  self.feat_shape = np.array(roi_size)/np.array(patch_size)
280
  self.test_mode = test_mode
281
+ self.dice_loss = BinaryDiceLoss().to(self.custom_device)
282
+ self.bce_loss = BCELoss().to(self.custom_device)
283
  self.decoder_iter = 6
284
 
285
  def forward(self, image, text=None, boxes=None, points=None, **kwargs):
 
382
  iter_points_ax.append(point)
383
  iter_point_labels.append(point_label)
384
  # batched prompt
385
+ iter_points_ax = torch.stack(iter_points_ax, dim=0).to(self.custom_device)
386
+ iter_point_labels = torch.stack(iter_point_labels, dim=0).to(self.custom_device)
387
  iter_points = (iter_points_ax, iter_point_labels)
388
+ iter_bboxes = torch.stack(iter_bboxes, dim=0).float().to(self.custom_device)
389
  return iter_points, iter_bboxes, iter_organs
390
 
391
  def build_pseudo_point_prompt_label(self, input_shape, seg_labels):
392
+ pseudo_labels = torch.zeros(input_shape).to(self.custom_device)
393
  # generate points
394
  points = []
395
  point_labels = []
 
411
  fix_extra_point_num=num_positive_extra_max + num_negative_extra_max)
412
  points.append(point)
413
  point_labels.append(point_label)
414
+ points = torch.stack(points, dim=0).to(self.custom_device)
415
+ point_labels = torch.stack(point_labels, dim=0).to(self.custom_device)
416
  pseudo_points_prompt = (points, point_labels)
417
  return pseudo_labels, pseudo_points_prompt
418
 
419
  def build_pseudo_box_prompt_label(self, input_shape, seg_labels_cleaned):
420
+ pseudo_labels = torch.zeros(input_shape).to(self.custom_device)
421
  iter_bboxes = []
422
  # generate boxes
423
  for batch_idx in range(input_shape[0]):
 
448
  # print(f"Mask value {value} has IOU > 0.90 in binary cube.")
449
  pseudo_labels[batch_idx][seg_labels_cleaned[batch_idx]==value] = 1
450
 
451
+ bboxes = torch.stack(iter_bboxes, dim=0).float().to(self.custom_device)
452
  return pseudo_labels, bboxes
453
 
454
  class TextEncoder(nn.Module):
455
  def __init__(self):
456
  super().__init__()
457
+ self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
458
  config = CLIPTextConfig()
459
  self.clip_text_model = CLIPTextModel(config)
460
  self.tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-base-patch32')
 
467
  text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
468
  tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
469
  for key in tokens.keys():
470
+ tokens[key] = tokens[key].to(self.custom_device)
471
  return tokens
472
 
473
  def forward(self, text):
 
756
  if use_point:
757
  window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
758
  point, point_label = select_points(window_binary_prompt_map.squeeze())
759
+ points = (point.unsqueeze(0).float().to(device), point_label.unsqueeze(0).float().to(device))
760
  pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
761
+ boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().to(device)
762
  if use_box:
763
  if num_win == 1:
764
  window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
765
+ boxes = generate_box(window_binary_prompt_map.squeeze()).unsqueeze(0).float().to(device)
766
  else:
767
  pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
768
+ boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().to(device)
769
  seg_prob_out = predictor(window_data, text, boxes, points) # batched patch segmentation
770
  #############
771
  # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.