"""Two-input DenseNet-169 feature extractor with a multi-branch conv classifier."""
import torch
from torch import nn
from torchvision.models import densenet169

from config.finetune_config import set_args

# Parsed once at import time; `finetune_path` is read by M_DenseNet below.
args = set_args()


class Classifier(nn.Module):
    """Classification head built from three parallel dilated convolutions.

    The same input goes through three Conv2d branches, each followed by its
    own LayerNorm. The branch outputs are flattened, concatenated, passed
    through GELU and dropout, and projected to ``num_classes`` by a linear
    layer.

    Args:
        num_classes: Number of output features of the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        in_channels = 1664 * 2
        branch_channels = 1024

        # Each LayerNorm is configured for a [1024, 1, 1] shape, so every
        # branch has to reduce its input to a 1x1 spatial map.
        # NOTE(review): with a 7x7 input the effective kernels below are
        # 7, 9 (on a map padded to 9x9) and 7, which gives 1x1 outputs --
        # confirm that 7x7 is the expected feature-map size.
        self.GDConv1 = nn.Conv2d(
            in_channels, branch_channels, kernel_size=4, padding=0, dilation=2
        )
        self.GDConv2 = nn.Conv2d(
            in_channels, branch_channels, kernel_size=5, padding=1, dilation=2
        )
        self.GDConv3 = nn.Conv2d(
            in_channels, branch_channels, kernel_size=3, padding=0, dilation=3
        )

        norm_shape = [branch_channels, 1, 1]
        self.LN1 = nn.LayerNorm(norm_shape)
        self.LN2 = nn.LayerNorm(norm_shape)
        self.LN3 = nn.LayerNorm(norm_shape)

        self.gelu = nn.GELU()
        self.fc_dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(branch_channels * 3, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise conv weights, batch-norm affine terms and linear biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                # Only the bias is reset; the weight keeps its default init.
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the three branches on ``x`` and classify the merged features.

        Args:
            x: Input tensor with ``1664 * 2`` channels.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        branches = (
            (self.GDConv1, self.LN1),
            (self.GDConv2, self.LN2),
            (self.GDConv3, self.LN3),
        )
        flattened = []
        for conv, norm in branches:
            feat = norm(conv(x))
            flattened.append(feat.view(feat.size(0), -1))

        merged = self.gelu(torch.cat(flattened, 1))
        return self.fc(self.fc_dropout(merged))


class M_DenseNet(nn.Module):
    """Shared feature extractor for two inputs, followed by ``Classifier``.

    Args:
        pretrain: ``'IN'`` builds torchvision's ``densenet169`` with
            ``pretrained=True`` (presumably ImageNet weights) and drops its
            last child module. Any other value loads a model object from
            ``args.finetune_path`` and drops its last two child modules.
        num_classes: Forwarded to ``Classifier``.
    """

    def __init__(self, pretrain='IN', num_classes=8):
        super().__init__()

        # Feature layers.
        if pretrain == 'IN':
            # `backbone` already carries the pretrained parameters here.
            backbone = densenet169(pretrained=True)
            kept_layers = list(backbone.children())[:-1]
        else:
            # NOTE: torch.load unpickles the file; only load trusted
            # checkpoints. The loaded object's layout is not visible here.
            backbone = torch.load(args.finetune_path)
            kept_layers = list(backbone.children())[:-2]
        self.feature = nn.Sequential(*kept_layers)

        self.classifier = Classifier(num_classes)

    def forward(self, left, right):
        """Extract features from both inputs with the same layers and classify.

        Args:
            left: First input batch.
            right: Second input batch.

        Returns:
            Output of ``Classifier`` on the channel-wise concatenation of
            the two feature maps.
        """
        left_features = self.feature(left)
        right_features = self.feature(right)
        fused = torch.cat((left_features, right_features), 1)
        return self.classifier(fused)


if __name__ == '__main__':
    # Smoke test: push two random batches through the default model.
    batch_shape = (4, 3, 224, 224)
    net = M_DenseNet()
    left_batch = torch.normal(0, 1, size=batch_shape)
    right_batch = torch.normal(0, 1, size=batch_shape)
    result = net(left_batch, right_batch)
    print(result)