Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer | |
from mmengine.model import ModuleList, Sequential | |
from mmseg.models.decode_heads.decode_head import BaseDecodeHead | |
from mmseg.models.utils import Upsample | |
from opencd.registry import MODELS | |
class CrossAttention(nn.Module):
    """Multi-head attention from query tokens ``x`` to reference tokens ``ref``.

    Queries are projected from ``x``; keys and values are projected from
    ``ref``. Passing the same tensor for both gives ordinary self-attention.

    Args:
        in_dims (int): Feature size (last axis) of ``x`` and ``ref``; also the
            output feature size.
        embed_dims (int): Total width of the Q/K/V projections, split evenly
            across the heads.
        num_heads (int): Number of attention heads. Must divide ``embed_dims``.
        dropout_rate (float): Dropout applied after the output projection.
            Default: 0.
        apply_softmax (bool): Softmax-normalize the attention scores over the
            reference axis. If False, the raw scaled scores weight the values.
            Default: True.

    Raises:
        ValueError: If ``embed_dims`` is not divisible by ``num_heads``.
    """

    def __init__(self,
                 in_dims,
                 embed_dims,
                 num_heads,
                 dropout_rate=0.,
                 apply_softmax=True):
        super(CrossAttention, self).__init__()
        # Fail fast: otherwise the per-head reshape in forward() raises an
        # opaque size-mismatch error.
        if embed_dims % num_heads != 0:
            raise ValueError(
                f'embed_dims ({embed_dims}) must be divisible by '
                f'num_heads ({num_heads})')
        self.num_heads = num_heads
        # NOTE: the scale uses in_dims rather than the per-head width
        # (embed_dims // num_heads). Changing it would alter the outputs of
        # already-trained models, so it is intentionally left as is.
        self.scale = in_dims ** -0.5
        self.apply_softmax = apply_softmax
        self.to_q = nn.Linear(in_dims, embed_dims, bias=False)
        self.to_k = nn.Linear(in_dims, embed_dims, bias=False)
        self.to_v = nn.Linear(in_dims, embed_dims, bias=False)
        self.fc_out = nn.Sequential(
            nn.Linear(embed_dims, in_dims),
            nn.Dropout(dropout_rate)
        )

    def forward(self, x, ref):
        """Attend from ``x`` to ``ref``.

        Args:
            x (Tensor): Query tokens of shape (B, N, in_dims).
            ref (Tensor): Reference tokens of shape (B, M, in_dims).

        Returns:
            Tensor: Attended tokens of shape (B, N, in_dims).
        """
        b, n = x.shape[:2]
        m = ref.shape[1]
        h = self.num_heads
        # (B, L, embed_dims) -> (B, heads, L, head_dim)
        q = self.to_q(x).reshape((b, n, h, -1)).permute((0, 2, 1, 3))
        k = self.to_k(ref).reshape((b, m, h, -1)).permute((0, 2, 1, 3))
        v = self.to_v(ref).reshape((b, m, h, -1)).permute((0, 2, 1, 3))
        # Scores: (B, heads, N, M)
        mult = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        if self.apply_softmax:
            mult = F.softmax(mult, dim=-1)
        out = torch.matmul(mult, v)
        # (B, heads, N, head_dim) -> (B, N, embed_dims)
        out = out.permute((0, 2, 1, 3)).flatten(2)
        return self.fc_out(out)
class FeedForward(nn.Sequential):
    """Token-wise MLP: Linear -> ReLU -> Dropout -> Linear -> Dropout.

    The last axis is expanded from ``dim`` to ``hidden_dim`` and projected
    back to ``dim``, so the output has the same feature size as the input.

    Args:
        dim (int): Input and output feature size.
        hidden_dim (int): Width of the hidden layer.
        dropout_rate (float): Dropout probability used after the activation
            and after the output projection. Default: 0.
    """

    def __init__(self, dim, hidden_dim, dropout_rate=0.):
        # TODO: build these layers from mmcv configs to be more mmlab-like.
        expand = nn.Linear(dim, hidden_dim)
        shrink = nn.Linear(hidden_dim, dim)
        super().__init__(expand, nn.ReLU(), nn.Dropout(dropout_rate),
                         shrink, nn.Dropout(dropout_rate))
