yuxin
commited on
Commit
·
72b250b
1
Parent(s):
e46628a
add model
Browse files- model_segvol_single.py +8 -6
model_segvol_single.py
CHANGED
@@ -44,14 +44,16 @@ class SegVolModel(PreTrainedModel):
|
|
44 |
image,
|
45 |
zoomed_image=None,
|
46 |
text_prompt=None,
|
47 |
-
|
48 |
-
|
49 |
-
point_prompt=None,
|
50 |
-
point_prompt_map=None,
|
51 |
use_zoom=True):
|
52 |
-
# print(image.shape, zoomed_image.shape, text_prompt)
|
53 |
assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
55 |
volume_shape = image[0][0].shape
|
56 |
|
57 |
with torch.no_grad():
|
|
|
44 |
image,
|
45 |
zoomed_image=None,
|
46 |
text_prompt=None,
|
47 |
+
bbox_prompt_group=None,
|
48 |
+
point_prompt_group=None,
|
|
|
|
|
49 |
use_zoom=True):
|
|
|
50 |
assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
|
51 |
+
assert not (text_prompt is None and bbox_prompt_group is None and point_prompt_group is None)
|
52 |
+
bbox_prompt, bbox_prompt_map, point_prompt, point_prompt_map=None, None, None, None
|
53 |
+
if bbox_prompt_group is not None:
|
54 |
+
bbox_prompt, bbox_prompt_map = bbox_prompt_group
|
55 |
+
if point_prompt_group is not None:
|
56 |
+
point_prompt, point_prompt_map = point_prompt_group
|
57 |
volume_shape = image[0][0].shape
|
58 |
|
59 |
with torch.no_grad():
|