import torch
import torch.nn as nn
import math
def patchify(x, patch_size=8):
    """Split (B, C, H, W) images into a sequence of flattened patches (B, N, C*P*P)."""
    b, c, h, w = x.shape
    assert h % patch_size == 0 and w % patch_size == 0, "Image size must be divisible by patch_size"
    th = h // patch_size
    tw = w // patch_size
    out = x.reshape(b, c, th, patch_size, tw, patch_size)
    out = out.permute(0, 2, 4, 1, 3, 5).contiguous()  # (B, th, tw, C, P, P)
    out = out.view(b, th * tw, c * (patch_size ** 2))
    return out
def unpatchify(x, patch_size=8):
    """Inverse of patchify: (B, N, C*P*P) -> (B, C, H, W), assuming a square grid of patches."""
    b, z, p = x.shape
    c = p // (patch_size ** 2)
    th = int(math.sqrt(z))
    tw = th
    h = th * patch_size
    w = tw * patch_size
    x = x.view(b, th, tw, c, patch_size, patch_size)
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous()  # (B, C, th, P, tw, P)
    out = x.view(b, c, h, w)
    return out
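
# Illustrative sketch (not part of the original model): a quick sanity check that unpatchify
# inverts patchify for a square image whose side is divisible by patch_size.
def _check_patchify_roundtrip(patch_size=8):
    x = torch.rand(2, 3, 64, 64)
    patches = patchify(x, patch_size)            # (2, 64, 192): 8x8 grid, 3*8*8 values per patch
    restored = unpatchify(patches, patch_size)   # back to (2, 3, 64, 64)
    assert restored.shape == x.shape
    assert torch.allclose(restored, x)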
def random_mask(x, mask_ratio=0.75):
    """Per-sample random masking of a patch sequence (B, N, P).

    Returns the kept patches, a binary mask in original patch order (1 = masked),
    and the shuffle/restore index tensors needed to undo the permutation.
    """
    b, n, p = x.shape
    len_keep = int(n * (1 - mask_ratio))
    noise = torch.rand(b, n, device=x.device)          # random score per patch
    ids_shuffle = torch.argsort(noise, dim=1)          # ascending: first len_keep are kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)    # maps shuffled order back to original
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, p))
    mask = torch.ones(b, n, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)  # back to original patch order
    return x_masked, mask, ids_restore, ids_keep
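
# Illustrative sketch (not part of the original model): shapes produced by random_mask.
# With N = 16 patches and mask_ratio = 0.75, 4 patches are kept and 12 are masked.
def _demo_random_mask():
    x = torch.rand(2, 16, 192)                        # (B, N, P) patch sequence
    x_masked, mask, ids_restore, ids_keep = random_mask(x, mask_ratio=0.75)
    print(x_masked.shape)     # torch.Size([2, 4, 192])  kept (visible) patches
    print(mask.shape)         # torch.Size([2, 16])      1 = masked, 0 = visible, original order
    print(ids_keep.shape)     # torch.Size([2, 4])
    print(ids_restore.shape)  # torch.Size([2, 16])
    assert int(mask.sum(dim=1)[0]) == 12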
def mae_loss(pred, target, mask):
# pred/target: (B, N, P), mask: (B, N) with 1=masked
B, N, P = pred.shape
mask = mask.unsqueeze(-1).float() # (B, N, 1)
loss = (pred - target) ** 2
loss = (loss * mask).sum() / mask.sum().clamp_min(1.0)
return loss
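
# Illustrative sketch (not part of the original model): mae_loss only penalizes masked patches,
# so a prediction that is wrong only on visible patches contributes zero loss.
def _demo_mae_loss():
    target = torch.zeros(1, 4, 8)
    pred = torch.zeros(1, 4, 8)
    pred[:, 0] = 1.0                         # error on patch 0 only
    mask = torch.tensor([[0., 1., 1., 1.]])  # patch 0 is visible, the rest are masked
    print(mae_loss(pred, target, mask))      # tensor(0.) -- visible-patch error is ignored
    mask = torch.tensor([[1., 0., 0., 0.]])  # now patch 0 is the masked one
    print(mae_loss(pred, target, mask))      # tensor(8.) -- patch SSE / number of masked patches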
class PositionalEncoding(nn.Module):
    """Learned positional embeddings, gathered for the visible (unmasked) patches only."""
    def __init__(self, num_patches, hidden_dim=768):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.empty(1, num_patches, hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x, visible_indices):
        # x: (B, len_keep, D); visible_indices: (B, len_keep)
        B, L, D = x.shape
        pos = self.pos_embed.expand(B, -1, -1)  # expand table to (B, N, D)
        idx = visible_indices.unsqueeze(-1).expand(B, L, pos.size(-1))  # gather index (B, L, D)
        visible_pos = torch.gather(pos, 1, idx)  # (B, L, D)
        return x + visible_pos
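
# Illustrative sketch (not part of the original model): positional embeddings are looked up
# only for the visible patches, using the indices returned by random_mask.
def _demo_positional_encoding():
    pe = PositionalEncoding(num_patches=16, hidden_dim=32)
    tokens = torch.rand(2, 4, 32)  # 4 visible tokens per sample
    visible_indices = torch.tensor([[0, 3, 7, 15], [1, 2, 5, 8]])
    out = pe(tokens, visible_indices)
    print(out.shape)  # torch.Size([2, 4, 32])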
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: LN -> self-attention -> residual, then LN -> MLP -> residual."""
    def __init__(self, hidden_dim, mlp_dim, num_heads, dropout):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(hidden_dim)
        self.multihead = nn.MultiheadAttention(batch_first=True, embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout)
        self.layernorm2 = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, hidden_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # self-attention sub-block
        residual = x
        x = self.layernorm1(x)
        attn, _ = self.multihead(x, x, x)
        x = residual + attn
        # MLP sub-block
        residual = x
        x = self.layernorm2(x)
        x = self.mlp(x)
        x = residual + x
        return x
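
# Illustrative sketch (not part of the original model): the block preserves the (B, L, D) shape,
# so it can be stacked on either the encoder's visible tokens or the decoder's full sequence.
def _demo_transformer_block():
    block = TransformerBlock(hidden_dim=64, mlp_dim=256, num_heads=8, dropout=0.0)
    tokens = torch.rand(2, 10, 64)
    out = block(tokens)
    print(out.shape)  # torch.Size([2, 10, 64])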
class MAEEncoder(nn.Module):
    """ViT-style encoder that is applied only to the visible (unmasked) patches.

    patch_dim is the flattened patch dimension, i.e. c * patch_size**2.
    """
    def __init__(self, patch_dim, num_patches=(384 // 8) ** 2, hidden_dim=768, mlp_dim=768 * 4,
                 num_heads=8, depth=12, dropout=0.25, mask_ratio=0.75, patch_size=8):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        self.patch_embed = nn.Linear(patch_dim, hidden_dim)
        self.pos_embed = PositionalEncoding(num_patches=num_patches, hidden_dim=hidden_dim)
        self.transformer = nn.ModuleList([
            TransformerBlock(hidden_dim=hidden_dim, mlp_dim=mlp_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(depth)
        ])
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x_in):
        x_p = patchify(x_in, self.patch_size)                             # (B, N, patch_dim)
        x_masked, mask, ids_restore, ids_keep = random_mask(x_p, self.mask_ratio)
        x = self.patch_embed(x_masked)                                    # (B, len_keep, hidden_dim)
        x = self.pos_embed(x, ids_keep)
        for attn_layer in self.transformer:
            x = attn_layer(x)
        return x, mask, ids_keep, ids_restore
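
# Illustrative sketch (not part of the original model): the encoder only processes the visible
# tokens, so its output length is N * (1 - mask_ratio) rather than N.
def _demo_encoder():
    enc = MAEEncoder(patch_dim=1 * 8 * 8, num_patches=(64 // 8) ** 2, hidden_dim=128, mlp_dim=512,
                     num_heads=8, depth=2, dropout=0.0, mask_ratio=0.75, patch_size=8)
    img = torch.rand(2, 1, 64, 64)
    tokens, mask, ids_keep, ids_restore = enc(img)
    print(tokens.shape)  # torch.Size([2, 16, 128])  only 16 of 64 patches are visible
    print(mask.shape)    # torch.Size([2, 64])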
class MAEDecoder(nn.Module):
    """Lightweight decoder that reconstructs all patches from the encoded visible tokens plus mask tokens."""
    def __init__(self, c, num_patches, patch_size, encoder_dim, decoder_dim, decoder_depth, mlp_dim, num_heads, dropout):
        super().__init__()
        self.num_patches = num_patches
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.mask_token = nn.Parameter(torch.empty(1, 1, decoder_dim))
        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim)
        self.pos_embed = nn.Parameter(torch.empty(1, num_patches, decoder_dim))
        self.transformer = nn.ModuleList([
            TransformerBlock(hidden_dim=decoder_dim, mlp_dim=mlp_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(decoder_depth)
        ])
        self.layernorm = nn.LayerNorm(decoder_dim)
        self.pred = nn.Linear(decoder_dim, c * (patch_size ** 2))
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.mask_token, std=0.02)

    def forward(self, x, ids_keep, ids_restore):
        b, n, _ = x.shape
        xdec = self.enc_to_dec(x)                  # project encoder tokens to decoder width
        len_keep = xdec.size(1)
        num_patches = ids_restore.size(1)
        num_mask = num_patches - len_keep
        mask_token = self.mask_token.expand(b, num_mask, -1)
        x_ = torch.cat([xdec, mask_token], dim=1)  # visible tokens followed by mask tokens
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x_.size(-1)))  # restore original order
        x_ = x_ + self.pos_embed
        for block in self.transformer:
            x_ = block(x_)
        x_ = self.layernorm(x_)
        out = self.pred(x_)                        # (B, N, c * patch_size**2)
        return out
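
# Illustrative sketch (not part of the original model): the decoder fills in a mask token for each
# dropped patch, restores the original order with ids_restore, and predicts every patch.
def _demo_decoder():
    enc = MAEEncoder(patch_dim=64, num_patches=64, hidden_dim=128, mlp_dim=512,
                     num_heads=8, depth=2, dropout=0.0, mask_ratio=0.75, patch_size=8)
    dec = MAEDecoder(c=1, num_patches=64, patch_size=8, encoder_dim=128, decoder_dim=64,
                     decoder_depth=2, mlp_dim=256, num_heads=8, dropout=0.0)
    img = torch.rand(2, 1, 64, 64)
    tokens, mask, ids_keep, ids_restore = enc(img)
    pred = dec(tokens, ids_keep, ids_restore)
    print(pred.shape)  # torch.Size([2, 64, 64])  every patch is predicted, visible or not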
class MaskedAutoEncoder(nn.Module):
    """Full MAE: encode visible patches, decode all patches, return (target patches, predictions, mask)."""
    def __init__(self, c=1, mask_ratio=0.75, dropout=0.25, img_size=384, encoder_dim=768, mlp_dim=3072,
                 decoder_dim=512, encoder_depth=12, encoder_head=8, decoder_depth=8, decoder_head=8, patch_size=8):
        super().__init__()
        self.patch_size = patch_size
        self.encoder = MAEEncoder(patch_dim=c * (patch_size ** 2), num_patches=(img_size // patch_size) ** 2,
                                  hidden_dim=encoder_dim, mlp_dim=mlp_dim, num_heads=encoder_head,
                                  depth=encoder_depth, dropout=dropout, mask_ratio=mask_ratio, patch_size=patch_size)
        self.decoder = MAEDecoder(c, num_patches=(img_size // patch_size) ** 2, patch_size=patch_size,
                                  encoder_dim=encoder_dim, decoder_dim=decoder_dim, decoder_depth=decoder_depth,
                                  mlp_dim=mlp_dim, num_heads=decoder_head, dropout=dropout)

    def forward(self, x):
        encoded, mask, ids_keep, ids_restore = self.encoder(x)
        decoded = self.decoder(encoded, ids_keep, ids_restore)
        xpatched = patchify(x, self.patch_size)  # reconstruction target in patch space
        return xpatched, decoded, mask

    @staticmethod
    def testme():
        # Quick shape check on a single 384x384 grayscale image.
        img = torch.rand(1, 1, 384, 384)
        mae = MaskedAutoEncoder()
        target, pred, mask = mae(img)
        print(target.shape)  # torch.Size([1, 2304, 64])
        print(pred.shape)    # torch.Size([1, 2304, 64])
        print(mask.shape)    # torch.Size([1, 2304])
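
# Illustrative sketch (not part of the original model): one self-supervised training step, assuming
# images come from a DataLoader yielding (B, 1, 384, 384) tensors. The function name, optimizer
# choice, and hyperparameters below are hypothetical.
def _train_step(model, images, optimizer):
    model.train()
    target, pred, mask = model(images)   # target/pred: (B, N, P), mask: (B, N)
    loss = mae_loss(pred, target, mask)  # reconstruction error on masked patches only
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Example usage (hypothetical):
#   mae = MaskedAutoEncoder()
#   opt = torch.optim.AdamW(mae.parameters(), lr=1.5e-4, weight_decay=0.05)
#   for images in loader:
#       print(_train_step(mae, images, opt))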