Add application files
- app.py +130 -0
- model/MIRNet/ChannelAttention.py +40 -0
- model/MIRNet/ChannelCompression.py +16 -0
- model/MIRNet/Downsampling.py +135 -0
- model/MIRNet/DualAttentionUnit.py +39 -0
- model/MIRNet/MultiScaleResidualBlock.py +124 -0
- model/MIRNet/ResidualRecurrentGroup.py +34 -0
- model/MIRNet/SelectiveKernelFeatureFusion.py +65 -0
- model/MIRNet/SpatialAttention.py +24 -0
- model/MIRNet/Upsampling.py +56 -0
- model/MIRNet/__init__.py +0 -0
- model/MIRNet/__pycache__/ChannelAttention.cpython-310.pyc +0 -0
- model/MIRNet/__pycache__/ChannelCompression.cpython-310.pyc +0 -0
- model/MIRNet/__pycache__/Downsampling.cpython-310.pyc +0 -0
- model/MIRNet/__pycache__/DualAttentionUnit.cpython-310.pyc +0 -0
- model/MIRNet/__pycache__/MultiScaleResidualBlock.cpython-310.pyc +0 -0
- model/MIRNet/__pycache__/ResidualRecurrentGroup.cpython-310.pyc +0 -0
- model/MIRNet/__pycache__/SelectiveKernelFeatureFusion.cpython-310.pyc +0 -0
- model/MIRNet/__pycache__/SpatialAttention.cpython-310.pyc +0 -0
- model/MIRNet/__pycache__/Upsampling.cpython-310.pyc +0 -0
- model/MIRNet/__pycache__/__init__.cpython-310.pyc +0 -0
- model/MIRNet/__pycache__/model.cpython-310.pyc +0 -0
- model/MIRNet/model.py +47 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,130 @@
import streamlit as st
from PIL import Image
import io
import os
import base64
import torch
import torchvision.transforms as T
import torchvision.utils as vutils
from huggingface_hub import hf_hub_download

from model.MIRNet.model import MIRNet


def run_model(input_image):
    device = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else torch.device("mps")
        if torch.backends.mps.is_available()
        else torch.device("cpu")
    )

    model = MIRNet(num_features=64).to(device)
    model_path = hf_hub_download(
        repo_id="dblasko/mirnet-low-light-img-enhancement",
        filename="mirnet_finetuned.pth",
    )
    model.load_state_dict(
        torch.load(model_path, map_location=device)["model_state_dict"]
    )

    model.eval()
    with torch.no_grad():
        img_tensor = T.Compose(
            [
                T.Resize(400),
                T.ToTensor(),
                T.Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
            ]
        )(input_image).unsqueeze(0)
        img_tensor = img_tensor.to(device)

        # Crop height and width to multiples of 8 so the downsampling streams divide evenly.
        if img_tensor.shape[2] % 8 != 0:
            img_tensor = img_tensor[:, :, : -(img_tensor.shape[2] % 8), :]
        if img_tensor.shape[3] % 8 != 0:
            img_tensor = img_tensor[:, :, :, : -(img_tensor.shape[3] % 8)]

        output = model(img_tensor)

        vutils.save_image(output, "temp.png")
        output_image = Image.open("temp.png")
        os.remove("temp.png")
        return output_image


def get_base64_font(font_path):
    with open(font_path, "rb") as font_file:
        return base64.b64encode(font_file.read()).decode()


st.set_page_config(layout="wide")

font_name = "Gloock"
gloock_b64 = get_base64_font("utils/assets/Gloock-Regular.ttf")
font_name_text = "Merriweather Sans"
merri_b64 = get_base64_font("utils/assets/MerriweatherSans-Regular.ttf")
hide_streamlit_style = f"""
    <style>
        #MainMenu {{visibility: hidden;}}
        footer {{visibility: hidden;}}

        @font-face {{
            font-family: '{font_name}';
            src: url(data:font/ttf;base64,{gloock_b64}) format('truetype');
        }}
        @font-face {{
            font-family: '{font_name_text}';
            src: url(data:font/ttf;base64,{merri_b64}) format('truetype');
        }}
        span {{
            font-family: '{font_name_text}';
        }}
        .e1nzilvr1, .st-emotion-cache-10trblm {{
            font-family: '{font_name}';
            font-size: 65px;
        }}
    </style>
"""
st.markdown(hide_streamlit_style, unsafe_allow_html=True)

st.title("Low-light event-image enhancement with MIRNet.")

# File uploader widget
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    # Read the uploaded file as bytes and decode it into an RGB image.
    bytes_data = uploaded_file.getvalue()
    image = Image.open(io.BytesIO(bytes_data)).convert("RGB")

    # Two columns: original image on the left, enhanced image on the right.
    col1, col2 = st.columns(2)

    with col1:
        st.image(image, caption="Original Image", use_column_width="always")

    # Button to enhance the image
    if st.button("Enhance Image"):
        with col2:
            enhanced_image = run_model(image)
            st.image(
                enhanced_image, caption="Enhanced Image", use_column_width="always"
            )

            # Offer the enhanced image as a JPEG download.
            buf = io.BytesIO()
            enhanced_image.save(buf, format="JPEG")
            byte_im = buf.getvalue()
            st.download_button(
                label="Download image",
                data=byte_im,
                file_name="enhanced_image.jpg",
                mime="image/jpeg",
            )
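A possible simplification of the temp.png round-trip in run_model, as a sketch only: torchvision's ToPILImage can convert the output tensor in memory, assuming the model output stays in [0, 1] (the clamp guards against slight overshoot).

import torch
import torchvision.transforms as T
from PIL import Image


def tensor_to_pil(output: torch.Tensor) -> Image.Image:
    # output is 1xCxHxW; drop the batch dim, clamp to [0, 1], move to CPU, convert.
    return T.ToPILImage()(output.squeeze(0).clamp(0.0, 1.0).cpu())

Separately, since run_model rebuilds the model on every button press, wrapping the model setup in st.cache_resource would avoid reloading the weights each time; hf_hub_download already caches the downloaded file on disk.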
model/MIRNet/ChannelAttention.py
ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """
    Squeezes the input down to 1x1xC, applies the excitation operation and restores the C channels through a 1x1 convolution.

    In: HxWxC
    Out: HxWxC (original channels are restored by multiplying the output with the original input)
    """

    def __init__(self, in_channels, reduction_ratio=8, bias=True):
        super().__init__()
        self.squeezing = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Conv2d(
                in_channels,
                in_channels // reduction_ratio,
                kernel_size=1,
                padding=0,
                bias=bias,
            ),
            nn.PReLU(),
            nn.Conv2d(
                in_channels // reduction_ratio,
                in_channels,
                kernel_size=1,
                padding=0,
                bias=bias,
            ),
            nn.Sigmoid(),
        )

    def forward(self, x):
        squeezed_x = self.squeezing(x)  # 1x1xC
        excitation = self.excitation(squeezed_x)  # 1x1xC attention weights (the C/r bottleneck is internal)
        return (
            excitation * x
        )  # HxWxC restored through the mult. with the original input
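A quick sanity check of the squeeze-and-excitation behaviour (a sketch with arbitrary sizes, not part of the commit):

import torch
from model.MIRNet.ChannelAttention import ChannelAttention

ca = ChannelAttention(in_channels=64, reduction_ratio=8)
x = torch.randn(1, 64, 32, 32)
assert ca(x).shape == x.shape  # the attention only rescales channels, never reshapes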
model/MIRNet/ChannelCompression.py
ADDED
@@ -0,0 +1,16 @@
import torch
import torch.nn as nn


class ChannelCompression(nn.Module):
    """
    Reduces the input to 2 channels by concatenating, at each spatial location, the maximum and the mean over the channel dimension.

    In: HxWxC
    Out: HxWx2
    """

    def forward(self, x):
        return torch.cat(
            (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
        )
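A minimal illustration of what the compression computes (arbitrary sizes, for intuition only): channel 0 of the output is the per-pixel maximum over the input channels, channel 1 their per-pixel mean.

import torch
from model.MIRNet.ChannelCompression import ChannelCompression

cc = ChannelCompression()
x = torch.randn(1, 64, 32, 32)
out = cc(x)
assert out.shape == (1, 2, 32, 32)
assert torch.equal(out[:, 0], x.max(dim=1).values)
assert torch.allclose(out[:, 1], x.mean(dim=1))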
model/MIRNet/Downsampling.py
ADDED
@@ -0,0 +1,135 @@
import torch
import torch.nn as nn
import torch.nn.functional as fun
import numpy as np


class DownsamplingBlock(nn.Module):
    """
    Downsamples the input to halve the spatial dimensions while doubling the channels, through two parallel conv + antialiased-downsampling branches.

    In: HxWxC
    Out: H/2 x W/2 x 2C
    """

    def __init__(self, in_channels, bias=False):
        super().__init__()
        self.branch1 = nn.Sequential(  # 1x1 conv + PReLU -> 3x3 conv + PReLU -> AD -> 1x1 conv
            nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=bias),
            nn.PReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=bias),
            nn.PReLU(),
            DownSample(channels=in_channels, filter_size=3, stride=2),
            nn.Conv2d(in_channels, in_channels * 2, kernel_size=1, padding=0, bias=bias),
        )
        self.branch2 = nn.Sequential(  # AD -> 1x1 conv
            DownSample(channels=in_channels, filter_size=3, stride=2),
            nn.Conv2d(in_channels, in_channels * 2, kernel_size=1, padding=0, bias=bias),
        )

    def forward(self, x):
        return self.branch1(x) + self.branch2(x)  # H/2 x W/2 x 2C


class DownsamplingModule(nn.Module):
    """
    Downsampling module of the network, composed of log2(scaling factor) DownsamplingBlocks.

    In: HxWxC
    Out: H/(scaling factor) x W/(scaling factor) x (scaling factor)*C
    """

    def __init__(self, in_channels, scaling_factor, stride=2):
        super().__init__()
        self.scaling_factor = int(np.log2(scaling_factor))

        blocks = []
        for i in range(self.scaling_factor):
            blocks.append(DownsamplingBlock(in_channels))
            in_channels = int(in_channels * stride)
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x)  # H/(scaling factor) x W/(scaling factor) x (scaling factor)*C


class DownSample(nn.Module):
    """
    Antialiased downsampling module using the blur-pooling method.

    From Adobe's implementation available here: https://github.com/yilundu/improved_contrastive_divergence/blob/master/downsample.py
    """

    def __init__(
        self, pad_type="reflect", filter_size=3, stride=2, channels=None, pad_off=0
    ):
        super().__init__()
        self.filter_size = filter_size
        self.stride = stride
        self.pad_off = pad_off
        self.channels = channels
        self.pad_sizes = [
            int(1.0 * (filter_size - 1) / 2),
            int(np.ceil(1.0 * (filter_size - 1) / 2)),
            int(1.0 * (filter_size - 1) / 2),
            int(np.ceil(1.0 * (filter_size - 1) / 2)),
        ]
        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
        self.off = int((self.stride - 1) / 2.0)

        # Rows of Pascal's triangle: the blur kernel is a normalized binomial filter.
        if self.filter_size == 1:
            a = np.array([1.0])
        elif self.filter_size == 2:
            a = np.array([1.0, 1.0])
        elif self.filter_size == 3:
            a = np.array([1.0, 2.0, 1.0])
        elif self.filter_size == 4:
            a = np.array([1.0, 3.0, 3.0, 1.0])
        elif self.filter_size == 5:
            a = np.array([1.0, 4.0, 6.0, 4.0, 1.0])
        elif self.filter_size == 6:
            a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0])
        elif self.filter_size == 7:
            a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0])

        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        self.register_buffer(
            "filt", filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
        )
        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, x):
        if self.filter_size == 1:
            if self.pad_off == 0:
                return x[:, :, :: self.stride, :: self.stride]
            else:
                return self.pad(x)[:, :, :: self.stride, :: self.stride]
        else:
            # Depthwise convolution with the blur kernel, subsampled by the stride.
            return fun.conv2d(
                self.pad(x), self.filt, stride=self.stride, groups=x.shape[1]
            )


def get_pad_layer(pad_type):
    if pad_type == "reflect":
        return nn.ReflectionPad2d
    elif pad_type == "replication":
        return nn.ReplicationPad2d
    raise ValueError("Pad type [%s] not recognized" % pad_type)
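A sketch of the shapes involved, under the defaults used above (filter_size=3, stride=2, arbitrary input sizes): the blur-pool layer halves the spatial dimensions without touching channels, and the full module combines that with channel expansion.

import torch
from model.MIRNet.Downsampling import DownSample, DownsamplingModule

x = torch.randn(1, 64, 64, 64)
blurpool = DownSample(channels=64, filter_size=3, stride=2)
assert blurpool(x).shape == (1, 64, 32, 32)  # antialiased /2, channels unchanged

down = DownsamplingModule(in_channels=64, scaling_factor=4)
assert down(x).shape == (1, 256, 16, 16)  # log2(4) = 2 blocks: /4 spatially, x4 channels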
model/MIRNet/DualAttentionUnit.py
ADDED
@@ -0,0 +1,39 @@
import torch
import torch.nn as nn

from model.MIRNet.ChannelAttention import ChannelAttention
from model.MIRNet.SpatialAttention import SpatialAttention


class DualAttentionUnit(nn.Module):
    """
    Combines the ChannelAttention and SpatialAttention modules.
    (conv, PReLU, conv -> concat. SA & CA outputs -> conv -> skip connection from input)

    In: HxWxC
    Out: HxWxC (the final 1x1 convolution restores the original channel count)
    """

    def __init__(self, in_channels, kernel_size=3, reduction_ratio=8, bias=False):
        super().__init__()
        # Note: padding=1 assumes the default kernel_size=3.
        self.initial_convs = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size, padding=1, bias=bias),
            nn.PReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size, padding=1, bias=bias),
        )
        self.channel_attention = ChannelAttention(in_channels, reduction_ratio, bias)
        self.spatial_attention = SpatialAttention()
        self.final_conv = nn.Conv2d(
            in_channels * 2, in_channels, kernel_size=1, bias=bias
        )
        self.in_channels = in_channels

    def forward(self, x):
        initial_convs = self.initial_convs(x)  # HxWxC
        channel_attention = self.channel_attention(initial_convs)  # HxWxC
        spatial_attention = self.spatial_attention(initial_convs)  # HxWxC
        attention = torch.cat((spatial_attention, channel_attention), dim=1)  # HxWx2C
        block_output = self.final_conv(
            attention
        )  # HxWxC - the 1x1 conv. restores the C channels for the skip connection
        return x + block_output  # the addition is the skip connection from the input
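Like its two submodules, the unit is shape-preserving, which is what lets the MultiScaleResidualBlock below chain it freely; a quick check (a sketch, arbitrary sizes):

import torch
from model.MIRNet.DualAttentionUnit import DualAttentionUnit

dau = DualAttentionUnit(in_channels=64)
x = torch.randn(2, 64, 48, 48)
assert dau(x).shape == x.shape  # residual unit: same shape in and out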
model/MIRNet/MultiScaleResidualBlock.py
ADDED
@@ -0,0 +1,124 @@
import torch
import numpy as np
import torch.nn as nn

from model.MIRNet.Downsampling import DownsamplingModule
from model.MIRNet.DualAttentionUnit import DualAttentionUnit
from model.MIRNet.SelectiveKernelFeatureFusion import SelectiveKernelFeatureFusion
from model.MIRNet.Upsampling import UpsamplingModule


class MultiScaleResidualBlock(nn.Module):
    """
    Three parallel convolutional streams at different resolutions. Information is exchanged between the streams through residual connections.
    """

    def __init__(self, num_features, height, width, stride, bias):
        super().__init__()
        self.num_features = num_features
        self.height = height
        self.width = width
        features = [int((stride**i) * num_features) for i in range(height)]
        scale = [2**i for i in range(1, height)]

        # Note: each inner list repeats the same DualAttentionUnit instance `width`
        # times, so a stream shares the unit's weights across its width positions.
        self.dual_attention_units = nn.ModuleList(
            [
                nn.ModuleList(
                    [DualAttentionUnit(int(num_features * stride**i))] * width
                )
                for i in range(height)
            ]
        )
        self.last_up = nn.ModuleDict()
        for i in range(1, height):
            self.last_up.update(
                {
                    f"{i}": UpsamplingModule(
                        in_channels=int(num_features * stride**i),
                        scaling_factor=2**i,
                        stride=stride,
                    )
                }
            )

        self.down = nn.ModuleDict()
        i = 0
        scale.reverse()
        for f in features:
            for s in scale[i:]:
                self.down.update({f"{f}_{s}": DownsamplingModule(f, s, stride)})
            i += 1

        self.up = nn.ModuleDict()
        i = 0
        features.reverse()
        for f in features:
            for s in scale[i:]:
                self.up.update({f"{f}_{s}": UpsamplingModule(f, s, stride)})
            i += 1

        self.out_conv = nn.Conv2d(
            num_features, num_features, kernel_size=3, padding=1, bias=bias
        )
        self.skff_blocks = nn.ModuleList(
            [
                SelectiveKernelFeatureFusion(num_features * stride**i, height)
                for i in range(height)
            ]
        )

    def forward(self, x):
        inp = x.clone()
        out = []

        # First column: the input cascades down the streams, one resolution at a time.
        for j in range(self.height):
            if j == 0:
                inp = self.dual_attention_units[j][0](inp)
            else:
                inp = self.dual_attention_units[j][0](
                    self.down[f"{inp.size(1)}_{2}"](inp)
                )
            out.append(inp)

        # Remaining columns: fuse all streams at every resolution, then apply the column's DAU.
        for i in range(1, self.width):
            temp = []
            for j in range(self.height):
                to_fuse = [
                    self.select_up_down(out[k], j, k) for k in range(self.height)
                ]
                temp.append(self.skff_blocks[j](to_fuse))

            for j in range(self.height):
                out[j] = self.dual_attention_units[j][i](temp[j])

        # Bring every stream back to full resolution, fuse, and add the block's input.
        output = []
        for k in range(self.height):
            output.append(self.select_last_up(out[k], k))

        output = self.skff_blocks[0](output)
        output = self.out_conv(output)
        return output + x

    def select_up_down(self, tensor, j, k):
        if j == k:
            return tensor
        diff = 2 ** np.abs(j - k)
        if j < k:
            return self.up[f"{tensor.size(1)}_{diff}"](tensor)
        return self.down[f"{tensor.size(1)}_{diff}"](tensor)

    def select_last_up(self, tensor, k):
        if k == 0:
            return tensor
        return self.last_up[f"{k}"](tensor)
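A sketch of the block in isolation, with the values the top-level model passes in (num_features=64, height=3, width=2, stride=2): with three streams, the deepest one runs at 1/4 resolution, so spatial dimensions should be divisible by 4 here; the app crops to multiples of 8, which is more than enough.

import torch
from model.MIRNet.MultiScaleResidualBlock import MultiScaleResidualBlock

msrb = MultiScaleResidualBlock(num_features=64, height=3, width=2, stride=2, bias=False)
x = torch.randn(1, 64, 64, 64)
assert msrb(x).shape == x.shape  # the block is residual and shape-preserving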
model/MIRNet/ResidualRecurrentGroup.py
ADDED
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn

from model.MIRNet.MultiScaleResidualBlock import MultiScaleResidualBlock


class ResidualRecurrentGroup(nn.Module):
    """
    Group of multi-scale residual blocks followed by a convolutional layer, wrapped in a residual connection from the group's input.
    """

    def __init__(
        self, num_features, number_msrb_blocks, height, width, stride, bias=False
    ):
        super().__init__()
        blocks = [
            MultiScaleResidualBlock(num_features, height, width, stride, bias)
            for _ in range(number_msrb_blocks)
        ]
        blocks.append(
            nn.Conv2d(
                num_features,
                num_features,
                kernel_size=3,
                padding=1,
                stride=1,
                bias=bias,
            )
        )
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        output = self.blocks(x)
        return x + output  # residual connection around the whole group, HxWxC
model/MIRNet/SelectiveKernelFeatureFusion.py
ADDED
@@ -0,0 +1,65 @@
import torch
import torch.nn as nn


class SelectiveKernelFeatureFusion(nn.Module):
    """
    Merges the outputs of the three different resolutions through self-attention.

    The three inputs are summed -> global average pooling -> channel downscaling -> the signal is passed through 3 different convs to obtain three descriptors.
    Softmax is applied across the three descriptors to get 3 attention activations, which recalibrate the three input feature maps before they are summed.
    """

    def __init__(self, in_channels, reduction_ratio, bias=False):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        conv_out_channels = max(int(in_channels / reduction_ratio), 4)
        self.convolution = nn.Sequential(
            nn.Conv2d(
                in_channels, conv_out_channels, kernel_size=1, padding=0, bias=bias
            ),
            nn.PReLU(),
        )

        self.attention_convs = nn.ModuleList([])
        for i in range(3):
            self.attention_convs.append(
                nn.Conv2d(
                    conv_out_channels, in_channels, kernel_size=1, stride=1, bias=bias
                )
            )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch_size = x[0].shape[0]
        n_features = x[0].shape[1]

        # The three same-resolution inputs are concatenated along the channel dimension...
        x = torch.cat(x, dim=1)
        # ...and viewed as batch_size x 3 x n_features x H x W.
        x = x.view(batch_size, 3, n_features, x.shape[2], x.shape[3])

        z = torch.sum(x, dim=1)  # batch_size x n_features x H x W
        z = self.avg_pool(z)  # batch_size x n_features x 1 x 1
        z = self.convolution(z)  # batch_size x n_features/r x 1 x 1

        attention_activations = [
            atn(z) for atn in self.attention_convs
        ]  # 3 x (batch_size x n_features x 1 x 1)
        attention_activations = torch.cat(
            attention_activations, dim=1
        )  # batch_size x 3*n_features x 1 x 1
        attention_activations = attention_activations.view(
            batch_size, 3, n_features, 1, 1
        )  # batch_size x 3 x n_features x 1 x 1

        # Softmax across the three branches, per channel.
        attention_activations = self.softmax(attention_activations)

        return torch.sum(
            x * attention_activations, dim=1
        )  # batch_size x n_features x H x W (the three feature maps are recalibrated and summed)
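A sketch of the fusion on dummy data (arbitrary sizes): three equally-shaped feature maps go in as a list, and a single map of the same shape comes out, with per-channel softmax weights deciding how much each branch contributes.

import torch
from model.MIRNet.SelectiveKernelFeatureFusion import SelectiveKernelFeatureFusion

skff = SelectiveKernelFeatureFusion(in_channels=64, reduction_ratio=8)
streams = [torch.randn(1, 64, 32, 32) for _ in range(3)]
assert skff(streams).shape == (1, 64, 32, 32)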
model/MIRNet/SpatialAttention.py
ADDED
@@ -0,0 +1,24 @@
import torch
import torch.nn as nn

from model.MIRNet.ChannelCompression import ChannelCompression


class SpatialAttention(nn.Module):
    """
    Reduces the input to 2 channels with the ChannelCompression module, applies a 2D convolution with 1 output channel, and uses the sigmoid of the result as a per-pixel scaling map.

    In: HxWxC
    Out: HxWxC (original channels are restored by multiplying the scaling map with the original input)
    """

    def __init__(self):
        super().__init__()
        self.channel_compression = ChannelCompression()
        self.conv = nn.Conv2d(2, 1, kernel_size=5, stride=1, padding=2)

    def forward(self, x):
        x_compressed = self.channel_compression(x)  # HxWx2
        x_conv = self.conv(x_compressed)  # HxWx1
        scaling_factor = torch.sigmoid(x_conv)
        return x * scaling_factor  # HxWxC
model/MIRNet/Upsampling.py
ADDED
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
import numpy as np


class UpsamplingBlock(nn.Module):
    """
    Upsamples the input to double the spatial dimensions while halving the channels, through two parallel conv + bilinear-upsampling branches.

    In: HxWxC
    Out: 2H x 2W x C/2
    """

    def __init__(self, in_channels, bias=False):
        super().__init__()
        self.branch1 = nn.Sequential(  # 1x1 conv + PReLU -> 3x3 conv + PReLU -> BU -> 1x1 conv
            nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0, bias=bias),
            nn.PReLU(),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=bias),
            nn.PReLU(),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=bias),
            nn.Conv2d(in_channels, in_channels // 2, kernel_size=1, padding=0, bias=bias),
        )
        self.branch2 = nn.Sequential(  # BU -> 1x1 conv
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=bias),
            nn.Conv2d(in_channels, in_channels // 2, kernel_size=1, padding=0, bias=bias),
        )

    def forward(self, x):
        return self.branch1(x) + self.branch2(x)  # 2H x 2W x C/2


class UpsamplingModule(nn.Module):
    """
    Upsampling module of the network, composed of log2(scaling factor) UpsamplingBlocks.

    In: HxWxC
    Out: (scaling factor)H x (scaling factor)W x C/(scaling factor)
    """

    def __init__(self, in_channels, scaling_factor, stride=2):
        super().__init__()
        self.scaling_factor = int(np.log2(scaling_factor))

        blocks = []
        for i in range(self.scaling_factor):
            blocks.append(UpsamplingBlock(in_channels))
            in_channels = int(in_channels // 2)
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x)  # (scaling factor)H x (scaling factor)W x C/(scaling factor)
model/MIRNet/__init__.py
ADDED
File without changes
model/MIRNet/__pycache__/ChannelAttention.cpython-310.pyc
ADDED
Binary file (1.28 kB).
model/MIRNet/__pycache__/ChannelCompression.cpython-310.pyc
ADDED
Binary file (764 Bytes).
model/MIRNet/__pycache__/Downsampling.cpython-310.pyc
ADDED
Binary file (4.24 kB).
model/MIRNet/__pycache__/DualAttentionUnit.cpython-310.pyc
ADDED
Binary file (1.61 kB).
model/MIRNet/__pycache__/MultiScaleResidualBlock.cpython-310.pyc
ADDED
Binary file (3.6 kB).
model/MIRNet/__pycache__/ResidualRecurrentGroup.cpython-310.pyc
ADDED
Binary file (1.43 kB).
model/MIRNet/__pycache__/SelectiveKernelFeatureFusion.cpython-310.pyc
ADDED
Binary file (2.07 kB).
model/MIRNet/__pycache__/SpatialAttention.cpython-310.pyc
ADDED
Binary file (1.23 kB).
model/MIRNet/__pycache__/Upsampling.cpython-310.pyc
ADDED
Binary file (2 kB).
model/MIRNet/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (141 Bytes).
model/MIRNet/__pycache__/model.cpython-310.pyc
ADDED
Binary file (1.72 kB).
model/MIRNet/model.py
ADDED
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn

from model.MIRNet.ResidualRecurrentGroup import ResidualRecurrentGroup


class MIRNet(nn.Module):
    """
    Low-level features are extracted through a convolution and passed to number_rrg residual recurrent groups that operate at multiple resolutions.
    The final convolution's output is added to the input image for restoration.

    Please refer to the documentation of the different blocks of the model in this folder for detailed explanations.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        num_features=64,
        kernel_size=3,
        stride=2,
        number_msrb=2,
        number_rrg=3,
        height=3,
        width=2,
        bias=False,
    ):
        super().__init__()
        self.conv_start = nn.Conv2d(
            in_channels, num_features, kernel_size, padding=1, bias=bias
        )
        # Residual recurrent groups, each containing number_msrb multi-scale residual blocks.
        msrb_blocks = [
            ResidualRecurrentGroup(
                num_features, number_msrb, height, width, stride, bias
            )
            for _ in range(number_rrg)
        ]
        self.msrb_blocks = nn.Sequential(*msrb_blocks)
        self.conv_end = nn.Conv2d(
            num_features, out_channels, kernel_size, padding=1, bias=bias
        )

    def forward(self, x):
        output = self.conv_start(x)
        output = self.msrb_blocks(output)
        output = self.conv_end(output)
        return x + output  # restored image, HxWxC
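An end-to-end sketch with the defaults the app uses (num_features=64). The dummy image's spatial dimensions are multiples of 8, matching the cropping app.py performs before inference:

import torch
from model.MIRNet.model import MIRNet

model = MIRNet(num_features=64).eval()
with torch.no_grad():
    x = torch.rand(1, 3, 64, 64)  # dummy RGB image in [0, 1]
    y = model(x)
assert y.shape == x.shape  # the network predicts a residual that is added to the input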
requirements.txt
ADDED
@@ -0,0 +1,5 @@
huggingface_hub
datasets
torch
torchvision
streamlit