dblasko's picture
Add application files
3a273df
raw
history blame
1.43 kB
import torch
import torch.nn as nn
from model.MIRNet.ResidualRecurrentGroup import ResidualRecurrentGroup
class MIRNet(nn.Module):
    """
    MIRNet image-restoration network.

    Low-level features are extracted through convolution and passed to a chain
    of residual recurrent groups (which, per the original design, operate at
    different resolutions). A final convolution projects the features back to
    image space, and the result is added to the input image for restoration.
    Please refer to the documentation of the different blocks of the model in
    this folder for detailed explanations.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels produced by the final convolution.
            Must equal ``in_channels``, because the output is added to the
            input in ``forward``.
        num_features: Number of feature channels used between the first and
            last convolutions.
        kernel_size: Kernel size of the first and last convolutions. The
            padding is ``kernel_size // 2``, so odd kernel sizes preserve the
            spatial size, as required by the residual addition.
        stride: Forwarded unchanged to every ``ResidualRecurrentGroup``.
        number_msrb: Forwarded unchanged to every ``ResidualRecurrentGroup``
            (presumably the number of multi-scale residual blocks per group;
            see that module).
        number_rrg: Number of ``ResidualRecurrentGroup`` modules applied in
            sequence.
        height: Forwarded unchanged to every ``ResidualRecurrentGroup``.
        width: Forwarded unchanged to every ``ResidualRecurrentGroup``.
        bias: Whether the convolutions use a bias term (also forwarded to
            every ``ResidualRecurrentGroup``).
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        num_features=64,
        kernel_size=3,
        stride=2,
        number_msrb=2,
        number_rrg=3,
        height=3,
        width=2,
        bias=False,
    ):
        super().__init__()
        # "Same" padding for odd kernels. The previous hard-coded padding=1 was
        # only correct for kernel_size=3; any other size changed H/W and broke
        # the residual addition in forward().
        padding = kernel_size // 2
        self.conv_start = nn.Conv2d(
            in_channels, num_features, kernel_size, padding=padding, bias=bias
        )
        # Attribute name kept as `msrb_blocks` (although it holds residual
        # recurrent groups) so existing state_dict checkpoints still load.
        self.msrb_blocks = nn.Sequential(
            *[
                ResidualRecurrentGroup(
                    num_features, number_msrb, height, width, stride, bias
                )
                for _ in range(number_rrg)
            ]
        )
        self.conv_end = nn.Conv2d(
            num_features, out_channels, kernel_size, padding=padding, bias=bias
        )

    def forward(self, x):
        """
        Restore ``x`` by predicting a residual correction and adding it.

        Args:
            x: Image tensor in PyTorch's channels-first layout, (N, C, H, W),
                with C == in_channels.

        Returns:
            The restored image, with the same shape as ``x``.
        """
        residual = self.conv_start(x)
        residual = self.msrb_blocks(residual)
        residual = self.conv_end(residual)
        # Global skip connection: the network only learns the correction.
        return x + residual  # restored image, (N, C, H, W)