File size: 3,742 Bytes
b6bb35e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.nn as nn

# -----------------------------------------------------------------------------
# Blocks
# -----------------------------------------------------------------------------

class Conv2d(nn.Module):
    """2D convolution block: ``nn.Conv2d`` optionally followed by ``nn.PReLU``.

    Input is [b, c, h, w]:
        b -- batch size
        c -- channel count (should equal ``in_channels``)
        h -- height
        w -- width

    The layers live in ``self.conv`` (an ``nn.Sequential``): index 0 is the
    convolution, index 1 (present only when ``do_activation``) is the PReLU.
    """
    def __init__(self, 
                 in_channels: int, 
                 out_channels: int, 
                 kernel_size: int, 
                 padding: int,
                 do_activation: bool = True, 
                 ):
        super().__init__()

        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        # Zero or one activation layers, appended after the convolution.
        activation = [nn.PReLU()] if do_activation else []
        self.conv = nn.Sequential(convolution, *activation)

    def forward(self, x):
        """Apply the convolution (and the PReLU, if enabled) to ``x``."""
        out = self.conv(x)
        return out
    
# -----------------------------------------------------------------------------
# Network
# -----------------------------------------------------------------------------

class _UNet(nn.Module):
    """U-Net whose convolution block is pluggable.

    Encoder: one ``conv`` block per entry of ``features``; each block's output
    is kept as a skip connection, then max-pooled (kernel 2, stride 2).
    Bottleneck: one ``conv`` block keeping ``features[-1]`` channels.
    Decoder (deepest level first): upsample bilinearly by 2, concatenate the
    matching skip connection along dim 1, apply a ``conv`` block.
    Output: a 1x1 ``conv`` block without activation mapping ``features[0]``
    channels to ``out_channels``.
    """
    def __init__(self,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 features: Optional[List[int]] = None,
                 conv_kernel_size: int = 3,
                 conv: Optional[Callable[..., nn.Module]] = None,
                 conv_kwargs: Optional[Dict[str, Any]] = None,
                 ):
        """
        Args:
            in_channels: channels of the input tensor.
            out_channels: channels produced by the final 1x1 block.
            features: output channels of each encoder level, shallowest first.
                Defaults to ``[64, 64, 64, 64, 64]``. Entries may differ.
            conv_kernel_size: kernel size of every non-final block; padding is
                ``(conv_kernel_size - 1) // 2``.
            conv: factory (typically a module class) called as
                ``conv(in_ch, out_ch, kernel_size=..., padding=..., **conv_kwargs)``.
                The final block is also given ``do_activation=False``, so the
                factory must accept that keyword.
            conv_kwargs: extra keyword arguments forwarded to every ``conv`` call.

        Raises:
            ValueError: if ``conv`` is None or ``features`` is empty.
        """
        super().__init__()

        if conv is None:
            raise ValueError("_UNet requires a `conv` block factory, got None")
        # Fresh copies so no list/dict is shared between instances or with callers.
        features = [64, 64, 64, 64, 64] if features is None else list(features)
        if not features:
            raise ValueError("`features` must contain at least one entry")
        conv_kwargs = {} if conv_kwargs is None else dict(conv_kwargs)

        self.in_channels = in_channels

        padding = (conv_kernel_size - 1) // 2

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: each level maps the previous level's channels to `feat`.
        channels = in_channels
        for feat in features:
            self.downs.append(
                conv(channels, feat, kernel_size=conv_kernel_size, padding=padding, **conv_kwargs)
            )
            channels = feat

        # Decoder, stored as alternating [upsample, conv] entries in `self.ups`.
        # Each conv receives the skip connection (`feat` channels) concatenated
        # with the tensor coming up from the level below (`prev` channels). The
        # bottleneck keeps features[-1] channels, so the deepest level starts there.
        # (For uniform `features` this is the same as `feat * 2`.)
        prev = features[-1]
        for feat in reversed(features):
            # Same as nn.UpsamplingBilinear2d(scale_factor=2), which is deprecated.
            self.ups.append(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True))
            self.ups.append(
                conv(feat + prev, feat, kernel_size=conv_kernel_size, padding=padding, **conv_kwargs)
            )
            prev = feat

        self.bottleneck = conv(
            features[-1], features[-1], kernel_size=conv_kernel_size, padding=padding, **conv_kwargs
            )
        # The output projection never gets an activation; merging (instead of
        # passing do_activation separately) avoids a duplicate-keyword TypeError
        # when conv_kwargs also sets do_activation.
        final_kwargs = {**conv_kwargs, "do_activation": False}
        self.final_conv = conv(
            features[0], out_channels, kernel_size=1, padding=0, **final_kwargs
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` of shape [B, in_channels, H, W].

        Returns a tensor with ``out_channels`` channels. Its spatial size equals
        that of the first encoder block's output (the input H, W when the
        conv blocks are size-preserving).

        If pooling floored an odd size, an upsampled map will not match its skip
        connection. In that case it is bilinearly resized to the skip's H, W
        before concatenation, so H and W need not be divisible by
        2 ** len(features). Very small inputs can still fail inside max-pooling.
        """
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Pair each decoder (upsample, conv) with the encoder output of the
        # same level, deepest level first.
        for upsample, up_conv, skip in zip(
                self.ups[0::2], self.ups[1::2], reversed(skip_connections)):
            x = upsample(x)
            if x.shape[-2:] != skip.shape[-2:]:
                x = nn.functional.interpolate(
                    x, size=tuple(skip.shape[-2:]), mode="bilinear", align_corners=True
                )
            x = up_conv(torch.cat((skip, x), dim=1))

        return self.final_conv(x)
    

class UNet(_UNet):
    """
    U-Net built from this file's ``Conv2d`` blocks (conv + optional PReLU).

    All keyword arguments are forwarded to ``_UNet`` (e.g. ``in_channels``,
    ``out_channels``, ``features``, ``conv_kernel_size``, ``conv_kwargs``).
    ``conv`` is fixed to ``Conv2d`` and must not be passed.

    input shape:  B x in_channels x H x W
    output shape: B x out_channels x H x W  (H, W preserved for odd conv_kernel_size)
    """
    def __init__(self, **kwargs) -> None:
        super().__init__(conv=Conv2d, **kwargs)

    # forward is inherited unchanged from _UNet.