# BEiT-style masked image modeling encoder (pre-training head) with a
# runnable smoke test under `if __name__ == "__main__"`.
import math
from functools import partial

import torch
from torch import Tensor, nn

from .components import Encoder, ImgLinearBackbone, PositionEmbedding
class BeitEncoder(nn.Module):
    """BEiT-style masked image modeling encoder.

    Embeds an image into patch tokens via ``backbone``, replaces the tokens
    at masked positions with a learned mask token, adds positional
    embeddings, runs a transformer ``encoder``, and predicts visual-codebook
    logits — by default only for the masked positions.
    """

    def __init__(
        self,
        d_model: int,  # embedding dimension shared by backbone/encoder
        backbone: nn.Module,  # maps image -> (B, S, d_model) patch embeddings
        max_seq_len: int,  # capacity of the learned positional embedding
        codebook_tokens: int,  # number of discrete codebook entries (output classes)
        dropout: float,  # dropout rate used inside the positional embedding
        encoder: Encoder,  # transformer encoder stack operating on (B, S, d_model)
        norm_layer: nn.Module,  # factory for the final norm, e.g. partial(nn.LayerNorm, eps=1e-6)
        init_std: float = 0.02,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.init_std = init_std
        self.backbone = backbone
        self.pos_embed = PositionEmbedding(
            max_seq_len=max_seq_len, d_model=d_model, dropout=dropout
        )
        self.encoder = encoder
        self.norm = norm_layer(d_model)
        # Projects encoder outputs to codebook logits.
        self.generator = nn.Linear(d_model, codebook_tokens)
        # Truncated-normal initializer clipped to +/- init_std.
        self.trunc_normal = partial(
            nn.init.trunc_normal_, std=init_std, a=-init_std, b=init_std
        )
        # Register the mask token BEFORE the recursive init pass so the
        # module is fully constructed when apply() runs.  apply() only
        # visits submodules (not bare Parameters), so the token remains
        # zero-initialized — NOTE(review): reference BEiT trunc-normal
        # initializes the mask token; confirm zeros are intended here.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module) -> None:
        """Per-module initializer used by ``self.apply`` in ``__init__``."""
        if isinstance(m, nn.Linear):
            self.trunc_normal(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.Conv2d):
            self.trunc_normal(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, PositionEmbedding):
            # The positional table is a learned nn.Embedding inside the
            # PositionEmbedding module.
            self.trunc_normal(m.embedding.weight)

    def no_weight_decay(self) -> set:
        """Parameter names the optimizer should exclude from weight decay."""
        return {"pos_embed"}

    def forward(
        self, x: Tensor, bool_masked_pos: Tensor, return_all_tokens: bool = False
    ) -> Tensor:
        """Predict codebook logits for a masked image.

        Args:
            x: input image batch accepted by ``backbone``
                (presumably (B, C, H, W) — see the __main__ smoke test).
            bool_masked_pos: (B, S) boolean tensor; True marks positions
                whose patch embedding is replaced by the mask token.
            return_all_tokens: if True, return logits for every position
                as (B, S, codebook_tokens); otherwise only the masked
                positions, flattened to (num_masked, codebook_tokens).
        """
        x = self.backbone(x)
        B, S, E = x.shape
        assert E == self.d_model
        # Blend in the learned mask token at masked positions only.
        mask_token = self.mask_token.expand(B, S, -1)
        w = bool_masked_pos.unsqueeze(-1).type_as(mask_token)
        x = x * (1 - w) + mask_token * w
        x = self.pos_embed(x)
        x = self.encoder(x)
        x = self.norm(x)
        if return_all_tokens:
            return self.generator(x)
        else:
            # Boolean indexing gathers masked positions across the batch.
            return self.generator(x[bool_masked_pos])
if __name__ == "__main__":
    # Smoke test: build a BEiT encoder and run one masked forward pass.
    d_model = 512
    patch_size = 16
    nhead = 8
    dropout = 0.0
    activation = "gelu"  # fixed typo: was "acitvation"
    norm_first = True
    nlayer = 12
    ff_ratio = 4
    norm_layer = partial(nn.LayerNorm, eps=1e-6)
    codebook_tokens = 8192
    img_size = 448
    # One token per patch.
    max_seq_len = (img_size // patch_size) ** 2

    backbone = ImgLinearBackbone(d_model=d_model, patch_size=patch_size)
    encoder = Encoder(
        d_model=d_model,
        nhead=nhead,
        dropout=dropout,
        activation=activation,
        norm_first=norm_first,
        nlayer=nlayer,
        ff_ratio=ff_ratio,
    )
    model = BeitEncoder(
        d_model=d_model,
        backbone=backbone,
        max_seq_len=max_seq_len,
        codebook_tokens=codebook_tokens,
        dropout=dropout,
        encoder=encoder,
        norm_layer=norm_layer,
    )
    print(model)

    # Random image and a ~50% random patch mask.
    x = torch.rand((1, 3, img_size, img_size))
    bool_masked_pos = torch.rand((1, (img_size // patch_size) ** 2)) < 0.5
    y = model(x, bool_masked_pos)
    print(torch.sum(bool_masked_pos))
    # Expected: (num_masked, codebook_tokens).
    print(y.shape)