internimage_t_1k_224 / intern_image.py
Weiyun1025's picture
Upload model
a9905e1
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from transformers import PreTrainedModel
from timm.models.layers import trunc_normal_, DropPath
from .intern_image_config import InternImageConfig
from .dcnv3 import DCNv3_pytorch
class to_channels_first(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 3, 1, 2)
class to_channels_last(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.permute(0, 2, 3, 1)
def build_norm_layer(dim,
norm_layer,
in_format='channels_last',
out_format='channels_last',
eps=1e-6):
layers = []
if norm_layer == 'BN':
if in_format == 'channels_last':
layers.append(to_channels_first())
layers.append(nn.BatchNorm2d(dim))
if out_format == 'channels_last':
layers.append(to_channels_last())
elif norm_layer == 'LN':
if in_format == 'channels_first':
layers.append(to_channels_last())
layers.append(nn.LayerNorm(dim, eps=eps))
if out_format == 'channels_first':
layers.append(to_channels_first())
else:
raise NotImplementedError(
f'build_norm_layer does not support {norm_layer}')
return nn.Sequential(*layers)
def build_act_layer(act_layer):
if act_layer == 'ReLU':
return nn.ReLU(inplace=True)
elif act_layer == 'SiLU':
return nn.SiLU(inplace=True)
elif act_layer == 'GELU':
return nn.GELU()
raise NotImplementedError(f'build_act_layer does not support {act_layer}')
class StemLayer(nn.Module):
r""" Stem layer of InternImage
Args:
in_chans (int): number of input channels
out_chans (int): number of output channels
act_layer (str): activation layer
norm_layer (str): normalization layer
"""
def __init__(self,
in_chans=3,
out_chans=96,
act_layer='GELU',
norm_layer='BN'):
super().__init__()
self.conv1 = nn.Conv2d(in_chans,
out_chans // 2,
kernel_size=3,
stride=2,
padding=1)
self.norm1 = build_norm_layer(out_chans // 2, norm_layer,
'channels_first', 'channels_first')
self.act = build_act_layer(act_layer)
self.conv2 = nn.Conv2d(out_chans // 2,
out_chans,
kernel_size=3,
stride=2,
padding=1)
self.norm2 = build_norm_layer(out_chans, norm_layer, 'channels_first',
'channels_last')
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.act(x)
x = self.conv2(x)
x = self.norm2(x)
return x
class DownsampleLayer(nn.Module):
r""" Downsample layer of InternImage
Args:
channels (int): number of input channels
norm_layer (str): normalization layer
"""
def __init__(self, channels, norm_layer='LN'):
super().__init__()
self.conv = nn.Conv2d(channels,
2 * channels,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.norm = build_norm_layer(2 * channels, norm_layer,
'channels_first', 'channels_last')
def forward(self, x):
x = self.conv(x.permute(0, 3, 1, 2))
x = self.norm(x)
return x
class MLPLayer(nn.Module):
r""" MLP layer of InternImage
Args:
in_features (int): number of input features
hidden_features (int): number of hidden features
out_features (int): number of output features
act_layer (str): activation layer
drop (float): dropout rate
"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer='GELU',
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = build_act_layer(act_layer)
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class InternImageLayer(nn.Module):
r""" Basic layer of InternImage
Args:
core_op (nn.Module): core operation of InternImage
channels (int): number of input channels
groups (list): Groups of each block.
mlp_ratio (float): ratio of mlp hidden features to input channels
drop (float): dropout rate
drop_path (float): drop path rate
act_layer (str): activation layer
norm_layer (str): normalization layer
post_norm (bool): whether to use post normalization
layer_scale (float): layer scale
offset_scale (float): offset scale
with_cp (bool): whether to use checkpoint
"""
def __init__(self,
core_op,
channels,
groups,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer='GELU',
norm_layer='LN',
post_norm=False,
layer_scale=None,
offset_scale=1.0,
with_cp=False):
super().__init__()
self.channels = channels
self.groups = groups
self.mlp_ratio = mlp_ratio
self.with_cp = with_cp
self.norm1 = build_norm_layer(channels, 'LN')
self.post_norm = post_norm
self.dcn = core_op(channels=channels,
kernel_size=3,
stride=1,
pad=1,
dilation=1,
group=groups,
offset_scale=offset_scale,
act_layer=act_layer,
norm_layer=norm_layer)
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.norm2 = build_norm_layer(channels, 'LN')
self.mlp = MLPLayer(in_features=channels,
hidden_features=int(channels * mlp_ratio),
act_layer=act_layer,
drop=drop)
self.layer_scale = layer_scale is not None
if self.layer_scale:
self.gamma1 = nn.Parameter(layer_scale * torch.ones(channels),
requires_grad=True)
self.gamma2 = nn.Parameter(layer_scale * torch.ones(channels),
requires_grad=True)
def forward(self, x):
def _inner_forward(x):
if not self.layer_scale:
if self.post_norm:
x = x + self.drop_path(self.norm1(self.dcn(x)))
x = x + self.drop_path(self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.dcn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
if self.post_norm:
x = x + self.drop_path(self.gamma1 * self.norm1(self.dcn(x)))
x = x + self.drop_path(self.gamma2 * self.norm2(self.mlp(x)))
else:
x = x + self.drop_path(self.gamma1 * self.dcn(self.norm1(x)))
x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
return x
if self.with_cp and x.requires_grad:
x = checkpoint.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class InternImageBlock(nn.Module):
r""" Block of InternImage
Args:
core_op (nn.Module): core operation of InternImage
channels (int): number of input channels
depths (list): Depth of each block.
groups (list): Groups of each block.
mlp_ratio (float): ratio of mlp hidden features to input channels
drop (float): dropout rate
drop_path (float): drop path rate
act_layer (str): activation layer
norm_layer (str): normalization layer
post_norm (bool): whether to use post normalization
layer_scale (float): layer scale
offset_scale (float): offset scale
with_cp (bool): whether to use checkpoint
"""
def __init__(self,
core_op,
channels,
depth,
groups,
downsample=True,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
act_layer='GELU',
norm_layer='LN',
post_norm=False,
offset_scale=1.0,
layer_scale=None,
with_cp=False):
super().__init__()
self.channels = channels
self.depth = depth
self.post_norm = post_norm
self.blocks = nn.ModuleList([
InternImageLayer(core_op=core_op,
channels=channels,
groups=groups,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(
drop_path, list) else drop_path,
act_layer=act_layer,
norm_layer=norm_layer,
post_norm=post_norm,
layer_scale=layer_scale,
offset_scale=offset_scale,
with_cp=with_cp) for i in range(depth)
])
if not self.post_norm:
self.norm = build_norm_layer(channels, 'LN')
self.downsample = DownsampleLayer(
channels=channels, norm_layer=norm_layer) if downsample else None
def forward(self, x, return_wo_downsample=False):
for blk in self.blocks:
x = blk(x)
if not self.post_norm:
x = self.norm(x)
if return_wo_downsample:
x_ = x
if self.downsample is not None:
x = self.downsample(x)
if return_wo_downsample:
return x, x_
return x
class InternImage(nn.Module):
r""" InternImage
A PyTorch impl of : `InternImage: Exploring Large-Scale Vision Foundation Models with Deformable Convolutions` -
https://arxiv.org/pdf/2103.14030
Args:
core_op (str): Core operator. Default: 'DCNv3'
channels (int): Number of the first stage. Default: 64
depths (list): Depth of each block. Default: [3, 4, 18, 5]
groups (list): Groups of each block. Default: [3, 6, 12, 24]
num_classes (int): Number of classes. Default: 1000
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
drop_rate (float): Probability of an element to be zeroed. Default: 0.
drop_path_rate (float): Stochastic depth rate. Default: 0.
act_layer (str): Activation layer. Default: 'GELU'
norm_layer (str): Normalization layer. Default: 'LN'
layer_scale (bool): Whether to use layer scale. Default: False
cls_scale (bool): Whether to use class scale. Default: False
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
"""
def __init__(self,
core_op='DCNv3_pytorch',
channels=64,
depths=[3, 4, 18, 5],
groups=[3, 6, 12, 24],
num_classes=1000,
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.2,
drop_path_type='linear',
act_layer='GELU',
norm_layer='LN',
layer_scale=None,
offset_scale=1.0,
post_norm=False,
cls_scale=1.5,
with_cp=False,
**kwargs):
super().__init__()
assert core_op == 'DCNv3_pytorch'
core_op = DCNv3_pytorch
self.core_op = core_op
self.num_classes = num_classes
self.num_levels = len(depths)
self.depths = depths
self.channels = channels
self.num_features = int(channels * 2**(self.num_levels - 1))
self.post_norm = post_norm
self.mlp_ratio = mlp_ratio
print(f'using core type: {core_op}')
print(f'using activation layer: {act_layer}')
print(f'using main norm layer: {norm_layer}')
print(f'using dpr: {drop_path_type}, {drop_path_rate}')
in_chans = 3
self.patch_embed = StemLayer(in_chans=in_chans,
out_chans=channels,
act_layer=act_layer,
norm_layer=norm_layer)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
]
if drop_path_type == 'uniform':
for i in range(len(dpr)):
dpr[i] = drop_path_rate
self.levels = nn.ModuleList()
for i in range(self.num_levels):
level = InternImageBlock(
core_op=core_op,
channels=int(channels * 2**i),
depth=depths[i],
groups=groups[i],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])],
act_layer=act_layer,
norm_layer=norm_layer,
post_norm=post_norm,
downsample=(i < self.num_levels - 1),
layer_scale=layer_scale,
offset_scale=offset_scale,
with_cp=with_cp)
self.levels.append(level)
self.conv_head = nn.Sequential(
nn.Conv2d(self.num_features,
int(self.num_features * cls_scale),
kernel_size=1,
bias=False),
build_norm_layer(int(self.num_features * cls_scale), 'BN',
'channels_first', 'channels_first'),
build_act_layer(act_layer))
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.head = nn.Linear(int(self.num_features * cls_scale), num_classes) \
if num_classes > 0 else nn.Identity()
self.num_layers = len(depths)
self.apply(self._init_weights)
self.apply(self._init_deform_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def _init_deform_weights(self, m):
if isinstance(m, self.core_op):
m._reset_parameters()
@torch.jit.ignore
def lr_decay_keywards(self, decay_ratio=0.87):
lr_ratios = {}
# blocks
idx = 0
for i in range(4):
layer_num = 3 - i # 3 2 1 0
for j in range(self.depths[layer_num]):
block_num = self.depths[layer_num] - j - 1
tag = 'levels.{}.blocks.{}.'.format(layer_num, block_num)
decay = 1.0 * (decay_ratio**idx)
lr_ratios[tag] = decay
idx += 1
# patch_embed (before stage-1)
lr_ratios["patch_embed"] = lr_ratios['levels.0.blocks.0.']
# levels.0.downsample (between stage-1 and stage-2)
lr_ratios["levels.0.downsample"] = lr_ratios['levels.1.blocks.0.']
lr_ratios["levels.0.norm"] = lr_ratios['levels.1.blocks.0.']
# levels.1.downsample (between stage-2 and stage-3)
lr_ratios["levels.1.downsample"] = lr_ratios['levels.2.blocks.0.']
lr_ratios["levels.1.norm"] = lr_ratios['levels.2.blocks.0.']
# levels.2.downsample (between stage-3 and stage-4)
lr_ratios["levels.2.downsample"] = lr_ratios['levels.3.blocks.0.']
lr_ratios["levels.2.norm"] = lr_ratios['levels.3.blocks.0.']
return lr_ratios
def forward_features(self, x):
x = self.patch_embed(x)
x = self.pos_drop(x)
for level in self.levels:
x = level(x)
x = self.conv_head(x.permute(0, 3, 1, 2))
x = self.avgpool(x)
x = torch.flatten(x, 1)
return x
def forward_features_seq_out(self, x):
x = self.patch_embed(x)
x = self.pos_drop(x)
seq_out = []
for level in self.levels:
x, x_ = level(x, return_wo_downsample=True)
seq_out.append(x_)
return seq_out
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
class InternImageModel(PreTrainedModel):
config_class = InternImageConfig
def __init__(self, config):
super().__init__(config)
self.model = InternImage(
core_op=config.core_op,
channels=config.channels,
depths=config.depths,
groups=config.groups,
num_classes=config.num_classes,
mlp_ratio=config.mlp_ratio,
drop_rate=config.drop_rate,
drop_path_rate=config.drop_path_rate,
drop_path_type=config.drop_path_type,
act_layer=config.act_layer,
norm_layer=config.norm_layer,
layer_scale=config.layer_scale,
offset_scale=config.offset_scale,
post_norm=config.post_norm,
cls_scale=config.cls_scale,
with_cp=config.with_cp,
)
def forward(self, tensor):
return self.model.forward_features(tensor)
class InternImageModelForImageClassification(PreTrainedModel):
config_class = InternImageConfig
def __init__(self, config):
super().__init__(config)
self.model = InternImage(
core_op=config.core_op,
channels=config.channels,
depths=config.depths,
groups=config.groups,
num_classes=config.num_classes,
mlp_ratio=config.mlp_ratio,
drop_rate=config.drop_rate,
drop_path_rate=config.drop_path_rate,
drop_path_type=config.drop_path_type,
act_layer=config.act_layer,
norm_layer=config.norm_layer,
layer_scale=config.layer_scale,
offset_scale=config.offset_scale,
post_norm=config.post_norm,
cls_scale=config.cls_scale,
with_cp=config.with_cp,
)
def forward(self, tensor, labels=None):
logits = self.model(tensor)
if labels is not None:
loss = F.cross_entropy(logits, labels)
return {'loss': loss, 'logits': logits}
return {'logits': logits}