IDKiro committed
Commit 7eafae4
1 Parent(s): f78b820
.gitignore ADDED
@@ -0,0 +1,114 @@
+ # Add by user
+ .vscode/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # celery beat schedule file
+ celerybeat-schedule
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 IDKiro
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,50 @@
+ import torch
+ import numpy as np
+ import gradio as gr
+
+ from PIL import Image
+ from models import dehazeformer
+
+
+ def infer(raw_image):
+     network = dehazeformer()
+     network.load_state_dict(torch.load('./saved_models/dehazeformer.pth', map_location=torch.device('cpu'))['state_dict'])
+     # torch.save({'state_dict': network.state_dict()}, './saved_models/dehazeformer.pth')
+
+     network.eval()
+
+     image = np.array(raw_image, np.float32) / 255. * 2 - 1
+     image = torch.from_numpy(image)
+     image = image.permute((2, 0, 1)).unsqueeze(0)
+
+     with torch.no_grad():
+         output = network(image).clamp_(-1, 1)[0] * 0.5 + 0.5
+         output = output.permute((1, 2, 0))
+         output = np.array(output, np.float32)
+         output = np.round(output * 255.0)
+
+     output = Image.fromarray(output.astype(np.uint8))
+
+     return output
+
+
+ title = "DehazeFormer"
+ description = f"We use a mixed dataset to train the model, allowing the trained model to work better on real hazy images. To allow the model to process high-resolution images more efficiently and effectively, we extend it to the [MCT](https://github.com/IDKiro/MCT) variant."
+ examples = [
+     ["examples/1.jpg"],
+     ["examples/2.jpg"],
+     ["examples/3.jpg"],
+     ["examples/4.jpg"],
+     ["examples/5.jpg"],
+     ["examples/6.jpg"]
+ ]
+
+ iface = gr.Interface(
+     infer,
+     inputs="image", outputs="image",
+     title=title,
+     description=description,
+     allow_flagging='never',
+     examples=examples,
+ )
+ iface.launch()
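
Editor's note: the preprocessing and postprocessing in `infer` above can also be run without launching the Gradio UI. The snippet below is a minimal sketch (not part of the commit), assuming the LFS checkpoint at `./saved_models/dehazeformer.pth` has been pulled and that `examples/1.jpg` exists; `dehazed_1.jpg` is just an illustrative output name.

    import numpy as np
    import torch
    from PIL import Image
    from models import dehazeformer

    net = dehazeformer()
    net.load_state_dict(torch.load('./saved_models/dehazeformer.pth',
                                   map_location=torch.device('cpu'))['state_dict'])
    net.eval()

    # HWC uint8 -> 1CHW float in [-1, 1], matching infer() above
    img = np.array(Image.open('examples/1.jpg').convert('RGB'), np.float32) / 255. * 2 - 1
    x = torch.from_numpy(img).permute((2, 0, 1)).unsqueeze(0)

    with torch.no_grad():
        y = net(x).clamp_(-1, 1)[0] * 0.5 + 0.5   # back to [0, 1]

    out = (y.permute((1, 2, 0)).numpy() * 255.0).round().astype(np.uint8)
    Image.fromarray(out).save('dehazed_1.jpg')
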
examples/1.jpg ADDED
examples/2.jpg ADDED
examples/3.jpg ADDED
examples/4.jpg ADDED
examples/5.jpg ADDED
examples/6.jpg ADDED
models/__init__.py ADDED
@@ -0,0 +1 @@
+ from .dehazeformer import MCT as dehazeformer
models/dehazeformer.py ADDED
@@ -0,0 +1,474 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class RLN(nn.Module):
+     r"""Revised LayerNorm"""
+     def __init__(self, dim, eps=1e-5, detach_grad=False):
+         super(RLN, self).__init__()
+         self.eps = eps
+         self.detach_grad = detach_grad
+
+         self.weight = nn.Parameter(torch.ones((1, dim, 1, 1)))
+         self.bias = nn.Parameter(torch.zeros((1, dim, 1, 1)))
+
+         self.meta1 = nn.Conv2d(1, dim, 1)
+         self.meta2 = nn.Conv2d(1, dim, 1)
+
+     def forward(self, input):
+         mean = torch.mean(input, dim=(1, 2, 3), keepdim=True)
+         std = torch.sqrt((input - mean).pow(2).mean(dim=(1, 2, 3), keepdim=True) + self.eps)
+
+         normalized_input = (input - mean) / std
+
+         if self.detach_grad:
+             rescale, rebias = self.meta1(std.detach()), self.meta2(mean.detach())
+         else:
+             rescale, rebias = self.meta1(std), self.meta2(mean)
+
+         out = normalized_input * self.weight + self.bias
+         return out, rescale, rebias
+
+
+ class Mlp(nn.Module):
+     def __init__(self, network_depth, in_features, hidden_features=None, out_features=None):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+
+         self.network_depth = network_depth
+
+         self.mlp = nn.Sequential(
+             nn.Conv2d(in_features, hidden_features, 1),
+             nn.ReLU(True),
+             nn.Conv2d(hidden_features, out_features, 1)
+         )
+
+     def forward(self, x):
+         return self.mlp(x)
+
+
+ def window_partition(x, window_size):
+     B, H, W, C = x.shape
+     x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+     windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size**2, C)
+     return windows
+
+
+ def window_reverse(windows, window_size, H, W):
+     B = int(windows.shape[0] / (H * W / window_size / window_size))
+     x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+     return x
+
+
+ def get_relative_positions(window_size):
+     coords_h = torch.arange(window_size)
+     coords_w = torch.arange(window_size)
+
+     coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
+     coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+     relative_positions = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+
+     relative_positions = relative_positions.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+     relative_positions_log = torch.sign(relative_positions) * torch.log(1. + relative_positions.abs())
+
+     return relative_positions_log
+
+
+ class WindowAttention(nn.Module):
+     def __init__(self, dim, window_size, num_heads):
+
+         super().__init__()
+         self.dim = dim
+         self.window_size = window_size  # Wh, Ww
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim ** -0.5
+
+         relative_positions = get_relative_positions(self.window_size)
+         self.register_buffer("relative_positions", relative_positions)
+         self.meta = nn.Sequential(
+             nn.Linear(2, 256, bias=True),
+             nn.ReLU(True),
+             nn.Linear(256, num_heads, bias=True)
+         )
+
+         self.softmax = nn.Softmax(dim=-1)
+
+     def forward(self, qkv):
+         B_, N, _ = qkv.shape
+
+         qkv = qkv.reshape(B_, N, 3, self.num_heads, self.dim // self.num_heads).permute(2, 0, 3, 1, 4)
+
+         q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+         q = q * self.scale
+         attn = (q @ k.transpose(-2, -1))
+
+         relative_position_bias = self.meta(self.relative_positions)
+         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+         attn = attn + relative_position_bias.unsqueeze(0)
+
+         attn = self.softmax(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B_, N, self.dim)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(self, network_depth, dim, num_heads, window_size, shift_size, use_attn=False, conv_type=None):
+         super().__init__()
+         self.dim = dim
+         self.head_dim = int(dim // num_heads)
+         self.num_heads = num_heads
+
+         self.window_size = window_size
+         self.shift_size = shift_size
+
+         self.network_depth = network_depth
+         self.use_attn = use_attn
+         self.conv_type = conv_type
+
+         if self.conv_type == 'Conv':
+             self.conv = nn.Sequential(
+                 nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode='reflect'),
+                 nn.ReLU(True),
+                 nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode='reflect')
+             )
+
+         if self.conv_type == 'DWConv':
+             self.conv = nn.Conv2d(dim, dim, kernel_size=5, padding=2, groups=dim, padding_mode='reflect')
+
+         if self.conv_type == 'DWConv' or self.use_attn:
+             self.V = nn.Conv2d(dim, dim, 1)
+             self.proj = nn.Conv2d(dim, dim, 1)
+
+         if self.use_attn:
+             self.QK = nn.Conv2d(dim, dim * 2, 1)
+             self.attn = WindowAttention(dim, window_size, num_heads)
+
+     def check_size(self, x, shift=False):
+         _, _, h, w = x.size()
+         mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
+         mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
+
+         if shift:
+             x = F.pad(x, (self.shift_size, (self.window_size-self.shift_size+mod_pad_w) % self.window_size,
+                           self.shift_size, (self.window_size-self.shift_size+mod_pad_h) % self.window_size), mode='reflect')
+         else:
+             x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+         return x
+
+     def forward(self, X):
+         B, C, H, W = X.shape
+
+         if self.conv_type == 'DWConv' or self.use_attn:
+             V = self.V(X)
+
+         if self.use_attn:
+             QK = self.QK(X)
+             QKV = torch.cat([QK, V], dim=1)
+
+             # shift
+             shifted_QKV = self.check_size(QKV, self.shift_size > 0)
+             Ht, Wt = shifted_QKV.shape[2:]
+
+             # partition windows
+             shifted_QKV = shifted_QKV.permute(0, 2, 3, 1)
+             qkv = window_partition(shifted_QKV, self.window_size)  # nW*B, window_size**2, C
+
+             attn_windows = self.attn(qkv)
+
+             # merge windows
+             shifted_out = window_reverse(attn_windows, self.window_size, Ht, Wt)  # B H' W' C
+
+             # reverse cyclic shift
+             out = shifted_out[:, self.shift_size:(self.shift_size+H), self.shift_size:(self.shift_size+W), :]
+             attn_out = out.permute(0, 3, 1, 2)
+
+             if self.conv_type in ['Conv', 'DWConv']:
+                 conv_out = self.conv(V)
+                 out = self.proj(conv_out + attn_out)
+             else:
+                 out = self.proj(attn_out)
+
+         else:
+             if self.conv_type == 'Conv':
+                 out = self.conv(X)  # no attention and use conv, no projection
+             elif self.conv_type == 'DWConv':
+                 out = self.proj(self.conv(V))
+
+         return out
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, network_depth, dim, num_heads, mlp_ratio=4.,
+                  norm_layer=nn.LayerNorm, mlp_norm=False,
+                  window_size=8, shift_size=0, use_attn=True, conv_type=None):
+         super().__init__()
+         self.use_attn = use_attn
+         self.mlp_norm = mlp_norm
+
+         self.norm1 = norm_layer(dim) if use_attn else nn.Identity()
+         self.attn = Attention(network_depth, dim, num_heads=num_heads, window_size=window_size,
+                               shift_size=shift_size, use_attn=use_attn, conv_type=conv_type)
+
+         self.norm2 = norm_layer(dim) if use_attn and mlp_norm else nn.Identity()
+         self.mlp = Mlp(network_depth, dim, hidden_features=int(dim * mlp_ratio))
+
+     def forward(self, x):
+         identity = x
+         if self.use_attn: x, rescale, rebias = self.norm1(x)
+         x = self.attn(x)
+         if self.use_attn: x = x * rescale + rebias
+         x = identity + x
+
+         identity = x
+         if self.use_attn and self.mlp_norm: x, rescale, rebias = self.norm2(x)
+         x = self.mlp(x)
+         if self.use_attn and self.mlp_norm: x = x * rescale + rebias
+         x = identity + x
+         return x
+
+
+ class BasicLayer(nn.Module):
+     def __init__(self, network_depth, dim, depth, num_heads, mlp_ratio=4.,
+                  norm_layer=nn.LayerNorm, window_size=8,
+                  attn_ratio=0., attn_loc='last', conv_type=None):
+
+         super().__init__()
+         self.dim = dim
+         self.depth = depth
+
+         attn_depth = attn_ratio * depth
+
+         if attn_loc == 'last':
+             use_attns = [i >= depth-attn_depth for i in range(depth)]
+         elif attn_loc == 'first':
+             use_attns = [i < attn_depth for i in range(depth)]
+         elif attn_loc == 'middle':
+             use_attns = [i >= (depth-attn_depth)//2 and i < (depth+attn_depth)//2 for i in range(depth)]
+
+         # build blocks
+         self.blocks = nn.ModuleList([
+             TransformerBlock(network_depth=network_depth,
+                              dim=dim,
+                              num_heads=num_heads,
+                              mlp_ratio=mlp_ratio,
+                              norm_layer=norm_layer,
+                              window_size=window_size,
+                              shift_size=0 if (i % 2 == 0) else window_size // 2,
+                              use_attn=use_attns[i], conv_type=conv_type)
+             for i in range(depth)])
+
+     def forward(self, x):
+         for blk in self.blocks:
+             x = blk(x)
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     def __init__(self, patch_size=4, in_chans=3, embed_dim=96, kernel_size=None):
+         super().__init__()
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+         if kernel_size is None:
+             kernel_size = patch_size
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=patch_size,
+                               padding=(kernel_size-patch_size+1)//2, padding_mode='reflect')
+
+     def forward(self, x):
+         x = self.proj(x)
+         return x
+
+
+ class PatchUnEmbed(nn.Module):
+     def __init__(self, patch_size=4, out_chans=3, embed_dim=96, kernel_size=None):
+         super().__init__()
+         self.out_chans = out_chans
+         self.embed_dim = embed_dim
+
+         if kernel_size is None:
+             kernel_size = 1
+
+         self.proj = nn.Sequential(
+             nn.Conv2d(embed_dim, out_chans*patch_size**2, kernel_size=kernel_size,
+                       padding=kernel_size//2, padding_mode='reflect'),
+             nn.PixelShuffle(patch_size)
+         )
+
+     def forward(self, x):
+         x = self.proj(x)
+         return x
+
+
+ class SKFusion(nn.Module):
+     def __init__(self, dim, height=2, reduction=8):
+         super(SKFusion, self).__init__()
+
+         self.height = height
+         d = max(int(dim/reduction), 4)
+
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.mlp = nn.Sequential(
+             nn.Conv2d(dim, d, 1, bias=False),
+             nn.ReLU(),
+             nn.Conv2d(d, dim*height, 1, bias=False)
+         )
+
+         self.softmax = nn.Softmax(dim=1)
+
+     def forward(self, in_feats):
+         B, C, H, W = in_feats[0].shape
+
+         in_feats = torch.cat(in_feats, dim=1)
+         in_feats = in_feats.view(B, self.height, C, H, W)
+
+         feats_sum = torch.sum(in_feats, dim=1)
+         attn = self.mlp(self.avg_pool(feats_sum))
+         attn = self.softmax(attn.view(B, self.height, C, 1, 1))
+
+         out = torch.sum(in_feats*attn, dim=1)
+         return out
+
+
+ class DehazeFormer(nn.Module):
+     def __init__(self, in_chans=3, out_chans=3, window_size=8,
+                  embed_dims=[24, 48, 96, 48, 24],
+                  mlp_ratios=[2., 2., 4., 2., 2.],
+                  depths=[4, 4, 8, 4, 4],
+                  num_heads=[2, 4, 6, 4, 2],
+                  attn_ratio=[1., 1., 1., 1., 1.],
+                  conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'],
+                  norm_layer=[RLN, RLN, RLN, RLN, RLN]):
+         super(DehazeFormer, self).__init__()
+
+         # setting
+         self.patch_size = 4
+         self.window_size = window_size
+         self.mlp_ratios = mlp_ratios
+
+         # split image into non-overlapping patches
+         self.patch_embed = PatchEmbed(
+             patch_size=1, in_chans=in_chans, embed_dim=embed_dims[0], kernel_size=3)
+
+         # backbone
+         self.layer1 = BasicLayer(network_depth=sum(depths), dim=embed_dims[0], depth=depths[0],
+                                  num_heads=num_heads[0], mlp_ratio=mlp_ratios[0],
+                                  norm_layer=norm_layer[0], window_size=window_size,
+                                  attn_ratio=attn_ratio[0], attn_loc='last', conv_type=conv_type[0])
+
+         self.patch_merge1 = PatchEmbed(
+             patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
+
+         self.skip1 = nn.Conv2d(embed_dims[0], embed_dims[0], 1)
+
+         self.layer2 = BasicLayer(network_depth=sum(depths), dim=embed_dims[1], depth=depths[1],
+                                  num_heads=num_heads[1], mlp_ratio=mlp_ratios[1],
+                                  norm_layer=norm_layer[1], window_size=window_size,
+                                  attn_ratio=attn_ratio[1], attn_loc='last', conv_type=conv_type[1])
+
+         self.patch_merge2 = PatchEmbed(
+             patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
+
+         self.skip2 = nn.Conv2d(embed_dims[1], embed_dims[1], 1)
+
+         self.layer3 = BasicLayer(network_depth=sum(depths), dim=embed_dims[2], depth=depths[2],
+                                  num_heads=num_heads[2], mlp_ratio=mlp_ratios[2],
+                                  norm_layer=norm_layer[2], window_size=window_size,
+                                  attn_ratio=attn_ratio[2], attn_loc='last', conv_type=conv_type[2])
+
+         self.patch_split1 = PatchUnEmbed(
+             patch_size=2, out_chans=embed_dims[3], embed_dim=embed_dims[2])
+
+         assert embed_dims[1] == embed_dims[3]
+         self.fusion1 = SKFusion(embed_dims[3])
+
+         self.layer4 = BasicLayer(network_depth=sum(depths), dim=embed_dims[3], depth=depths[3],
+                                  num_heads=num_heads[3], mlp_ratio=mlp_ratios[3],
+                                  norm_layer=norm_layer[3], window_size=window_size,
+                                  attn_ratio=attn_ratio[3], attn_loc='last', conv_type=conv_type[3])
+
+         self.patch_split2 = PatchUnEmbed(
+             patch_size=2, out_chans=embed_dims[4], embed_dim=embed_dims[3])
+
+         assert embed_dims[0] == embed_dims[4]
+         self.fusion2 = SKFusion(embed_dims[4])
+
+         self.layer5 = BasicLayer(network_depth=sum(depths), dim=embed_dims[4], depth=depths[4],
+                                  num_heads=num_heads[4], mlp_ratio=mlp_ratios[4],
+                                  norm_layer=norm_layer[4], window_size=window_size,
+                                  attn_ratio=attn_ratio[4], attn_loc='last', conv_type=conv_type[4])
+
+         # merge non-overlapping patches into image
+         self.patch_unembed = PatchUnEmbed(
+             patch_size=1, out_chans=out_chans, embed_dim=embed_dims[4], kernel_size=3)
+
+     def forward(self, x):
+         x = self.patch_embed(x)
+         x = self.layer1(x)
+         skip1 = x
+
+         x = self.patch_merge1(x)
+         x = self.layer2(x)
+         skip2 = x
+
+         x = self.patch_merge2(x)
+         x = self.layer3(x)
+         x = self.patch_split1(x)
+
+         x = self.fusion1([x, self.skip2(skip2)]) + x
+         x = self.layer4(x)
+         x = self.patch_split2(x)
+
+         x = self.fusion2([x, self.skip1(skip1)]) + x
+         x = self.layer5(x)
+         x = self.patch_unembed(x)
+         return x
+
+
+ class MCT(nn.Module):
+     def __init__(self):
+         super(MCT, self).__init__()
+         self.ts = 256
+         self.l = 8
+
+         self.dims = 3 * 3 * self.l
+
+         self.basenet = DehazeFormer(3, self.dims)
+
+     def get_coord(self, x):
+         B, _, H, W = x.size()
+
+         coordh, coordw = torch.meshgrid([torch.linspace(-1,1,H), torch.linspace(-1,1,W)], indexing="ij")
+         coordh = coordh.unsqueeze(0).unsqueeze(1).repeat(B,1,1,1)
+         coordw = coordw.unsqueeze(0).unsqueeze(1).repeat(B,1,1,1)
+
+         return coordw.detach(), coordh.detach()
+
+     def mapping(self, x, param):
+         # curves
+         curve = torch.stack(torch.chunk(param, 3, dim=1), dim=1)
+         curve_list = list(torch.chunk(curve, 3, dim=2))
+
+         # grid: x, y, z -> w, h, d ~[-1 ,1]
+         x_list = list(torch.chunk(x.detach(), 3, dim=1))
+         coordw, coordh = self.get_coord(x)
+         grid_list = [torch.stack([coordw, coordh, x_i], dim=4) for x_i in x_list]
+
+         # mapping
+         out = sum([F.grid_sample(curve_i, grid_i, 'bilinear', 'border', True) \
+                    for curve_i, grid_i in zip(curve_list, grid_list)]).squeeze(2)
+
+         return out  # no Tanh is much better than using Tanh
+
+     def forward(self, x):
+         # param input
+         x_d = F.interpolate(x, (self.ts, self.ts), mode='area')
+         param = self.basenet(x_d)
+         out = self.mapping(x, param)
+         return out
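
Editor's note: as a rough sanity check (a sketch, not part of the commit), `MCT` downsamples its input to 256×256 for the DehazeFormer base network, predicts 3×3×8 = 72 curve parameters, and maps the full-resolution input through those curves with `grid_sample`, so the output keeps the input's spatial size. The 480×640 shape below is an arbitrary example.

    import torch
    from models.dehazeformer import MCT

    net = MCT().eval()
    with torch.no_grad():
        x = torch.rand(1, 3, 480, 640) * 2 - 1   # dummy hazy image in [-1, 1]
        y = net(x)
    print(y.shape)   # torch.Size([1, 3, 480, 640]) -- same resolution as the input
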
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ torch
+ numpy
+ Pillow
saved_models/dehazeformer.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:479e2017166ed8f97edcde059db759b38cc89388da5e456881ed8892ba35f0d7
+ size 5927945