File size: 1,431 Bytes
3a273df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import torch.nn as nn

from model.MIRNet.ResidualRecurrentGroup import ResidualRecurrentGroup


class MIRNet(nn.Module):
    """
    Low-level features are extracted through convolution and passed to n residual recurrent groups that operate at different resolutions.
    Their output is added to the input image for restoration.

    Please refer to the documentation of the different blocks of the model in this folder for detailed explanations.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        num_features=64,
        kernel_size=3,
        stride=2,
        number_msrb=2,
        number_rrg=3,
        height=3,
        width=2,
        bias=False,
    ):
        super().__init__()
        self.conv_start = nn.Conv2d(
            in_channels, num_features, kernel_size, padding=1, bias=bias
        )
        msrb_blocks = [
            ResidualRecurrentGroup(
                num_features, number_msrb, height, width, stride, bias
            )
            for _ in range(number_rrg)
        ]
        self.msrb_blocks = nn.Sequential(*msrb_blocks)
        self.conv_end = nn.Conv2d(
            num_features, out_channels, kernel_size, padding=1, bias=bias
        )

    def forward(self, x):
        output = self.conv_start(x)
        output = self.msrb_blocks(output)
        output = self.conv_end(output)
        return x + output  # restored image, HxWxC