import torch
import numpy as np
import torch.nn as nn
from model.MIRNet.Downsampling import DownsamplingModule
from model.MIRNet.DualAttentionUnit import DualAttentionUnit
from model.MIRNet.SelectiveKernelFeatureFusion import SelectiveKernelFeatureFusion
from model.MIRNet.Upsampling import UpsamplingModule
class MultiScaleResidualBlock(nn.Module):
    """Multi-scale residual block (MIRNet).

    Runs ``height`` parallel convolutional streams at progressively lower
    spatial resolutions (stream ``j`` is downsampled by ``2**j``), with
    ``width`` columns of dual-attention units per stream.  After each
    column, information is exchanged across streams via selective-kernel
    feature fusion (SKFF); the block ends with a residual connection to
    its input.

    Args:
        num_features: channel count of the full-resolution stream.
        height: number of parallel resolution streams.
        width: number of dual-attention columns per stream.
        stride: per-level channel multiplier (stream ``i`` has
            ``num_features * stride**i`` channels).
        bias: whether the final 3x3 convolution uses a bias term.
    """

    def __init__(self, num_features, height, width, stride, bias):
        super().__init__()
        self.num_features = num_features
        self.height = height
        self.width = width

        # Channel count per stream and the cross-stream scaling factors.
        features = [int((stride ** i) * num_features) for i in range(height)]
        scale = [2 ** i for i in range(1, height)]

        # One dual-attention unit per (stream, column).
        # NOTE(fix): the original used ``[DualAttentionUnit(...)] * width``,
        # which registers the SAME module instance ``width`` times, silently
        # sharing weights across all columns of a stream.  Each column now
        # gets its own independent unit.
        self.dual_attention_units = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        DualAttentionUnit(int(num_features * stride ** i))
                        for _ in range(width)
                    ]
                )
                for i in range(height)
            ]
        )

        # Upsamplers that bring stream i (i >= 1) back to full resolution
        # before the final fusion, keyed by stream index.
        self.last_up = nn.ModuleDict()
        for i in range(1, height):
            self.last_up[f"{i}"] = UpsamplingModule(
                in_channels=int(num_features * stride ** i),
                scaling_factor=2 ** i,
                stride=stride,
            )

        # Cross-stream downsampling paths, keyed "<in_channels>_<factor>".
        # Stream i only ever needs to move down by at most height-1-i levels.
        self.down = nn.ModuleDict()
        scale.reverse()
        for i, f in enumerate(features):
            for s in scale[i:]:
                self.down[f"{f}_{s}"] = DownsamplingModule(f, s, stride)

        # Cross-stream upsampling paths, keyed the same way (channel counts
        # walked from the deepest stream upward).
        self.up = nn.ModuleDict()
        features.reverse()
        for i, f in enumerate(features):
            for s in scale[i:]:
                self.up[f"{f}_{s}"] = UpsamplingModule(f, s, stride)

        self.out_conv = nn.Conv2d(
            num_features, num_features, kernel_size=3, padding=1, bias=bias
        )

        # One SKFF block per stream; each fuses ``height`` resampled inputs.
        self.skff_blocks = nn.ModuleList(
            [
                SelectiveKernelFeatureFusion(num_features * stride ** i, height)
                for i in range(height)
            ]
        )

    def forward(self, x):
        """Apply the multi-scale block; output shape matches ``x``.

        Args:
            x: input feature map, (batch, num_features, H, W).

        Returns:
            Tensor of the same shape as ``x`` (residual output).
        """
        # First column: stream 0 keeps full resolution; each deeper stream
        # is the previous stream downsampled by a factor of 2.
        inp = x.clone()
        out = []
        for j in range(self.height):
            if j == 0:
                inp = self.dual_attention_units[j][0](inp)
            else:
                inp = self.dual_attention_units[j][0](
                    self.down[f"{inp.size(1)}_2"](inp)
                )
            out.append(inp)

        # Remaining columns: fuse all streams into each stream's resolution
        # via SKFF, then apply that column's dual-attention unit per stream.
        # (The original guarded this with a dead ``if True:`` branch and an
        # unused ``nfeats`` local; both removed.)
        for i in range(1, self.width):
            temp = []
            for j in range(self.height):
                fused = self.skff_blocks[j](
                    [self.select_up_down(out[k], j, k) for k in range(self.height)]
                )
                temp.append(fused)
            for j in range(self.height):
                out[j] = self.dual_attention_units[j][i](temp[j])

        # Bring every stream back to full resolution, fuse, convolve, and
        # add the residual input.
        output = self.skff_blocks[0](
            [self.select_last_up(out[k], k) for k in range(self.height)]
        )
        output = self.out_conv(output)
        return output + x

    def select_up_down(self, tensor, j, k):
        """Resample stream ``k``'s tensor to stream ``j``'s resolution.

        Identity when ``j == k``; otherwise routes through the up/down
        module matching the tensor's channel count and the 2**|j-k| factor.
        """
        if j == k:
            return tensor
        diff = 2 ** abs(j - k)
        if j < k:
            return self.up[f"{tensor.size(1)}_{diff}"](tensor)
        return self.down[f"{tensor.size(1)}_{diff}"](tensor)

    def select_last_up(self, tensor, k):
        """Upsample stream ``k`` back to full resolution (identity for k == 0)."""
        return tensor if k == 0 else self.last_up[f"{k}"](tensor)