xmutly committed on
Commit 88e5994 · verified · 1 Parent(s): 5df2892

Upload adversarial_training_clip_with_object_token.py

train/adversarial_training_clip_with_object_token.py CHANGED
@@ -31,6 +31,8 @@ import argparse
from slots.DINOSAUR import DINOSAURpp
import matplotlib.pyplot as plt
from einops import rearrange, repeat
+ from IPG.IPG_arch import IPG
+

parser = argparse.ArgumentParser()
parser.add_argument('--clip_model_name', type=str, default='ViT-L-14', help='ViT-L-14, ViT-B-32')
@@ -129,9 +131,42 @@ def main(args):
####################################################### get slot-attention model #########################################################
cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
model_slots = DINOSAURpp(cfg_dict)
- proj_head = torch.nn.Linear(256, 1024) # slot-num to slot-num
+ # proj_head = torch.nn.Linear(256, 1024) # slot-num to slot-num
+ # add for IPG
+ upscale = 1
+ height = (8 // upscale)
+ width = (8 // upscale)
+ proj_head = IPG(
+     upscale=upscale,
+     in_chans=64,
+     out_chans=64,
+     img_size=(height, width),
+     window_size=2,
+     img_range=1.,
+     depths=[2, 2],
+     embed_dim=256,
+     num_heads=[8, 8],
+     mlp_ratio=4,
+     upsampler='sam',
+     resi_connection='1conv',
+     graph_flags=[1, 1],
+     stage_spec=[['GN', 'GS'], ['GN', 'GS']],
+     dist_type='cossim',
+     top_k=256,
+     head_wise=0,
+     sample_size=4,
+     graph_switch=1,
+     flex_type='interdiff_plain',
+     FFNtype='basic-dwconv3',
+     conv_scale=0,
+     conv_type='dwconv3-gelu-conv1-ca',
+     diff_scales=[1.5, 1.5],
+     fast_graph=1
+ )
if args.optimizer_state != '':
    proj_head.load_state_dict(torch.load(args.pretrained_proj_head))
+ if args.slots_ckp != '':
+     model_slots.load_state_dict(torch.load(args.slots_ckp))



@@ -338,7 +373,37 @@ def train_one_epoch(
embedding_orig, patches_orig = model_orig(vision=data, output_normalize=args.output_normalize)
reconstruction, slots, masks, x_dinov2 = model_slots(patches_orig) # (B, token, 768)

- object_token = proj_head(slots)
+
+
+ with torch.no_grad():
+     b, hw, c = reconstruction.shape
+     h = int(pow(hw, 0.5))
+     w = h
+     k = masks.size(1)
+     reconstruction = rearrange(reconstruction, 'b (h w) c -> b c h w', h=h, w=w)
+     masks = rearrange(masks, 'b k (h w) -> b k h w', h=h, w=w)
+     masks_recon_feat = torch.einsum('b k h w, b c h w -> b k c', masks, reconstruction)
+     masks_recon_feat = masks_recon_feat.repeat(1, k, 1)
+     b, hw, c = masks_recon_feat.shape
+     h = int(pow(hw, 0.5))
+     w = h
+     sim = F.cosine_similarity(masks_recon_feat[:, None, :, :], masks_recon_feat[:, :, None, :], dim=-1).mean(-1)
+     sim = rearrange(sim, 'b (h w) -> b h w', h=h, w=w)
+
+     top_values, top_indices = torch.topk(sim[:, 1], k - 2)
+     maxsim_idx = torch.argmax(sim[:, 1], dim=-1)
+     top_indices_slos = top_indices.unsqueeze(-1).repeat(1, 1, slots.size(-1))
+     top_indices_sim = top_indices.unsqueeze(-1).repeat(1, 1, k - 2)
+
+     h, w = k - 2, k - 2
+     slots = torch.gather(slots, dim=1, index=top_indices_slos)
+     sim = torch.gather(sim, dim=1, index=top_indices_sim)
+     slot_tokens = slots.repeat(1, k - 2, 1)
+     slot_tokens = rearrange(slot_tokens, 'b (h w) c -> b c h w', h=h, w=w)
+     b, c, h, w = slot_tokens.shape
+ object_token = proj_head(slot_tokens, sim_matric=sim)
+
+ # object_token = proj_head(slots)

# loss for the attack
loss_inner_wrapper = ComputeLossWrapper(
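
The added training-loop code selects a subset of the DINOSAURpp slots by mask-weighted feature similarity and tiles them into a small pseudo-image before the graph-based projection head produces the object token. Below is a minimal, self-contained sketch of that flow with dummy tensors. It is an illustration only: DummyProjHead, the tensor sizes, and the reference slot index 1 are assumptions for this sketch, and the real IPG module (with its sim_matric keyword) from IPG.IPG_arch is not reproduced here.

import torch
import torch.nn.functional as F
from einops import rearrange

class DummyProjHead(torch.nn.Module):
    """Hypothetical stand-in for IPG: maps a (B, C, H, W) slot grid to object tokens."""
    def __init__(self, dim=256, out_dim=1024):
        super().__init__()
        self.proj = torch.nn.Linear(dim, out_dim)

    def forward(self, x, sim_matric=None):
        # flatten the spatial grid back into tokens before projecting
        x = rearrange(x, 'b c h w -> b (h w) c')
        return self.proj(x)

b, k, hw, c = 2, 10, 256, 256                          # batch, slots, patch tokens, slot dim (assumed)
slots = torch.randn(b, k, c)                           # slot vectors from the slot-attention model
masks = torch.softmax(torch.randn(b, k, hw), dim=1)    # per-slot attention masks over patches
reconstruction = torch.randn(b, hw, c)                 # reconstructed patch features

with torch.no_grad():
    h = w = int(hw ** 0.5)
    recon = rearrange(reconstruction, 'b (h w) c -> b c h w', h=h, w=w)
    m = rearrange(masks, 'b k (h w) -> b k h w', h=h, w=w)
    # mask-weighted pooling of reconstructed features -> one feature vector per slot
    slot_feat = torch.einsum('bkhw,bchw->bkc', m, recon)
    # pairwise cosine similarity between slot features: (b, k, k)
    sim = F.cosine_similarity(slot_feat[:, None, :, :], slot_feat[:, :, None, :], dim=-1)
    # keep the k-2 slots most similar to a reference slot (index 1, mirroring the diff)
    _, top_idx = torch.topk(sim[:, 1], k - 2)
    selected = torch.gather(slots, 1, top_idx.unsqueeze(-1).expand(-1, -1, c))
    # tile the selected slots into a (k-2) x (k-2) pseudo-image for the projection head
    grid = rearrange(selected.repeat(1, k - 2, 1), 'b (h w) c -> b c h w', h=k - 2, w=k - 2)

proj_head = DummyProjHead(dim=c, out_dim=1024)
object_token = proj_head(grid, sim_matric=sim)
print(object_token.shape)                              # torch.Size([2, 64, 1024])

The repeat/rearrange step exists so the selected slot vectors form a dense (B, C, H, W) grid, which is the input layout an image-style backbone such as IPG expects in place of the original single Linear projection.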