File size: 994 Bytes
9b9b1dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn

from model.MIRNet.MultiScaleResidualBlock import MultiScaleResidualBlock


class ResidualRecurrentGroup(nn.Module):
    """
    Chain of multi-scale residual blocks capped by a 3x3 convolution, wrapped in a skip connection.

    The chain's result is summed with the group's own input, so the chain only
    has to produce the correction that is added for restoration.

    Args:
        num_features: Channel count, used as both the input and output channels
            of the closing convolution and forwarded to every block.
        number_msrb_blocks: How many MultiScaleResidualBlock instances to chain
            before the closing convolution.
        height: Forwarded unchanged to each MultiScaleResidualBlock
            (presumably the number of resolution streams - confirm there).
        width: Forwarded unchanged to each MultiScaleResidualBlock.
        stride: Forwarded unchanged to each MultiScaleResidualBlock.
        bias: Whether the closing convolution has a bias term; also forwarded
            to each MultiScaleResidualBlock.
    """

    def __init__(
        self, num_features, number_msrb_blocks, height, width, stride, bias=False
    ):
        super().__init__()

        # Build the residual blocks first so sub-module order (and therefore
        # parameter naming inside ``self.blocks``) stays blocks-then-conv.
        layers = []
        for _ in range(number_msrb_blocks):
            layers.append(
                MultiScaleResidualBlock(num_features, height, width, stride, bias)
            )

        # kernel 3 / padding 1 / stride 1 keeps both the channel count and the
        # spatial size, which the skip connection in ``forward`` relies on.
        closing_conv = nn.Conv2d(
            in_channels=num_features,
            out_channels=num_features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )

        self.blocks = nn.Sequential(*layers, closing_conv)

    def forward(self, x):
        """
        Run ``x`` through the block chain and add the result back onto ``x``.

        Assumes every MultiScaleResidualBlock returns a tensor with the same
        shape as its input (TODO confirm); otherwise the addition cannot
        broadcast.
        """
        return x + self.blocks(x)