yuxin committed on
Commit 7a01238
1 Parent(s): 2fbf9d3
Files changed (2)
  1. config.json +1 -1
  2. model_segvol_single.py +1661 -0
config.json CHANGED
@@ -4,7 +4,7 @@
   ],
   "auto_map": {
     "AutoConfig": "config_segvol.SegVolConfig",
-    "AutoModel": "model_segvol.SegVolModel"
+    "AutoModel": "model_segvol_single.SegVolModel"
   },
   "model_type": "segvol",
   "patch_size": [
model_segvol_single.py ADDED
@@ -0,0 +1,1661 @@
from transformers import PreTrainedModel
from config_segvol import SegVolConfig

class SegVolModel(PreTrainedModel):
    config_class = SegVolConfig

    def __init__(self, config):
        super().__init__(config)
        sam_model = _build_sam(
            image_encoder_type='vit',
            embed_dim=768,
            patch_size=self.config.patch_size,
            checkpoint=None,
            image_size=self.config.spatial_size,
        )
        self.model = SegVol(
            image_encoder=sam_model.image_encoder,
            mask_decoder=sam_model.mask_decoder,
            prompt_encoder=sam_model.prompt_encoder,
            roi_size=self.config.spatial_size,
            patch_size=self.config.patch_size,
            test_mode=self.config.test_mode,
        )

    def forward(self, image, text=None, boxes=None, points=None, **kwargs):
        return self.model.forward(image, text=text, boxes=boxes, points=points, **kwargs)
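# Usage sketch (illustrative): with the `auto_map` entry in config.json now pointing at
# model_segvol_single.SegVolModel, this wrapper is intended to be loaded through the
# transformers Auto classes. The checkpoint path below is a placeholder, not a value
# taken from this commit.
#
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("path/to/segvol-checkpoint", trust_remote_code=True)
#   logits = model(ct_volume, text=['liver'])   # config.test_mode decides which branch runs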
# SegVol
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextConfig
import random

#%% set up model
class SegVol(nn.Module):
    def __init__(self,
                 image_encoder,
                 mask_decoder,
                 prompt_encoder,
                 roi_size,
                 patch_size,
                 test_mode=False,
                 ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        self.text_encoder = TextEncoder()
        self.feat_shape = np.array(roi_size) / np.array(patch_size)
        self.test_mode = test_mode
        self.dice_loss = BinaryDiceLoss().cuda()
        self.bce_loss = BCELoss().cuda()
        self.decoder_iter = 6

    def forward(self, image, text=None, boxes=None, points=None, **kwargs):
        bs = image.shape[0]
        img_shape = (image.shape[2], image.shape[3], image.shape[4])
        image_embedding, _ = self.image_encoder(image)
        image_embedding = image_embedding.transpose(1, 2).view(bs, -1,
            int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))
        # test mode
        if self.test_mode:
            return self.forward_decoder(image_embedding, img_shape, text, boxes, points)

        # train mode
        ## sl
        sl_loss = self.supervised_forward(image, image_embedding, img_shape, kwargs['train_organs'], kwargs['train_labels'])
        ## ssl
        ssl_loss = self.unsupervised_forward(image, image_embedding, kwargs['pseudo_seg_cleaned'], img_shape)
        return sl_loss, ssl_loss

    def forward_decoder(self, image_embedding, img_shape, text=None, boxes=None, points=None):
        with torch.no_grad():
            if boxes is not None:
                if len(boxes.shape) == 2:
                    boxes = boxes[:, None, :]  # (B, 1, 6)
            if text is not None:
                text_embedding = self.text_encoder(text)  # (B, 768)
            else:
                text_embedding = None
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=None,
            text_embedding=text_embedding,
        )

        dense_pe = self.prompt_encoder.get_dense_pe()
        low_res_masks, _ = self.mask_decoder(
            image_embeddings=image_embedding,
            text_embedding=text_embedding,
            image_pe=dense_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        logits = F.interpolate(low_res_masks, size=img_shape, mode='trilinear', align_corners=False)
        return logits

    def supervised_forward(self, image, image_embedding, img_shape, training_organs, train_labels):
        iter_points, iter_bboxes, iter_organs = self.build_prompt_label(image.shape[0], training_organs, train_labels)
        # select prompt
        prompt_options = [[None, iter_points, iter_organs], [iter_bboxes, None, iter_organs],
                          [None, None, iter_organs], [iter_bboxes, None, None], [None, iter_points, None],
                          [iter_bboxes, iter_points, None]]
        sl_loss = 0
        for prompt in prompt_options:
            bboxes, points, organs = prompt
            logits = self.forward_decoder(image_embedding, img_shape, text=organs, boxes=bboxes, points=points)
            # cal loss
            sl_loss_dice = self.dice_loss.forward(logits.squeeze().float(), train_labels.squeeze().float())
            sl_loss_bce = self.bce_loss.forward(logits.squeeze().float(), train_labels.squeeze().float())
            sl_loss += sl_loss_dice + sl_loss_bce
        return sl_loss

    def unsupervised_forward(self, image, image_embedding, pseudo_seg_cleaned, img_shape):
        sll_loss = 0
        for iter in range(self.decoder_iter):
            if iter % 2 == 0:
                pseudo_labels, pseudo_points_prompt = self.build_pseudo_point_prompt_label(image.shape, pseudo_seg_cleaned)
                logits = self.forward_decoder(image_embedding, img_shape, text=None, boxes=None, points=pseudo_points_prompt)
            else:
                pseudo_labels, pseudo_bboxes_prompt = self.build_pseudo_box_prompt_label(image.shape, pseudo_seg_cleaned)
                logits = self.forward_decoder(image_embedding, img_shape, text=None, boxes=pseudo_bboxes_prompt, points=None)
            # cal loss
            sll_loss_dice = self.dice_loss.forward(logits.squeeze().float(), pseudo_labels.squeeze().float())
            sll_loss_bce = self.bce_loss.forward(logits.squeeze().float(), pseudo_labels.squeeze().float())
            sll_loss += sll_loss_dice + sll_loss_bce
        return sll_loss

    def build_prompt_label(self, bs, training_organs, train_labels):
        # generate prompt & label
        iter_organs = []
        iter_bboxes = []
        iter_points_ax = []
        iter_point_labels = []
        for sample_idx in range(bs):
            # organ prompt
            iter_organs.append(training_organs)
            # box prompt
            box = generate_box(train_labels[sample_idx])
            iter_bboxes.append(box)
            # point prompt
            num_positive_extra_max, num_negative_extra_max = 10, 10
            num_positive_extra = random.randint(0, num_positive_extra_max)
            num_negative_extra = random.randint(0, num_negative_extra_max)
            point, point_label = select_points(
                train_labels[sample_idx],
                num_positive_extra=num_positive_extra,
                num_negative_extra=num_negative_extra,
                fix_extra_point_num=num_positive_extra_max + num_negative_extra_max)
            iter_points_ax.append(point)
            iter_point_labels.append(point_label)
        # batched prompt
        iter_points_ax = torch.stack(iter_points_ax, dim=0).cuda()
        iter_point_labels = torch.stack(iter_point_labels, dim=0).cuda()
        iter_points = (iter_points_ax, iter_point_labels)
        iter_bboxes = torch.stack(iter_bboxes, dim=0).float().cuda()
        return iter_points, iter_bboxes, iter_organs

    def build_pseudo_point_prompt_label(self, input_shape, seg_labels):
        pseudo_labels = torch.zeros(input_shape).cuda()
        # generate points
        points = []
        point_labels = []
        for batch_idx in range(input_shape[0]):
            # generate pseudo label
            unique_ids = torch.unique(seg_labels[batch_idx])
            unique_ids = unique_ids[unique_ids != -1]
            region_id = random.choice(unique_ids).item()
            pseudo_labels[batch_idx][seg_labels[batch_idx]==region_id] = 1
            # generate point prompt
            num_positive_extra_max, num_negative_extra_max = 10, 10
            num_positive_extra = random.randint(4, num_positive_extra_max)
            num_negative_extra = random.randint(0, num_negative_extra_max)
            assert len(pseudo_labels[batch_idx][0].shape) == 3
            point, point_label = select_points(
                pseudo_labels[batch_idx][0],
                num_positive_extra=num_positive_extra,
                num_negative_extra=num_negative_extra,
                fix_extra_point_num=num_positive_extra_max + num_negative_extra_max)
            points.append(point)
            point_labels.append(point_label)
        points = torch.stack(points, dim=0).cuda()
        point_labels = torch.stack(point_labels, dim=0).cuda()
        pseudo_points_prompt = (points, point_labels)
        return pseudo_labels, pseudo_points_prompt

    def build_pseudo_box_prompt_label(self, input_shape, seg_labels_cleaned):
        pseudo_labels = torch.zeros(input_shape).cuda()
        iter_bboxes = []
        # generate boxes
        for batch_idx in range(input_shape[0]):
            # generate ori pseudo label
            unique_ids = torch.unique(seg_labels_cleaned[batch_idx])
            unique_ids = unique_ids[unique_ids != -1]
            region_id = random.choice(unique_ids).item()
            pseudo_labels[batch_idx][seg_labels_cleaned[batch_idx]==region_id] = 1
            # generate box prompt
            box = generate_box(pseudo_labels[batch_idx][0])
            iter_bboxes.append(box)
            # refine pseudo label
            x_min, y_min, z_min, x_max, y_max, z_max = box
            binary_cube = torch.zeros_like(pseudo_labels[batch_idx][0]).int()
            binary_cube[x_min:x_max+1, y_min:y_max+1, z_min:z_max+1] = 1
            # cal iou
            mask_label = seg_labels_cleaned[batch_idx][0]
            assert binary_cube.shape == mask_label.shape, str(binary_cube.shape) + ' ' + str(mask_label.shape)
            mask_values_in_binary_cube = mask_label[binary_cube == 1]
            unique_mask_values = torch.unique(mask_values_in_binary_cube)
            # print('unique_mask_values ', unique_mask_values)
            for value in unique_mask_values:
                if value == -1: continue
                mask_area = (mask_label == value)
                intersection = (binary_cube & mask_area)
                iou = intersection.float().sum() / mask_area.float().sum()
                if iou > 0.90:
                    # print(f"Mask value {value} has IOU > 0.90 in binary cube.")
                    pseudo_labels[batch_idx][seg_labels_cleaned[batch_idx]==value] = 1

        bboxes = torch.stack(iter_bboxes, dim=0).float().cuda()
        return pseudo_labels, bboxes
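# Prompt formats used above (summary of the code, for reference):
#   * point prompts: a tuple (coords, labels); coords is (B, N, 3) voxel coordinates and
#     labels is (B, N) with 1 = positive, 0 = negative, -1 = padding/ignore (see select_points).
#   * box prompts: (B, 6) tensors (x_min, y_min, z_min, x_max, y_max, z_max); forward_decoder
#     expands them to (B, 1, 6) before the prompt encoder.
#   * text prompts: a list of organ names, encoded by TextEncoder below into (B, 768).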
class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        config = CLIPTextConfig()
        self.clip_text_model = CLIPTextModel(config)
        self.tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-base-patch32')
        self.dim_align = nn.Linear(512, 768)
        # freeze text encoder
        for param in self.clip_text_model.parameters():
            param.requires_grad = False

    def organ2tokens(self, organ_names):
        text_list = ['A computerized tomography of a {}.'.format(organ_name) for organ_name in organ_names]
        tokens = self.tokenizer(text_list, padding=True, return_tensors="pt")
        for key in tokens.keys():
            tokens[key] = tokens[key].cuda()
        return tokens

    def forward(self, text):
        if text is None:
            return None
        if type(text) is str:
            text = [text]
        tokens = self.organ2tokens(text)
        clip_outputs = self.clip_text_model(**tokens)
        text_embedding = clip_outputs.pooler_output
        text_embedding = self.dim_align(text_embedding)
        return text_embedding

# loss
import torch
import torch.nn as nn

class BinaryDiceLoss(nn.Module):
    def __init__(self, smooth=1, p=2, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        predict = torch.sigmoid(predict)
        target_ = target.clone()
        target_[target == -1] = 0
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match\n" + str(predict.shape) + '\n' + str(target.shape[0])
        predict = predict.contiguous().view(predict.shape[0], -1)
        target_ = target_.contiguous().view(target_.shape[0], -1)

        num = torch.sum(torch.mul(predict, target_), dim=1)
        den = torch.sum(predict, dim=1) + torch.sum(target_, dim=1) + self.smooth

        dice_score = 2*num / den
        dice_loss = 1 - dice_score

        # dice_loss_avg = dice_loss[target[:,0]!=-1].sum() / dice_loss[target[:,0]!=-1].shape[0]
        dice_loss_avg = dice_loss.sum() / dice_loss.shape[0]

        return dice_loss_avg

class BCELoss(nn.Module):
    def __init__(self):
        super(BCELoss, self).__init__()
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, predict, target):
        assert predict.shape == target.shape, 'predict & target shape do not match\n' + str(predict.shape) + '\n' + str(target.shape)
        target_ = target.clone()
        target_[target == -1] = 0

        ce_loss = self.criterion(predict, target_)

        return ce_loss
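# Minimal sanity-check sketch for the two losses above (illustrative only; the helper is
# defined here but never called by this module, and the shapes are made up for the example).
def _example_loss_check():
    logits = torch.randn(2, 8, 8, 8)                      # raw decoder logits
    labels = torch.randint(0, 2, (2, 8, 8, 8)).float()    # binary ground truth
    dice = BinaryDiceLoss()(logits, labels)               # soft Dice on sigmoid(logits)
    bce = BCELoss()(logits, labels)                       # BCE-with-logits, -1 labels treated as 0
    return dice, bce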
# monai inference

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
import random

from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
from monai.transforms import Resize
from monai.utils import (
    BlendMode,
    PytorchPadMode,
    convert_data_type,
    ensure_tuple,
    fall_back_tuple,
    look_up_option,
    optional_import,
)

tqdm, _ = optional_import("tqdm", name="tqdm")

__all__ = ["sliding_window_inference"]

def logits2roi_coor(spatial_size, logits_global_single):
    # crop predict
    pred_global_single = torch.sigmoid(logits_global_single) > 0.5
    ## get all pos idx
    nonzero_indices = torch.nonzero(pred_global_single)
    if nonzero_indices.shape[0] == 0:
        return None, None, None, None, None, None
    ## get boundary
    min_d, max_d = nonzero_indices[:, 0].min(), nonzero_indices[:, 0].max()
    min_h, max_h = nonzero_indices[:, 1].min(), nonzero_indices[:, 1].max()
    min_w, max_w = nonzero_indices[:, 2].min(), nonzero_indices[:, 2].max()
    ## padding
    crop_d, crop_h, crop_w = max_d - min_d + 1, max_h - min_h + 1, max_w - min_w + 1,
    window_d, window_h, window_w = spatial_size
    padding_d, padding_h, padding_w = max(0, window_d-crop_d), max(0, window_h-crop_h), max(0, window_w-crop_w)
    global_d, global_h, global_w = logits_global_single.shape
    min_d = max(0, min_d - int(padding_d)//2)
    min_h = max(0, min_h - int(padding_h)//2)
    min_w = max(0, min_w - int(padding_w)//2)
    max_d = min(global_d, max_d + int(padding_d)//2)
    max_h = min(global_h, max_h + int(padding_h)//2)
    max_w = min(global_w, max_w + int(padding_w)//2)
    return min_d, min_h, min_w, max_d, max_h, max_w

def build_binary_cube(bbox, binary_cube_shape):
    min_coord = bbox[0][:3].int().tolist()
    max_coord = bbox[0][3:].int().tolist()
    binary_cube = torch.zeros(binary_cube_shape)
    binary_cube[min_coord[0]:max_coord[0]+1, min_coord[1]:max_coord[1]+1, min_coord[2]:max_coord[2]+1] = 1
    return binary_cube

def build_binary_points(points, labels, shape):
    binary_points = torch.zeros(shape, dtype=torch.int16)
    binary_points[points[labels == 1, 0].long(), points[labels == 1, 1].long(), points[labels == 1, 2].long()] = 1
    return binary_points
def sliding_window_inference(
    inputs: torch.Tensor,
    prompt_reflection: Union[torch.Tensor, Tuple[torch.Tensor, ...]],
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    progress: bool = False,
    roi_weight_map: Union[torch.Tensor, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
    """
    Sliding window inference on `inputs` with `predictor`.

    The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
    Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
    e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
    could be ([128,64,256], [64,32,128]).
    In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
    an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
    so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).

    When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D])
        roi_size: the spatial window size for inferences.
            When its components have None or non-positives, the corresponding inputs dimension will be used.
            If the components of the `roi_size` are non-positive values, the transform will use the
            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor ``patch_data`` in shape NCHW[D],
            the outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
            with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
            where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
            N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
            the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
            to ensure the scaled output ROI sizes are still integers.
            If the `predictor`'s input and output spatial sizes are different,
            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
        overlap: Amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI.
            If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

    Note:
        - input must be channel-first and have a batch dim, supports N-D sliding window.

    """
    print('sliding window inference for ROI')
    text = kwargs['text']
    use_box = kwargs['use_box']
    use_point = kwargs['use_point']
    assert not (use_box and use_point)
    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    batch_size, _, *image_size_ = inputs.shape

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)
    #############
    if use_point or use_box:
        binary_prompt_map, global_preds = prompt_reflection
        global_preds = F.pad(global_preds, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)
    #############
    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)

    # Store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows

    # Create window-level importance map
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
    else:
        try:
            importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
        except BaseException as e:
            raise RuntimeError(
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0]  # type: ignore
    # handle non-positive weights
    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
    importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)

    # Perform predictions
    dict_key, output_image_list, count_map_list = None, [], []
    _initialized_ss = -1
    is_tensor_output = True  # whether the predictor's output is a tensor (instead of dict/tuple)

    # for each patch
    for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
        slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
        unravel_slice = [
            [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
            for idx in slice_range
        ]
        window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
        #############

        boxes = None
        points = None
        if use_point:
            window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
            point, point_label = select_points(window_binary_prompt_map.squeeze())
            points = (point.unsqueeze(0).float().cuda(), point_label.unsqueeze(0).float().cuda())
            pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
            boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().cuda()
        if use_box:
            if num_win == 1:
                window_binary_prompt_map = torch.cat([binary_prompt_map[win_slice] for win_slice in unravel_slice]).to(sw_device)
                boxes = generate_box(window_binary_prompt_map.squeeze()).unsqueeze(0).float().cuda()
            else:
                pseudo_label = torch.cat([global_preds[win_slice] for win_slice in unravel_slice]).to(sw_device)
                boxes = generate_box(pseudo_label.squeeze()).unsqueeze(0).float().cuda()
        seg_prob_out = predictor(window_data, text, boxes, points)  # batched patch segmentation
        #############
        # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out,)
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(seg_prob_out.keys())  # track predictor's output keys
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
        else:
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False

        # for each output in multi-output list
        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN

            # compute zoom scale: out_roi_size/in_roi_size
            zoom_scale = []
            for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
                zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
            ):
                _scale = out_w_i / float(in_w_i)
                if not (img_s_i * _scale).is_integer():
                    warnings.warn(
                        f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
                        f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
                    )
                zoom_scale.append(_scale)

            if _initialized_ss < ss:  # init. the ss-th buffer at the first iteration
                # construct multi-resolution outputs
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
                ]
                # allocate memory to store the full output and the count for overlapping parts
                output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
                count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
                _initialized_ss += 1

            # resizing the importance_map
            resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)

            # store the result in the proper location of the full output. Apply weights from importance map.
            for idx, original_idx in zip(slice_range, unravel_slice):
                # zoom roi
                original_idx_zoom = list(original_idx)  # 4D for 2D image, 5D for 3D image
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
                        warnings.warn(
                            f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
                            f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
                            f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
                            f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
                            f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
                            "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
                        )
                    original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
                importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
                # store results and weights
                output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g]
                count_map_list[ss][original_idx_zoom] += (
                    importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape)
                )

    # account for any overlapping sections
    for ss in range(len(output_image_list)):
        output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)

    # remove padding if image_size smaller than roi_size
    for ss, output_i in enumerate(output_image_list):
        if torch.isnan(output_i).any() or torch.isinf(output_i).any():
            warnings.warn("Sliding window inference results contain NaN or Inf.")

        zoom_scale = [
            seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
        ]

        final_slicing: List[slice] = []
        for sp in range(num_spatial_dims):
            slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
            slice_dim = slice(
                int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
                int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
            )
            final_slicing.insert(0, slice_dim)
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[final_slicing]

    if dict_key is not None:  # if output of predictor is a dict
        final_output = dict(zip(dict_key, output_image_list))
    else:
        final_output = tuple(output_image_list)  # type: ignore
    return final_output[0] if is_tensor_output else final_output  # type: ignore


def _get_scan_interval(
    image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
) -> Tuple[int, ...]:
    """
    Compute scan interval according to the image size, roi size and overlap.
    Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
    use 1 instead to make sure sliding window works.

    """
    if len(image_size) != num_spatial_dims:
        raise ValueError("image coord different from spatial dims.")
    if len(roi_size) != num_spatial_dims:
        raise ValueError("roi coord different from spatial dims.")

    scan_interval = []
    for i in range(num_spatial_dims):
        if roi_size[i] == image_size[i]:
            scan_interval.append(int(roi_size[i]))
        else:
            interval = int(roi_size[i] * (1 - overlap))
            scan_interval.append(interval if interval > 0 else 1)
    return tuple(scan_interval)
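# Illustrative call pattern for sliding_window_inference above (a sketch; every name other
# than the function's own parameters is a placeholder): `segvol` is a SegVol module built
# with test_mode=True, `image` is a (1, 1, D, H, W) volume, `logits_global` is a coarse
# full-volume prediction, and `binary_prompt_map` is a user prompt rendered into the volume.
#
#   logits = sliding_window_inference(
#       image.cuda(), (binary_prompt_map, logits_global), roi_size, 1, segvol,
#       overlap=0.5, text=['liver'], use_box=True, use_point=False,
#   )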


def generate_box(pred_pre, bbox_shift=None):
    meaning_post_label = pred_pre  # [h, w, d]
    ones_idx = (meaning_post_label > 0).nonzero(as_tuple=True)
    if all(tensor.nelement() == 0 for tensor in ones_idx):
        bboxes = torch.tensor([-1, -1, -1, -1, -1, -1])
        # print(bboxes, bboxes.shape)
        return bboxes
    min_coords = [dim.min() for dim in ones_idx]  # [x_min, y_min, z_min]
    max_coords = [dim.max() for dim in ones_idx]  # [x_max, y_max, z_max]

    if bbox_shift is None:
        corner_min = []
        corner_max = []
        shape = meaning_post_label.shape
        for coor in min_coords:
            coor_ = max(0, coor)
            corner_min.append(coor_)
        for idx, coor in enumerate(max_coords):
            coor_ = min(shape[idx], coor)
            corner_max.append(coor_)
        corner_min = torch.tensor(corner_min)
        corner_max = torch.tensor(corner_max)
        return torch.cat((corner_min, corner_max), dim=0)
    else:
        # add perturbation to bounding box coordinates
        corner_min = []
        corner_max = []
        shape = meaning_post_label.shape
        for coor in min_coords:
            coor_ = max(0, coor + random.randint(-bbox_shift, bbox_shift))
            corner_min.append(coor_)
        for idx, coor in enumerate(max_coords):
            coor_ = min(shape[idx], coor + random.randint(-bbox_shift, bbox_shift))
            corner_max.append(coor_)
        corner_min = torch.tensor(corner_min)
        corner_max = torch.tensor(corner_max)
        return torch.cat((corner_min, corner_max), dim=0)


def select_points(preds, num_positive_extra=4, num_negative_extra=0, fix_extra_point_num=None):
    spacial_dim = 3
    points = torch.zeros((0, 3))
    labels = torch.zeros((0))
    pos_thred = 0.9
    neg_thred = 0.1

    # get pos/net indices
    positive_indices = torch.nonzero(preds > pos_thred, as_tuple=True)  # ([pos x], [pos y], [pos z])
    negative_indices = torch.nonzero(preds < neg_thred, as_tuple=True)

    ones_idx = (preds > pos_thred).nonzero(as_tuple=True)
    if all(tmp.nelement() == 0 for tmp in ones_idx):
        # all neg
        num_positive_extra = 0
        selected_positive_point = torch.tensor([-1, -1, -1]).unsqueeze(dim=0)
        points = torch.cat((points, selected_positive_point), dim=0)
        labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))
    else:
        # random select a pos point
        random_idx = torch.randint(len(positive_indices[0]), (1,))
        selected_positive_point = torch.tensor([positive_indices[i][random_idx] for i in range(spacial_dim)]).unsqueeze(dim=0)
        points = torch.cat((points, selected_positive_point), dim=0)
        labels = torch.cat((labels, torch.ones((1))))

    if num_positive_extra > 0:
        pos_idx_list = torch.randperm(len(positive_indices[0]))[:num_positive_extra]
        extra_positive_points = []
        for pos_idx in pos_idx_list:
            extra_positive_points.append([positive_indices[i][pos_idx] for i in range(spacial_dim)])
        extra_positive_points = torch.tensor(extra_positive_points).reshape(-1, 3)
        points = torch.cat((points, extra_positive_points), dim=0)
        labels = torch.cat((labels, torch.ones((extra_positive_points.shape[0]))))

    if num_negative_extra > 0:
        neg_idx_list = torch.randperm(len(negative_indices[0]))[:num_negative_extra]
        extra_negative_points = []
        for neg_idx in neg_idx_list:
            extra_negative_points.append([negative_indices[i][neg_idx] for i in range(spacial_dim)])
        extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3)
        points = torch.cat((points, extra_negative_points), dim=0)
        labels = torch.cat((labels, torch.zeros((extra_negative_points.shape[0]))))
        # print('extra_negative_points ', extra_negative_points, extra_negative_points.shape)
        # print('==> points ', points.shape, labels)

    if fix_extra_point_num is None:
        left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0]
    else:
        left_point_num = fix_extra_point_num + 1 - labels.shape[0]

    for _ in range(left_point_num):
        ignore_point = torch.tensor([-1, -1, -1]).unsqueeze(dim=0)
        points = torch.cat((points, ignore_point), dim=0)
        labels = torch.cat((labels, torch.tensor([-1]).reshape(1)))

    return (points, labels)

# build 3D SAM
import torch
import numpy as np
from monai.networks.nets import ViT

def _build_sam(
    image_encoder_type,
    embed_dim,
    patch_size,
    checkpoint,
    image_size,
):
    mlp_dim = 3072
    num_layers = 12
    num_heads = 12
    pos_embed = 'perceptron'
    dropout_rate = 0.0

    image_encoder = ViT(
        in_channels=1,
        img_size=image_size,
        patch_size=patch_size,
        hidden_size=embed_dim,
        mlp_dim=mlp_dim,
        num_layers=num_layers,
        num_heads=num_heads,
        pos_embed=pos_embed,
        classification=False,
        dropout_rate=dropout_rate,
    )
    image_embedding_size = [int(item) for item in (np.array(image_size) / np.array(patch_size))]

    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f, map_location='cpu')['state_dict']
            encoder_dict = {k.replace('model.encoder.', ''): v for k, v in state_dict.items() if 'model.encoder.' in k}
        image_encoder.load_state_dict(encoder_dict)
        print(f'===> image_encoder.load_param: {checkpoint}')
    sam = Sam(
        image_encoder=image_encoder,
        prompt_encoder=PromptEncoder(
            embed_dim=embed_dim,
            image_embedding_size=image_embedding_size,
            input_image_size=image_size,
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            image_encoder_type=image_encoder_type,
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            image_size=np.array(image_size),
            patch_size=np.array(patch_size),
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    return sam
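# Architecture note (summarizing the builder above): the image encoder is MONAI's 3D ViT
# with hidden size `embed_dim` (768 when called from SegVolModel), the prompt encoder and
# mask decoder are 3D adaptations of SAM's, and the mask decoder uses a 2-layer
# TwoWayTransformer. `checkpoint` is None in the SegVolModel wrapper, so weights are
# presumably filled in from the Hugging Face checkpoint rather than loaded here.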

# mask decoder
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from typing import List, Tuple, Type, Optional

class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        image_encoder_type: str,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
        image_size,
        patch_size,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        if image_encoder_type == 'swin_vit':
            self.feat_shape = image_size / patch_size
            self.output_upscaling = nn.Sequential(
                nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
                nn.LayerNorm((transformer_dim // 4, int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))),  # swin
                activation(),
                nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),  # swin
                # nn.Conv3d(transformer_dim // 4, transformer_dim // 8, kernel_size=3, stride=1, padding=1),  # vit
                activation(),
            )
        else:
            self.feat_shape = image_size / patch_size * 2
            self.output_upscaling = nn.Sequential(
                nn.ConvTranspose3d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
                nn.LayerNorm((transformer_dim // 4, int(self.feat_shape[0]), int(self.feat_shape[1]), int(self.feat_shape[2]))),  # vit
                activation(),
                nn.ConvTranspose3d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
                # nn.Conv3d(transformer_dim // 4, transformer_dim // 8, kernel_size=3, stride=1, padding=1),
                activation(),
            )
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

        self.txt_align_upscaled_embedding = nn.Linear(768, 96)

    def forward(
        self,
        image_embeddings: torch.Tensor,
        text_embedding: Optional[torch.Tensor],
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Returns:
          torch.Tensor: batched predicted masks
        """
        # print('--------------decoder here--------------')
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            text_embedding=text_embedding,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        text_embedding: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
        # Expand per-image data in batch direction to be per-mask
        if image_embeddings.shape[0] != tokens.shape[0]:
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        else:
            src = image_embeddings
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w, d = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w, d)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w, d = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w * d)).view(b, -1, h, w, d)

        if text_embedding is not None:
            text_embedding_down = self.txt_align_upscaled_embedding(text_embedding).unsqueeze(dim=1)
            upscaled_embedding = upscaled_embedding.view(b, c, h * w * d)
            sim = (text_embedding_down @ upscaled_embedding).view(b, -1, h, w, d)
            sim = sim.repeat(1, masks.shape[1], 1, 1, 1)
            masks = masks + sim
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred

# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
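# Note on predict_masks above: when a text prompt is given, the 768-d CLIP embedding is
# projected to transformer_dim // 8 channels (txt_align_upscaled_embedding, 96 for the
# 768-d decoder) and its dot product with the upscaled image embedding is added to the
# mask logits, which is how the text prompt steers the predicted mask.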

# prompt encoder
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torch import nn

from typing import Any, Optional, Tuple, Type

class PromptEncoder(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        image_embedding_size: Tuple[int, int, int],
        input_image_size: Tuple[int, int, int],
        mask_in_chans: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        """
        Encodes prompts for input to SAM's mask decoder.

        Arguments:
          embed_dim (int): The prompts' embedding dimension
          image_embedding_size (tuple(int, int)): The spatial size of the
            image embedding, as (H, W).
          input_image_size (int): The padded size of the image as input
            to the image encoder, as (H, W).
          mask_in_chans (int): The number of hidden channels used for
            encoding input masks.
          activation (nn.Module): The activation to use when encoding
            input masks.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size
        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)

        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
        point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1], 4 * image_embedding_size[2])
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
        self,
        points: torch.Tensor,
        labels: torch.Tensor,
        pad: bool,
    ) -> torch.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 3), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 3)
        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask inputs."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
        text_embedding: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        elif text_embedding is not None:
            return text_embedding.shape[0]
        else:
            return 1
1147
+
1148
+ def _get_device(self) -> torch.device:
1149
+ return self.point_embeddings[0].weight.device
1150
+
1151
+ def forward(
1152
+ self,
1153
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
1154
+ boxes: Optional[torch.Tensor],
1155
+ masks: Optional[torch.Tensor],
1156
+ text_embedding: Optional[torch.Tensor],
1157
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1158
+
1159
+ bs = self._get_batch_size(points, boxes, masks, text_embedding)
1160
+ sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
1161
+
1162
+ if points is not None:
1163
+ coords, labels = points
1164
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
1165
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
1166
+
1167
+ if boxes is not None:
1168
+ box_embeddings = self._embed_boxes(boxes)
1169
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
1170
+
1171
+ if text_embedding is not None:
1172
+ sparse_embeddings = torch.cat([sparse_embeddings, text_embedding.unsqueeze(dim=1)], dim=1)
1173
+
1174
+ if masks is not None:
1175
+ dense_embeddings = self._embed_masks(masks)
1176
+ else:
1177
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1, 1).expand(
1178
+ bs, -1, int(self.image_embedding_size[0]), int(self.image_embedding_size[1]), int(self.image_embedding_size[2])
1179
+ )
1180
+
1181
+ return sparse_embeddings, dense_embeddings
1182
+
1183
+
1184
+ class PositionEmbeddingRandom(nn.Module):
1185
+ """
1186
+ Positional encoding using random spatial frequencies.
1187
+ """
1188
+
1189
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
1190
+ super().__init__()
1191
+ if scale is None or scale <= 0.0:
1192
+ scale = 1.0
1193
+ self.register_buffer(
1194
+ "positional_encoding_gaussian_matrix",
1195
+ scale * torch.randn((3, num_pos_feats)),
1196
+ )
1197
+
1198
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
1199
+ """Positionally encode points that are normalized to [0,1]."""
1200
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
1201
+ coords = 2 * coords - 1
1202
+ coords = coords @ self.positional_encoding_gaussian_matrix
1203
+ coords = 2 * np.pi * coords
1204
+ # outputs d_1 x ... x d_n x C shape
1205
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
1206
+
1207
+ def forward(self, size: Tuple[int, int, int]) -> torch.Tensor:
1208
+ """Generate positional encoding for a grid of the specified size."""
1209
+ h, w, d = size
1210
+ device: Any = self.positional_encoding_gaussian_matrix.device
1211
+ grid = torch.ones((h, w, d), device=device, dtype=torch.float32)
1212
+ y_embed = grid.cumsum(dim=0) - 0.5
1213
+ x_embed = grid.cumsum(dim=1) - 0.5
1214
+ z_embed = grid.cumsum(dim=2) - 0.5
1215
+ y_embed = y_embed / h
1216
+ x_embed = x_embed / w
1217
+ z_embed = z_embed / d
1218
+
1219
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
1220
+ return pe.permute(3, 0, 1, 2) # C x H x W x D
1221
+
1222
+ def forward_with_coords(
1223
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int, int]
1224
+ ) -> torch.Tensor:
1225
+ """Positionally encode points that are not normalized to [0,1]."""
1226
+ coords = coords_input.clone()
1227
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
1228
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
1229
+ coords[:, :, 2] = coords[:, :, 2] / image_size[2]
1230
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
1231
+
1232
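PositionEmbeddingRandom is self-contained, so its behavior can be checked directly; the grid and image sizes below are illustrative.

import torch

pe_layer = PositionEmbeddingRandom(num_pos_feats=64)        # output channels = 2 * 64 = 128

# Dense positional encoding for an 8 x 8 x 8 image-embedding grid.
dense_pe = pe_layer((8, 8, 8))                               # 128 x 8 x 8 x 8 (C x H x W x D)

# Encode un-normalized voxel coordinates against an (H, W, D) image size.
coords = torch.tensor([[[10.0, 20.0, 5.0]]])                 # B x N x 3
point_pe = pe_layer.forward_with_coords(coords, image_size=(32, 32, 32))   # B x N x 128
print(dense_pe.shape, point_pe.shape)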
+ # two way transformer
1233
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
1234
+ # All rights reserved.
1235
+
1236
+ # This source code is licensed under the license found in the
1237
+ # LICENSE file in the root directory of this source tree.
1238
+
1239
+ import torch
1240
+ from torch import Tensor, nn
1241
+
1242
+ import math
1243
+ from typing import Tuple, Type
1244
+
1245
+ class TwoWayTransformer(nn.Module):
1246
+ def __init__(
1247
+ self,
1248
+ depth: int,
1249
+ embedding_dim: int,
1250
+ num_heads: int,
1251
+ mlp_dim: int,
1252
+ activation: Type[nn.Module] = nn.ReLU,
1253
+ attention_downsample_rate: int = 2,
1254
+ ) -> None:
1255
+ """
1256
+ A transformer decoder that attends to an input image using
1257
+ queries whose positional embedding is supplied.
1258
+
1259
+ Args:
1260
+ depth (int): number of layers in the transformer
1261
+ embedding_dim (int): the channel dimension for the input embeddings
1262
+ num_heads (int): the number of heads for multihead attention. Must
1263
+ divide embedding_dim
1264
+ mlp_dim (int): the channel dimension internal to the MLP block
1265
+ activation (nn.Module): the activation to use in the MLP block
1266
+ """
1267
+ super().__init__()
1268
+ self.depth = depth
1269
+ self.embedding_dim = embedding_dim
1270
+ self.num_heads = num_heads
1271
+ self.mlp_dim = mlp_dim
1272
+ self.layers = nn.ModuleList()
1273
+
1274
+ for i in range(depth):
1275
+ self.layers.append(
1276
+ TwoWayAttentionBlock(
1277
+ embedding_dim=embedding_dim,
1278
+ num_heads=num_heads,
1279
+ mlp_dim=mlp_dim,
1280
+ activation=activation,
1281
+ attention_downsample_rate=attention_downsample_rate,
1282
+ skip_first_layer_pe=(i == 0),
1283
+ )
1284
+ )
1285
+
1286
+ self.final_attn_token_to_image = Attention(
1287
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
1288
+ )
1289
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
1290
+
1291
+ def forward(
1292
+ self,
1293
+ image_embedding: Tensor,
1294
+ image_pe: Tensor,
1295
+ point_embedding: Tensor,
1296
+ ) -> Tuple[Tensor, Tensor]:
1297
+ """
1298
+ Args:
1299
+ image_embedding (torch.Tensor): image to attend to. Should be shape
1300
+ B x embedding_dim x h x w x d for any h, w, and d.
1301
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
1302
+ have the same shape as image_embedding.
1303
+ point_embedding (torch.Tensor): the embedding to add to the query points.
1304
+ Must have shape B x N_points x embedding_dim for any N_points.
1305
+
1306
+ Returns:
1307
+ torch.Tensor: the processed point_embedding
1308
+ torch.Tensor: the processed image_embedding
1309
+ """
1310
+ # BxCxHxWxD -> BxHWDxC == B x N_image_tokens x C
1311
+ bs, c, h, w, d = image_embedding.shape
1312
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
1313
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
1314
+
1315
+ # Prepare queries
1316
+ queries = point_embedding
1317
+ keys = image_embedding
1318
+
1319
+ # Apply transformer blocks and final layernorm
1320
+ for layer in self.layers:
1321
+ queries, keys = layer(
1322
+ queries=queries,
1323
+ keys=keys,
1324
+ query_pe=point_embedding,
1325
+ key_pe=image_pe,
1326
+ )
1327
+
1328
+ # Apply the final attention layer from the points to the image
1329
+ q = queries + point_embedding
1330
+ k = keys + image_pe
1331
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
1332
+ queries = queries + attn_out
1333
+ queries = self.norm_final_attn(queries)
1334
+
1335
+ return queries, keys
1336
+
1337
+
1338
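A shape-level sanity check for TwoWayTransformer; the depth, dims, and token counts are illustrative rather than SegVol's actual configuration.

import torch

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)

B, C = 1, 256
image_embedding = torch.rand(B, C, 8, 8, 8)
image_pe = torch.rand(B, C, 8, 8, 8)        # same shape as the image embedding
point_embedding = torch.rand(B, 6, C)       # prompt tokens (e.g. output tokens + sparse prompts)

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape)    # (1, 6, 256)   -- processed prompt tokens
print(keys.shape)       # (1, 512, 256) -- processed image tokens, 512 = 8*8*8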
+ class TwoWayAttentionBlock(nn.Module):
1339
+ def __init__(
1340
+ self,
1341
+ embedding_dim: int,
1342
+ num_heads: int,
1343
+ mlp_dim: int = 2048,
1344
+ activation: Type[nn.Module] = nn.ReLU,
1345
+ attention_downsample_rate: int = 2,
1346
+ skip_first_layer_pe: bool = False,
1347
+ ) -> None:
1348
+ """
1349
+ A transformer block with four layers: (1) self-attention of sparse
1350
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
1351
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
1352
+ inputs.
1353
+
1354
+ Arguments:
1355
+ embedding_dim (int): the channel dimension of the embeddings
1356
+ num_heads (int): the number of heads in the attention layers
1357
+ mlp_dim (int): the hidden dimension of the mlp block
1358
+ activation (nn.Module): the activation of the mlp block
1359
+ skip_first_layer_pe (bool): skip the PE on the first layer
1360
+ """
1361
+ super().__init__()
1362
+ self.self_attn = Attention(embedding_dim, num_heads)
1363
+ self.norm1 = nn.LayerNorm(embedding_dim)
1364
+
1365
+ self.cross_attn_token_to_image = Attention(
1366
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
1367
+ )
1368
+ self.norm2 = nn.LayerNorm(embedding_dim)
1369
+
1370
+ self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
1371
+ self.norm3 = nn.LayerNorm(embedding_dim)
1372
+
1373
+ self.norm4 = nn.LayerNorm(embedding_dim)
1374
+ self.cross_attn_image_to_token = Attention(
1375
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
1376
+ )
1377
+
1378
+ self.skip_first_layer_pe = skip_first_layer_pe
1379
+
1380
+ def forward(
1381
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
1382
+ ) -> Tuple[Tensor, Tensor]:
1383
+ # Self attention block
1384
+ if self.skip_first_layer_pe:
1385
+ queries = self.self_attn(q=queries, k=queries, v=queries)
1386
+ else:
1387
+ q = queries + query_pe
1388
+ attn_out = self.self_attn(q=q, k=q, v=queries)
1389
+ queries = queries + attn_out
1390
+ queries = self.norm1(queries)
1391
+
1392
+ # Cross attention block, tokens attending to image embedding
1393
+ q = queries + query_pe
1394
+ k = keys + key_pe
1395
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
1396
+ queries = queries + attn_out
1397
+ queries = self.norm2(queries)
1398
+
1399
+ # MLP block
1400
+ mlp_out = self.mlp(queries)
1401
+ queries = queries + mlp_out
1402
+ queries = self.norm3(queries)
1403
+
1404
+ # Cross attention block, image embedding attending to tokens
1405
+ q = queries + query_pe
1406
+ k = keys + key_pe
1407
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
1408
+ keys = keys + attn_out
1409
+ keys = self.norm4(keys)
1410
+
1411
+ return queries, keys
1412
+
1413
+
1414
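One block in isolation, showing that both token streams keep their shapes while being updated by the four steps described in the docstring; sizes are illustrative.

import torch

block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048,
    skip_first_layer_pe=True)

queries, keys = torch.rand(1, 6, 256), torch.rand(1, 512, 256)
query_pe, key_pe = torch.rand(1, 6, 256), torch.rand(1, 512, 256)

q_out, k_out = block(queries, keys, query_pe=query_pe, key_pe=key_pe)
print(q_out.shape, k_out.shape)   # (1, 6, 256) (1, 512, 256)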
+ class Attention(nn.Module):
1415
+ """
1416
+ An attention layer that allows for downscaling the size of the embedding
1417
+ after projection to queries, keys, and values.
1418
+ """
1419
+
1420
+ def __init__(
1421
+ self,
1422
+ embedding_dim: int,
1423
+ num_heads: int,
1424
+ downsample_rate: int = 1,
1425
+ ) -> None:
1426
+ super().__init__()
1427
+ self.embedding_dim = embedding_dim
1428
+ self.internal_dim = embedding_dim // downsample_rate
1429
+ self.num_heads = num_heads
1430
+ assert self.internal_dim % num_heads == 0, "num_heads must divide internal_dim (embedding_dim // downsample_rate)."
1431
+
1432
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
1433
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
1434
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
1435
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
1436
+
1437
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
1438
+ b, n, c = x.shape
1439
+ x = x.reshape(b, n, num_heads, c // num_heads)
1440
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
1441
+
1442
+ def _recombine_heads(self, x: Tensor) -> Tensor:
1443
+ b, n_heads, n_tokens, c_per_head = x.shape
1444
+ x = x.transpose(1, 2)
1445
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
1446
+
1447
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
1448
+ # Input projections
1449
+ q = self.q_proj(q)
1450
+ k = self.k_proj(k)
1451
+ v = self.v_proj(v)
1452
+
1453
+ # Separate into heads
1454
+ q = self._separate_heads(q, self.num_heads)
1455
+ k = self._separate_heads(k, self.num_heads)
1456
+ v = self._separate_heads(v, self.num_heads)
1457
+
1458
+ # Attention
1459
+ _, _, _, c_per_head = q.shape
1460
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
1461
+ attn = attn / math.sqrt(c_per_head)
1462
+ attn = torch.softmax(attn, dim=-1)
1463
+
1464
+ # Get output
1465
+ out = attn @ v
1466
+ out = self._recombine_heads(out)
1467
+ out = self.out_proj(out)
1468
+
1469
+ return out
1470
+
1471
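The downsample_rate argument only shrinks the internal attention width; inputs and outputs stay at embedding_dim, as a quick check shows.

import torch

attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
# Q/K/V are projected from 256 down to internal_dim = 128 (8 heads x 16 channels),
# attention is computed there, and out_proj maps the result back to 256.
q = torch.rand(1, 6, 256)
kv = torch.rand(1, 512, 256)
print(attn(q=q, k=kv, v=kv).shape)   # (1, 6, 256)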
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
1472
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
1473
+ class LayerNorm2d(nn.Module):
1474
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
1475
+ super().__init__()
1476
+ self.weight = nn.Parameter(torch.ones(num_channels))
1477
+ self.bias = nn.Parameter(torch.zeros(num_channels))
1478
+ self.eps = eps
1479
+
1480
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1481
+ u = x.mean(1, keepdim=True)
1482
+ s = (x - u).pow(2).mean(1, keepdim=True)
1483
+ x = (x - u) / torch.sqrt(s + self.eps)
1484
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
1485
+ return x
1486
+
1487
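LayerNorm2d normalizes each spatial position across the channel dimension of a channels-first map; a quick check with illustrative sizes:

import torch

norm = LayerNorm2d(num_channels=64)
x = torch.rand(2, 64, 16, 16)             # B x C x H x W
y = norm(x)
print(y.shape)                            # (2, 64, 16, 16)
print(y.mean(dim=1).abs().max())          # ~0: zero mean over channels at every position
# (weight starts at ones and bias at zeros, so the affine step is initially the identity)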
+ class MLPBlock(nn.Module):
1488
+ def __init__(
1489
+ self,
1490
+ embedding_dim: int,
1491
+ mlp_dim: int,
1492
+ act: Type[nn.Module] = nn.GELU,
1493
+ ) -> None:
1494
+ super().__init__()
1495
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
1496
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
1497
+ self.act = act()
1498
+
1499
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1500
+ return self.lin2(self.act(self.lin1(x)))
1501
+
1502
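MLPBlock is a plain per-token feed-forward layer (Linear -> activation -> Linear): it widens to mlp_dim internally and returns to embedding_dim.

import torch

mlp = MLPBlock(embedding_dim=256, mlp_dim=2048)
tokens = torch.rand(1, 6, 256)
print(mlp(tokens).shape)   # (1, 6, 256)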
+
1503
+ # sam
1504
+ class Sam(nn.Module):
1505
+ mask_threshold: float = 0.0
1506
+ image_format: str = "RGB"
1507
+
1508
+ def __init__(
1509
+ self,
1510
+ image_encoder,
1511
+ prompt_encoder,
1512
+ mask_decoder,
1513
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
1514
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
1515
+ ) -> None:
1516
+ """
1517
+ SAM predicts object masks from an image and input prompts.
1518
+
1519
+ Arguments:
1520
+ image_encoder (ImageEncoderViT): The backbone used to encode the
1521
+ image into image embeddings that allow for efficient mask prediction.
1522
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
1523
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
1524
+ and encoded prompts.
1525
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
1526
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
1527
+ """
1528
+ super().__init__()
1529
+ self.image_encoder = image_encoder
1530
+ self.prompt_encoder = prompt_encoder
1531
+ self.mask_decoder = mask_decoder
1532
+ self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
1533
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
1534
+
1535
+ @property
1536
+ def device(self) -> Any:
1537
+ return self.pixel_mean.device
1538
+
1539
+ @torch.no_grad()
1540
+ def forward(
1541
+ self,
1542
+ batched_input: List[Dict[str, Any]],
1543
+ multimask_output: bool,
1544
+ ) -> List[Dict[str, torch.Tensor]]:
1545
+ """
1546
+ Predicts masks end-to-end from provided images and prompts.
1547
+ If prompts are not known in advance, using SamPredictor is
1548
+ recommended over calling the model directly.
1549
+
1550
+ Arguments:
1551
+ batched_input (list(dict)): A list over input images, each a
1552
+ dictionary with the following keys. A prompt key can be
1553
+ excluded if it is not present.
1554
+ 'image': The image as a torch tensor in 3xHxW format,
1555
+ already transformed for input to the model.
1556
+ 'original_size': (tuple(int, int)) The original size of
1557
+ the image before transformation, as (H, W).
1558
+ 'point_coords': (torch.Tensor) Batched point prompts for
1559
+ this image, with shape BxNx2. Already transformed to the
1560
+ input frame of the model.
1561
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
1562
+ with shape BxN.
1563
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
1564
+ Already transformed to the input frame of the model.
1565
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
1566
+ in the form Bx1xHxW.
1567
+ multimask_output (bool): Whether the model should predict multiple
1568
+ disambiguating masks, or return a single mask.
1569
+
1570
+ Returns:
1571
+ (list(dict)): A list over input images, where each element is
1572
+ a dictionary with the following keys.
1573
+ 'masks': (torch.Tensor) Batched binary mask predictions,
1574
+ with shape BxCxHxW, where B is the number of input prompts,
1575
+ C is determined by multimask_output, and (H, W) is the
1576
+ original size of the image.
1577
+ 'iou_predictions': (torch.Tensor) The model's predictions
1578
+ of mask quality, in shape BxC.
1579
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
1580
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
1581
+ to subsequent iterations of prediction.
1582
+ """
1583
+ input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
1584
+ image_embeddings = self.image_encoder(input_images)
1585
+
1586
+ outputs = []
1587
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
1588
+ if "point_coords" in image_record:
1589
+ points = (image_record["point_coords"], image_record["point_labels"])
1590
+ else:
1591
+ points = None
1592
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
1593
+ points=points,
1594
+ boxes=image_record.get("boxes", None),
1595
+ masks=image_record.get("mask_inputs", None),
+ text_embedding=None,  # required by PromptEncoder.forward above; no text prompt in this path
1596
+ )
1597
+ low_res_masks, iou_predictions = self.mask_decoder(
1598
+ image_embeddings=curr_embedding.unsqueeze(0),
1599
+ image_pe=self.prompt_encoder.get_dense_pe(),
1600
+ sparse_prompt_embeddings=sparse_embeddings,
1601
+ dense_prompt_embeddings=dense_embeddings,
1602
+ multimask_output=multimask_output,
1603
+ )
1604
+ masks = self.postprocess_masks(
1605
+ low_res_masks,
1606
+ input_size=image_record["image"].shape[-2:],
1607
+ original_size=image_record["original_size"],
1608
+ )
1609
+ masks = masks > self.mask_threshold
1610
+ outputs.append(
1611
+ {
1612
+ "masks": masks,
1613
+ "iou_predictions": iou_predictions,
1614
+ "low_res_logits": low_res_masks,
1615
+ }
1616
+ )
1617
+ return outputs
1618
+
1619
+ def postprocess_masks(
1620
+ self,
1621
+ masks: torch.Tensor,
1622
+ input_size: Tuple[int, ...],
1623
+ original_size: Tuple[int, ...],
1624
+ ) -> torch.Tensor:
1625
+ """
1626
+ Remove padding and upscale masks to the original image size.
1627
+
1628
+ Arguments:
1629
+ masks (torch.Tensor): Batched masks from the mask_decoder,
1630
+ in BxCxHxW format.
1631
+ input_size (tuple(int, int)): The size of the image input to the
1632
+ model, in (H, W) format. Used to remove padding.
1633
+ original_size (tuple(int, int)): The original size of the image
1634
+ before resizing for input to the model, in (H, W) format.
1635
+
1636
+ Returns:
1637
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
1638
+ is given by original_size.
1639
+ """
1640
+ masks = F.interpolate(
1641
+ masks,
1642
+ (self.image_encoder.img_size, self.image_encoder.img_size),
1643
+ mode="bilinear",
1644
+ align_corners=False,
1645
+ )
1646
+ masks = masks[..., : input_size[0], : input_size[1]]
1647
+ masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
1648
+ return masks
1649
+
1650
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
1651
+ """Normalize pixel values and pad to a square input."""
1652
+ # Normalize colors
1653
+ # TODO
1654
+ x = (x - self.pixel_mean) / self.pixel_std
1655
+
1656
+ # Pad
1657
+ h, w = x.shape[-2:]
1658
+ padh = self.image_encoder.img_size - h
1659
+ padw = self.image_encoder.img_size - w
1660
+ x = F.pad(x, (0, padw, 0, padh))
1661
+ return x
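To make the preprocessing step concrete, here is a standalone sketch that mirrors what preprocess does (normalize, then pad right/bottom up to the encoder input size); the 1024 encoder size is illustrative, not read from this diff.

import torch
import torch.nn.functional as F

img_size = 1024                                             # illustrative encoder input size
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

x = torch.rand(3, 768, 1000) * 255                          # un-padded 3 x H x W input
x = (x - pixel_mean) / pixel_std                            # per-channel normalization
pad_h, pad_w = img_size - x.shape[-2], img_size - x.shape[-1]
x = F.pad(x, (0, pad_w, 0, pad_h))                          # pad bottom/right only
print(x.shape)                                              # torch.Size([3, 1024, 1024])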