yuxin
committed on
Commit
•
8d8118b
1
Parent(s):
35cfafe
add model
Browse files- model_segvol_single.py +51 -17
model_segvol_single.py
CHANGED
@@ -40,29 +40,44 @@ class SegVolModel(PreTrainedModel):
|
|
40 |
|
41 |
self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
42 |
|
43 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
bbox_prompt, bbox_prompt_map = bbox_prompt
|
45 |
point_prompt, point_prompt_map = point_prompt
|
46 |
print(image.shape, zoomed_image.shape, text_prompt)
|
|
|
47 |
print(bbox_prompt.shape, bbox_prompt_map.shape, point_prompt[0].shape, point_prompt[1].shape, point_prompt_map.shape)
|
48 |
volume_shape = image[0][0].shape
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
logits_global_single
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
size=volume_shape, mode='nearest')[0][0]
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
|
67 |
# processor
|
68 |
class SegVolProcessor():
|
@@ -136,6 +151,25 @@ class SegVolProcessor():
|
|
136 |
binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
|
137 |
return box_single, binary_cube_resize
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
class MinMaxNormalization(transforms.Transform):
|
140 |
def __call__(self, data):
|
141 |
d = dict(data)
|
|
|
40 |
|
41 |
self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
42 |
|
43 |
+
def forward_test(self,
|
44 |
+
image,
|
45 |
+
zoomed_image=None,
|
46 |
+
text_prompt=None,
|
47 |
+
bbox_prompt=None,
|
48 |
+
point_prompt=None,
|
49 |
+
use_zoom=True,
|
50 |
+
**kwargs):
|
51 |
bbox_prompt, bbox_prompt_map = bbox_prompt
|
52 |
point_prompt, point_prompt_map = point_prompt
|
53 |
print(image.shape, zoomed_image.shape, text_prompt)
|
54 |
+
assert image.shape[0] == 1 and zoomed_image.shape[0] == 1, 'batch size should be 1'
|
55 |
print(bbox_prompt.shape, bbox_prompt_map.shape, point_prompt[0].shape, point_prompt[1].shape, point_prompt_map.shape)
|
56 |
volume_shape = image[0][0].shape
|
57 |
+
|
58 |
+
with torch.no_grad():
|
59 |
+
logits_global_single = self.model(zoomed_image,
|
60 |
+
text=text_prompt,
|
61 |
+
boxes=bbox_prompt,
|
62 |
+
points=point_prompt)
|
63 |
+
print(logits_global_single.shape)
|
64 |
+
logits_global_single = F.interpolate(
|
65 |
+
logits_global_single.cpu(),
|
66 |
+
size=volume_shape, mode='nearest')
|
67 |
+
if not use_zoom:
|
68 |
+
return logits_global_single
|
69 |
+
logits_global_single = logits_global_single[0][0]
|
70 |
+
print(logits_global_single.shape)
|
71 |
+
if point_prompt_map is not None:
|
72 |
+
binary_points = F.interpolate(
|
73 |
+
point_prompt_map.unsqueeze(0).unsqueeze(0).float(),
|
74 |
size=volume_shape, mode='nearest')[0][0]
|
75 |
+
if bbox_prompt_map is not None:
|
76 |
+
binary_cube = F.interpolate(
|
77 |
+
bbox_prompt_map.unsqueeze(0).unsqueeze(0).float(),
|
78 |
+
size=volume_shape, mode='nearest')[0][0]
|
79 |
+
return
|
80 |
+
|
81 |
|
82 |
# processor
|
83 |
class SegVolProcessor():
|
|
|
151 |
binary_cube_resize = build_binary_cube(box_single, binary_cube_shape=label_single_resize.shape).unsqueeze(0).unsqueeze(0)
|
152 |
return box_single, binary_cube_resize
|
153 |
|
154 |
def dice_score(self, preds, labels):
    """Compute a hard (binarized) Dice coefficient between logits and labels.

    Args:
        preds: prediction logits; sigmoid + 0.5 threshold is applied here,
            so raw (pre-sigmoid) scores are expected.
        labels: binary ground-truth tensor with the same batch size.

    Returns:
        Scalar tensor: 2*TP / (|pred| + |target| + 1). The +1 in the
        denominator guards against division by zero when both are empty.
    """
    assert preds.shape[0] == labels.shape[0], "predict & target batch size don't match\n" + str(preds.shape) + str(labels.shape)
    # Flatten both tensors into a single row for element-wise comparison.
    predict = preds.view(1, -1)
    target = labels.view(1, -1)
    # Only move tensors to the accelerator when small enough to fit
    # (< 1e8 elements); huge volumes stay on CPU to avoid OOM.
    # NOTE(review): relies on self.custom_device being set on this object —
    # in this diff it is assigned on SegVolModel, not SegVolProcessor; verify.
    if target.shape[1] < 1e8:
        predict = predict.to(self.custom_device)
        target = target.to(self.custom_device)
    # Binarize predictions: sigmoid then 0.5 threshold.
    predict = torch.sigmoid(predict)
    predict = torch.where(predict > 0.5, 1., 0.)

    tp = torch.sum(torch.mul(predict, target))
    den = torch.sum(predict) + torch.sum(target) + 1
    dice = 2 * tp / den

    # Move intermediates back to CPU to release accelerator memory.
    if target.shape[1] < 1e8:
        predict = predict.cpu()
        target = target.cpu()
    return dice
|
173 |
class MinMaxNormalization(transforms.Transform):
|
174 |
def __call__(self, data):
|
175 |
d = dict(data)
|