import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import ModuleList, Sequential

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.utils import Upsample
from opencd.registry import MODELS

class CrossAttention(nn.Module):
    """Multi-head cross-attention.

    Queries come from ``x`` and keys/values from ``ref``; passing the same
    tensor for both reduces this to plain self-attention.
    """

    def __init__(self,
                 in_dims,
                 embed_dims,
                 num_heads,
                 dropout_rate=0.,
                 apply_softmax=True):
        super().__init__()
        self.num_heads = num_heads
        # Note: logits are scaled by the input dim rather than the per-head
        # dim, following the original BIT implementation.
        self.scale = in_dims ** -0.5

        self.apply_softmax = apply_softmax

        self.to_q = nn.Linear(in_dims, embed_dims, bias=False)
        self.to_k = nn.Linear(in_dims, embed_dims, bias=False)
        self.to_v = nn.Linear(in_dims, embed_dims, bias=False)

        self.fc_out = nn.Sequential(
            nn.Linear(embed_dims, in_dims),
            nn.Dropout(dropout_rate))

    def forward(self, x, ref):
        b, n = x.shape[:2]
        h = self.num_heads

        q = self.to_q(x)
        k = self.to_k(ref)
        v = self.to_v(ref)

        # Split heads: (b, n, embed_dims) -> (b, h, n, embed_dims // h).
        q = q.reshape((b, n, h, -1)).permute((0, 2, 1, 3))
        k = k.reshape((b, ref.shape[1], h, -1)).permute((0, 2, 1, 3))
        v = v.reshape((b, ref.shape[1], h, -1)).permute((0, 2, 1, 3))

        # Scaled dot-product attention logits: (b, h, n, n_ref).
        attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        if self.apply_softmax:
            attn = F.softmax(attn, dim=-1)

        out = torch.matmul(attn, v)
        # Merge heads back: (b, h, n, d) -> (b, n, h * d).
        out = out.permute((0, 2, 1, 3)).flatten(2)
        return self.fc_out(out)
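
# A minimal shape sketch for CrossAttention (hypothetical tensors; assumes
# embed_dims is divisible by num_heads):
#
#   attn = CrossAttention(in_dims=32, embed_dims=64, num_heads=8)
#   x = torch.randn(2, 4096, 32)   # queries:     (batch, n_query, in_dims)
#   ref = torch.randn(2, 8, 32)    # keys/values: (batch, n_ref, in_dims)
#   out = attn(x, ref)             # -> (2, 4096, 32), same shape as x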

class FeedForward(nn.Sequential):
    """Two-layer MLP with ReLU and dropout."""

    def __init__(self, dim, hidden_dim, dropout_rate=0.):
        super().__init__(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout_rate))

class TransformerEncoder(nn.Module):
    """Pre-norm transformer encoder block (self-attention + feed-forward)."""

    def __init__(self,
                 in_dims,
                 embed_dims,
                 num_heads,
                 drop_rate,
                 norm_cfg,
                 apply_softmax=True):
        super().__init__()
        self.attn = CrossAttention(
            in_dims,
            embed_dims,
            num_heads,
            dropout_rate=drop_rate,
            apply_softmax=apply_softmax)
        self.ff = FeedForward(in_dims, embed_dims, drop_rate)
        self.norm1 = build_norm_layer(norm_cfg, in_dims)[1]
        self.norm2 = build_norm_layer(norm_cfg, in_dims)[1]

    def forward(self, x):
        # Self-attention: queries and keys/values share the same input.
        x_n = self.norm1(x)
        x = self.attn(x_n, x_n) + x
        return self.ff(self.norm2(x)) + x

class TransformerDecoder(nn.Module):
    """Pre-norm transformer decoder block (cross-attention + feed-forward)."""

    def __init__(self,
                 in_dims,
                 embed_dims,
                 num_heads,
                 drop_rate,
                 norm_cfg,
                 apply_softmax=True):
        super().__init__()
        self.attn = CrossAttention(
            in_dims,
            embed_dims,
            num_heads,
            dropout_rate=drop_rate,
            apply_softmax=apply_softmax)
        self.ff = FeedForward(in_dims, embed_dims, drop_rate)
        self.norm1 = build_norm_layer(norm_cfg, in_dims)[1]
        self.norm1_ = build_norm_layer(norm_cfg, in_dims)[1]
        self.norm2 = build_norm_layer(norm_cfg, in_dims)[1]

    def forward(self, x, ref):
        # Cross-attention: pixel features attend to the semantic tokens.
        x = self.attn(self.norm1(x), self.norm1_(ref)) + x
        return self.ff(self.norm2(x)) + x

@MODELS.register_module()
class BITHead(BaseDecodeHead):
    """BIT Head.

    This head is an improved implementation of `Remote Sensing Image
    Change Detection With Transformers
    <https://github.com/justchenhao/BIT_CD>`_.

    Args:
        in_channels (int): Number of input feature channels (from the
            backbone). Default: 256.
        channels (int): Number of output channels of pre_process.
            Default: 32.
        embed_dims (int): Number of expanded channels of the attention
            block. Default: 64.
        enc_depth (int): Number of transformer encoder blocks. Default: 1.
        enc_with_pos (bool): Whether to add a learned position embedding to
            the transformer encoder input. Default: True.
        dec_depth (int): Number of transformer decoder blocks. Default: 8.
        num_heads (int): Number of heads of multi-head cross-attention.
            Default: 8.
        drop_rate (float): Dropout rate of the attention and feed-forward
            layers. Default: 0.
        pool_size (int): Output size of the token pooling layer when the
            semantic tokenizer is disabled. Default: 2.
        pool_mode (str): Pooling mode, 'max' or 'avg', when the semantic
            tokenizer is disabled. Default: 'max'.
        use_tokenizer (bool): Whether to use the semantic tokenizer; if
            False, tokens are obtained by pooling instead. Default: True.
        token_len (int): Number of semantic tokens. Default: 4.
        pre_upsample (int): Scale factor of the upsampling in pre_process
            (by default upsamples the input features to 64x64). Default: 2.
        upsample_size (int): Scale factor of the final upsampling.
            Default: 4.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
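
    Examples:
        >>> # A minimal sketch; the values below are illustrative, not
        >>> # taken from any released config.
        >>> head = BITHead(in_channels=256, channels=32, num_classes=2)
        >>> feats = [torch.randn(2, 512, 32, 32)]
        >>> head(feats).shape
        torch.Size([2, 2, 256, 256])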
    """

    def __init__(self,
                 in_channels=256,
                 channels=32,
                 embed_dims=64,
                 enc_depth=1,
                 enc_with_pos=True,
                 dec_depth=8,
                 num_heads=8,
                 drop_rate=0.,
                 pool_size=2,
                 pool_mode='max',
                 use_tokenizer=True,
                 token_len=4,
                 pre_upsample=2,
                 upsample_size=4,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(in_channels, channels, **kwargs)
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.embed_dims = embed_dims
        self.use_tokenizer = use_tokenizer
        self.num_heads = num_heads
        if not use_tokenizer:
            # Tokens come from pooling, so token_len is fixed by pool_size.
            self.pool_size = pool_size
            self.pool_mode = pool_mode
            self.token_len = pool_size * pool_size
        else:
            self.token_len = token_len
            # 1x1 conv producing one spatial attention map per token.
            self.conv_att = ConvModule(
                self.channels,
                self.token_len,
                1,
                conv_cfg=self.conv_cfg)

        self.enc_with_pos = enc_with_pos
        if enc_with_pos:
            # Learned position embedding over the concatenated bitemporal
            # tokens (token_len tokens per temporal phase).
            self.enc_pos_embedding = nn.Parameter(
                torch.randn(1, self.token_len * 2, self.channels))

        # Upsample and project the backbone features to `channels` dims.
        self.pre_process = Sequential(
            Upsample(
                scale_factor=pre_upsample,
                mode='bilinear',
                align_corners=self.align_corners),
            ConvModule(
                self.in_channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg))

        self.encoder = ModuleList()
        for _ in range(enc_depth):
            block = TransformerEncoder(
                self.channels,
                self.embed_dims,
                self.num_heads,
                drop_rate=drop_rate,
                norm_cfg=self.norm_cfg)
            self.encoder.append(block)

        self.decoder = ModuleList()
        for _ in range(dec_depth):
            block = TransformerDecoder(
                self.channels,
                self.embed_dims,
                self.num_heads,
                drop_rate=drop_rate,
                norm_cfg=self.norm_cfg)
            self.decoder.append(block)

        self.upsample = Upsample(
            scale_factor=upsample_size,
            mode='bilinear',
            align_corners=self.align_corners)

    def _forward_semantic_tokens(self, x):
        """Compute ``token_len`` semantic tokens via spatial attention."""
        b, c = x.shape[:2]
        att_map = self.conv_att(x)
        # (b, token_len, 1, h*w): one spatial attention map per token.
        att_map = att_map.reshape((b, self.token_len, 1, -1))
        att_map = F.softmax(att_map, dim=-1)
        x = x.reshape((b, 1, c, -1))
        # Attention-weighted sum over pixels -> (b, token_len, c).
        tokens = (x * att_map).sum(-1)
        return tokens
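
    # Equivalently, the tokenizer above computes, per token l and channel c:
    #     tokens[b, l, c] = sum_p softmax(A[b, l, :])[p] * X[b, c, p]
    # where A is the (token_len, H*W) attention map and p indexes pixels,
    # i.e. each token is an attention-weighted average of the feature map.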

    def _forward_reshaped_tokens(self, x):
        """Pool the feature map into ``pool_size**2`` tokens."""
        if self.pool_mode == 'max':
            x = F.adaptive_max_pool2d(x, (self.pool_size, self.pool_size))
        elif self.pool_mode == 'avg':
            x = F.adaptive_avg_pool2d(x, (self.pool_size, self.pool_size))
        # Any other mode leaves the feature map unpooled.
        tokens = x.permute((0, 2, 3, 1)).flatten(1, 2)
        return tokens

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel
        with ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is the feature map of the last layer of the
                decoder head.
        """
        inputs = self._transform_inputs(inputs)
        # The bitemporal features arrive concatenated along the channel
        # axis; split them back into the two temporal phases.
        x1, x2 = torch.chunk(inputs, 2, dim=1)
        x1 = self.pre_process(x1)
        x2 = self.pre_process(x2)

        # Tokenization: compact semantic tokens per temporal phase.
        if self.use_tokenizer:
            token1 = self._forward_semantic_tokens(x1)
            token2 = self._forward_semantic_tokens(x2)
        else:
            token1 = self._forward_reshaped_tokens(x1)
            token2 = self._forward_reshaped_tokens(x2)

        # Transformer encoder over the concatenated bitemporal tokens.
        token = torch.cat([token1, token2], dim=1)
        if self.enc_with_pos:
            token += self.enc_pos_embedding
        for _encoder in self.encoder:
            token = _encoder(token)
        token1, token2 = torch.chunk(token, 2, dim=1)

        # Transformer decoder: pixel features attend to their own tokens.
        for _decoder in self.decoder:
            b, c, h, w = x1.shape
            x1 = x1.permute((0, 2, 3, 1)).flatten(1, 2)
            x2 = x2.permute((0, 2, 3, 1)).flatten(1, 2)

            x1 = _decoder(x1, token1)
            x2 = _decoder(x2, token2)

            x1 = x1.transpose(1, 2).reshape((b, c, h, w))
            x2 = x2.transpose(1, 2).reshape((b, c, h, w))

        # Feature differencing, then upsample to the output resolution.
        y = torch.abs(x1 - x2)
        y = self.upsample(y)

        return y

    def forward(self, inputs):
        """Forward function."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output
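
# A minimal smoke test of the pooling-based token path (a sketch: the
# shapes and config values below are illustrative, not taken from any
# released config).
if __name__ == '__main__':
    head = BITHead(
        in_channels=256,
        channels=32,
        num_classes=2,
        use_tokenizer=False,
        pool_size=2)
    # Bitemporal features arrive concatenated along the channel axis, so a
    # single input tensor carries 2 * in_channels channels.
    feats = [torch.randn(2, 512, 32, 32)]
    logits = head(feats)
    # pre_process upsamples 32x32 -> 64x64 (x2); the final upsample gives
    # 64x64 -> 256x256 (x4).
    assert logits.shape == (2, 2, 256, 256)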