dblasko's picture
Upload 11 files
9b9b1dc
raw
history blame contribute delete
No virus
1.43 kB
import torch
import torch.nn as nn
from model.MIRNet.ResidualRecurrentGroup import ResidualRecurrentGroup
class MIRNet(nn.Module):
"""
Low-level features are extracted through convolution and passed to n residual recurrent groups that operate at different resolutions.
Their output is added to the input image for restoration.
Please refer to the documentation of the different blocks of the model in this folder for detailed explanations.
"""
def __init__(
self,
in_channels=3,
out_channels=3,
num_features=64,
kernel_size=3,
stride=2,
number_msrb=2,
number_rrg=3,
height=3,
width=2,
bias=False,
):
super().__init__()
self.conv_start = nn.Conv2d(
in_channels, num_features, kernel_size, padding=1, bias=bias
)
msrb_blocks = [
ResidualRecurrentGroup(
num_features, number_msrb, height, width, stride, bias
)
for _ in range(number_rrg)
]
self.msrb_blocks = nn.Sequential(*msrb_blocks)
self.conv_end = nn.Conv2d(
num_features, out_channels, kernel_size, padding=1, bias=bias
)
def forward(self, x):
output = self.conv_start(x)
output = self.msrb_blocks(output)
output = self.conv_end(output)
return x + output # restored image, HxWxC