Update models/monet.py
models/monet.py CHANGED (+17 -6)
@@ -69,7 +69,7 @@ class MAL(nn.Module):
     Multi-view Attention Learning (MAL) module
     """
 
-    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
+    def __init__(self, in_dim=768, feature_num=4, feature_size=28, is_gpu=True):
         super().__init__()
 
         self.channel_attention = Attention_Block(in_dim * feature_num)  # Channel-wise self attention
@@ -82,9 +82,14 @@ class MAL(nn.Module):
 
         self.feature_num = feature_num
         self.in_dim = in_dim
+        self.is_gpu = is_gpu
 
     def forward(self, features):
-        feature = torch.tensor([]).cuda()
+        if self.is_gpu:
+            feature = torch.tensor([]).cuda()
+        else:
+            feature = torch.tensor([])
+
         for index, _ in enumerate(features):
             feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
         features = feature
@@ -118,7 +123,7 @@ class SaveOutput:
 
 
 class MoNet(nn.Module):
-    def __init__(self, config, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
+    def __init__(self, config, patch_size=8, drop=0.1, dim_mlp=768, img_size=224, is_gpu=True):
         super().__init__()
         self.img_size = img_size
         self.input_size = img_size // patch_size
@@ -136,10 +141,10 @@ class MoNet(nn.Module):
 
         self.MALs = nn.ModuleList()
         for _ in range(config.mal_num):
-            self.MALs.append(MAL())
+            self.MALs.append(MAL(is_gpu=is_gpu))
 
         # Image Quality Score Regression
-        self.fusion_wam = MAL(feature_num=config.mal_num)
+        self.fusion_wam = MAL(feature_num=config.mal_num, is_gpu=is_gpu)
         self.block = Block(dim_mlp, 12)
         self.cnn = nn.Sequential(
             nn.Conv2d(dim_mlp, 256, 5),
@@ -163,6 +168,8 @@ class MoNet(nn.Module):
             nn.Sigmoid()
         )
 
+        self.is_gpu = is_gpu
+
     def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
         x1 = save_output.outputs[block_index[0]][:, 1:]
         x2 = save_output.outputs[block_index[1]][:, 1:]
@@ -182,7 +189,11 @@ class MoNet(nn.Module):
         x = x.permute(1, 0, 2, 3, 4)  # bs, 4, 768, 28 * 28
 
         # Different Opinion Features (DOF)
-        DOF = torch.tensor([]).cuda()
+        if self.is_gpu:
+            DOF = torch.tensor([]).cuda()
+        else:
+            DOF = torch.tensor([])
+
         for index, _ in enumerate(self.MALs):
             DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
         DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # 3, bs, 768, 28, 28
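
For reference, the commit threads an is_gpu flag through MAL and MoNet so the empty accumulator tensors (feature in MAL.forward, DOF in MoNet.forward) are allocated on CPU rather than via .cuda() when no GPU is available. A minimal CPU-only usage sketch, assuming a hypothetical config that carries only the mal_num attribute the diff itself reads, and assuming MoNet.forward takes a batch of images (its signature is not shown in the diff):

    from types import SimpleNamespace

    import torch

    from models.monet import MoNet

    config = SimpleNamespace(mal_num=3)  # hypothetical minimal config; the real one may carry more fields
    model = MoNet(config, is_gpu=False)  # new flag; the default True keeps the old GPU-only behaviour
    model.eval()

    with torch.no_grad():
        scores = model(torch.rand(2, 3, 224, 224))  # assumed forward(images) signature

An alternative this commit does not take would be to make the accumulation device-agnostic instead of flag-driven: collect the per-view outputs in a Python list and concatenate once, so the result inherits the inputs' device and neither .cuda() nor is_gpu is needed. A standalone sketch of that pattern:

    import torch

    def stack_views(parts):
        # Accumulate in a list and concatenate once; torch.cat keeps the result
        # on the same device as its inputs, so no device flag is required.
        return torch.cat([p.unsqueeze(0) for p in parts], dim=0)

    views = [torch.rand(2, 768, 784) for _ in range(4)]  # CPU tensors
    print(stack_views(views).shape)  # torch.Size([4, 2, 768, 784])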