Spaces:
Running
Running
import numpy as np | |
import torch | |
from torch import nn | |
from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_ | |
from openrec.modeling.common import Block, PatchEmbed | |
from openrec.modeling.encoders.svtrv2_lnconv import Feat2D, LastStage | |
class ViT(nn.Module):
    """Vision Transformer encoder.

    Pipeline: ``PatchEmbed`` -> optional learnable class token -> learned
    positional embedding -> dropout -> ``depth`` transformer ``Block``s ->
    ``norm_layer`` -> optional output stage (``LastStage`` or ``Feat2D``).

    Args:
        img_size: ``[height, width]`` the model is built for. The positional
            embedding has a fixed length derived from ``PatchEmbed``, so an
            input must yield exactly ``patch_embed.num_patches`` patches.
        patch_size: ``[height, width]`` of one patch. ``feat_sz`` is
            ``img_size // patch_size`` per axis.
        in_channels: Input channel count forwarded to ``PatchEmbed``.
        out_channels: Output width of ``LastStage``. Only used when
            ``last_stage`` is True; otherwise ``self.out_channels`` is
            ``embed_dim``.
        embed_dim, depth, num_heads, mlp_ratio, qkv_bias, qk_scale:
            Transformer width, number of ``Block``s, and per-block
            hyper-parameters forwarded to every ``Block``.
        drop_rate: Dropout applied after adding the positional embedding.
            It is also passed to each ``Block`` as ``drop``.
        attn_drop_rate: Passed to each ``Block`` as ``attn_drop``.
        drop_path_rate: Largest drop-path rate. Block ``i`` gets a value
            linearly spaced from 0 up to this rate.
        norm_layer, act_layer: Normalization and activation classes.
        last_stage: Append ``LastStage(embed_dim, out_channels,
            last_drop=0.1)`` as the output stage.
        feat2d: Append ``Feat2D()`` as the output stage. If ``last_stage`` is
            also set, ``Feat2D`` replaces the ``LastStage`` module.
        use_cls_token: Prepend a learnable token before the blocks. It is
            removed again after the final norm.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        img_size=[32, 128],
        patch_size=[4, 8],
        in_channels=3,
        out_channels=256,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        last_stage=False,
        feat2d=False,
        use_cls_token=False,
        **kwargs,
    ):
        # NOTE: the list defaults are only read here, never mutated, so
        # sharing them across calls is harmless.
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        # Replaced with ``out_channels`` below when a LastStage is attached.
        self.out_channels = embed_dim
        self.use_cls_token = use_cls_token
        # Token-grid size (rows, cols), handed to the output stage in forward.
        self.feat_sz = [
            img_size[0] // patch_size[0],
            img_size[1] // patch_size[1],
        ]

        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels,
                                      embed_dim)
        num_tokens = self.patch_embed.num_patches
        if use_cls_token:
            # Registered (and initialized) before pos_embed to keep parameter
            # order and RNG consumption stable.
            self.cls_token = nn.Parameter(
                torch.zeros([1, 1, embed_dim], dtype=torch.float32),
                requires_grad=True,
            )
            trunc_normal_(self.cls_token, mean=0, std=0.02)
            num_tokens += 1  # one extra position for the class token
        self.pos_embed = nn.Parameter(
            torch.zeros([1, num_tokens, embed_dim], dtype=torch.float32),
            requires_grad=True,
        )
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic-depth schedule: 0 for the first block, rising linearly
        # to drop_path_rate for the last.
        dpr = np.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=act_layer,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
            ) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        self.last_stage = last_stage
        self.feat2d = feat2d
        if last_stage:
            self.out_channels = out_channels
            self.stages = LastStage(embed_dim, out_channels, last_drop=0.1)
        if feat2d:
            # NOTE(review): when last_stage is also set, the LastStage built
            # above is discarded but out_channels still reports its width.
            # Confirm whether combining the two flags is meant to be supported.
            self.stages = Feat2D()

        trunc_normal_(self.pos_embed, mean=0, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize one submodule; used via ``self.apply``.

        - ``nn.Linear``: truncated-normal weight (std 0.02), zero bias.
        - ``nn.LayerNorm``: weight 1, bias 0. Affine parameters that are
          absent (``None``) are skipped.
        - ``nn.Conv2d``: Kaiming-normal weight (``fan_out``, relu gain). The
          bias is left untouched.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, mean=0, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            # elementwise_affine=False (or bias=False) leaves these as None;
            # initializing them unconditionally would crash.
            if m.bias is not None:
                zeros_(m.bias)
            if m.weight is not None:
                ones_(m.weight)
        elif isinstance(m, nn.Conv2d):
            kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def no_weight_decay(self):
        """Return parameter names that should be excluded from weight decay.

        NOTE: only ``pos_embed`` is listed; ``cls_token`` (when present) is
        still decayed.
        """
        return {'pos_embed'}

    def forward(self, x):
        """Encode ``x``.

        ``x`` is passed straight to ``PatchEmbed`` (presumably a
        (B, in_channels, H, W) image batch matching ``img_size``; confirm).
        Returns the normalized patch tokens with any class token removed. If
        an output stage is configured, returns that stage's first output
        instead.
        """
        x = self.patch_embed(x)
        if self.use_cls_token:
            # Broadcast view of the single learned token per sample; cat copies.
            cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat([cls_tokens, x], dim=1)
        x = self.pos_drop(x + self.pos_embed)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        if self.use_cls_token:
            x = x[:, 1:, :]  # drop the class token, keep patch tokens
        # __init__ stores at most one output stage in ``self.stages``, so it
        # is applied exactly once, even if both flags are set.
        if self.last_stage or self.feat2d:
            x, _ = self.stages(x, self.feat_sz)
        return x