TTP / opencd /models /necks /feature_fusion.py
KyanChen's picture
Upload 1861 files
3b96cb1
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from opencd.registry import MODELS
@MODELS.register_module()
class FeatureFusionNeck(BaseModule):
    """Feature Fusion Neck.

    Fuses two equally long sequences of feature maps level by level with a
    single element-wise (or concatenation) policy and returns the levels
    selected by ``out_indices``.

    Args:
        policy (str): The operation to fuse features. Candidates
            are `concat`, `sum`, `diff` and `abs_diff`.
        in_channels (Sequence(int)): Input channels. Only stored on the
            instance; not used by this class itself.
        channels (int): Channels after modules, before conv_seg. Only
            stored on the instance; not used by this class itself.
        out_indices (tuple[int]): Output from which layer.
    """

    def __init__(self,
                 policy,
                 in_channels=None,
                 channels=None,
                 out_indices=(0, 1, 2, 3)):
        super().__init__()
        self.policy = policy
        self.in_channels = in_channels
        self.channels = channels
        self.out_indices = out_indices

    @staticmethod
    def fusion(x1, x2, policy):
        """Specify the form of feature fusion.

        Args:
            x1 (Tensor): First feature map.
            x2 (Tensor): Second feature map.
            policy (str): One of `concat`, `sum`, `diff`, `abs_diff`.

        Returns:
            Tensor: ``cat([x1, x2], dim=1)`` for `concat`, ``x1 + x2`` for
            `sum`, ``x2 - x1`` for `diff`, ``|x1 - x2|`` for `abs_diff`.

        Raises:
            AssertionError: If ``policy`` is not a supported policy.
        """
        _fusion_policies = ['concat', 'sum', 'diff', 'abs_diff']
        assert policy in _fusion_policies, 'The fusion policies {} are ' \
            'supported'.format(_fusion_policies)

        if policy == 'concat':
            # Concatenate along dim 1 (the channel axis for NCHW tensors).
            x = torch.cat([x1, x2], dim=1)
        elif policy == 'sum':
            x = x1 + x2
        elif policy == 'diff':
            # Signed difference: second input minus first input.
            x = x2 - x1
        elif policy == 'abs_diff':
            x = torch.abs(x1 - x2)

        return x

    def forward(self, x1, x2):
        """Forward function.

        Args:
            x1 (Sequence[Tensor]): Feature maps of the first input, one
                entry per level.
            x2 (Sequence[Tensor]): Feature maps of the second input, with
                the same number of levels as ``x1``.

        Returns:
            tuple[Tensor]: Fused feature maps for the levels listed in
            ``self.out_indices``, in that order.

        Raises:
            AssertionError: If ``x1`` and ``x2`` differ in length.
        """
        assert len(x1) == len(x2), "The features x1 and x2 from the " \
            "backbone should be of equal length"

        # Fuse every level first, then pick the requested ones, so that
        # out_indices keeps list-indexing semantics (negative or repeated
        # indices behave exactly as they would on a list).
        outs = [
            self.fusion(feat1, feat2, self.policy)
            for feat1, feat2 in zip(x1, x2)
        ]
        return tuple(outs[i] for i in self.out_indices)