TTP / opencd /models /utils /interaction_layer.py
KyanChen's picture
Upload 1861 files
3b96cb1
raw
history blame
3.47 kB
# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmengine.model import BaseModule
from opencd.models.utils.builder import ITERACTION_LAYERS
@ITERACTION_LAYERS.register_module()
class ChannelExchange(BaseModule):
    """Exchange a subset of channels between two feature maps.

    Channel ``i`` is swapped between ``x1`` and ``x2`` when
    ``i % int(1 / p) == 0``; every other channel stays in place.

    Args:
        p (float, optional): Approximate fraction of channels to exchange.
            The effective fraction is ``1 / int(1 / p)``; e.g. ``p=0.4``
            gives ``int(2.5) == 2``, i.e. every 2nd channel. Must satisfy
            ``0 < p <= 1``. Defaults to 1/2.

    Raises:
        ValueError: If ``p`` is not in ``(0, 1]``.
    """

    def __init__(self, p=1/2):
        super().__init__()
        # p == 0 used to pass the old ``0 <= p <= 1`` assert and then fail
        # with ZeroDivisionError below; reject it (and NaN) explicitly.
        if not 0 < p <= 1:
            raise ValueError(f'p must be in (0, 1], got {p}')
        # Stride between exchanged channels.
        self.p = int(1 / p)

    def forward(self, x1, x2):
        """Swap every ``self.p``-th channel between ``x1`` and ``x2``.

        Args:
            x1 (Tensor): First feature map with channels on dim 1
                (presumably (N, C, H, W) — confirm against callers).
            x2 (Tensor): Second feature map, same shape as ``x1``.

        Returns:
            tuple[Tensor, Tensor]: ``(out_x1, out_x2)`` where the selected
            channels come from the other input.
        """
        c = x1.shape[1]
        # Build the mask on the inputs' device and shape it (1, C, 1, ...)
        # so it broadcasts over the batch and any trailing dims.
        exchange_mask = torch.arange(c, device=x1.device) % self.p == 0
        exchange_mask = exchange_mask.reshape(1, c, *([1] * (x1.dim() - 2)))
        out_x1 = torch.where(exchange_mask, x2, x1)
        out_x2 = torch.where(exchange_mask, x1, x2)
        return out_x1, out_x2
@ITERACTION_LAYERS.register_module()
class SpatialExchange(BaseModule):
    """Exchange a subset of columns (last dimension) between two feature maps.

    Column ``j`` along the last dimension (width, for NCHW inputs) is
    swapped between ``x1`` and ``x2`` when ``j % int(1 / p) == 0``; every
    other column stays in place.

    Args:
        p (float, optional): Approximate fraction of columns to exchange.
            The effective fraction is ``1 / int(1 / p)``; e.g. ``p=0.4``
            gives ``int(2.5) == 2``, i.e. every 2nd column. Must satisfy
            ``0 < p <= 1``. Defaults to 1/2.

    Raises:
        ValueError: If ``p`` is not in ``(0, 1]``.
    """

    def __init__(self, p=1/2):
        super().__init__()
        # p == 0 used to pass the old ``0 <= p <= 1`` assert and then fail
        # with ZeroDivisionError below; reject it (and NaN) explicitly.
        if not 0 < p <= 1:
            raise ValueError(f'p must be in (0, 1], got {p}')
        # Stride between exchanged columns.
        self.p = int(1 / p)

    def forward(self, x1, x2):
        """Swap every ``self.p``-th column of the last dim between inputs.

        Args:
            x1 (Tensor): First feature map (presumably (N, C, H, W) —
                confirm against callers); columns are taken from the last
                dimension.
            x2 (Tensor): Second feature map, same shape as ``x1``.

        Returns:
            tuple[Tensor, Tensor]: ``(out_x1, out_x2)`` where the selected
            columns come from the other input.
        """
        w = x1.shape[-1]
        # A 1-D (W,) mask broadcasts against the trailing dimension; build it
        # on the inputs' device so torch.where does not mix devices.
        exchange_mask = torch.arange(w, device=x1.device) % self.p == 0
        out_x1 = torch.where(exchange_mask, x2, x1)
        out_x2 = torch.where(exchange_mask, x1, x2)
        return out_x1, out_x2
