"""Mean-opinion Network (MoNet) for image quality assessment.

The model taps intermediate outputs of a timm Vision Transformer backbone,
refines them with several Multi-view Attention Learning (MAL) modules (one per
"opinion"), fuses the opinions with a further MAL, and regresses a scalar
quality score in [0, 1].
"""
import torch
import torch.nn as nn
import timm
from timm.models.vision_transformer import Block
from einops import rearrange


class Attention_Block(nn.Module):
    """Single-head self-attention over the rows of a ``(B, C, N)`` tensor.

    Each of the ``C`` rows is treated as a token of length ``N``; queries,
    keys and values are linear projections of the last dimension, and the
    attended result is added back to the input as a residual.

    Args:
        dim: Size of the last input dimension ``N``.
        drop: Dropout probability applied before the residual addition.
    """

    def __init__(self, dim, drop=0.1):
        super().__init__()
        self.c_q = nn.Linear(dim, dim)
        self.c_k = nn.Linear(dim, dim)
        self.c_v = nn.Linear(dim, dim)
        self.norm_fact = dim ** -0.5  # 1 / sqrt(N) dot-product scaling
        self.softmax = nn.Softmax(dim=-1)
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):
        """Return a tensor with the same ``(B, C, N)`` shape as ``x``."""
        _x = x
        B, C, N = x.shape
        q = self.c_q(x)
        k = self.c_k(x)
        v = self.c_v(x)
        attn = q @ k.transpose(-2, -1) * self.norm_fact  # (B, C, C)
        attn = self.softmax(attn)
        # NOTE(review): ``attn @ v`` is already (B, C, N). Transposing to
        # (B, N, C) and then *reshaping* (not permuting) back to (B, C, N)
        # reorders elements instead of undoing the transpose. Kept unchanged
        # because trained checkpoints depend on this exact layout.
        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
        x = self.proj_drop(x)
        x = x + _x
        return x


class Self_Attention(nn.Module):
    """Spatial self-attention over a ``(bs, C, w, h)`` feature map.

    Query/key 1x1 convolutions reduce channels to ``in_dim // 8``; the value
    convolution keeps ``in_dim`` channels. The attended values are scaled by a
    learnable ``gamma`` (initialised to 0, so the layer starts as an identity)
    and added to the input.
    """

    def __init__(self, in_dim):
        super(Self_Attention, self).__init__()
        self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inFeature):
        """Return a tensor with the same ``(bs, C, w, h)`` shape as the input."""
        bs, C, w, h = inFeature.size()
        proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1)  # (bs, w*h, C//8)
        proj_key = self.kConv(inFeature).view(bs, -1, w * h)  # (bs, C//8, w*h)
        energy = torch.bmm(proj_query, proj_key)  # (bs, w*h, w*h)
        attention = self.softmax(energy)
        proj_value = self.vConv(inFeature).view(bs, -1, w * h)  # (bs, C, w*h)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(bs, C, w, h)
        out = self.gamma * out + inFeature
        return out


