# mirnet-low-light-img-enhancement/model/MIRNet/MultiScaleResidualBlock.py
import torch
import torch.nn as nn

from model.MIRNet.Downsampling import DownsamplingModule
from model.MIRNet.DualAttentionUnit import DualAttentionUnit
from model.MIRNet.SelectiveKernelFeatureFusion import SelectiveKernelFeatureFusion
from model.MIRNet.Upsampling import UpsamplingModule

class MultiScaleResidualBlock(nn.Module):
"""
Three parallel convolutional streams at different resolutions. Information is exchanged through residual connexions.
"""
def __init__(self, num_features, height, width, stride, bias):
super().__init__()
self.num_features = num_features
self.height = height
self.width = width
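        # Per-stream channel counts (stream i has num_features * stride**i channels)
        # and the scale factors 2, 4, ..., 2**(height - 1) used between streams.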
features = [int((stride**i) * num_features) for i in range(height)]
scale = [2**i for i in range(1, height)]
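
        # Grid of dual attention units: one row per stream, `width` columns per row.
        # Note that the same DualAttentionUnit instance is repeated across a row, so
        # its weights are shared by every column of that stream.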
self.dual_attention_units = nn.ModuleList(
[
nn.ModuleList(
[DualAttentionUnit(int(num_features * stride**i))] * width
)
for i in range(height)
]
)
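
        # Upsampling modules that bring each lower-resolution stream back to the
        # resolution and channel count of stream 0 for the final fusion.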
self.last_up = nn.ModuleDict()
for i in range(1, height):
self.last_up.update(
{
f"{i}": UpsamplingModule(
in_channels=int(num_features * stride**i),
scaling_factor=2**i,
stride=stride,
)
}
)
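
        # Downsampling modules keyed by "<in_channels>_<scale_factor>", covering every
        # transition from a higher-resolution stream to a lower-resolution one.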
self.down = nn.ModuleDict()
i = 0
scale.reverse()
for f in features:
for s in scale[i:]:
self.down.update({f"{f}_{s}": DownsamplingModule(f, s, stride)})
i += 1
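
        # Upsampling modules keyed the same way, for transitions in the other direction.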
self.up = nn.ModuleDict()
i = 0
features.reverse()
for f in features:
for s in scale[i:]:
self.up.update({f"{f}_{s}": UpsamplingModule(f, s, stride)})
i += 1
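
        # 3x3 convolution applied to the fused full-resolution features before the
        # residual addition.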
self.out_conv = nn.Conv2d(
num_features, num_features, kernel_size=3, padding=1, bias=bias
)
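
        # One selective kernel feature fusion (SKFF) block per stream, each fusing
        # `height` resampled inputs.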
self.skff_blocks = nn.ModuleList(
[
SelectiveKernelFeatureFusion(num_features * stride**i, height)
for i in range(height)
]
)

    def forward(self, x):
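        """
        Runs the parallel multi-resolution streams, exchanges information between them
        via repeated up-/downsampling and SKFF, and returns the fused result added back
        to the input (same shape as `x`).
        """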
inp = x.clone()
out = []
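        # First column: build the parallel streams by progressively downsampling the
        # input by a factor of 2 and applying each stream's first dual attention unit.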
for j in range(self.height):
if j == 0:
inp = self.dual_attention_units[j][0](inp)
else:
inp = self.dual_attention_units[j][0](
                    self.down[f"{inp.size(1)}_2"](inp)
)
out.append(inp)
        # Remaining columns: at each column, resample every stream's output to each
        # stream's resolution, fuse with SKFF, then apply that column's dual attention unit.
        for i in range(1, self.width):
            temp = []
            for j in range(self.height):
                resampled = [
                    self.select_up_down(out[k], j, k) for k in range(self.height)
                ]
                temp.append(self.skff_blocks[j](resampled))
            for j in range(self.height):
                out[j] = self.dual_attention_units[j][i](temp[j])

        # Final fusion: bring every stream back to full resolution, fuse, convolve,
        # and add the residual input.
output = []
for k in range(self.height):
output.append(self.select_last_up(out[k], k))
output = self.skff_blocks[0](output)
output = self.out_conv(output)
output = output + x
return output

    def select_up_down(self, tensor, j, k):
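        """
        Resamples `tensor` (the output of stream k) to the resolution and channel count
        of stream j; returns it unchanged when j == k.
        """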
if j == k:
return tensor
else:
            diff = 2 ** abs(j - k)
if j < k:
return self.up[f"{tensor.size(1)}_{diff}"](tensor)
else:
return self.down[f"{tensor.size(1)}_{diff}"](tensor)

    def select_last_up(self, tensor, k):
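        """
        Brings the output of stream k back to the resolution and channel count of
        stream 0 (identity for k == 0).
        """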
if k == 0:
return tensor
else:
return self.last_up[f"{k}"](tensor)
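

# Minimal smoke-test sketch (not part of the original file). It assumes the sibling
# model/MIRNet modules imported above expose these constructors with the signatures
# used here, and the parameter values below are hypothetical MIRNet-style defaults.
if __name__ == "__main__":
    block = MultiScaleResidualBlock(
        num_features=64, height=3, width=2, stride=2, bias=False
    )
    dummy = torch.randn(1, 64, 128, 128)  # (batch, channels, H, W)
    out = block(dummy)
    # The block is residual, so the output shape matches the input shape.
    print(out.shape)  # expected: torch.Size([1, 64, 128, 128])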