tennant committed
Commit 47b1e6f
1 Parent(s): 627ad10

update laion

Files changed (2)
  1. app.py +1 -1
  2. model.py +3 -2
app.py CHANGED
@@ -17,7 +17,7 @@ for k, v in ckpt.items():
     k = k[len('image_encoder.model.'):]
     new_dict.update({k: v})
 
-model = mae_vit_base_patch16(uni_dim=768, less_u=True)
+model = mae_vit_base_patch16(uni_dim=768, uni_heads=12, less_u=True)
 
 msg = model.load_state_dict(new_dict, strict=False)
 print(msg)
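
For context, a minimal sketch of the loading flow this hunk sits in, assuming a checkpoint path and the import location of mae_vit_base_patch16 (neither is visible in the diff):

import torch

from model import mae_vit_base_patch16  # assumed import; model.py is in this repo

# Hypothetical checkpoint path; the real one is not shown in this hunk.
ckpt = torch.load('laion_ckpt.pth', map_location='cpu')

# As in the hunk: strip the 'image_encoder.model.' prefix from every key
# (assumes all checkpoint keys carry that prefix).
new_dict = {}
for k, v in ckpt.items():
    k = k[len('image_encoder.model.'):]
    new_dict.update({k: v})

# uni_heads=12 presumably matches the head count the LAION checkpoint
# was trained with, so the attention projection shapes load cleanly.
model = mae_vit_base_patch16(uni_dim=768, uni_heads=12, less_u=True)

# strict=False tolerates missing/unexpected keys; msg reports both lists.
msg = model.load_state_dict(new_dict, strict=False)
print(msg)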
model.py CHANGED
@@ -143,6 +143,7 @@ class ParallelTransformerBlock(nn.Module):
 
         attn_inner_dim = dim_head * heads
         ff_inner_dim = dim * ff_mult
+        # import ipdb; ipdb.set_trace()
         self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
 
         self.heads = heads
@@ -431,7 +432,7 @@ class MaskedAutoencoderViT(nn.Module):
         # NOTE: +1 for mask token used by MLM objective
         # self.token_emb = nn.Embedding(len(self.tokenizer.vocab) + 1, uni_dim)
 
-        self.token_emb = nn.Embedding(len(self.tokenizer.vocab), uni_dim)
+        self.token_emb = nn.Embedding(len(self.tokenizer.vocab) + 1, uni_dim)
         self.text_cls_token = nn.Parameter(torch.randn(uni_dim))
 
         self.embed_dim = embed_dim
@@ -528,7 +529,7 @@ class MaskedAutoencoderViT(nn.Module):
         # self.text_mask_token = nn.Parameter(torch.randn(embed_dim))
         self.mask_token_id = len(self.tokenizer.vocab)
 
-        # self.text_position_embed = nn.Parameter(torch.zeros(1, text_length, embed_dim), requires_grad=False)
+        self.text_position_embed = nn.Parameter(torch.zeros(1, text_length, embed_dim), requires_grad=False)
         self.text_length = text_length
 
         self.latent_projector_layer = projector_layer
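
A minimal sketch of what the model.py changes buy, with illustrative sizes (the real values come from the tokenizer and model config, not from this diff):

import torch
import torch.nn as nn

# Illustrative stand-ins for len(self.tokenizer.vocab), uni_dim,
# embed_dim and text_length in the real model.
vocab_size, uni_dim, embed_dim, text_length = 49408, 768, 768, 77

# The +1 row makes the MLM mask token (id == vocab_size) a valid index;
# without it, embedding the mask token is an out-of-range lookup.
token_emb = nn.Embedding(vocab_size + 1, uni_dim)
mask_token_id = vocab_size

# The re-enabled positional embedding is all zeros with requires_grad=False,
# so it is carried in the state dict but never trained.
text_position_embed = nn.Parameter(
    torch.zeros(1, text_length, embed_dim), requires_grad=False)

ids = torch.tensor([[11, 42, mask_token_id]])  # a partially masked sequence
emb = token_emb(ids)  # shape (1, 3, uni_dim); valid only thanks to the +1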