yuxin commited on
Commit
c995392
1 Parent(s): faa61f9
Files changed (3) hide show
  1. config.json +0 -1
  2. config_segvol.py +0 -2
  3. model_segvol_single.py +15 -6
config.json CHANGED
@@ -18,7 +18,6 @@
18
  256
19
  ],
20
  "test_mode": true,
21
- "test_w_zoom": false,
22
  "torch_dtype": "float32",
23
  "transformers_version": "4.18.0"
24
  }
 
18
  256
19
  ],
20
  "test_mode": true,
 
21
  "torch_dtype": "float32",
22
  "transformers_version": "4.18.0"
23
  }
config_segvol.py CHANGED
@@ -6,11 +6,9 @@ class SegVolConfig(PretrainedConfig):
6
  def __init__(
7
  self,
8
  test_mode=True,
9
- test_w_zoom=False,
10
  **kwargs,
11
  ):
12
  self.spatial_size = [32, 256, 256]
13
  self.patch_size = [4, 16, 16]
14
  self.test_mode = test_mode
15
- self.test_w_zoom = test_w_zoom
16
  super().__init__(**kwargs)
 
6
  def __init__(
7
  self,
8
  test_mode=True,
 
9
  **kwargs,
10
  ):
11
  self.spatial_size = [32, 256, 256]
12
  self.patch_size = [4, 16, 16]
13
  self.test_mode = test_mode
 
14
  super().__init__(**kwargs)
model_segvol_single.py CHANGED
@@ -26,8 +26,16 @@ class SegVolModel(PreTrainedModel):
26
 
27
  self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
28
 
29
- def forward(self, image, text=None, boxes=None, points=None, **kwargs):
30
- return self.model.forward(image, text=text, boxes=boxes, points=points, **kwargs)
 
 
 
 
 
 
 
 
31
 
32
  # processor
33
  class SegVolProcessor():
@@ -89,15 +97,15 @@ class SegVolProcessor():
89
  item['zoom_out_label'] = item_zoom_out['label']
90
  return item
91
 
92
- def point_prompt(self, label_single_resize):
93
  point, point_label = select_points(label_single_resize, num_positive_extra=3, num_negative_extra=3)
94
  points_single = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
95
- binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape)
96
  return points_single, binary_points_resize
97
 
98
- def bbox_prompt(self, label_single_resize):
99
  box_single = generate_box(label_single_resize).unsqueeze(0).float().cuda()
100
- binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape)
101
  return box_single, binary_cube_resize
102
 
103
  class MinMaxNormalization(transforms.Transform):
@@ -461,6 +469,7 @@ class TextEncoder(nn.Module):
461
  if text is None:
462
  return None
463
  if type(text) is str:
 
464
  text = [text]
465
  tokens = self.organ2tokens(text)
466
  clip_outputs = self.clip_text_model(**tokens)
 
26
 
27
  self.processor = SegVolProcessor(spatial_size=self.config.spatial_size)
28
 
29
+ def forward(self, image, zoomed_image=None, text_prompt=None, bbox_prompt=None, point_prompt=None, **kwargs):
30
+ print(image.shape, zoomed_image.shape, text_prompt)
31
+ print(bbox_prompt[0].shape, bbox_prompt[1].shape, point_prompt[0].shape, point_prompt[1].shape)
32
+ # test mode
33
+ if self.config.test_mode:
34
+ return
35
+ else:
36
+ print('unsupport training mode now')
37
+ return
38
+ return self.model.forward(image, text=None, boxes=None, points=None, **kwargs)
39
 
40
  # processor
41
  class SegVolProcessor():
 
97
  item['zoom_out_label'] = item_zoom_out['label']
98
  return item
99
 
100
+ def point_prompt_b(self, label_single_resize):
101
  point, point_label = select_points(label_single_resize, num_positive_extra=3, num_negative_extra=3)
102
  points_single = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
103
+ binary_points_resize = build_binary_points(point, point_label, label_single_resize.shape).unsqueeze(0).unsqueeze(0)
104
  return points_single, binary_points_resize
105
 
106
+ def bbox_prompt_b(self, label_single_resize):
107
  box_single = generate_box(label_single_resize).unsqueeze(0).float().cuda()
108
+ binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
109
  return box_single, binary_cube_resize
110
 
111
  class MinMaxNormalization(transforms.Transform):
 
469
  if text is None:
470
  return None
471
  if type(text) is str:
472
+ # text is supposed to be list
473
  text = [text]
474
  tokens = self.organ2tokens(text)
475
  clip_outputs = self.clip_text_model(**tokens)