# Maskgit-pytorch / Models/models/transformer.py
# BERT architecture for the Masked Bidirectional Encoder Transformer
import torch
from torch import nn


class PreNorm(nn.Module):
    """Apply LayerNorm to the input before calling the wrapped module (pre-norm residual block)."""
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""
def __init__(self, dim, hidden_dim, dropout=0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim, bias=True),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim, bias=True),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)


class Attention(nn.Module):
    """Multi-head self-attention that also returns the attention weights."""

    def __init__(self, embed_dim, num_heads, dropout=0.):
        super().__init__()
        self.dim = embed_dim
        # batch_first=True: inputs and outputs are (batch, seq_len, embed_dim)
        self.mha = nn.MultiheadAttention(embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True, bias=True)

    def forward(self, x):
        # nn.MultiheadAttention returns the output and the attention weights (averaged over heads)
        attention_value, attention_weight = self.mha(x, x, x)
return attention_value, attention_weight


class TransformerEncoder(nn.Module):
    """Stack of pre-norm residual blocks: self-attention followed by a feed-forward MLP."""
def __init__(self, dim, depth, heads, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads, dropout=dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
]))

    def forward(self, x):
        l_attn = []  # attention maps collected from every layer
for attn, ff in self.layers:
attention_value, attention_weight = attn(x)
x = attention_value + x
x = ff(x) + x
l_attn.append(attention_weight)
return x, l_attn


class MaskTransformer(nn.Module):
    """Bidirectional transformer predicting masked VQ tokens, conditioned on a class label."""
def __init__(self, img_size=256, hidden_dim=768, codebook_size=1024, depth=24, heads=8, mlp_dim=3072, dropout=0.1, nclass=1000):
super().__init__()
        self.nclass = nclass
        self.patch_size = img_size // 16  # side of the token grid (image downsampled by a factor of 16)
        self.codebook_size = codebook_size
        self.tok_emb = nn.Embedding(codebook_size + 1 + nclass + 1, hidden_dim)  # +1 for the mask token of the visual codes, +1 for the class-drop token
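        # Index layout of the embedding table (as used in forward()):
        #   [0, codebook_size)                           visual (VQ) tokens
        #   codebook_size                                mask token
        #   [codebook_size + 1, codebook_size + nclass]  class tokens (label + codebook_size + 1)
        #   codebook_size + 1 + nclass                   class-drop token, used when the label is dropped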
# self.msk_emb = nn.Embedding(2, hidden_dim)
        # Learned positional embedding over the patch_size*patch_size image tokens plus one class token
        self.pos_emb = nn.init.trunc_normal_(nn.Parameter(torch.zeros(1, (self.patch_size * self.patch_size) + 1, hidden_dim)), 0., 0.02)
        # Input projection applied to the summed token and positional embeddings
        self.first_layer = nn.Sequential(
nn.LayerNorm(hidden_dim, eps=1e-12),
nn.Dropout(p=dropout),
nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
nn.GELU(),
nn.LayerNorm(hidden_dim, eps=1e-12),
nn.Dropout(p=dropout),
nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
)
self.transformer = TransformerEncoder(dim=hidden_dim, depth=depth, heads=heads, mlp_dim=mlp_dim, dropout=dropout)
        # Output head applied before projecting back onto the token vocabulary
        self.last_layer = nn.Sequential(
nn.LayerNorm(hidden_dim, eps=1e-12),
nn.Dropout(p=dropout),
nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
nn.GELU(),
nn.LayerNorm(hidden_dim, eps=1e-12),
)
        # Per-position bias added to the output logits (the projection itself reuses tok_emb.weight, see forward)
        self.bias = nn.Parameter(torch.zeros((self.patch_size * self.patch_size) + 1, codebook_size + 1 + nclass + 1))

    def forward(self, img_token, y=None, drop_label=None, return_attn=False):  # , masking_flag=None):
        b, w, h = img_token.size()

        # Shift class labels past the visual tokens and the mask token
        cls_token = y.view(b, -1) + self.codebook_size + 1
        if drop_label is not None:
            # Dropped labels are mapped to the class-drop (unconditional) token
            cls_token[drop_label] = self.codebook_size + 1 + self.nclass

        # Concatenate the flattened image tokens with the class token, embed, and add positions
        input_tokens = torch.cat([img_token.view(b, -1), cls_token.view(b, -1)], -1)
        tok_embeddings = self.tok_emb(input_tokens)
        pos_embeddings = self.pos_emb
        x = tok_embeddings + pos_embeddings
# if masking_flag is not None:
# flag = torch.cat([masking_flag.view(b, -1), torch.zeros_like(cls_token.view(b, -1))], -1)
# x += self.msk_emb(flag)
x = self.first_layer(x)
x, attn = self.transformer(x)
x = self.last_layer(x)
        # Tie the output projection to the token embedding weights and add the per-position bias
        logit = torch.matmul(x, self.tok_emb.weight.T) + self.bias

        # Keep only the image-token positions and the visual vocabulary (codebook entries + mask token)
        if return_attn:
            return logit[:, :self.patch_size * self.patch_size, :self.codebook_size + 1], attn
        return logit[:, :self.patch_size * self.patch_size, :self.codebook_size + 1]
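

if __name__ == "__main__":
    # Quick usage sketch / shape check. The sizes below assume a VQ tokenizer with a
    # 1024-entry codebook mapping 256x256 images to a 16x16 token grid; the random
    # tokens and labels are placeholders, not real data.
    model = MaskTransformer(img_size=256, hidden_dim=768, codebook_size=1024, depth=2, heads=8, mlp_dim=3072)

    batch = 4
    tokens = torch.randint(0, 1024, (batch, 16, 16))   # random visual token ids (a real input would hold the mask id, codebook_size, at masked positions)
    labels = torch.randint(0, 1000, (batch,))          # class labels
    drop = torch.rand(batch) < 0.1                     # randomly dropped labels (mapped to the class-drop token)

    logits = model(tokens, y=labels, drop_label=drop)
    print(logits.shape)  # (batch, 256, 1025): 1024 codebook entries + the mask token per position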