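"""M-DenseNet: a dual-stream DenseNet-169 classifier.

Two inputs (paired left/right images) pass through a shared DenseNet-169
feature extractor; the resulting feature maps are concatenated channel-wise
and classified by three parallel dilated "global" convolutions followed by
a fully connected head.
"""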
import torch
from torch import nn
from torchvision.models import densenet169
from config.finetune_config import set_args

# Configuration arguments; args.finetune_path locates a locally fine-tuned
# checkpoint (used when pretrain != 'IN').
args = set_args()


class Classifier(nn.Module):
    def __init__(self, num_classes):
        super(Classifier, self).__init__()
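        # With 224x224 inputs, each DenseNet-169 stream yields 1664 x 7 x 7
        # features, so the concatenated input has 1664*2 = 3328 channels. The
        # dilated kernels below have effective sizes 7, 9 (with padding 1), and
        # 7, so each branch collapses the 7x7 map to a 1024 x 1 x 1 output.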
        self.GDConv1 = nn.Conv2d(1664 * 2, 1024, kernel_size=4, padding=0, dilation=2)
        self.GDConv2 = nn.Conv2d(1664 * 2, 1024, kernel_size=5, padding=1, dilation=2)
        self.GDConv3 = nn.Conv2d(1664 * 2, 1024, kernel_size=3, padding=0, dilation=3)
        self.LN1 = nn.LayerNorm([1024, 1, 1])
        self.LN2 = nn.LayerNorm([1024, 1, 1])
        self.LN3 = nn.LayerNorm([1024, 1, 1])
        self.gelu = nn.GELU()
        self.fc_dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1024 * 3, num_classes)

        # Initialize weights. Note: this module contains no BatchNorm2d layers,
        # so that branch is a no-op here.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
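        # Each branch: dilated conv -> LayerNorm -> flatten to (N, 1024).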
        x1 = self.GDConv1(x)
        x1 = self.LN1(x1)
        x1 = x1.view(x1.size(0), -1)

        x2 = self.GDConv2(x)
        x2 = self.LN2(x2)
        x2 = x2.view(x2.size(0), -1)

        x3 = self.GDConv3(x)
        x3 = self.LN3(x3)
        x3 = x3.view(x3.size(0), -1)

        # Concatenate the three branches to (N, 3072), apply GELU, then
        # dropout and the linear classifier.
        out = torch.cat((x1, x2, x3), 1)
        out = self.gelu(out)
        output = self.fc(self.fc_dropout(out))

        return output


class M_DenseNet(nn.Module):
    def __init__(self, pretrain='IN', num_classes=8):
        super(M_DenseNet, self).__init__()
        # Feature extractor (shared by both input streams)
        if pretrain == 'IN':
            # ImageNet-pretrained backbone; the returned model already carries
            # the pretrained weights. (`pretrained=True` is deprecated in newer
            # torchvision; use weights=DenseNet169_Weights.IMAGENET1K_V1 there.)
            model = densenet169(pretrained=True)
            self.feature = nn.Sequential(*list(model.children())[:-1])  # drop the classifier head
        else:
            # Load a locally fine-tuned model and strip its last two children.
            model = torch.load(args.finetune_path)
            self.feature = nn.Sequential(*list(model.children())[:-2])
        self.classifier = Classifier(num_classes)

    def forward(self, left, right):
        # Extract features for both streams with the shared backbone.
        left = self.feature(left)
        right = self.feature(right)
        # Channel-wise concatenation of the two 1664-channel streams.
        x = torch.cat((left, right), 1)

        return self.classifier(x)


if __name__ == '__main__':
    # Smoke test: run two random 224x224 RGB batches through the model.
    model = M_DenseNet()
    input1 = torch.normal(0, 1, size=(4, 3, 224, 224))
    input2 = torch.normal(0, 1, size=(4, 3, 224, 224))
    output = model(input1, input2)
    print(output.shape)  # expected: torch.Size([4, 8])
    print(output)