# fundus_img/models/modified_dual_densenet.py
# Page metadata: dongsheng, commit be05fd1 ("Upload 6 files"), 2.53 kB
import torch
from torch import nn
from torchvision.models import densenet169
from config.finetune_config import set_args
# Project configuration, resolved once when this module is imported.
# M_DenseNet reads `args.finetune_path` when it is built with a `pretrain`
# value other than 'IN'.
# NOTE(review): set_args() presumably parses command-line arguments, so merely
# importing this module may consume sys.argv - confirm in config.finetune_config.
args = set_args()
class Classifier(nn.Module):
    """Classification head applied to a channel-concatenated feature map.

    Three parallel dilated convolutions each reduce the incoming feature map
    to a single spatial position. Every branch is layer-normalised and
    flattened, the three vectors are concatenated, passed through GELU and
    dropout, and finally projected to ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        in_channels: Channel count of the input feature map. Defaults to
            ``1664 * 2``, the width the original implementation hard-coded.
        hidden_channels: Output channels of each convolution branch.
            Defaults to ``1024``.

    Note:
        With the kernel/padding/dilation values used here a 7x7 input map
        becomes 1x1 in every branch, which is the shape the
        ``LayerNorm([hidden_channels, 1, 1])`` layers require. Any other
        spatial size is rejected inside LayerNorm.
    """

    def __init__(self, num_classes, in_channels=1664 * 2, hidden_channels=1024):
        super().__init__()

        # Attribute names and registration order are kept as in the original
        # so that existing state-dict keys and weight-init RNG order still match.
        # Output size per branch: in + 2 * padding - dilation * (kernel - 1).
        self.GDConv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=4, padding=0, dilation=2)
        self.GDConv2 = nn.Conv2d(in_channels, hidden_channels, kernel_size=5, padding=1, dilation=2)
        self.GDConv3 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=0, dilation=3)

        self.LN1 = nn.LayerNorm([hidden_channels, 1, 1])
        self.LN2 = nn.LayerNorm([hidden_channels, 1, 1])
        self.LN3 = nn.LayerNorm([hidden_channels, 1, 1])

        self.gelu = nn.GELU()
        self.fc_dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(hidden_channels * 3, num_classes)

        # Kaiming-normal conv weights; zeroed bias on the linear layer. Every
        # other parameter keeps its PyTorch default initialisation.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.Linear):
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for feature map ``x``.

        Args:
            x: Tensor laid out as ``(batch, in_channels, height, width)``;
                height and width must make every branch output 1x1
                (7x7 for the configured convolutions).
        """
        branches = (
            (self.GDConv1, self.LN1),
            (self.GDConv2, self.LN2),
            (self.GDConv3, self.LN3),
        )
        # Each branch yields (batch, hidden_channels, 1, 1) -> (batch, hidden_channels).
        pooled = [norm(conv(x)).flatten(1) for conv, norm in branches]
        fused = self.gelu(torch.cat(pooled, 1))
        return self.fc(self.fc_dropout(fused))
class M_DenseNet(nn.Module):
    """Two-input network: one shared feature extractor feeding a Classifier.

    Both inputs go through the same ``self.feature`` stack; the resulting
    feature maps are concatenated along the channel dimension and handed to
    ``self.classifier``.

    Args:
        pretrain: ``'IN'`` builds the extractor from torchvision's
            ``densenet169(pretrained=True)`` (presumably "ImageNet"). Any
            other value loads a serialized model from ``args.finetune_path``.
        num_classes: Number of logits produced by the classifier head.
    """

    def __init__(self, pretrain='IN', num_classes=8):
        super().__init__()

        # Feature extractor.
        if pretrain == 'IN':
            # The returned model already has its pretrained parameters loaded.
            # NOTE(review): newer torchvision deprecates `pretrained=` in
            # favour of `weights=`; kept as-is for older installations.
            backbone = densenet169(pretrained=True)
            # Drop the last child module and keep the rest as the extractor.
            kept_layers = list(backbone.children())[:-1]
        else:
            # SECURITY: torch.load unpickles a complete model object here, so
            # only point args.finetune_path at checkpoints you trust.
            # map_location='cpu' lets a checkpoint saved on a GPU be restored
            # on a CPU-only host; move the module to a device afterwards.
            # NOTE(review): recent PyTorch defaults to weights_only=True, which
            # may reject a fully pickled model - confirm for your version.
            backbone = torch.load(args.finetune_path, map_location='cpu')
            # Drop the last two child modules of the fine-tuned model.
            # NOTE(review): assumes that model ends with two head modules - confirm.
            kept_layers = list(backbone.children())[:-2]

        self.feature = nn.Sequential(*kept_layers)
        self.classifier = Classifier(num_classes)

    def forward(self, left, right):
        """Return classifier logits for a pair of inputs.

        Args:
            left: First input batch, passed through the shared extractor.
            right: Second input batch, passed through the same extractor.
        """
        left_features = self.feature(left)
        right_features = self.feature(right)
        # Concatenate along dim 1 (channels) before classification.
        fused = torch.cat((left_features, right_features), 1)
        return self.classifier(fused)
if __name__ == '__main__':
    # Smoke test: build the default network and push one random batch of
    # four 3x224x224 inputs through each of its two inputs.
    net = M_DenseNet()
    batch_shape = (4, 3, 224, 224)
    left_batch = torch.normal(0, 1, size=batch_shape)
    right_batch = torch.normal(0, 1, size=batch_shape)
    print(net(left_batch, right_batch))