class TransformerEncoder(nn.Module):
    """Pre-norm transformer block applying self-attention to a token sequence.

    Computes ``x_ = x + attn(norm1(x), norm1(x))`` followed by
    ``y = x_ + ff(norm2(x_))``.

    Args:
        in_dims (int): Token feature size.
        embed_dims (int): Q/K/V width of the attention and hidden width of
            the feed-forward layer.
        num_heads (int): Number of attention heads.
        drop_rate (float): Dropout rate of the attention output and the
            feed-forward layer.
        norm_cfg (dict): mmcv norm config passed to ``build_norm_layer``.
        apply_softmax (bool): Forwarded to the attention. Default: True.
    """

    def __init__(self,
                 in_dims,
                 embed_dims,
                 num_heads,
                 drop_rate,
                 norm_cfg,
                 apply_softmax=True):
        super(TransformerEncoder, self).__init__()
        self.attn = CrossAttention(
            in_dims,
            embed_dims,
            num_heads,
            dropout_rate=drop_rate,
            apply_softmax=apply_softmax)
        self.ff = FeedForward(
            in_dims,
            embed_dims,
            drop_rate
        )
        self.norm1 = build_norm_layer(norm_cfg, in_dims)[1]
        self.norm2 = build_norm_layer(norm_cfg, in_dims)[1]

    def forward(self, x):
        """Run one encoder block on tokens ``x`` (last axis of size in_dims)."""
        # Self-attention: the same normalized tensor is both query and
        # reference, so normalize once instead of twice.
        normed = self.norm1(x)
        x_ = self.attn(normed, normed) + x
        y = self.ff(self.norm2(x_)) + x_
        return y
class TransformerDecoder(nn.Module):
    """Pre-norm transformer block in which query tokens attend to references.

    The query ``x`` and reference ``ref`` are normalized by separate norm
    layers (``norm1`` and ``norm1_``) before cross-attention; a residual
    feed-forward step with its own norm (``norm2``) follows.

    Args:
        in_dims (int): Feature size of both query and reference tokens.
        embed_dims (int): Q/K/V width of the attention and hidden width of
            the feed-forward layer.
        num_heads (int): Number of attention heads.
        drop_rate (float): Dropout rate of the attention output and the
            feed-forward layer.
        norm_cfg (dict): mmcv norm config passed to ``build_norm_layer``.
        apply_softmax (bool): Forwarded to the attention. Default: True.
    """

    def __init__(self, in_dims, embed_dims, num_heads, drop_rate, norm_cfg,
                 apply_softmax=True):
        super(TransformerDecoder, self).__init__()
        self.attn = CrossAttention(in_dims, embed_dims, num_heads,
                                   dropout_rate=drop_rate,
                                   apply_softmax=apply_softmax)
        self.ff = FeedForward(in_dims, embed_dims, drop_rate)
        # Registration order (norm1, norm1_, norm2) matches the state_dict
        # layout of existing checkpoints.
        for norm_name in ('norm1', 'norm1_', 'norm2'):
            setattr(self, norm_name, build_norm_layer(norm_cfg, in_dims)[1])

    def forward(self, x, ref):
        """Refine query tokens ``x`` by attending to reference tokens ``ref``."""
        query = self.norm1(x)
        context = self.norm1_(ref)
        hidden = x + self.attn(query, context)
        return hidden + self.ff(self.norm2(hidden))
