Spaces:
Runtime error
Runtime error
| # Copyright (c) Open-CD. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer | |
| from mmengine.model import BaseModule | |
| from opencd.models.utils.builder import ITERACTION_LAYERS | |
| class ChannelExchange(BaseModule): | |
| """ | |
| channel exchange | |
| Args: | |
| p (float, optional): p of the features will be exchanged. | |
| Defaults to 1/2. | |
| """ | |
| def __init__(self, p=1/2): | |
| super().__init__() | |
| assert p >= 0 and p <= 1 | |
| self.p = int(1/p) | |
| def forward(self, x1, x2): | |
| N, c, h, w = x1.shape | |
| exchange_map = torch.arange(c) % self.p == 0 | |
| exchange_mask = exchange_map.unsqueeze(0).expand((N, -1)) | |
| out_x1, out_x2 = torch.zeros_like(x1), torch.zeros_like(x2) | |
| out_x1[~exchange_mask, ...] = x1[~exchange_mask, ...] | |
| out_x2[~exchange_mask, ...] = x2[~exchange_mask, ...] | |
| out_x1[exchange_mask, ...] = x2[exchange_mask, ...] | |
| out_x2[exchange_mask, ...] = x1[exchange_mask, ...] | |
| return out_x1, out_x2 | |
| class SpatialExchange(BaseModule): | |
| """ | |
| spatial exchange | |
| Args: | |
| p (float, optional): p of the features will be exchanged. | |
| Defaults to 1/2. | |
| """ | |
| def __init__(self, p=1/2): | |
| super().__init__() | |
| assert p >= 0 and p <= 1 | |
| self.p = int(1/p) | |
| def forward(self, x1, x2): | |
| N, c, h, w = x1.shape | |
| exchange_mask = torch.arange(w) % self.p == 0 | |
| out_x1, out_x2 = torch.zeros_like(x1), torch.zeros_like(x2) | |
| out_x1[..., ~exchange_mask] = x1[..., ~exchange_mask] | |
| out_x2[..., ~exchange_mask] = x2[..., ~exchange_mask] | |
| out_x1[..., exchange_mask] = x2[..., exchange_mask] | |
| out_x2[..., exchange_mask] = x1[..., exchange_mask] | |
| return out_x1, out_x2 | |
| class Aggregation_distribution(BaseModule): | |
| # Aggregation_Distribution Layer (AD) | |
| def __init__(self, | |
| channels, | |
| num_paths=2, | |
| attn_channels=None, | |
| act_cfg=dict(type='ReLU'), | |
| norm_cfg=dict(type='BN', requires_grad=True)): | |
| super().__init__() | |
| self.num_paths = num_paths # `2` is supported. | |
| attn_channels = attn_channels or channels // 16 | |
| attn_channels = max(attn_channels, 8) | |
| self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False) | |
| self.bn = build_norm_layer(norm_cfg, attn_channels)[1] | |
| self.act = build_activation_layer(act_cfg) | |
| self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False) | |
| def forward(self, x1, x2): | |
| x = torch.stack([x1, x2], dim=1) | |
| attn = x.sum(1).mean((2, 3), keepdim=True) | |
| attn = self.fc_reduce(attn) | |
| attn = self.bn(attn) | |
| attn = self.act(attn) | |
| attn = self.fc_select(attn) | |
| B, C, H, W = attn.shape | |
| attn1, attn2 = attn.reshape(B, self.num_paths, C // self.num_paths, H, W).transpose(0, 1) | |
| attn1 = torch.sigmoid(attn1) | |
| attn2 = torch.sigmoid(attn2) | |
| return x1 * attn1, x2 * attn2 | |
| class TwoIdentity(BaseModule): | |
| def __init__(self, *args, **kwargs): | |
| super().__init__() | |
| def forward(self, x1, x2): | |
| return x1, x2 | |