|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from mmengine.model import BaseModule |
|
|
|
from opencd.registry import MODELS |
|
|
|
|
|
@MODELS.register_module()
class FeatureFusionNeck(BaseModule):
    """Feature Fusion Neck.

    Fuses two equally long sequences of feature maps level by level using
    a single, parameter-free fusion policy, then returns the levels
    selected by ``out_indices``.

    Args:
        policy (str): The operation used to fuse features. Candidates
            are `concat`, `sum`, `diff` and `abs_diff`.
        in_channels (Sequence(int)): Input channels. Only stored on the
            module; the fusion itself does not read it.
        channels (int): Channels after modules, before conv_seg. Only
            stored on the module; the fusion itself does not read it.
        out_indices (tuple[int]): Output from which layer.
    """

    def __init__(self,
                 policy,
                 in_channels=None,
                 channels=None,
                 out_indices=(0, 1, 2, 3)):
        super().__init__()
        self.policy = policy
        self.in_channels = in_channels
        self.channels = channels
        self.out_indices = out_indices

    @staticmethod
    def fusion(x1, x2, policy):
        """Fuse a single pair of features according to ``policy``.

        Args:
            x1: First feature.
            x2: Second feature.
            policy (str): One of `concat`, `sum`, `diff`, `abs_diff`.

        Returns:
            The fused feature:
            - `concat`: ``torch.cat([x1, x2], dim=1)``
            - `sum`: ``x1 + x2``
            - `diff`: ``x2 - x1`` (signed, second minus first)
            - `abs_diff`: ``torch.abs(x1 - x2)``

        Raises:
            AssertionError: If ``policy`` is not a supported policy.
        """
        _fusion_policies = ['concat', 'sum', 'diff', 'abs_diff']
        assert policy in _fusion_policies, 'The fusion policies {} are ' \
            'supported'.format(_fusion_policies)

        if policy == 'concat':
            x = torch.cat([x1, x2], dim=1)
        elif policy == 'sum':
            x = x1 + x2
        elif policy == 'diff':
            # Order matters: the result is x2 minus x1.
            x = x2 - x1
        elif policy == 'abs_diff':
            x = torch.abs(x1 - x2)

        return x

    def forward(self, x1, x2):
        """Forward function.

        Args:
            x1: Sequence of features from the first input.
            x2: Sequence of features from the second input; must have
                the same length as ``x1``.

        Returns:
            tuple: Fused features at the positions listed in
            ``self.out_indices``, in that order.

        Raises:
            AssertionError: If ``x1`` and ``x2`` differ in length.
        """
        assert len(x1) == len(x2), "The features x1 and x2 from the " \
            "backbone should be of equal length"

        # Every level is fused before selection, so a problem at any level
        # surfaces even when that level is not part of ``out_indices``.
        fused = [
            self.fusion(feat1, feat2, self.policy)
            for feat1, feat2 in zip(x1, x2)
        ]

        return tuple(fused[i] for i in self.out_indices)