yuxin commited on
Commit
2b7366f
1 Parent(s): 8d8118b
Files changed (1) hide show
  1. model_segvol_single.py +44 -8
model_segvol_single.py CHANGED
@@ -48,8 +48,10 @@ class SegVolModel(PreTrainedModel):
48
  point_prompt=None,
49
  use_zoom=True,
50
  **kwargs):
51
- bbox_prompt, bbox_prompt_map = bbox_prompt
52
- point_prompt, point_prompt_map = point_prompt
 
 
53
  print(image.shape, zoomed_image.shape, text_prompt)
54
  assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
55
  print(bbox_prompt.shape, bbox_prompt_map.shape, point_prompt[0].shape, point_prompt[1].shape, point_prompt_map.shape)
@@ -66,17 +68,51 @@ class SegVolModel(PreTrainedModel):
66
  size=volume_shape, mode='nearest')
67
  if not use_zoom:
68
  return logits_global_single
69
- logits_global_single = logits_global_single[0][0]
70
  print(logits_global_single.shape)
71
  if point_prompt_map is not None:
72
  binary_points = F.interpolate(
73
- point_prompt_map.unsqueeze(0).unsqueeze(0).float(),
74
- size=volume_shape, mode='nearest')[0][0]
75
  if bbox_prompt_map is not None:
76
  binary_cube = F.interpolate(
77
- bbox_prompt_map.unsqueeze(0).unsqueeze(0).float(),
78
- size=volume_shape, mode='nearest')[0][0]
79
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
 
82
  # processor
 
48
  point_prompt=None,
49
  use_zoom=True,
50
  **kwargs):
51
+ if bbox_prompt is not None:
52
+ bbox_prompt, bbox_prompt_map = bbox_prompt
53
+ if point_prompt is not None:
54
+ point_prompt, point_prompt_map = point_prompt
55
  print(image.shape, zoomed_image.shape, text_prompt)
56
  assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
57
  print(bbox_prompt.shape, bbox_prompt_map.shape, point_prompt[0].shape, point_prompt[1].shape, point_prompt_map.shape)
 
68
  size=volume_shape, mode='nearest')
69
  if not use_zoom:
70
  return logits_global_single
 
71
  print(logits_global_single.shape)
72
  if point_prompt_map is not None:
73
  binary_points = F.interpolate(
74
+ point_prompt_map.float(),
75
+ size=volume_shape, mode='nearest')
76
  if bbox_prompt_map is not None:
77
  binary_cube = F.interpolate(
78
+ bbox_prompt_map.float(),
79
+ size=volume_shape, mode='nearest')
80
+
81
+ min_d, min_h, min_w, max_d, max_h, max_w = logits2roi_coor(self.config.spatial_size, logits_global_single[0][0])
82
+ assert min_d is None, 'Fail to detect foreground!'
83
+
84
+ # Crop roi
85
+ image_single_cropped = image[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
86
+ global_preds = (torch.sigmoid(logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1])>0.5).long()
87
+
88
+ assert not (bbox_prompt is not None and point_prompt is not None)
89
+ prompt_reflection = None
90
+ if bbox_prompt is not None:
91
+ binary_cube_cropped = binary_cube[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
92
+ prompt_reflection = (
93
+ binary_cube_cropped,
94
+ global_preds
95
+ )
96
+ if point_prompt is not None:
97
+ binary_points_cropped = binary_points[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1]
98
+ prompt_reflection = (
99
+ binary_points_cropped,
100
+ global_preds
101
+ )
102
+
103
+ ## inference
104
+ with torch.no_grad():
105
+ logits_single_cropped = sliding_window_inference(
106
+ image_single_cropped.cuda(), prompt_reflection,
107
+ self.config.spatial_size, 1, self.model, 0.5,
108
+ text=text_prompt,
109
+ use_box=bbox_prompt is not None,
110
+ use_point=point_prompt is not None,
111
+ )
112
+ logits_single_cropped = logits_single_cropped.cpu().squeeze()
113
+ logits_global_single[:, :, min_d:max_d+1, min_h:max_h+1, min_w:max_w+1] = logits_single_cropped
114
+ return logits_global_single
115
+
116
 
117
 
118
  # processor