Commit 430f024 by Zevin2023
Parent(s): 63ccb58

Update models/monet.py

Files changed (1): models/monet.py (+17 -6)
models/monet.py CHANGED
@@ -69,7 +69,7 @@ class MAL(nn.Module):
     Multi-view Attention Learning (MAL) module
     """
 
-    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
+    def __init__(self, in_dim=768, feature_num=4, feature_size=28, is_gpu=True):
         super().__init__()
 
         self.channel_attention = Attention_Block(in_dim * feature_num)  # Channel-wise self attention
@@ -82,9 +82,14 @@ class MAL(nn.Module):
 
         self.feature_num = feature_num
         self.in_dim = in_dim
+        self.is_gpu = is_gpu
 
     def forward(self, features):
-        feature = torch.tensor([]).cuda()
+        if self.is_gpu:
+            feature = torch.tensor([]).cuda()
+        else:
+            feature = torch.tensor([])
+
         for index, _ in enumerate(features):
             feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
         features = feature
@@ -118,7 +123,7 @@ class SaveOutput:
 
 
 class MoNet(nn.Module):
-    def __init__(self, config, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
+    def __init__(self, config, patch_size=8, drop=0.1, dim_mlp=768, img_size=224, is_gpu=True):
         super().__init__()
         self.img_size = img_size
         self.input_size = img_size // patch_size
@@ -136,10 +141,10 @@ class MoNet(nn.Module):
 
         self.MALs = nn.ModuleList()
         for _ in range(config.mal_num):
-            self.MALs.append(MAL())
+            self.MALs.append(MAL(is_gpu=is_gpu))
 
         # Image Quality Score Regression
-        self.fusion_wam = MAL(feature_num=config.mal_num)
+        self.fusion_wam = MAL(feature_num=config.mal_num, is_gpu=is_gpu)
         self.block = Block(dim_mlp, 12)
         self.cnn = nn.Sequential(
             nn.Conv2d(dim_mlp, 256, 5),
@@ -163,6 +168,8 @@ class MoNet(nn.Module):
             nn.Sigmoid()
         )
 
+        self.is_gpu = is_gpu
+
     def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
         x1 = save_output.outputs[block_index[0]][:, 1:]
         x2 = save_output.outputs[block_index[1]][:, 1:]
@@ -182,7 +189,11 @@ class MoNet(nn.Module):
         x = x.permute(1, 0, 2, 3, 4)  # bs, 4, 768, 28 * 28
 
         # Different Opinion Features (DOF)
-        DOF = torch.tensor([]).cuda()
+        if self.is_gpu:
+            DOF = torch.tensor([]).cuda()
+        else:
+            DOF = torch.tensor([])
+
         for index, _ in enumerate(self.MALs):
             DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
         DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # 3, bs, 768, 28, 28
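
After this change, CPU inference is a matter of `MoNet(config, is_gpu=False)`: the flag is threaded into every `MAL` so the empty accumulator tensors are created on the CPU instead of being hard-pinned to CUDA. A common device-agnostic alternative is to drop the flag entirely: collect the per-branch outputs in a Python list and call `torch.cat` once, so the result inherits the device of the inputs. The `MultiViewConcat` module below is a hypothetical sketch of that pattern, not code from this repository.

```python
# Hypothetical sketch (not from this repo): device-agnostic accumulation.
# Gathering outputs in a list and calling torch.cat once removes the need
# for an is_gpu flag, because each branch output stays on whatever device
# its input tensor is on.
import torch
import torch.nn as nn

class MultiViewConcat(nn.Module):
    def __init__(self, branches):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    def forward(self, features):
        # One output per branch; the concatenated result lands on the
        # same device as the inputs automatically.
        outs = [branch(feat).unsqueeze(0)
                for branch, feat in zip(self.branches, features)]
        return torch.cat(outs, dim=0)

# Usage sketch: runs unchanged on CPU, or on GPU after .cuda() on module and input.
views = MultiViewConcat([nn.Linear(8, 8) for _ in range(4)])
stacked = views(torch.randn(4, 2, 8))  # -> shape (4, 2, 8)
```

Beyond removing the flag, the list-then-cat form avoids re-copying the growing accumulator on every iteration, which `torch.cat` against an initially empty tensor inside a loop otherwise incurs.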