Spaces:
Running
Running
import torch
from torch import nn
class FeatureMerge(nn.Module):
    """Multimodal feature fusion used in VSR.

    Fuses per-level visual feature maps with per-level semantic (textual)
    feature maps using one of three strategies:

    * ``'Sum'``      — element-wise addition, no parameters.
    * ``'Concat'``   — channel concatenation followed by a linear projection
      back to the visual dimension.
    * ``'Weighted'`` — a learned sigmoid gate ``alpha`` blends the two
      modalities per position/channel.
    """

    def __init__(
        self,
        feature_names,
        visual_dim,
        semantic_dim,
        merge_type="Sum",
        dropout_ratio=0.1,
        with_extra_fc=True,
        shortcut=False,
    ):
        """Multimodal feature merge used in VSR.

        Args:
            feature_names (list[str]): keys of the feature levels to fuse.
            visual_dim (list): the dim of visual features, e.g. [256]
            semantic_dim (list): the dim of semantic features, e.g. [256]
            merge_type (str): fusion type, e.g. 'Sum', 'Concat', 'Weighted'
            dropout_ratio (float): dropout ratio of fusion features
            with_extra_fc (bool): whether add extra fc layers for adaptation
            shortcut (bool): whether add shortcut connection
                (only used by the 'Weighted' branch in ``forward``)

        Raises:
            ValueError: if ``merge_type`` is not one of the supported values.
        """
        super().__init__()
        # merge param
        self.feature_names = feature_names
        self.merge_type = merge_type
        self.visual_dim = visual_dim
        self.textual_dim = semantic_dim
        self.with_extra_fc = with_extra_fc
        self.shortcut = shortcut
        self.relu = nn.ReLU(inplace=True)

        if self.merge_type == "Sum":
            # Parameter-free fusion; only requires matching level counts.
            assert len(self.visual_dim) == len(self.textual_dim)
        elif self.merge_type == "Concat":
            assert len(self.visual_dim) == len(self.textual_dim)
            self.vis_proj = nn.ModuleList()
            self.text_proj = nn.ModuleList()
            self.alpha_proj = nn.ModuleList()
            for idx in range(len(self.visual_dim)):
                if self.with_extra_fc:
                    # Optional per-modality adaptation layers.
                    self.vis_proj.append(
                        nn.Linear(self.visual_dim[idx], self.visual_dim[idx]))
                    self.text_proj.append(
                        nn.Linear(self.textual_dim[idx], self.textual_dim[idx]))
                # Projection back to the visual dim is always needed:
                # forward() applies alpha_proj[idx] unconditionally.
                self.alpha_proj.append(
                    nn.Linear(self.visual_dim[idx] + self.textual_dim[idx],
                              self.visual_dim[idx]))
        elif self.merge_type == "Weighted":
            assert len(self.visual_dim) == len(self.textual_dim)
            self.total_num = len(self.visual_dim)
            # vis projection
            self.vis_proj = nn.ModuleList()
            self.vis_proj_relu = nn.ModuleList()
            # text projection
            self.text_proj = nn.ModuleList()
            self.text_proj_relu = nn.ModuleList()
            self.alpha_proj = nn.ModuleList()
            for idx in range(self.total_num):
                if self.with_extra_fc:
                    self.vis_proj.append(
                        nn.Linear(self.visual_dim[idx], self.visual_dim[idx]))
                    self.text_proj.append(
                        nn.Linear(self.textual_dim[idx], self.textual_dim[idx]))
                # Gate projection is used unconditionally in forward().
                self.alpha_proj.append(
                    nn.Linear(self.visual_dim[idx] + self.textual_dim[idx],
                              self.visual_dim[idx]))
        else:
            # BUG FIX: the original did `raise "Unknown merge type ..."` —
            # raising a str is itself a TypeError in Python 3, masking the
            # intended diagnostic. Raise a proper exception instead.
            raise ValueError(f"Unknown merge type {self.merge_type}")

        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, visual_feat=None, textual_feat=None):
        """Forward computation.

        Args:
            visual_feat (dict[str, Tensor]): visual feature maps keyed by
                feature name, each in shape B x C x H x W.
            textual_feat (dict[str, Tensor]): textual feature maps keyed by
                feature name, each in shape B x C x H x W.

        Returns:
            dict[str, Tensor]: fused feature maps keyed by feature name,
            each in shape B x C x H x W.
        """
        assert len(visual_feat) == len(textual_feat)
        # feature merge
        merged_feat = {}
        if self.merge_type == "Sum":
            for name in self.feature_names:
                merged_feat[name] = visual_feat[name] + textual_feat[name]
        elif self.merge_type == "Concat":
            for idx, name in enumerate(self.feature_names):
                # Move channels last so nn.Linear acts per spatial position.
                per_vis = visual_feat[name].permute(0, 2, 3, 1)
                per_text = textual_feat[name].permute(0, 2, 3, 1)
                if self.with_extra_fc:
                    per_vis = self.relu(self.vis_proj[idx](per_vis))
                    per_text = self.relu(self.text_proj[idx](per_text))
                x_sentence = self.alpha_proj[idx](
                    torch.cat((per_vis, per_text), -1))
                # Restore B x C x H x W layout.
                merged_feat[name] = x_sentence.permute(0, 3, 1, 2).contiguous()
        else:
            # 'Weighted' — __init__ guarantees no other value reaches here.
            assert self.total_num == len(visual_feat) or self.total_num == 1
            for idx, name in enumerate(self.feature_names):
                per_vis = visual_feat[name].permute(0, 2, 3, 1)
                per_text = textual_feat[name].permute(0, 2, 3, 1)
                if self.with_extra_fc:
                    per_vis = self.relu(self.vis_proj[idx](per_vis))
                    per_text = self.relu(self.text_proj[idx](per_text))
                # Learned gate in (0, 1) blending the two modalities.
                alpha = torch.sigmoid(
                    self.alpha_proj[idx](torch.cat((per_vis, per_text), -1)))
                if self.shortcut:
                    # shortcut: visual stream plus gated textual residual
                    x_sentence = per_vis + alpha * per_text
                else:
                    # selection: convex combination of the two modalities
                    x_sentence = alpha * per_vis + (1 - alpha) * per_text
                merged_feat[name] = x_sentence.permute(0, 3, 1, 2).contiguous()
        return merged_feat