Jassk28 commited on
Commit
aac53ff
1 Parent(s): d9ae36e

Upload Multimodalmodel.py

Browse files
Files changed (1) hide show
  1. net/Multimodalmodel.py +41 -0
net/Multimodalmodel.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from utils.config import cfg
5
+ from utils.basicblocks import BasicBlock
6
+ from utils.feature_fusion_block import DCT_Attention_Fusion_Conv
7
+ from utils.classifier import ClassifierModel
8
+
9
+ class Image_n_DCT(nn.Module):
10
+ def __init__(self,):
11
+ super(Image_n_DCT, self).__init__()
12
+ self.Img_Block = nn.ModuleList()
13
+ self.DCT_Block = nn.ModuleList()
14
+ self.RGB_n_DCT_Fusion = nn.ModuleList()
15
+ self.num_classes = len(cfg.CLASSES)
16
+
17
+
18
+
19
+ for i in range(len(cfg.MULTIMODAL_FUSION.IMG_CHANNELS) - 1):
20
+ self.Img_Block.append(BasicBlock(cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i], cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i+1], stride=1))
21
+ self.DCT_Block.append(BasicBlock(cfg.MULTIMODAL_FUSION.DCT_CHANNELS[i], cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i+1], stride=1))
22
+ self.RGB_n_DCT_Fusion.append(DCT_Attention_Fusion_Conv(cfg.MULTIMODAL_FUSION.IMG_CHANNELS[i+1]))
23
+
24
+
25
+ self.classifier = ClassifierModel(self.num_classes)
26
+
27
+
28
+
29
+ def forward(self, rgb_image, dct_image):
30
+ image = [rgb_image]
31
+ dct_image = [dct_image]
32
+
33
+ for i in range(len(self.Img_Block)):
34
+ image.append(self.Img_Block[i](image[-1]))
35
+ dct_image.append(self.DCT_Block[i](dct_image[-1]))
36
+ image[-1] = self.RGB_n_DCT_Fusion[i](image[-1], dct_image[-1])
37
+ dct_image[-1] = image[-1]
38
+ out = self.classifier(image[-1])
39
+
40
+ return out
41
+