import numpy as np
import torch
from torch import nn
from torch.nn.init import kaiming_normal_, ones_, trunc_normal_, zeros_
from openrec.modeling.common import Block, PatchEmbed
from openrec.modeling.encoders.svtrv2_lnconv import Feat2D, LastStage


class ViT(nn.Module):
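    """Plain Vision Transformer (ViT) encoder for text recognition.

    The input image is split into patches, augmented with a learned
    positional embedding (optionally prefixed with a [CLS] token), and
    passed through a stack of Transformer blocks. The token sequence can
    optionally be post-processed by a LastStage projection head or
    reshaped back into a 2D feature map via Feat2D.
    """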
    def __init__(
        self,
        img_size=[32, 128],
        patch_size=[4, 8],
        in_channels=3,
        out_channels=256,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        last_stage=False,
        feat2d=False,
        use_cls_token=False,
        **kwargs,
    ):
        super().__init__()
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.out_channels = embed_dim
        self.use_cls_token = use_cls_token
        self.feat_sz = [
            img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        ]
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels,
                                      embed_dim)
        num_patches = self.patch_embed.num_patches
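        # Learned absolute position embedding; one extra slot is reserved
        # when a [CLS] token is prepended.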
        if use_cls_token:
            self.cls_token = nn.Parameter(
                torch.zeros([1, 1, embed_dim], dtype=torch.float32),
                requires_grad=True,
            )
            trunc_normal_(self.cls_token, mean=0, std=0.02)
            self.pos_embed = nn.Parameter(
                torch.zeros([1, num_patches + 1, embed_dim],
                            dtype=torch.float32),
                requires_grad=True,
            )
        else:
            self.pos_embed = nn.Parameter(
                torch.zeros([1, num_patches, embed_dim], dtype=torch.float32),
                requires_grad=True,
            )
        self.pos_drop = nn.Dropout(p=drop_rate)
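        # Stochastic depth: the drop-path rate grows linearly from 0 to
        # drop_path_rate over the depth of the network.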
        dpr = np.linspace(0, drop_path_rate, depth)
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                act_layer=act_layer,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
            ) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        self.last_stage = last_stage
        self.feat2d = feat2d
        if last_stage:
            self.out_channels = out_channels
            self.stages = LastStage(embed_dim, out_channels, last_drop=0.1)
        if feat2d:
            # Note: if both flags are set, Feat2D replaces LastStage here.
            self.stages = Feat2D()
        trunc_normal_(self.pos_embed, mean=0, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, mean=0, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)
        elif isinstance(m, nn.Conv2d):
            kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names to be excluded from weight decay by the training
        # loop (timm-style convention).
        return {'pos_embed'}

    def forward(self, x):
        x = self.patch_embed(x)  # (B, num_patches, embed_dim)
        if self.use_cls_token:
            x = torch.concat([self.cls_token.tile([x.shape[0], 1, 1]), x], 1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        if self.use_cls_token:
            # Strip the [CLS] token; downstream stages expect patch tokens.
            x = x[:, 1:, :]
        if self.last_stage:
            x, sz = self.stages(x, self.feat_sz)
        if self.feat2d:
            x, sz = self.stages(x, self.feat_sz)
        return x
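

if __name__ == '__main__':
    # Minimal smoke test: a sketch, assuming the openrec package is on the
    # path and that PatchEmbed yields (32 // 4) * (128 // 8) = 128 patch
    # tokens for the default 32x128 input.
    model = ViT()
    dummy = torch.randn(2, 3, 32, 128)
    out = model(dummy)
    print(out.shape)  # expected: torch.Size([2, 128, 384])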