@ITERACTION_LAYERS.register_module()
class Aggregation_distribution(BaseModule):
    """Aggregation-Distribution (AD) interaction layer.

    The two inputs are summed and averaged over dims 2 and 3 (global
    average pooling for NCHW inputs). The result passes through a 1x1-conv
    bottleneck (``fc_reduce`` -> norm -> activation -> ``fc_select``) that
    yields one channel-attention vector per path. Each input is then scaled
    by the sigmoid of its own attention vector.

    Args:
        channels (int): Number of channels of each input.
        num_paths (int): Number of input paths. Only ``2`` is supported
            because ``forward`` takes exactly two inputs. Defaults to 2.
        attn_channels (int, optional): Bottleneck width. If falsy, uses
            ``channels // 16``; the value is then clamped to at least 8.
        act_cfg (dict): Config passed to ``build_activation_layer``.
            Defaults to ``dict(type='ReLU')``.
        norm_cfg (dict): Config passed to ``build_norm_layer``.
            Defaults to ``dict(type='BN', requires_grad=True)``.

    Raises:
        ValueError: If ``num_paths`` is not 2.
    """

    def __init__(self,
                 channels,
                 num_paths=2,
                 attn_channels=None,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN', requires_grad=True)):
        super().__init__()
        # forward() unpacks exactly two attention maps; fail at build time
        # rather than with an opaque unpacking error on the first forward.
        if num_paths != 2:
            raise ValueError(
                'Aggregation_distribution only supports num_paths=2, '
                f'got {num_paths}')
        self.num_paths = num_paths
        attn_channels = attn_channels or channels // 16
        attn_channels = max(attn_channels, 8)
        self.fc_reduce = nn.Conv2d(
            channels, attn_channels, kernel_size=1, bias=False)
        self.bn = build_norm_layer(norm_cfg, attn_channels)[1]
        self.act = build_activation_layer(act_cfg)
        # Emits ``channels`` attention logits for each path.
        self.fc_select = nn.Conv2d(
            attn_channels, channels * num_paths, kernel_size=1, bias=False)

    def forward(self, x1, x2):
        """Re-weight both inputs with attention computed from their sum.

        Args:
            x1 (Tensor): First input, presumably (B, C, H, W) with
                C == ``channels`` — confirm against callers.
            x2 (Tensor): Second input, same shape as ``x1``.

        Returns:
            tuple[Tensor, Tensor]: ``(x1 * sigmoid(attn1),
            x2 * sigmoid(attn2))``.
        """
        # Equivalent to stacking on dim 1 and summing, without materialising
        # the stacked tensor.
        attn = (x1 + x2).mean((2, 3), keepdim=True)
        attn = self.fc_reduce(attn)
        attn = self.bn(attn)
        attn = self.act(attn)
        attn = self.fc_select(attn)
        B, C, H, W = attn.shape
        # Split the (B, num_paths * channels, 1, 1) logits into one
        # (B, channels, 1, 1) map per path.
        attn1, attn2 = attn.reshape(
            B, self.num_paths, C // self.num_paths, H, W).unbind(1)
        return x1 * torch.sigmoid(attn1), x2 * torch.sigmoid(attn2)
@ITERACTION_LAYERS.register_module()
class TwoIdentity(BaseModule):
    """Pass-through interaction layer that returns both inputs unchanged.

    Any constructor arguments are accepted and ignored (presumably so it can
    stand in for another interaction layer in a config — confirm).
    """

    def __init__(self, *args, **kwargs):
        super(TwoIdentity, self).__init__()

    def forward(self, x1, x2):
        """Return ``(x1, x2)`` as given."""
        outputs = (x1, x2)
        return outputs