import copy
from typing import List, Tuple, Optional

import torch.nn.functional as F
import einops
import torch
from mmcv.cnn import ConvModule, build_norm_layer
from mmcv.cnn.bricks.transformer import PatchEmbed, FFN, build_transformer_layer
from mmengine.dist import is_main_process
from mmengine.model import BaseModule
from peft import get_peft_config, get_peft_model
from torch import Tensor, nn

# from mmdet.utils import OptConfigType, MultiConfig
from mmpretrain.models import resize_pos_embed
from mmpretrain.models.backbones.vit_sam import Attention, window_partition, window_unpartition
from mmseg.models import BaseSegmentor, EncoderDecoder
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.utils import resize
from mmseg.utils import OptConfigType, MultiConfig

from opencd.registry import MODELS
from mmpretrain.models import build_norm_layer as build_norm_layer_mmpretrain


@MODELS.register_module()
class MMPretrainSamVisionEncoder(BaseModule):

    def __init__(
            self,
            encoder_cfg,
            peft_cfg=None,
            init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)
        vision_encoder = MODELS.build(encoder_cfg)
        vision_encoder.init_weights()
        if peft_cfg is not None and isinstance(peft_cfg, dict):
            config = {
                "peft_type": "LORA",
                "r": 16,
                'target_modules': ["qkv"],
                "lora_alpha": 32,
                "lora_dropout": 0.05,
                "bias": "none",
                "inference_mode": False,
            }
            config.update(peft_cfg)
            peft_config = get_peft_config(config)
            self.vision_encoder = get_peft_model(vision_encoder, peft_config)
            if is_main_process():
                self.vision_encoder.print_trainable_parameters()
        else:
            self.vision_encoder = vision_encoder
            # freeze the vision encoder
            for param in self.vision_encoder.parameters():
                param.requires_grad = False
            for name, param in self.vision_encoder.named_parameters():
                if 'down_channel' in name:
                    param.requires_grad = True
                if 'soft_ffn' in name:
                    param.requires_grad = True
            if is_main_process() and peft_cfg is not None:
                self.vision_encoder.print_trainable_parameters()

    def forward(self, x):
        return self.vision_encoder(x)


@MODELS.register_module()
class MLPSegHead(BaseDecodeHead):

    def __init__(
            self,
            out_size,
            interpolate_mode='bilinear',
            **kwargs
    ):
        super().__init__(input_transform='multiple_select', **kwargs)
        self.interpolate_mode = interpolate_mode
        num_inputs = len(self.in_channels)
        assert num_inputs == len(self.in_index)
        self.out_size = out_size

        self.convs = nn.ModuleList()
        for i in range(num_inputs):
            self.convs.append(
                ConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=self.channels,
                    kernel_size=1,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg))

        self.fusion_conv = ConvModule(
            in_channels=self.channels * num_inputs,
            out_channels=self.channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg)

    def forward(self, inputs):
        inputs = self._transform_inputs(inputs)
        outs = []
        for idx in range(len(inputs)):
            x = inputs[idx]
            conv = self.convs[idx]
            outs.append(
                resize(
                    input=conv(x),
                    size=self.out_size,
                    mode=self.interpolate_mode,
                    align_corners=self.align_corners))

        out = self.fusion_conv(torch.cat(outs, dim=1))
        out = self.cls_seg(out)
        return out
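
def _lora_wrapping_example():
    """Illustrative sketch only; nothing in this file calls it.

    It wraps a toy module with the same default LoRA settings that
    ``MMPretrainSamVisionEncoder`` applies to its vision encoder. The toy
    module and this helper are hypothetical names introduced purely for
    demonstration and assume ``peft`` is installed.
    """

    class ToyQKVBlock(nn.Module):
        # minimal stand-in for a transformer block whose attention projection
        # is named ``qkv``, matching ``target_modules`` below
        def __init__(self, dim: int = 64):
            super().__init__()
            self.qkv = nn.Linear(dim, dim * 3)

        def forward(self, x):
            return self.qkv(x)

    # same defaults as in ``MMPretrainSamVisionEncoder.__init__`` before the
    # user-provided ``peft_cfg`` is merged via ``config.update(peft_cfg)``
    lora_config = get_peft_config({
        "peft_type": "LORA",
        "r": 16,
        "target_modules": ["qkv"],
        "lora_alpha": 32,
        "lora_dropout": 0.05,
        "bias": "none",
        "inference_mode": False,
    })
    wrapped = get_peft_model(ToyQKVBlock(), lora_config)
    # only the injected LoRA adapter weights remain trainable
    wrapped.print_trainable_parameters()
    return wrapped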
@MODELS.register_module()
class LN2d(nn.Module):
    """A LayerNorm variant, popularized by Transformers, that performs
    point-wise mean and variance normalization over the channel dimension for
    inputs that have shape (batch_size, channels, height, width)."""

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


@MODELS.register_module()
class SequentialNeck(BaseModule):

    def __init__(self, necks):
        super().__init__()
        self.necks = nn.ModuleList()
        for neck in necks:
            self.necks.append(MODELS.build(neck))

    def forward(self, *args, **kwargs):
        for neck in self.necks:
            args = neck(*args, **kwargs)
        return args


@MODELS.register_module()
class SimpleFPN(BaseModule):

    def __init__(self,
                 backbone_channel: int,
                 in_channels: List[int],
                 out_channels: int,
                 num_outs: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 act_cfg: OptConfigType = None,
                 init_cfg: MultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, list)
        self.backbone_channel = backbone_channel
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs

        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(self.backbone_channel,
                               self.backbone_channel // 2, 2, 2),
            build_norm_layer(norm_cfg, self.backbone_channel // 2)[1],
            nn.GELU(),
            nn.ConvTranspose2d(self.backbone_channel // 2,
                               self.backbone_channel // 4, 2, 2))
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(self.backbone_channel,
                               self.backbone_channel // 2, 2, 2))
        self.fpn3 = nn.Sequential(nn.Identity())
        self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()

        for i in range(self.num_ins):
            l_conv = ConvModule(
                in_channels[i],
                out_channels,
                1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
            fpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    def forward(self, input: Tensor) -> tuple:
        # build FPN
        inputs = []
        inputs.append(self.fpn1(input))
        inputs.append(self.fpn2(input))
        inputs.append(self.fpn3(input))
        inputs.append(self.fpn4(input))

        # build laterals
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build outputs
        # part 1: from original levels
        outs = [self.fpn_convs[i](laterals[i]) for i in range(self.num_ins)]

        # part 2: add extra levels
        if self.num_outs > len(outs):
            for i in range(self.num_outs - self.num_ins):
                outs.append(F.max_pool2d(outs[-1], 1, stride=2))
        return tuple(outs)
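
# Illustrative sketch only (not called anywhere in this file): run a dummy
# single-scale backbone feature through SimpleFPN to see the pyramid it
# produces. The channel sizes, norm_cfg and input resolution below are
# assumptions chosen for the example, not the project's actual config values.
def _simple_fpn_shape_example():
    fpn = SimpleFPN(
        backbone_channel=256,
        in_channels=[64, 128, 256, 256],
        out_channels=128,
        num_outs=5,
        norm_cfg=dict(type='BN'))
    feat = torch.randn(2, 256, 32, 32)  # (B, backbone_channel, H, W)
    outs = fpn(feat)
    # fpn1..fpn4 rescale the input by 4x, 2x, 1x and 0.5x, and the extra level
    # halves the last output once more, so the spatial sizes are
    # 128, 64, 32, 16 and 8, each with ``out_channels`` channels.
    return [o.shape for o in outs]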
@MODELS.register_module()
class TimeFusionTransformerEncoderLayer(BaseModule):

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 feedforward_channels: int,
                 drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 num_fcs: int = 2,
                 qkv_bias: bool = True,
                 act_cfg: dict = dict(type='GELU'),
                 norm_cfg: dict = dict(type='LN'),
                 use_rel_pos: bool = False,
                 window_size: int = 0,
                 input_size: Optional[Tuple[int, int]] = None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.window_size = window_size

        self.ln1 = build_norm_layer_mmpretrain(norm_cfg, self.embed_dims)

        self.attn = Attention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            input_size=input_size if window_size == 0 else
            (window_size, window_size),
        )

        self.ln2 = build_norm_layer_mmpretrain(norm_cfg, self.embed_dims)

        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

        if self.window_size == 0:
            # bitemporal fusion branch: only the global-attention blocks
            # (window_size == 0) exchange information between the two phases
            in_channels = embed_dims * 2
            self.down_channel = nn.Conv2d(
                in_channels, 1, kernel_size=1, stride=1, bias=False)
            self.down_channel.weight.data.fill_(1.0 / in_channels)

            self.soft_ffn = nn.Sequential(
                nn.Conv2d(embed_dims, embed_dims, kernel_size=1, stride=1),
                nn.GELU(),
                nn.Conv2d(embed_dims, embed_dims, kernel_size=1, stride=1),
            )

    @property
    def norm1(self):
        return self.ln1

    @property
    def norm2(self):
        return self.ln2

    def forward(self, x):
        shortcut = x
        x = self.ln1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = self.ffn(self.ln2(x), identity=x)

        # time phase fusion: the batch stacks both temporal phases, so the
        # first half of dim 0 holds time t1 and the second half time t2
        if self.window_size == 0:
            x = einops.rearrange(x, 'b h w d -> b d h w')  # 2B, C, H, W
            x0 = x[:x.size(0) // 2]
            x1 = x[x.size(0) // 2:]  # B, C, H, W
            x0_1 = torch.cat([x0, x1], dim=1)
            # spatial gate shared by both phases
            activate_map = self.down_channel(x0_1)
            activate_map = torch.sigmoid(activate_map)
            x0 = x0 + self.soft_ffn(x1 * activate_map)
            x1 = x1 + self.soft_ffn(x0 * activate_map)
            x = torch.cat([x0, x1], dim=0)
            x = einops.rearrange(x, 'b d h w -> b h w d')
        return x
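
# Illustrative sketch only (not called anywhere in this file): run one global
# (window_size=0) TimeFusionTransformerEncoderLayer on bitemporal feature maps
# stacked along the batch dimension, which is the layout the fusion branch
# above expects. All sizes here are assumptions chosen for the example.
def _time_fusion_layer_example():
    layer = TimeFusionTransformerEncoderLayer(
        embed_dims=64,
        num_heads=4,
        feedforward_channels=256,
        window_size=0,
        input_size=(16, 16))
    # features of the t1 images and the t2 images, concatenated on dim 0
    batch_size = 3
    x = torch.randn(2 * batch_size, 16, 16, 64)  # (2B, H, W, C)
    out = layer(x)
    # the same layout comes back: the first half of dim 0 is the fused t1
    # features and the second half is the fused t2 features
    return out.shape  # torch.Size([6, 16, 16, 64])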