# Spaces: Running on Zero
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
class FeatureFusionModule(nn.Module):
    """Fuse several feature maps into a single map with ``out_channels`` channels.

    Each input map goes through its own 1x1 convolution (bringing it to
    ``out_channels`` channels), is bilinearly resized to the spatial size of
    the first map, and the results are combined either by element-wise sum
    or by channel concatenation. A final 1x1 convolution produces the output.

    Args:
        in_channels_list: Channel count of each input feature map, in the
            same order the maps are later passed to ``forward``.
        out_channels: Channel count of every adjusted map and of the output.
        fusion_type: ``'sum'`` or ``'concat'``. The constructor accepts any
            value; ``forward`` raises ``ValueError`` for anything else.
    """

    def __init__(
        self,
        in_channels_list: Sequence[int],
        out_channels: int,
        fusion_type: str = 'sum',
    ) -> None:
        super().__init__()
        self.fusion_type = fusion_type
        # One 1x1 convolution per input level to unify the channel count.
        self.adjusted_features_convs = nn.ModuleList(
            [nn.Conv2d(in_channels, out_channels, 1) for in_channels in in_channels_list]
        )
        # Concatenation stacks every adjusted map along the channel axis, so
        # the output convolution sees out_channels * num_levels channels.
        if fusion_type == 'concat':
            fused_channels = out_channels * len(in_channels_list)
        else:
            fused_channels = out_channels
        self.output_conv = nn.Conv2d(fused_channels, out_channels, 1)

    def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor:
        """Fuse ``features`` and return the fused feature map.

        Args:
            features: One feature map per entry of ``in_channels_list``. The
                last two dimensions of the first map define the output
                spatial size.

        Returns:
            The output of the final 1x1 convolution applied to the fused map.

        Raises:
            ValueError: If ``fusion_type`` is neither ``'concat'`` nor
                ``'sum'``, or if the number of feature maps differs from the
                number of per-level convolutions.
        """
        # Fail before doing any convolution or resizing work.
        if self.fusion_type not in ('concat', 'sum'):
            raise ValueError(f"Invalid fusion_type: {self.fusion_type}")

        # zip() would silently drop unmatched maps or convolutions; in 'sum'
        # mode that produces a wrong result without any error.
        expected = len(self.adjusted_features_convs)
        if len(features) != expected:
            raise ValueError(
                f"Expected {expected} feature maps, got {len(features)}"
            )

        height, width = features[0].shape[-2:]
        adjusted_features = [
            # Adjust the channel count, then resize to the first map's size.
            F.interpolate(
                conv(feature),
                size=(height, width),
                mode='bilinear',
                align_corners=True,
            )
            for feature, conv in zip(features, self.adjusted_features_convs)
        ]

        if self.fusion_type == 'concat':
            fused_feature = torch.cat(adjusted_features, dim=1)
        else:
            fused_feature = sum(adjusted_features)

        return self.output_conv(fused_feature)
# # List of feature maps
# feature_maps = [
#     torch.randn(10, 32, 64, 64),
#     torch.randn(10, 64, 32, 32),
#     torch.randn(10, 128, 16, 16),
#     torch.randn(10, 256, 8, 8),
# ]
# # Build the model
# in_channels_list = [32, 64, 128, 256]  # channel count at each level
# out_channels = 64  # set the channel count after fusion to 64
# fusion_type = 'sum'  # specify 'sum' or 'concat'
# feature_fusion_module = FeatureFusionModule(in_channels_list, out_channels, fusion_type=fusion_type)
# # Fuse
# fused_feature = feature_fusion_module(feature_maps)
# print(fused_feature.shape)  # check the size and channel count of the fused feature map