Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,148 Bytes
c295391 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import torch.nn as nn
import torch
import torch.nn.functional as F
class FeatureFusionModule(nn.Module):
    """Fuse several feature maps of differing resolution into one feature map.

    Each input feature map is projected to ``out_channels`` with its own 1x1
    convolution, resized bilinearly to the spatial size of the first feature
    map, combined (element-wise sum or channel concatenation), and finally
    passed through a 1x1 output convolution that yields ``out_channels``
    channels.

    Args:
        in_channels_list: Channel count of each input feature map, in the
            same order the maps are passed to ``forward``.
        out_channels: Channel count of every projected map and of the output.
        fusion_type: ``'sum'`` to add the projected maps element-wise, or
            ``'concat'`` to concatenate them along the channel dimension.

    Raises:
        ValueError: If ``fusion_type`` is neither ``'sum'`` nor ``'concat'``.
    """

    _FUSION_TYPES = ('sum', 'concat')

    def __init__(self, in_channels_list, out_channels, fusion_type='sum'):
        super().__init__()
        # Fail fast: an unknown fusion type would otherwise only surface on
        # the first forward pass, after the layers were already built.
        if fusion_type not in self._FUSION_TYPES:
            raise ValueError(f"Invalid fusion_type: {fusion_type}")
        self.fusion_type = fusion_type

        # One 1x1 projection per input level, mapping it to out_channels.
        self.adjusted_features_convs = nn.ModuleList(
            [nn.Conv2d(in_channels, out_channels, 1) for in_channels in in_channels_list]
        )

        # Concatenation stacks every projected map along the channel axis,
        # so the output conv sees out_channels * num_levels input channels.
        if fusion_type == 'concat':
            fused_channels = out_channels * len(in_channels_list)
        else:
            fused_channels = out_channels
        self.output_conv = nn.Conv2d(fused_channels, out_channels, 1)

    def forward(self, features):
        """Fuse ``features`` into a single feature map.

        Args:
            features: Sequence of feature-map tensors, one per entry of the
                ``in_channels_list`` given at construction. The 1x1 Conv2d
                plus bilinear interpolation require 4-D ``(N, C, H, W)``
                tensors.

        Returns:
            Tensor with ``out_channels`` channels and the spatial size of
            ``features[0]``.

        Raises:
            ValueError: If the number of feature maps does not match the
                number of configured levels, or if ``self.fusion_type`` was
                changed to an unsupported value after construction.
        """
        expected = len(self.adjusted_features_convs)
        if len(features) != expected:
            # zip() would silently drop the surplus maps or convolutions.
            raise ValueError(
                f"Expected {expected} feature maps, got {len(features)}"
            )

        # Every level is resized to the spatial size of the first feature map.
        h, w = features[0].shape[-2:]
        adjusted_features = []
        for feature, conv in zip(features, self.adjusted_features_convs):
            # Adjust the number of channels.
            feature = conv(feature)
            # Resize to the reference resolution.
            adjusted_feature = F.interpolate(
                feature, size=(h, w), mode='bilinear', align_corners=True
            )
            adjusted_features.append(adjusted_feature)

        if self.fusion_type == 'concat':
            fused_feature = torch.cat(adjusted_features, dim=1)
        elif self.fusion_type == 'sum':
            fused_feature = sum(adjusted_features)
        else:
            raise ValueError(f"Invalid fusion_type: {self.fusion_type}")

        return self.output_conv(fused_feature)
# # List of feature maps
# feature_maps = [
#     torch.randn(10, 32, 64, 64),
#     torch.randn(10, 64, 32, 32),
#     torch.randn(10, 128, 16, 16),
#     torch.randn(10, 256, 8, 8),
# ]
# # Build the model
# in_channels_list = [32, 64, 128, 256] # number of channels at each level
# out_channels = 64 # set the number of channels after fusion to 64
# fusion_type = 'sum' # specify 'sum' or 'concat'
# feature_fusion_module = FeatureFusionModule(in_channels_list, out_channels, fusion_type=fusion_type)
# # Fusion
# fused_feature = feature_fusion_module(feature_maps)
# print(fused_feature.shape) # check the size and channel count of the fused feature map
|