yuxin commited on
Commit
72b250b
·
1 Parent(s): e46628a
Files changed (1) hide show
  1. model_segvol_single.py +8 -6
model_segvol_single.py CHANGED
@@ -44,14 +44,16 @@ class SegVolModel(PreTrainedModel):
44
  image,
45
  zoomed_image=None,
46
  text_prompt=None,
47
- bbox_prompt=None,
48
- bbox_prompt_map=None,
49
- point_prompt=None,
50
- point_prompt_map=None,
51
  use_zoom=True):
52
- # print(image.shape, zoomed_image.shape, text_prompt)
53
  assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
54
- # print(bbox_prompt.shape, bbox_prompt_map.shape, point_prompt[0].shape, point_prompt[1].shape, point_prompt_map.shape)
 
 
 
 
 
55
  volume_shape = image[0][0].shape
56
 
57
  with torch.no_grad():
 
44
  image,
45
  zoomed_image=None,
46
  text_prompt=None,
47
+ bbox_prompt_group=None,
48
+ point_prompt_group=None,
 
 
49
  use_zoom=True):
 
50
  assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
51
+ assert not (text_prompt is None and bbox_prompt_group is None and point_prompt_group is None)
52
+ bbox_prompt, bbox_prompt_map, point_prompt, point_prompt_map=None, None, None, None
53
+ if bbox_prompt_group is not None:
54
+ bbox_prompt, bbox_prompt_map = bbox_prompt_group
55
+ if point_prompt_group is not None:
56
+ point_prompt, point_prompt_map = point_prompt_group
57
  volume_shape = image[0][0].shape
58
 
59
  with torch.no_grad():