File size: 4,229 Bytes
d360108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from typing import Optional

import torch.nn as nn
import torch

class BasicBlock(nn.Module):
    """ResNet "basic" residual block made of two 3x3 convolutions.

    Main path: ``conv1 (3x3, stride) -> bn1 -> ReLU -> conv2 (3x3) -> bn2``.
    Its output is summed with the shortcut (the raw input, or
    ``identity_downsample(input)`` when a module is given) and passed through
    a final ReLU.

    Parameters
    ----------
    in_channels : int
        Number of channels of the incoming feature map.
    out_channels : int
        Number of channels produced by both convolutions.
    stride : int, optional
        Stride of the first convolution only, by default 1.
    identity_downsample : Optional[torch.nn.Module], optional
        Module applied to the block input before the residual addition, by
        default None (the input is added unchanged). Supply one whenever
        ``stride != 1`` or ``in_channels != out_channels``; otherwise the
        residual addition fails on mismatched shapes.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 identity_downsample: Optional[torch.nn.Module] = None):
        super().__init__()
        # Submodules are registered in a fixed order (conv1, bn1, relu, conv2,
        # bn2, identity_downsample); this order determines parameters() /
        # state_dict() ordering and the order in which weights are initialised.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.identity_downsample = identity_downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # The shortcut is projected only when a downsampling module was given,
        # so that its shape matches the main-path output.
        downsample = self.identity_downsample
        shortcut = x if downsample is None else downsample(x)

        # In-place add on the freshly produced bn2 output (the input x is untouched).
        out += shortcut
        return self.relu(out)

class ResNet18(nn.Module):
    """Construct ResNet-18 Model.

    Architecture:
      * stem: 7x7 conv (stride 2) -> BN -> ReLU -> 3x3 max-pool (stride 2);
      * four stages (``layer1`` .. ``layer4``) of ``BasicBlock`` units with
        64, 128, 256 and 512 channels; stages 2-4 use stride 2 in their
        first block;
      * adaptive average pooling to 1x1 and a linear classifier.

    Parameters
    ----------
    input_channels : int
        Number of input channels
    num_classes : int
        Number of class outputs
    """

    def __init__(self, input_channels: int, num_classes: int):
        super().__init__()
        # Stem. Submodule creation order is intentional: it fixes the
        # state_dict layout and the order in which weights are initialised.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages (two BasicBlocks each for ResNet-18).
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.layer3 = self._make_layer(128, 256, stride=2)
        self.layer4 = self._make_layer(256, 512, stride=2)

        # Last layers
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def identity_downsample(self,
                            in_channels: int,
                            out_channels: int,
                            stride: int = 2) -> nn.Module:
        """Build the shortcut projection (conv -> BN) for a residual block.

        Parameters
        ----------
        in_channels : int
            Channels of the block input.
        out_channels : int
            Channels of the block output the shortcut must match.
        stride : int, optional
            Spatial stride; must equal the stride of the block's first
            convolution so both paths produce the same spatial size.
            Defaults to 2, the value this method previously hard-coded.

        Returns
        -------
        nn.Module
            ``nn.Sequential(Conv2d, BatchNorm2d)``.
        """
        # NOTE: the original ResNet projection shortcut is a 1x1 convolution;
        # a 3x3/padding-1 kernel is kept here so existing checkpoints still load.
        # For stride s both variants yield the same output spatial size.
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
        )

    def _make_layer(self,
                    in_channels: int,
                    out_channels: int,
                    stride: int,
                    num_blocks: int = 2) -> nn.Module:
        """Create one residual stage as an ``nn.Sequential`` of BasicBlocks.

        Only the first block changes the channel count and applies ``stride``;
        the remaining ``num_blocks - 1`` blocks map ``out_channels`` to
        ``out_channels`` at stride 1.

        Raises
        ------
        ValueError
            If ``num_blocks`` is smaller than 1.
        """
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")

        # The shortcut needs a projection whenever the first block changes
        # either the spatial size (stride) or the channel count; otherwise the
        # residual addition inside BasicBlock would see mismatched shapes.
        identity_downsample = None
        if stride != 1 or in_channels != out_channels:
            identity_downsample = self.identity_downsample(in_channels, out_channels, stride)

        blocks = [BasicBlock(in_channels, out_channels,
                             identity_downsample=identity_downsample, stride=stride)]
        blocks.extend(BasicBlock(out_channels, out_channels) for _ in range(num_blocks - 1))
        return nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an ``(N, input_channels, H, W)`` batch to ``(N, num_classes)`` logits."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten everything but the batch dimension; unlike .view this also
        # works on non-contiguous tensors.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x