import math

import torch
import torch.nn as nn

def patchify(x, patch_size=8):
    """Split an image batch (B, C, H, W) into flattened patches (B, N, C*P*P)."""
    b, c, h, w = x.shape
    assert h % patch_size == 0 and w % patch_size == 0, "Image size must be divisible by patch_size"
    th = h // patch_size
    tw = w // patch_size
    # (B, C, H, W) -> (B, C, th, P, tw, P) -> (B, th, tw, C, P, P) -> (B, N, C*P*P)
    out = x.reshape(b, c, th, patch_size, tw, patch_size)
    out = out.permute(0, 2, 4, 1, 3, 5).contiguous()
    out = out.view(b, th * tw, c * (patch_size ** 2))
    return out

def unpatchify(x, patch_size=8):
    """Inverse of patchify; assumes a square grid of patches."""
    b, z, p = x.shape
    c = p // (patch_size ** 2)
    th = tw = int(math.sqrt(z))
    h = th * patch_size
    w = tw * patch_size
    # (B, N, C*P*P) -> (B, th, tw, C, P, P) -> (B, C, th, P, tw, P) -> (B, C, H, W)
    x = x.view(b, th, tw, c, patch_size, patch_size)
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous()
    out = x.view(b, c, h, w)
    return out

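
# A minimal sanity check, not part of the original module (the helper name is ours):
# unpatchify should exactly invert patchify for square images whose side is divisible
# by patch_size, since both use only reshape/permute/view.
def _check_patch_roundtrip():
    x = torch.rand(2, 1, 64, 64)
    patches = patchify(x, patch_size=8)  # (2, 64, 64): an 8x8 grid of 8x8 patches
    assert torch.equal(unpatchify(patches, patch_size=8), x)
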
def random_mask(x, mask_ratio=0.75):
    """Per-sample random masking by argsort of uniform noise (as in MAE)."""
    b, n, p = x.shape
    len_keep = int(n * (1 - mask_ratio))
    noise = torch.rand(b, n, device=x.device)        # uniform noise used to shuffle patches
    ids_shuffle = torch.argsort(noise, dim=1)        # ascending: smallest noise is kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # maps the shuffled order back to the original order
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, p))
    # Binary mask in the original patch order: 0 = kept, 1 = masked.
    mask = torch.ones(b, n, device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return x_masked, mask, ids_restore, ids_keep

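
# Illustrative only (the helper name is ours, not from the original code): with
# mask_ratio=0.75 random_mask keeps N/4 patches; mask is 1 for dropped patches and
# 0 for kept ones, in the original patch order.
def _check_random_mask():
    x = torch.rand(2, 16, 64)
    x_masked, mask, ids_restore, ids_keep = random_mask(x, mask_ratio=0.75)
    assert x_masked.shape == (2, 4, 64)
    assert mask.shape == (2, 16) and mask.sum(dim=1).eq(12).all()
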
def mae_loss(pred, target, mask):
    """Mean squared error computed only on masked patches (mask == 1), as in the MAE paper."""
    mask = mask.float()
    loss = (pred - target) ** 2
    loss = loss.mean(dim=-1)                                 # per-patch mean over pixel values
    loss = (loss * mask).sum() / mask.sum().clamp_min(1.0)   # average over masked patches only
    return loss

class PositionalEncoding(nn.Module):
    """Learned positional embeddings, gathered for the visible (unmasked) patches only."""

    def __init__(self, num_patches, hidden_dim=768):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.empty(1, num_patches, hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x, visible_indices):
        b, l, d = x.shape
        pos = self.pos_embed.expand(b, -1, -1)
        # Select the positional embedding of each visible patch.
        idx = visible_indices.unsqueeze(-1).expand(b, l, pos.size(-1))
        visible_pos = torch.gather(pos, 1, idx)
        return x + visible_pos

class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: LayerNorm -> self-attention -> residual, LayerNorm -> MLP -> residual."""

    def __init__(self, hidden_dim, mlp_dim, num_heads, dropout):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(hidden_dim)
        self.multihead = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads,
                                               dropout=dropout, batch_first=True)
        self.layernorm2 = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, hidden_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Self-attention sub-layer.
        residual = x
        x = self.layernorm1(x)
        attn, _ = self.multihead(x, x, x, need_weights=False)
        x = residual + attn
        # Feed-forward sub-layer.
        residual = x
        x = self.layernorm2(x)
        x = residual + self.mlp(x)
        return x

class MAEEncoder(nn.Module):
    """MAE encoder. ``patch_dim`` is the flattened patch dimension (c * patch_size**2); the
    encoder only processes the visible tokens, i.e. (1 - mask_ratio) * num_patches of them.
    """

    def __init__(self, patch_dim, num_patches=(384 // 8) ** 2, hidden_dim=768, mlp_dim=768 * 4,
                 num_heads=8, depth=12, dropout=0.25, mask_ratio=0.75, patch_size=8):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        self.patch_embed = nn.Linear(patch_dim, hidden_dim)
        self.pos_embed = PositionalEncoding(num_patches=num_patches, hidden_dim=hidden_dim)
        self.transformer = nn.ModuleList([
            TransformerBlock(hidden_dim=hidden_dim, mlp_dim=mlp_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(depth)
        ])
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x_in):
        # Patchify, drop a random subset of patches, then encode only the visible ones.
        x_p = patchify(x_in, self.patch_size)
        x_masked, mask, ids_restore, ids_keep = random_mask(x_p, self.mask_ratio)
        x = self.patch_embed(x_masked)
        x = self.pos_embed(x, ids_keep)
        for attn_layer in self.transformer:
            x = attn_layer(x)
        return x, mask, ids_keep, ids_restore

class MAEDecoder(nn.Module):
    """MAE decoder: projects encoder tokens, re-inserts mask tokens at their original positions,
    adds positional embeddings for the full sequence, and predicts the pixels of every patch.
    """

    def __init__(self, c, num_patches, patch_size, encoder_dim, decoder_dim, decoder_depth,
                 mlp_dim, num_heads, dropout):
        super().__init__()
        self.num_patches = num_patches
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.mask_token = nn.Parameter(torch.empty(1, 1, decoder_dim))
        self.enc_to_dec = nn.Linear(encoder_dim, decoder_dim)
        self.pos_embed = nn.Parameter(torch.empty(1, num_patches, decoder_dim))
        self.transformer = nn.ModuleList([
            TransformerBlock(hidden_dim=decoder_dim, mlp_dim=mlp_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(decoder_depth)
        ])
        self.layernorm = nn.LayerNorm(decoder_dim)
        self.pred = nn.Linear(decoder_dim, c * (patch_size ** 2))
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.mask_token, std=0.02)

    def forward(self, x, ids_keep, ids_restore):
        # ids_keep is unused here but kept for symmetry with the encoder's outputs.
        b, len_keep, _ = x.shape
        xdec = self.enc_to_dec(x)
        num_patches = ids_restore.size(1)
        num_mask = num_patches - len_keep
        # Append mask tokens and unshuffle back to the original patch order.
        mask_token = self.mask_token.expand(b, num_mask, -1)
        x_ = torch.cat([xdec, mask_token], dim=1)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x_.size(-1)))
        x_ = x_ + self.pos_embed
        for block in self.transformer:
            x_ = block(x_)
        x_ = self.layernorm(x_)
        out = self.pred(x_)
        return out

class MaskedAutoEncoder(nn.Module):
    """Full MAE: encoder over visible patches plus a lightweight decoder that reconstructs all patches."""

    def __init__(self, c=1, mask_ratio=0.75, dropout=0.25, img_size=384, encoder_dim=768,
                 mlp_dim=3072, decoder_dim=512, encoder_depth=12, encoder_head=8,
                 decoder_depth=8, decoder_head=8, patch_size=8):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.encoder = MAEEncoder(patch_dim=c * (patch_size ** 2), num_patches=num_patches,
                                  hidden_dim=encoder_dim, mlp_dim=mlp_dim, num_heads=encoder_head,
                                  depth=encoder_depth, dropout=dropout, mask_ratio=mask_ratio,
                                  patch_size=patch_size)
        self.decoder = MAEDecoder(c, num_patches=num_patches, patch_size=patch_size,
                                  encoder_dim=encoder_dim, decoder_dim=decoder_dim,
                                  decoder_depth=decoder_depth, mlp_dim=mlp_dim,
                                  num_heads=decoder_head, dropout=dropout)

    def forward(self, x):
        encoded, mask, ids_keep, ids_restore = self.encoder(x)
        decoded = self.decoder(encoded, ids_keep, ids_restore)
        # Return the patchified target alongside the reconstruction and the binary mask.
        xpatched = patchify(x, self.patch_size)
        return xpatched, decoded, mask

    @staticmethod
    def testme():
        img = torch.rand(1, 1, 384, 384)
        mae = MaskedAutoEncoder()
        target, pred, mask = mae(img)
        print(target.shape)  # (1, 2304, 64)
        print(pred.shape)    # (1, 2304, 64)
        print(mask.shape)    # (1, 2304)
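

# A rough single-step training sketch (ours, not from the original file), assuming the
# reconstruction target is the raw patchified image as returned by MaskedAutoEncoder.forward;
# hyperparameters below are placeholders, not tuned values.
def _train_step_example():
    model = MaskedAutoEncoder()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4, weight_decay=0.05)
    imgs = torch.rand(2, 1, 384, 384)      # stand-in for a real batch
    target, pred, mask = model(imgs)
    loss = mae_loss(pred, target, mask)    # loss is computed on masked patches only
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()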