|
import torch |
|
import torch.nn as nn |
|
|
|
from model.MIRNet.ResidualRecurrentGroup import ResidualRecurrentGroup |
|
|
|
|
|
class MIRNet(nn.Module):
    """MIRNet image-restoration network.

    Low-level features are extracted from the input with a convolution and
    passed through ``number_rrg`` residual recurrent groups applied one after
    another. A final convolution maps the features back to ``out_channels``,
    and that result is added to the input image (global residual), so the
    network predicts a correction to ``x``.

    Please refer to the documentation of the different blocks of the model in
    this folder (e.g. ``ResidualRecurrentGroup``) for detailed explanations.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels produced by the final convolution.
            Because that output is added to the input in ``forward``, this
            must equal ``in_channels``.
        num_features: Number of feature channels used inside the network.
        kernel_size: Kernel size of the first and last convolutions. Should be
            odd so that the ``kernel_size // 2`` padding preserves the spatial
            size, which the residual addition requires.
        stride: Forwarded to every ``ResidualRecurrentGroup``.
        number_msrb: Forwarded to every ``ResidualRecurrentGroup``; presumably
            the number of multi-scale residual blocks per group (TODO confirm).
        number_rrg: Number of ``ResidualRecurrentGroup`` modules chained in
            sequence.
        height: Forwarded to every ``ResidualRecurrentGroup``.
        width: Forwarded to every ``ResidualRecurrentGroup``.
        bias: Whether the first/last convolutions use a bias term; also
            forwarded to every ``ResidualRecurrentGroup``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        num_features: int = 64,
        kernel_size: int = 3,
        stride: int = 2,
        number_msrb: int = 2,
        number_rrg: int = 3,
        height: int = 3,
        width: int = 2,
        bias: bool = False,
    ):
        super().__init__()

        # "Same" padding for odd kernels: the spatial size must be preserved
        # so the network output can be added to the input in ``forward``.
        # A hard-coded padding of 1 is only correct for kernel_size == 3.
        padding = kernel_size // 2

        self.conv_start = nn.Conv2d(
            in_channels, num_features, kernel_size, padding=padding, bias=bias
        )

        msrb_blocks = [
            ResidualRecurrentGroup(
                num_features, number_msrb, height, width, stride, bias
            )
            for _ in range(number_rrg)
        ]
        # The attribute holds residual recurrent groups; its name is kept as
        # ``msrb_blocks`` so existing state_dict keys remain loadable.
        self.msrb_blocks = nn.Sequential(*msrb_blocks)

        self.conv_end = nn.Conv2d(
            num_features, out_channels, kernel_size, padding=padding, bias=bias
        )

    def forward(self, x):
        """Return ``x`` plus the residual predicted by the network.

        Args:
            x: Input image tensor. Assumed to be (batch, in_channels, H, W);
                TODO confirm whether ResidualRecurrentGroup accepts other
                layouts.

        Returns:
            Tensor with the same shape as ``x`` (given odd ``kernel_size`` and
            ``out_channels == in_channels``).
        """
        features = self.conv_start(x)
        features = self.msrb_blocks(features)
        residual = self.conv_end(features)
        # Global residual connection: the network learns a correction to x.
        return x + residual
|
|