# NOTE(review): ``MODELS`` is imported in this file but BITHead is not
# decorated with ``@MODELS.register_module()`` here; confirm it is registered
# elsewhere, otherwise config lookups by ``type='BITHead'`` cannot find it.
class BITHead(BaseDecodeHead):
    """BIT decode head for bi-temporal change detection.

    Improved implementation of `Remote Sensing Image Change Detection With
    Transformers <https://github.com/justchenhao/BIT_CD>`_.

    Pipeline:
    1. The transformed input is split in two along the channel axis, giving
       one feature map per image.
    2. Each half is upsampled and projected to ``channels``.
    3. Each map is compressed into a few tokens.
    4. The tokens of both images are refined jointly by the transformer
       encoder.
    5. Each image's feature map is refined by the transformer decoder,
       attending to that image's tokens.
    6. The absolute difference of the two maps is upsampled and classified.

    Args:
        in_channels (int): Channels of ONE image's feature. The transformed
            input is chunked in two along dim 1, so it must carry
            ``2 * in_channels`` channels. Default: 256.
        channels (int): Output channels of the pre-process conv; also the
            token / transformer width. Default: 32.
        embed_dims (int): Q/K/V width of the attention blocks and hidden
            width of their feed-forward layers. Default: 64.
        enc_depth (int): Number of transformer encoder blocks. Default: 1.
        enc_with_pos (bool): Add a learnable position embedding to the tokens
            before the encoder. Default: True.
        dec_depth (int): Number of transformer decoder blocks. Default: 8.
        num_heads (int): Number of attention heads in both the encoder and
            the decoder blocks. Default: 8.
        drop_rate (float): Dropout rate inside the transformer blocks.
            Default: 0.
        pool_size (int): When ``use_tokenizer`` is False, feature maps are
            adaptively pooled to ``pool_size x pool_size``, giving
            ``pool_size ** 2`` tokens per image. Default: 2.
        pool_mode (str): 'max' or 'avg' pooling when ``use_tokenizer`` is
            False. Any other value skips pooling and uses every pixel as a
            token. The token count then no longer matches the position
            embedding, so ``enc_with_pos`` must be False. Default: 'max'.
        use_tokenizer (bool): Build tokens with a learned spatial-attention
            tokenizer instead of pooling. Default: True.
        token_len (int): Number of tokens per image when ``use_tokenizer`` is
            True. Default: 4.
        pre_upsample (int): Scale factor of the bilinear upsample applied to
            each image's feature before the pre-process conv. Default: 2.
        upsample_size (int): Scale factor of the bilinear upsample applied to
            the difference map before classification. Default: 4.
        norm_cfg (dict): Norm config for the transformer blocks.
            Default: dict(type='LN').
        act_cfg (dict): Stored on the head; not passed to any layer built in
            this class. Default: dict(type='ReLU', inplace=True).
    """

    def __init__(self,
                 in_channels=256,
                 channels=32,
                 embed_dims=64,
                 enc_depth=1,
                 enc_with_pos=True,
                 dec_depth=8,
                 num_heads=8,
                 drop_rate=0.,
                 pool_size=2,
                 pool_mode='max',
                 use_tokenizer=True,
                 token_len=4,
                 pre_upsample=2,
                 upsample_size=4,
                 norm_cfg=dict(type='LN'),
                 act_cfg=dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(in_channels, channels, **kwargs)
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.embed_dims = embed_dims
        self.use_tokenizer = use_tokenizer
        self.num_heads = num_heads
        if not use_tokenizer:
            # Without a tokenizer, tokens come from down-pooled feature maps.
            self.pool_size = pool_size
            self.pool_mode = pool_mode
            self.token_len = pool_size * pool_size
        else:
            self.token_len = token_len
            # 1x1 conv producing one spatial attention map per token.
            # NOTE(review): no act_cfg is given, so presumably mmcv's default
            # activation (ReLU) is applied before the softmax; the reference
            # BIT uses a plain conv here — confirm this is intended.
            self.conv_att = ConvModule(
                self.channels,
                self.token_len,
                1,
                conv_cfg=self.conv_cfg,
            )
        self.enc_with_pos = enc_with_pos
        if enc_with_pos:
            # One embedding per token of the concatenated (image1, image2) sequence.
            self.enc_pos_embedding = nn.Parameter(
                torch.randn(1, self.token_len * 2, self.channels))
        # Pre-process each image's feature: upsample, then 3x3 conv to `channels`.
        self.pre_process = Sequential(
            Upsample(scale_factor=pre_upsample, mode='bilinear',
                     align_corners=self.align_corners),
            ConvModule(
                self.in_channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg
            )
        )
        # Transformer encoder
        self.encoder = ModuleList()
        for _ in range(enc_depth):
            block = TransformerEncoder(
                self.channels,
                self.embed_dims,
                self.num_heads,
                drop_rate=drop_rate,
                norm_cfg=self.norm_cfg,
            )
            self.encoder.append(block)
        # Transformer decoder
        self.decoder = ModuleList()
        for _ in range(dec_depth):
            block = TransformerDecoder(
                self.channels,
                self.embed_dims,
                self.num_heads,
                drop_rate=drop_rate,
                norm_cfg=self.norm_cfg,
            )
            self.decoder.append(block)
        self.upsample = Upsample(scale_factor=upsample_size, mode='bilinear',
                                 align_corners=self.align_corners)

    def _forward_semantic_tokens(self, x):
        """Compress ``x`` into ``token_len`` tokens via learned spatial attention.

        Args:
            x (Tensor): Feature map of shape (B, C, H, W).

        Returns:
            Tensor: Tokens of shape (B, token_len, C).
        """
        b, c = x.shape[:2]
        # One attention map per token, softmax-normalized over all H*W pixels.
        att_map = self.conv_att(x).reshape((b, self.token_len, -1))
        att_map = F.softmax(att_map, dim=-1)
        x = x.reshape((b, c, -1))
        # Attention-weighted sum over pixels; einsum avoids materializing the
        # (B, token_len, C, H*W) broadcast product.
        return torch.einsum('bln,bcn->blc', att_map, x)

    def _forward_reshaped_tokens(self, x):
        """Build tokens by pooling ``x`` (B, C, H, W) and flattening its pixels.

        Returns:
            Tensor: Tokens of shape (B, P, C), where P = pool_size ** 2 for
            'max'/'avg' pooling, or H * W when pooling is skipped.
        """
        if self.pool_mode == 'max':
            x = F.adaptive_max_pool2d(x, (self.pool_size, self.pool_size))
        elif self.pool_mode == 'avg':
            x = F.adaptive_avg_pool2d(x, (self.pool_size, self.pool_size))
        # Any other pool_mode: keep every pixel as a token.
        return x.permute((0, 2, 3, 1)).flatten(1, 2)

    def _forward_feature(self, inputs):
        """Forward function for feature maps before classifying each pixel with
        ``self.cls_seg`` fc.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            feats (Tensor): Upsampled absolute difference of the two decoded
                feature maps, with ``self.channels`` channels.
        """
        inputs = self._transform_inputs(inputs)
        # Channel halves correspond to the two images.
        x1, x2 = torch.chunk(inputs, 2, dim=1)
        x1 = self.pre_process(x1)
        x2 = self.pre_process(x2)
        # Tokenization
        if self.use_tokenizer:
            token1 = self._forward_semantic_tokens(x1)
            token2 = self._forward_semantic_tokens(x2)
        else:
            token1 = self._forward_reshaped_tokens(x1)
            token2 = self._forward_reshaped_tokens(x2)
        # Transformer encoder over the joint token sequence of both images.
        token = torch.cat([token1, token2], dim=1)
        if self.enc_with_pos:
            token = token + self.enc_pos_embedding
        for _encoder in self.encoder:
            token = _encoder(token)
        token1, token2 = torch.chunk(token, 2, dim=1)
        # Transformer decoder: flatten (B, C, H, W) -> (B, H*W, C) once, run
        # all blocks on the sequences, then restore the spatial layout once.
        b, c, h, w = x1.shape
        x1 = x1.permute((0, 2, 3, 1)).flatten(1, 2)
        x2 = x2.permute((0, 2, 3, 1)).flatten(1, 2)
        for _decoder in self.decoder:
            x1 = _decoder(x1, token1)
            x2 = _decoder(x2, token2)
        x1 = x1.transpose(1, 2).reshape((b, c, h, w))
        x2 = x2.transpose(1, 2).reshape((b, c, h, w))
        # Feature differencing
        y = torch.abs(x1 - x2)
        y = self.upsample(y)
        return y

    def forward(self, inputs):
        """Forward function: decode features, then classify each pixel."""
        output = self._forward_feature(inputs)
        output = self.cls_seg(output)
        return output