Upload DINOSAUR.py
slots/DINOSAUR.py  +78 -5
CHANGED
@@ -221,7 +221,78 @@ class Decoder(nn.Module):

        slot_maps = self.layer4(slot_maps)                               # (B * S, token, 1024 + 1)

-       return slot_maps


class ISA(nn.Module):
@@ -517,7 +588,9 @@ class DINOSAURpp(nn.Module):
        else:
            self.slot_encoder = SA(args, input_dim=1024)

-       self.slot_decoder = Decoder(args)

        self.pos_dec = nn.Parameter(torch.Tensor(1, self.token_num, self.slot_dim))
        init.normal_(self.pos_dec, mean=0., std=.02)
@@ -572,17 +645,17 @@ class DINOSAURpp(nn.Module):
            rel_grid = self.slot_encoder.get_rel_grid(attn)              # (B, S, token, D_slot)

            slot_maps = self.sbd_slots(slots) + rel_grid                 # (B, S, token, D_slot)
-           slot_maps = self.slot_decoder(slot_maps)                     # (B, S, token, 1024 + 1)

        else:
            slots = self.slot_encoder(features)                          # (B, S, D_slot), (B, S, token)
            assert torch.sum(torch.isnan(slots)) == 0

            slot_maps, pos_maps = self.sbd_slots(slots)
-           slot_maps = self.slot_decoder(slot_maps)                     # (B, S, token, 1024 + 1)

        reconstruction, masks = self.reconstruct_feature_map(slot_maps)  # (B, token, 1024), (B, S, token)

-       return reconstruction, slots, masks
@@ -221,7 +221,78 @@ class Decoder(nn.Module):

        slot_maps = self.layer4(slot_maps)                               # (B * S, token, 1024 + 1)

+       return slot_maps, slot_maps
+
+
+class Decoder_to_DINOV2(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+
+        # === Token calculations ===
+        slot_dim = args['slot_dim']
+        hidden_dim = 2048
+
+        # === MLP Based Decoder ===
+        self.layer1 = nn.Linear(slot_dim, hidden_dim)
+        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
+        self.layer3 = nn.Linear(hidden_dim, hidden_dim)
+        self.layer4 = nn.Linear(hidden_dim, 1024 + 1)
+
+        self.layer_to_dinov2 = nn.Linear(hidden_dim, 768)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, slot_maps):
+        # :arg slot_maps: (B * S, token, D_slot)
+        slot_maps = self.relu(self.layer1(slot_maps))                    # (B * S, token, D_hidden)
+        x_dinov2 = self.layer_to_dinov2(slot_maps)
+        slot_maps = self.relu(self.layer2(slot_maps))                    # (B * S, token, D_hidden)
+        slot_maps = self.relu(self.layer3(slot_maps))                    # (B * S, token, D_hidden)
+
+        slot_maps = self.layer4(slot_maps)                               # (B * S, token, 1024 + 1)
+
+        return slot_maps, x_dinov2
+
+from torch.nn.init import trunc_normal_
+class DINOHead(nn.Module):
+    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=768):
+        super().__init__()
+        nlayers = max(nlayers, 1)
+        if nlayers == 1:
+            self.mlp = nn.Linear(in_dim, bottleneck_dim)
+        else:
+            layers = [nn.Linear(in_dim, hidden_dim)]
+            if use_bn:
+                layers.append(nn.BatchNorm1d(hidden_dim))
+            layers.append(nn.GELU())
+            for _ in range(nlayers - 2):
+                layers.append(nn.Linear(hidden_dim, hidden_dim))
+                if use_bn:
+                    layers.append(nn.BatchNorm1d(hidden_dim))
+                layers.append(nn.GELU())
+            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
+            self.mlp = nn.Sequential(*layers)
+        self.apply(self._init_weights)
+        self.gelu = nn.GELU()
+        self.last_layer1 = nn.Linear(bottleneck_dim, bottleneck_dim)
+        self.last_layer2 = nn.Linear(bottleneck_dim, out_dim)
+
+        # self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+        # self.last_layer.weight_g.data.fill_(1)
+        # if norm_last_layer:
+        #     self.last_layer.weight_g.requires_grad = False
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x_dinov2 = self.mlp(x)
+        # x = nn.functional.normalize(x, dim=-1, p=2)
+        x = self.gelu(self.last_layer1(x_dinov2))
+        x = self.last_layer2(x)
+        return x, x_dinov2


class ISA(nn.Module):
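Note: the existing Decoder now returns its output twice (return slot_maps, slot_maps), so it exposes the same two-value interface as the new Decoder_to_DINOV2, which taps the first hidden layer through layer_to_dinov2 and yields a 768-dimensional side output (the patch-embedding width of a ViT-B backbone such as DINOv2 ViT-B/14) alongside the usual 1024 + 1 channel map. A minimal smoke test of that interface, assuming a dict-like args and an illustrative slot_dim of 256 (taken from the commented-out DINOHead line further down, not stated in this hunk):

import torch

args = {'slot_dim': 256}                        # assumed; matches in_dim=256 in the commented DINOHead option
decoder = Decoder_to_DINOV2(args)

B, S, token = 2, 7, 784                         # illustrative sizes
slot_maps = torch.randn(B * S, token, args['slot_dim'])

recon, x_dinov2 = decoder(slot_maps)
print(recon.shape)                              # (14, 784, 1025): 1024 feature channels + 1 alpha logit
print(x_dinov2.shape)                           # (14, 784, 768):  DINOv2-sized projection from the first hidden layer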
@@ -517,7 +588,9 @@ class DINOSAURpp(nn.Module):

        else:
            self.slot_encoder = SA(args, input_dim=1024)

+       self.slot_decoder = Decoder(args)                                                          # ori easy mlp
+       # self.slot_decoder = DINOHead(in_dim=256, out_dim=1024+1, nlayers=3, bottleneck_dim=768)  # ori easy mlp
+       # self.slot_decoder = Decoder_to_DINOV2(args)                                              # ori easy mlp

        self.pos_dec = nn.Parameter(torch.Tensor(1, self.token_num, self.slot_dim))
        init.normal_(self.pos_dec, mean=0., std=.02)
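Only the plain Decoder is active here; DINOHead (a DINO-style projection head) and Decoder_to_DINOV2 are left as commented-out alternatives behind the same attribute. Whichever variant is plugged in has to honour the two-output contract, because the forward pass below unpacks slot_maps, x_dinov2 from the decoder call. A hedged interface check under that assumption; sizes are illustrative, it presumes the classes above are in scope, and it presumes Decoder accepts the same args dict:

import torch

args = {'slot_dim': 256}                        # assumed, as in the sketch above
x = torch.randn(2 * 7, 784, 256)                # (B * S, token, D_slot)

for dec in (Decoder(args),
            DINOHead(in_dim=256, out_dim=1024 + 1, nlayers=3, bottleneck_dim=768),
            Decoder_to_DINOV2(args)):
    out, extra = dec(x)                         # every option returns a 2-tuple
    assert out.shape[-1] == 1024 + 1            # decoded features + alpha channel in all three cases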
@@ -572,17 +645,17 @@ class DINOSAURpp(nn.Module):

            rel_grid = self.slot_encoder.get_rel_grid(attn)              # (B, S, token, D_slot)

            slot_maps = self.sbd_slots(slots) + rel_grid                 # (B, S, token, D_slot)
+           slot_maps, x_dinov2 = self.slot_decoder(slot_maps)           # (B, S, token, 1024 + 1)

        else:
            slots = self.slot_encoder(features)                          # (B, S, D_slot), (B, S, token)
            assert torch.sum(torch.isnan(slots)) == 0

            slot_maps, pos_maps = self.sbd_slots(slots)
+           slot_maps, x_dinov2 = self.slot_decoder(slot_maps)           # (B, S, token, 1024 + 1)

        reconstruction, masks = self.reconstruct_feature_map(slot_maps)  # (B, token, 1024), (B, S, token)

+       return reconstruction, slots, masks, x_dinov2
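DINOSAURpp.forward now returns reconstruction, slots, masks, x_dinov2, threading the decoder's auxiliary output to the caller. Nothing in this commit consumes x_dinov2 yet; with the default Decoder it is simply a second copy of slot_maps, and it only becomes a 768-dimensional projection when Decoder_to_DINOV2 (or DINOHead) is switched in. A heavily hedged sketch of how a caller might use it, assuming x_dinov2 keeps the (B, S, token, 768) layout of slot_maps, that the reconstruction target is the encoder features as in DINOSAUR, and that dinov2_feats is a frozen (B, token, 768) target from DINOv2 ViT-B/14 (none of this appears in the file):

import torch
import torch.nn.functional as F

reconstruction, slots, masks, x_dinov2 = model(features)      # model: DINOSAURpp; inputs as in the existing training code

# Mix the per-slot projections with the same masks used for the feature reconstruction.
x_mixed = torch.einsum('bst,bstd->btd', masks, x_dinov2)      # (B, token, 768)

loss = F.mse_loss(reconstruction, features) + F.mse_loss(x_mixed, dinov2_feats)   # second term is an assumed auxiliary objective
loss.backward()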