yuxin
commited on
Commit
•
35cfafe
1
Parent(s):
28efdaa
add model
Browse files- model_segvol_single.py +5 -0
model_segvol_single.py
CHANGED
@@ -45,6 +45,7 @@ class SegVolModel(PreTrainedModel):
|
|
45 |
point_prompt, point_prompt_map = point_prompt
|
46 |
print(image.shape, zoomed_image.shape, text_prompt)
|
47 |
print(bbox_prompt.shape, bbox_prompt_map.shape, point_prompt[0].shape, point_prompt[1].shape, point_prompt_map.shape)
|
|
|
48 |
# test mode
|
49 |
if self.config.test_mode:
|
50 |
with torch.no_grad():
|
@@ -53,6 +54,10 @@ class SegVolModel(PreTrainedModel):
|
|
53 |
boxes=bbox_prompt,
|
54 |
points=point_prompt)
|
55 |
print(logits_global_single.shape)
|
|
|
|
|
|
|
|
|
56 |
return
|
57 |
else:
|
58 |
print('unsupport training mode now')
|
|
|
45 |
point_prompt, point_prompt_map = point_prompt
|
46 |
print(image.shape, zoomed_image.shape, text_prompt)
|
47 |
print(bbox_prompt.shape, bbox_prompt_map.shape, point_prompt[0].shape, point_prompt[1].shape, point_prompt_map.shape)
|
48 |
+
volume_shape = image[0][0].shape
|
49 |
# test mode
|
50 |
if self.config.test_mode:
|
51 |
with torch.no_grad():
|
|
|
54 |
boxes=bbox_prompt,
|
55 |
points=point_prompt)
|
56 |
print(logits_global_single.shape)
|
57 |
+
logits_global_single = F.interpolate(
|
58 |
+
logits_global_single.cpu(),
|
59 |
+
size=volume_shape, mode='nearest')[0][0]
|
60 |
+
print(logits_global_single.shape)
|
61 |
return
|
62 |
else:
|
63 |
print('unsupport training mode now')
|