import math

import torch
import torch.nn as nn

from src.models.utils.modules import Block, CrossAttention, CrossAttentionBlock
from src.utils.tensors import trunc_normal_


class AttentivePooler(nn.Module):
    """Attentive pooler: a set of learnable query tokens cross-attends to an
    input token sequence and summarizes it into `num_queries` output vectors."""

    def __init__(
        self,
        num_queries=1,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        depth=1,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        qkv_bias=True,
        complete_block=True,
        use_activation_checkpointing=False,
    ):
        super().__init__()
        self.use_activation_checkpointing = use_activation_checkpointing
        # Learnable query tokens that attend over the input sequence.
        self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim))

        self.complete_block = complete_block
        if complete_block:
            self.cross_attention_block = CrossAttentionBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, norm_layer=norm_layer
            )
        else:
            self.cross_attention_block = CrossAttention(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias)

        # Optional stack of self-attention blocks applied to the input sequence
        # before cross-attention pooling; only instantiated when depth > 1.
        self.blocks = None
        if depth > 1:
            self.blocks = nn.ModuleList(
                [
                    Block(
                        dim=embed_dim,
                        num_heads=num_heads,
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=False,
                        norm_layer=norm_layer,
                    )
                    for i in range(depth - 1)
                ]
            )

        self.init_std = init_std
        trunc_normal_(self.query_tokens, std=self.init_std)
        self.apply(self._init_weights)
        self._rescale_blocks()

    def _rescale_blocks(self):
        # Depth-dependent init: scale down each residual-branch output projection
        # by sqrt(2 * layer_id) so that deeper layers start with smaller updates.
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        layer_id = 0
        if self.blocks is not None:
            for layer_id, layer in enumerate(self.blocks):
                rescale(layer.attn.proj.weight.data, layer_id + 1)
                rescale(layer.mlp.fc2.weight.data, layer_id + 1)

        if self.complete_block:
            rescale(self.cross_attention_block.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Optionally refine the input tokens with the self-attention blocks,
        # checkpointing activations to save memory if requested.
        if self.blocks is not None:
            for blk in self.blocks:
                if self.use_activation_checkpointing:
                    x = torch.utils.checkpoint.checkpoint(blk, x, False, None, use_reentrant=False)
                else:
                    x = blk(x)
        # Broadcast the query tokens across the batch and pool via cross-attention.
        q = self.query_tokens.repeat(len(x), 1, 1)
        q = self.cross_attention_block(q, x)
        return q
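
# Usage sketch (illustrative; the shapes and hyperparameters below are
# assumptions, not values fixed by this module): pooling a batch of encoder
# tokens into a single vector per sample.
#
#   pooler = AttentivePooler(num_queries=1, embed_dim=768, num_heads=12, depth=1)
#   tokens = torch.randn(4, 196, 768)   # [batch, num_tokens, embed_dim]
#   pooled = pooler(tokens)             # -> [4, 1, 768], one query output per sample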


class AttentiveClassifier(nn.Module):
    """Attentive classifier: an AttentivePooler with a single query token
    followed by a linear classification head."""

    def __init__(
        self,
        embed_dim=768,
        num_heads=12,
        mlp_ratio=4.0,
        depth=1,
        norm_layer=nn.LayerNorm,
        init_std=0.02,
        qkv_bias=True,
        num_classes=1000,
        complete_block=True,
        use_activation_checkpointing=False,
    ):
        super().__init__()
        # A single query token pools the sequence into one embedding, which the
        # linear head then maps to class logits.
        self.pooler = AttentivePooler(
            num_queries=1,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            depth=depth,
            norm_layer=norm_layer,
            init_std=init_std,
            qkv_bias=qkv_bias,
            complete_block=complete_block,
            use_activation_checkpointing=use_activation_checkpointing,
        )
        self.linear = nn.Linear(embed_dim, num_classes, bias=True)

    def forward(self, x):
        x = self.pooler(x).squeeze(1)  # [B, N, D] -> [B, D]
        x = self.linear(x)             # [B, D] -> [B, num_classes]
        return x
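

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: the tensor sizes below are illustrative
    # assumptions, not requirements of the classes above.
    classifier = AttentiveClassifier(embed_dim=768, num_heads=12, depth=1, num_classes=1000)
    feats = torch.randn(2, 196, 768)  # [batch, num_tokens, embed_dim]
    logits = classifier(feats)
    print(logits.shape)  # expected: torch.Size([2, 1000])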