yuxin commited on
Commit
8d8118b
1 Parent(s): 35cfafe
Files changed (1) hide show
  1. model_segvol_single.py +51 -17
model_segvol_single.py CHANGED
@@ -40,29 +40,44 @@ class SegVolModel(PreTrainedModel):
40
 
41
  self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
 
43
- def forward(self, image, zoomed_image=None, text_prompt=None, bbox_prompt=None, point_prompt=None, **kwargs):
 
 
 
 
 
 
 
44
  bbox_prompt, bbox_prompt_map = bbox_prompt
45
  point_prompt, point_prompt_map = point_prompt
46
  print(image.shape, zoomed_image.shape, text_prompt)
 
47
  print(bbox_prompt.shape, bbox_prompt_map.shape, point_prompt[0].shape, point_prompt[1].shape, point_prompt_map.shape)
48
  volume_shape = image[0][0].shape
49
- # test mode
50
- if self.config.test_mode:
51
- with torch.no_grad():
52
- logits_global_single = self.model(zoomed_image,
53
- text=text_prompt,
54
- boxes=bbox_prompt,
55
- points=point_prompt)
56
- print(logits_global_single.shape)
57
- logits_global_single = F.interpolate(
58
- logits_global_single.cpu(),
 
 
 
 
 
 
 
59
  size=volume_shape, mode='nearest')[0][0]
60
- print(logits_global_single.shape)
61
- return
62
- else:
63
- print('unsupport training mode now')
64
- return
65
- return self.model.forward(image, text=None, boxes=None, points=None, **kwargs)
66
 
67
  # processor
68
  class SegVolProcessor():
@@ -136,6 +151,25 @@ class SegVolProcessor():
136
  binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
137
  return box_single, binary_cube_resize
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  class MinMaxNormalization(transforms.Transform):
140
  def __call__(self, data):
141
  d = dict(data)
 
40
 
41
  self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
 
43
+ def forward_test(self,
44
+ image,
45
+ zoomed_image=None,
46
+ text_prompt=None,
47
+ bbox_prompt=None,
48
+ point_prompt=None,
49
+ use_zoom=True,
50
+ **kwargs):
51
  bbox_prompt, bbox_prompt_map = bbox_prompt
52
  point_prompt, point_prompt_map = point_prompt
53
  print(image.shape, zoomed_image.shape, text_prompt)
54
+ assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
55
  print(bbox_prompt.shape, bbox_prompt_map.shape, point_prompt[0].shape, point_prompt[1].shape, point_prompt_map.shape)
56
  volume_shape = image[0][0].shape
57
+
58
+ with torch.no_grad():
59
+ logits_global_single = self.model(zoomed_image,
60
+ text=text_prompt,
61
+ boxes=bbox_prompt,
62
+ points=point_prompt)
63
+ print(logits_global_single.shape)
64
+ logits_global_single = F.interpolate(
65
+ logits_global_single.cpu(),
66
+ size=volume_shape, mode='nearest')
67
+ if not use_zoom:
68
+ return logits_global_single
69
+ logits_global_single = logits_global_single[0][0]
70
+ print(logits_global_single.shape)
71
+ if point_prompt_map is not None:
72
+ binary_points = F.interpolate(
73
+ point_prompt_map.unsqueeze(0).unsqueeze(0).float(),
74
  size=volume_shape, mode='nearest')[0][0]
75
+ if bbox_prompt_map is not None:
76
+ binary_cube = F.interpolate(
77
+ bbox_prompt_map.unsqueeze(0).unsqueeze(0).float(),
78
+ size=volume_shape, mode='nearest')[0][0]
79
+ return
80
+
81
 
82
  # processor
83
  class SegVolProcessor():
 
151
  binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
152
  return box_single, binary_cube_resize
153
 
154
+ def dice_score(self, preds, labels):
155
+ assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
156
+ predict = preds.view(1, -1)
157
+ target = labels.view(1, -1)
158
+ if target.shape[1] < 1e8:
159
+ predict = predict.to(self.custom_device)
160
+ target = target.to(self.custom_device)
161
+ predict = torch.sigmoid(predict)
162
+ predict = torch.where(predict > 0.5, 1., 0.)
163
+
164
+ tp = torch.sum(torch.mul(predict, target))
165
+ den = torch.sum(predict) + torch.sum(target) + 1
166
+ dice = 2 * tp / den
167
+
168
+ if target.shape[1] < 1e8:
169
+ predict = predict.cpu()
170
+ target = target.cpu()
171
+ return dice
172
+
173
  class MinMaxNormalization(transforms.Transform):
174
  def __call__(self, data):
175
  d = dict(data)