class MAL(nn.Module):
    """Multi-view Attention Learning (MAL) module.

    Fuses ``feature_num`` feature maps of shape ``(bs, in_dim, w, h)`` into a
    single ``(bs, in_dim, w * h)`` map using per-view spatial attention
    followed by parallel pixel-wise and channel-wise attention branches.

    Args:
        in_dim: Channels per input feature map.
        feature_num: Number of feature maps (views) fused by this module.
        feature_size: Side length of each square feature map (``w == h``).
        is_gpu: Kept for backward compatibility only. Intermediate tensors are
            created on the inputs' device, so the flag no longer affects
            computation.
    """

    def __init__(self, in_dim=768, feature_num=4, feature_size=28, is_gpu=True):
        super().__init__()
        # Channel-wise self attention: tokens are spatial positions.
        self.channel_attention = Attention_Block(in_dim * feature_num)
        # Pixel-wise self attention: tokens are (regrouped) channel rows.
        self.feature_attention = Attention_Block(feature_size ** 2 * feature_num)
        # Self attention module for each input feature.
        self.attention_module = nn.ModuleList(
            [Self_Attention(in_dim) for _ in range(feature_num)]
        )
        self.feature_num = feature_num
        self.in_dim = in_dim
        self.is_gpu = is_gpu

    def forward(self, features):
        """Fuse the views in ``features``.

        Args:
            features: Sequence indexable along its first axis (e.g. a tensor
                stacked on dim 0) of ``feature_num`` maps, each
                ``(bs, in_dim, w, h)``.

        Returns:
            Tensor of shape ``(bs, in_dim, w * h)``.
        """
        # (feature_num, bs, in_dim, w, h). torch.stack keeps the inputs'
        # device and dtype, unlike concatenating onto a hard-coded
        # ``torch.tensor([]).cuda()``.
        features = torch.stack(
            [self.attention_module[index](feature) for index, feature in enumerate(features)],
            dim=0,
        )
        input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)')  # bs, feature_num * 768, 28 * 28
        bs = input_tensor.shape[0]

        # Pixel-wise branch. NOTE(review): this rearrange is a plain reshape,
        # so each of the in_dim rows holds feature_num consecutive rows of the
        # (n c) channel layout; the inverse rearrange below restores it.
        in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)',
                               w=self.in_dim, c=self.feature_num)  # bs, 768, 28 * 28 * feature_num
        feature_weight_sum = self.feature_attention(in_feature)  # bs, 768, 28 * 28 * feature_num

        # Channel-wise branch.
        in_channel = input_tensor.permute(0, 2, 1)  # bs, 28 * 28, 768 * feature_num
        channel_weight_sum = self.channel_attention(in_channel)  # bs, 28 * 28, 768 * feature_num

        # Average both branches in the (bs, feature_num * 768, 28 * 28) layout.
        weight_sum_res = (
            rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim, c=self.feature_num)
            + channel_weight_sum.permute(0, 2, 1)
        ) / 2
        # Average over views.
        weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1)
        return weight_sum_res  # bs, 768, 28 * 28


class SaveOutput:
    """Forward hook that appends every output of the modules it is attached to."""

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


