xmutly committed · Commit c5ff5f9 · verified · 1 Parent(s): 2aa0c85

Upload DINOSAUR.py

Files changed (1)
  1. slots/DINOSAUR.py +78 -5
slots/DINOSAUR.py CHANGED
@@ -221,7 +221,78 @@ class Decoder(nn.Module):
 
         slot_maps = self.layer4(slot_maps)  # (B * S, token, 1024 + 1)
 
-        return slot_maps
+        return slot_maps, slot_maps
+
+
+class Decoder_to_DINOV2(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+
+        # === Token calculations ===
+        slot_dim = args['slot_dim']
+        hidden_dim = 2048
+
+        # === MLP Based Decoder ===
+        self.layer1 = nn.Linear(slot_dim, hidden_dim)
+        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
+        self.layer3 = nn.Linear(hidden_dim, hidden_dim)
+        self.layer4 = nn.Linear(hidden_dim, 1024 + 1)
+
+        self.layer_to_dinov2 = nn.Linear(hidden_dim, 768)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, slot_maps):
+        # :arg slot_maps: (B * S, token, D_slot)
+        slot_maps = self.relu(self.layer1(slot_maps))  # (B * S, token, D_hidden)
+        x_dinov2 = self.layer_to_dinov2(slot_maps)
+        slot_maps = self.relu(self.layer2(slot_maps))  # (B * S, token, D_hidden)
+        slot_maps = self.relu(self.layer3(slot_maps))  # (B * S, token, D_hidden)
+
+        slot_maps = self.layer4(slot_maps)  # (B * S, token, 1024 + 1)
+
+        return slot_maps, x_dinov2
+
+
+from torch.nn.init import trunc_normal_
+class DINOHead(nn.Module):
+    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=768):
+        super().__init__()
+        nlayers = max(nlayers, 1)
+        if nlayers == 1:
+            self.mlp = nn.Linear(in_dim, bottleneck_dim)
+        else:
+            layers = [nn.Linear(in_dim, hidden_dim)]
+            if use_bn:
+                layers.append(nn.BatchNorm1d(hidden_dim))
+            layers.append(nn.GELU())
+            for _ in range(nlayers - 2):
+                layers.append(nn.Linear(hidden_dim, hidden_dim))
+                if use_bn:
+                    layers.append(nn.BatchNorm1d(hidden_dim))
+                layers.append(nn.GELU())
+            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
+            self.mlp = nn.Sequential(*layers)
+        self.apply(self._init_weights)
+        self.gelu = nn.GELU()
+        self.last_layer1 = nn.Linear(bottleneck_dim, bottleneck_dim)
+        self.last_layer2 = nn.Linear(bottleneck_dim, out_dim)
+
+        # self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+        # self.last_layer.weight_g.data.fill_(1)
+        # if norm_last_layer:
+        #     self.last_layer.weight_g.requires_grad = False
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x_dinov2 = self.mlp(x)
+        # x = nn.functional.normalize(x, dim=-1, p=2)
+        x = self.gelu(self.last_layer1(x_dinov2))
+        x = self.last_layer2(x)
+        return x, x_dinov2
 
 
 class ISA(nn.Module):
@@ -517,7 +588,9 @@ class DINOSAURpp(nn.Module):
         else:
             self.slot_encoder = SA(args, input_dim=1024)
 
-        self.slot_decoder = Decoder(args)
+        self.slot_decoder = Decoder(args)  #ori easy mlp
+        # self.slot_decoder = DINOHead(in_dim=256, out_dim=1024+1, nlayers=3, bottleneck_dim=768)  #ori easy mlp
+        # self.slot_decoder = Decoder_to_DINOV2(args)  #ori easy mlp
 
         self.pos_dec = nn.Parameter(torch.Tensor(1, self.token_num, self.slot_dim))
         init.normal_(self.pos_dec, mean=0., std=.02)
@@ -572,17 +645,17 @@ class DINOSAURpp(nn.Module):
             rel_grid = self.slot_encoder.get_rel_grid(attn)  # (B, S, token, D_slot)
 
             slot_maps = self.sbd_slots(slots) + rel_grid  # (B, S, token, D_slot)
-            slot_maps = self.slot_decoder(slot_maps)  # (B, S, token, 1024 + 1)
+            slot_maps, x_dinov2 = self.slot_decoder(slot_maps)  # (B, S, token, 1024 + 1)
 
         else:
             slots = self.slot_encoder(features)  # (B, S, D_slot), (B, S, token)
             assert torch.sum(torch.isnan(slots)) == 0
 
             slot_maps, pos_maps = self.sbd_slots(slots)
-            slot_maps = self.slot_decoder(slot_maps)  # (B, S, token, 1024 + 1)
+            slot_maps, x_dinov2 = self.slot_decoder(slot_maps)  # (B, S, token, 1024 + 1)
 
         reconstruction, masks = self.reconstruct_feature_map(slot_maps)  # (B, token, 1024), (B, S, token)
 
-        return reconstruction, slots, masks
+        return reconstruction, slots, masks, x_dinov2
 
 
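
All three decoder variants in this commit (Decoder, Decoder_to_DINOV2, DINOHead) now return a pair, so the surrounding model code can unpack them uniformly. Below is a minimal shape check for the DINOHead variant, assuming it is importable from slots.DINOSAUR and reusing the dimensions from the commented-out constructor call (in_dim=256, out_dim=1024+1, bottleneck_dim=768); the batch and token sizes are illustrative only.

import torch
from slots.DINOSAUR import DINOHead  # assumed import path for this repo layout

head = DINOHead(in_dim=256, out_dim=1024 + 1, nlayers=3, bottleneck_dim=768)

slot_maps = torch.randn(56, 256, 256)    # (B * S, token, D_slot), example sizes only
out, x_dinov2 = head(slot_maps)          # forward returns (projection, bottleneck features)

print(out.shape)       # torch.Size([56, 256, 1025]) -> 1024-d reconstruction + 1 mask logit
print(x_dinov2.shape)  # torch.Size([56, 256, 768])  -> DINOv2-sized bottleneck output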
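
Because DINOSAURpp.forward now returns four values instead of three, existing call sites must be updated. The commit does not include the training code, so the following is only a sketch of how a caller might adapt; model, features, and the reconstruction loss are illustrative assumptions, not part of this diff.

import torch.nn.functional as F

def training_step(model, features):
    # Previously: reconstruction, slots, masks = model(features)
    reconstruction, slots, masks, x_dinov2 = model(features)

    # DINOSAUR-style feature reconstruction objective (assumed here).
    recon_loss = F.mse_loss(reconstruction, features)

    # x_dinov2 is the extra output exposed by this commit; how it is supervised
    # (e.g., against DINOv2 features via the 768-d branch) is not part of this diff.
    return recon_loss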