# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from openrec.modeling.common import DropPath
class LayerNorm(nn.Module):
""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""
def __init__(self,
normalized_shape,
eps=1e-6,
data_format='channels_last'):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ['channels_last', 'channels_first']:
raise NotImplementedError
self.normalized_shape = (normalized_shape, )
def forward(self, x):
if self.data_format == 'channels_last':
return F.layer_norm(x, self.normalized_shape, self.weight,
self.bias, self.eps)
elif self.data_format == 'channels_first':
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
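
# A minimal usage sketch for the LayerNorm above (illustrative only, not part of
# the original module): both data formats normalize over the channel dimension,
# they just expect it in different positions.
def _layer_norm_example():
    ln_last = LayerNorm(64, data_format='channels_last')    # for (N, H, W, C)
    ln_first = LayerNorm(64, data_format='channels_first')  # for (N, C, H, W)
    y_last = ln_last(torch.randn(2, 8, 32, 64))
    y_first = ln_first(torch.randn(2, 64, 8, 32))
    return y_last.shape, y_first.shape  # shapes are unchanged
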
class GRN(nn.Module):
""" GRN (Global Response Normalization) layer
"""
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
def forward(self, inputs, mask=None):
x = inputs
if mask is not None:
x = x * (1. - mask)
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)  # global L2 response per channel: (N, 1, 1, C)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)  # divisive normalization across channels
return self.gamma * (inputs * Nx) + self.beta + inputs
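
# A minimal usage sketch for GRN (illustrative only): inputs are channels-last
# and the output keeps the same shape; each channel is rescaled by its global
# spatial L2 response relative to the mean response over channels. The mask
# layout used here is an assumption, not prescribed by the module.
def _grn_example():
    grn = GRN(dim=384)
    x = torch.randn(2, 8, 32, 384)   # (N, H, W, C)
    y = grn(x)                       # (2, 8, 32, 384)
    mask = torch.zeros(2, 8, 32, 1)  # positions with mask == 1 are excluded
    y_masked = grn(x, mask=mask)
    return y.shape, y_masked.shape
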
class Block(nn.Module):
""" ConvNeXtV2 Block.
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
"""
def __init__(self, dim, drop_path=0.):
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3,
groups=dim) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(
dim,
4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(4 * dim)
self.pwconv2 = nn.Linear(4 * dim, dim)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
input = x
x = self.dwconv(x.contiguous())
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = input + self.drop_path(x)
return x
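
# A minimal shape sketch for Block (illustrative only): the block is residual
# and channel-preserving, so the output shape equals the input shape.
def _block_example():
    block = Block(dim=96, drop_path=0.1)
    x = torch.randn(2, 96, 8, 32)  # (N, C, H, W)
    return block(x).shape          # (2, 96, 8, 32)
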
class ConvNeXtV2(nn.Module):
""" ConvNeXt V2
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_path_rate (float): Stochastic depth rate. Default: 0.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
"""
def __init__(
self,
in_channels=3,
depths=[3, 3, 9, 3],
dims=[96, 192, 384, 768],
drop_path_rate=0.,
strides=[(4, 4), (2, 2), (2, 2), (2, 2)],
out_channels=256,
last_stage=False,
feat2d=False,
**kwargs,
):
super().__init__()
self.strides = strides
self.depths = depths
self.downsample_layers = nn.ModuleList(
) # stem and 3 intermediate downsampling conv layers
stem = nn.Sequential(
nn.Conv2d(in_channels,
dims[0],
kernel_size=strides[0],
stride=strides[0]),
LayerNorm(dims[0], eps=1e-6, data_format='channels_first'))
self.downsample_layers.append(stem)
for i in range(3):
downsample_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format='channels_first'),
nn.Conv2d(dims[i],
dims[i + 1],
kernel_size=strides[i + 1],
stride=strides[i + 1]),
)
self.downsample_layers.append(downsample_layer)
self.stages = nn.ModuleList(
) # 4 feature resolution stages, each consisting of multiple residual blocks
dp_rates = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
]
cur = 0
for i in range(4):
stage = nn.Sequential(*[
Block(dim=dims[i], drop_path=dp_rates[cur + j])
for j in range(depths[i])
])
self.stages.append(stage)
cur += depths[i]
self.out_channels = dims[-1]
self.last_stage = last_stage
self.feat2d = feat2d
if last_stage:
self.out_channels = out_channels
self.last_conv = nn.Linear(dims[-1], self.out_channels, bias=False)
self.hardswish = nn.Hardswish()
self.dropout = nn.Dropout(p=0.1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
if isinstance(m, (nn.Conv2d, nn.Linear)) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.SyncBatchNorm):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.BatchNorm2d):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
def no_weight_decay(self):
return {}
def forward(self, x):
feats = []
for i in range(4):
x = self.downsample_layers[i](x)
x = self.stages[i](x)
feats.append(x)
if self.last_stage:
            x = x.mean(2).transpose(1, 2)  # pool height: (N, C, H, W) -> (N, W, C)
            x = self.last_conv(x)  # project dims[-1] -> out_channels
x = self.hardswish(x)
x = self.dropout(x)
return x
if self.feat2d:
return x
return feats
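
# A minimal end-to-end sketch (illustrative only; the input size and flags below
# are assumptions, not values prescribed by this module). With the default
# strides the spatial size is reduced 32x, so a 32x128 text image collapses to a
# 1x4 feature map; with last_stage=True the height is pooled away and a
# (N, W, out_channels) sequence is returned for a recognition head.
if __name__ == '__main__':
    model = ConvNeXtV2(in_channels=3, last_stage=True)
    dummy = torch.randn(1, 3, 32, 128)
    out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 4, 256])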