AQC-Net-PH-space / model.py
jervinjosh68's picture
added app.py and others
2794d4e
raw
history blame contribute delete
No virus
1.48 kB
import torch
import torch.nn as nn
from torch import Tensor as tensor
from torch.nn import functional as F
import torchvision.models as models
class SCA_Block(nn.Module):
    """Channel-attention gate that rescales each channel of a 4-D feature map.

    Three 1x1 convolutions (``conv_A``, ``conv_B``, ``conv_E``) project the
    input from ``in_channel`` to ``downsample_channel`` channels. A
    channel-by-channel affinity matrix is built from A and B and normalised
    with a softmax; it is then used to mix the channels of E. The mixed map
    is globally average-pooled, projected back to ``in_channel`` by
    ``linear`` and squashed with a sigmoid. The result is one gate in (0, 1)
    per input channel, which multiplies ``feature_in``.

    Args:
        in_channel: number of channels of the input feature map.
        downsample_channel: number of channels of the internal A/B/E
            projections.
    """

    def __init__(self, in_channel, downsample_channel):
        super().__init__()
        self.conv_A = nn.Conv2d(in_channel, downsample_channel, (1, 1))
        self.conv_B = nn.Conv2d(in_channel, downsample_channel, (1, 1))
        self.conv_E = nn.Conv2d(in_channel, downsample_channel, (1, 1))
        # Projects the pooled attention descriptor back to in_channel gates.
        self.linear = nn.Linear(downsample_channel, in_channel)

    def forward(self, feature_in):
        """Return ``feature_in`` rescaled channel-wise by the attention gate.

        Args:
            feature_in: tensor of shape (batch, in_channel, H, W).

        Returns:
            Tensor with the same shape as ``feature_in``.
        """
        b_size, c, h, w = feature_in.shape
        # Flatten the spatial dims of each projection: (b, c1, H*W).
        A = self.conv_A(feature_in).flatten(2)
        B = self.conv_B(feature_in).flatten(2)
        E = self.conv_E(feature_in).flatten(2)
        # Channel affinity A @ B^T -> (b, c1, c1). A batched matmul is
        # required: torch.dot only accepts 1-D tensors, and a reshape is not
        # a transpose. Softmax over the last dim makes each row of Z a set of
        # convex weights over E's channels.
        # NOTE(review): the original requested axis=1; dim=-1 is the usual
        # self-attention convention — confirm against the reference model.
        Z = F.softmax(torch.bmm(A, B.transpose(1, 2)), dim=-1)
        # Mix E's channels with the attention weights: (b, c1, H*W).
        D = torch.bmm(Z, E)
        # Mean over H*W is the global average pool -> (b, c1); the linear
        # layer maps it back to c channels so it can gate feature_in.
        gate = torch.sigmoid(self.linear(D.mean(dim=2)))
        return feature_in * gate.view(b_size, c, 1, 1)
class AQC_NET(nn.Module):
    """ResNet-18 image classifier with SCA blocks attached to ``layer3``.

    The torchvision ResNet-18 backbone gets one ``SCA_Block(256, 16)``
    registered as ``sca_1`` on ``layer3[0]`` and one registered as ``sca_2``
    on ``layer3[1]``. Its final ``fc`` layer is replaced by
    ``nn.Linear(512, num_label)``.

    NOTE(review): ``add_module`` only registers the SCA blocks as children,
    so their parameters appear in ``parameters()`` and ``state_dict()``.
    torchvision's ``BasicBlock.forward`` calls its own fixed layers and never
    calls extra children added this way. By default, therefore, the SCA blocks
    do not influence the output. Pass ``apply_sca=True`` to route each
    block's output through its SCA child via a forward hook. This is opt-in
    because any weights trained with the default wiring would give different
    predictions once the SCA gates are active. Forward hooks are not part of
    ``state_dict``, so checkpoint keys are identical either way.

    Args:
        pretrain: forwarded as ``pretrained=`` to
            ``torchvision.models.resnet18``.
        num_label: number of output classes.
        apply_sca: keyword-only. When True, apply ``sca_1`` / ``sca_2`` to
            the outputs of ``layer3[0]`` / ``layer3[1]``. Requires
            ``SCA_Block.forward`` to map a (batch, 256, H, W) tensor to a
            tensor of the same shape.
    """

    def __init__(self, pretrain=True, num_label=5, *, apply_sca=False):
        super().__init__()
        self.resnet18 = models.resnet18(pretrained=pretrain)
        # Child names must stay 'sca_1'/'sca_2' so existing state_dict keys
        # (resnet18.layer3.0.sca_1.*, resnet18.layer3.1.sca_2.*) still load.
        self.resnet18.layer3[0].add_module('sca_1', SCA_Block(256, 16))
        self.resnet18.layer3[1].add_module('sca_2', SCA_Block(256, 16))
        if apply_sca:
            # A forward hook that returns a value replaces the module output.
            self.resnet18.layer3[0].register_forward_hook(self._apply_sca_1)
            self.resnet18.layer3[1].register_forward_hook(self._apply_sca_2)
        self.resnet18.fc = nn.Linear(512, num_label)

    @staticmethod
    def _apply_sca_1(module, inputs, output):
        """Forward hook: pass layer3[0]'s output through its ``sca_1`` child."""
        return module.sca_1(output)

    @staticmethod
    def _apply_sca_2(module, inputs, output):
        """Forward hook: pass layer3[1]'s output through its ``sca_2`` child."""
        return module.sca_2(output)

    def forward(self, x):
        """Return class probabilities of shape (batch, num_label).

        A softmax is applied over the class dimension, so the output is
        probabilities, not logits. Do not feed it to ``nn.CrossEntropyLoss``,
        which expects logits.
        """
        # For this 2-D output, the implicit-dim softmax chose dim=1. Stating it
        # explicitly keeps the result and silences the deprecation warning.
        return F.softmax(self.resnet18(x), dim=1)