class MoNet(nn.Module):
    """Mean-opinion Network.

    Args:
        config: Object providing ``backbone`` (timm model name) and
            ``mal_num`` (number of MAL "opinion" modules).
        patch_size: Backbone patch size; ``img_size // patch_size`` is the side
            length of the token grid.
        drop: Dropout probability in the score head.
        dim_mlp: Token embedding width of the backbone.
        img_size: Input image side length.
        is_gpu: Forwarded to MAL for backward compatibility only.
    """

    def __init__(self, config, patch_size=8, drop=0.1, dim_mlp=768, img_size=224, is_gpu=True):
        super().__init__()
        self.img_size = img_size
        self.input_size = img_size // patch_size
        self.dim_mlp = dim_mlp
        self.vit = timm.create_model(config.backbone, pretrained=False)
        self.save_output = SaveOutput()

        # Register hooks: record the output of every transformer Block in the backbone.
        for layer in self.vit.modules():
            if isinstance(layer, Block):
                layer.register_forward_hook(self.save_output)

        # One MAL per opinion, each fusing the tapped backbone levels. Channel
        # width and grid size follow the constructor arguments, so non-default
        # dim_mlp / img_size / patch_size no longer mismatch MAL's defaults.
        self.MALs = nn.ModuleList([
            MAL(in_dim=dim_mlp, feature_size=self.input_size, is_gpu=is_gpu)
            for _ in range(config.mal_num)
        ])

        # Image Quality Score Regression
        self.fusion_wam = MAL(in_dim=dim_mlp, feature_num=config.mal_num,
                              feature_size=self.input_size, is_gpu=is_gpu)
        self.block = Block(dim_mlp, 12)
        self.cnn = nn.Sequential(
            nn.Conv2d(dim_mlp, 256, 5),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((3, 3)),
        )
        self.fc_score = nn.Sequential(
            nn.Linear(128, 128 // 2),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(128 // 2, 1),
            nn.Sigmoid(),
        )
        self.is_gpu = is_gpu

    def extract_feature(self, save_output, block_index=(2, 5, 8, 11)):
        """Concatenate hooked block outputs along the channel axis.

        Token 0 of each selected output is dropped (presumably the class
        token — confirm for non-default backbones).

        Args:
            save_output: ``SaveOutput`` holding one entry per backbone Block.
            block_index: Indices of the blocks to use. This is a tuple to
                avoid a shared mutable default.

        Returns:
            Tensor ``(bs, num_tokens - 1, dim_mlp * len(block_index))``.
        """
        return torch.cat([save_output.outputs[i][:, 1:] for i in block_index], dim=2)

    def forward(self, x):
        """Predict a quality score in [0, 1] per image; returns shape ``(bs,)``."""
        # Multi-level Feature From Different Transformer Blocks. Start with an
        # empty hook buffer and always clear it afterwards, so a failed pass
        # cannot leave stale activations for the next call to index into.
        self.save_output.clear()
        try:
            self.vit(x)  # run for the hooks only; the backbone's own output is unused
            x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        finally:
            self.save_output.clear()

        x = x.permute(0, 2, 1)  # bs, 768 * 4, 28 * 28
        # d=4 matches the four default block indices of extract_feature.
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp,
                      w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4)  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF): mal_num, bs, 768, 28 * 28
        DOF = torch.stack([mal(x) for mal in self.MALs], dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h',
                        w=self.input_size, h=self.input_size)  # mal_num, bs, 768, 28, 28

        # Image Quality Score Regression
        wam = self.fusion_wam(DOF).permute(0, 2, 1)  # bs, 28 * 28, 768
        wam = self.block(wam).permute(0, 2, 1)  # bs, 768, 28 * 28
        wam = rearrange(wam, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)
        score = self.cnn(wam).squeeze(-1).squeeze(-1)  # bs, 128
        score = self.fc_score(score).view(-1)
        return score


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', dest='seed', type=int, default=3407)
    parser.add_argument('--gpu_id', dest='gpu_id', type=str, default='0')

    # model related
    parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
                        help='The backbone for MoNet.')
    parser.add_argument('--mal_num', dest='mal_num', type=int, default=3,
                        help='The number of the MAL modules.')

    # data related
    parser.add_argument('--dataset', dest='dataset', type=str, default='livec',
                        help='Support datasets: livec|koniq10k|bid|spaq')
    parser.add_argument('--train_patch_num', dest='train_patch_num', type=int, default=5,
                        help='Number of sample patches from training image')
    parser.add_argument('--test_patch_num', dest='test_patch_num', type=int, default=25,
                        help='Number of sample patches from testing image')
    parser.add_argument('--patch_size', dest='patch_size', type=int, default=224,
                        help='Crop size for training & testing image patches')

    # training related
    parser.add_argument('--lr', dest='lr', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--weight_decay', dest='weight_decay', type=float, default=1e-5,
                        help='Weight decay')
    parser.add_argument('--batch_size', dest='batch_size', type=int, default=11, help='Batch size')
    parser.add_argument('--epochs', dest='epochs', type=int, default=50, help='Epochs for training')
    parser.add_argument('--T_max', dest='T_max', type=int, default=50,
                        help='Hyper-parameter for CosineAnnealingLR')
    # eta_min is a learning-rate floor, so it must accept fractional values.
    parser.add_argument('--eta_min', dest='eta_min', type=float, default=0.0,
                        help='Hyper-parameter for CosineAnnealingLR')
    parser.add_argument('--save_path', dest='save_path', type=str, default='./training_for_IQA',
                        help='The path where the model and logs will be saved.')
    config = parser.parse_args()

    # Smoke test: run one forward pass on a dummy batch, falling back to CPU
    # when CUDA is unavailable.
    use_gpu = torch.cuda.is_available()
    device = torch.device('cuda' if use_gpu else 'cpu')
    in_tensor = torch.zeros((2, 3, 224, 224), dtype=torch.float, device=device)
    model = MoNet(config, is_gpu=use_gpu).to(device)
    with torch.no_grad():
        res = model(in_tensor)
    print('{} : {} [M]'.format('#Params', sum(p.numel() for p in model.parameters()) / 10 ** 6))