ohayonguy
commited on
Commit
•
b7f3942
1
Parent(s):
1b8b226
first commit fixed
Browse files- app.py +1 -6
- arch/hourglass/__init__.py +0 -0
- arch/hourglass/axial_rope.py +113 -0
- arch/hourglass/flags.py +60 -0
- arch/hourglass/flops.py +58 -0
- arch/hourglass/image_transformer_v2.py +772 -0
- arch/swinir/__init__.py +0 -0
- arch/swinir/swinir.py +904 -0
- packages.txt +3 -0
- requirements.txt +21 -0
- utils/__init__.py +0 -0
- utils/basicsr_custom.py +954 -0
- utils/create_arch.py +143 -0
- utils/create_degradation.py +144 -0
- utils/img_utils.py +5 -0
app.py
CHANGED
@@ -24,17 +24,12 @@ if not os.path.exists(realesr_model_path):
|
|
24 |
os.system(
|
25 |
"wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -O experiments/pretrained_models/RealESRGAN_x4plus.pth")
|
26 |
|
27 |
-
pmrf_model_path = 'blind_face_restoration_pmrf.ckpt'
|
28 |
-
|
29 |
# background enhancer with RealESRGAN
|
30 |
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
31 |
half = True if torch.cuda.is_available() else False
|
32 |
upsampler = RealESRGANer(scale=4, model_path=realesr_model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
|
33 |
|
34 |
-
pmrf = MMSERectifiedFlow.
|
35 |
-
mmse_model_arch='swinir_L',
|
36 |
-
mmse_model_ckpt_path=None,
|
37 |
-
map_location='cpu').to(device)
|
38 |
|
39 |
os.makedirs('output', exist_ok=True)
|
40 |
|
|
|
24 |
os.system(
|
25 |
"wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -O experiments/pretrained_models/RealESRGAN_x4plus.pth")
|
26 |
|
|
|
|
|
27 |
# background enhancer with RealESRGAN
|
28 |
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
29 |
half = True if torch.cuda.is_available() else False
|
30 |
upsampler = RealESRGANer(scale=4, model_path=realesr_model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
|
31 |
|
32 |
+
pmrf = MMSERectifiedFlow.from_pretrained('ohayonguy/PMRF_blind_face_image_restoration').to(device)
|
|
|
|
|
|
|
33 |
|
34 |
os.makedirs('output', exist_ok=True)
|
35 |
|
arch/hourglass/__init__.py
ADDED
File without changes
|
arch/hourglass/axial_rope.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""k-diffusion transformer diffusion models, version 2.
|
2 |
+
Codes adopted from https://github.com/crowsonkb/k-diffusion
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch._dynamo
|
9 |
+
from torch import nn
|
10 |
+
|
11 |
+
from . import flags
|
12 |
+
|
13 |
+
if flags.get_use_compile():
|
14 |
+
torch._dynamo.config.suppress_errors = True
|
15 |
+
|
16 |
+
|
17 |
+
def rotate_half(x):
|
18 |
+
x1, x2 = x[..., 0::2], x[..., 1::2]
|
19 |
+
x = torch.stack((-x2, x1), dim=-1)
|
20 |
+
*shape, d, r = x.shape
|
21 |
+
return x.view(*shape, d * r)
|
22 |
+
|
23 |
+
|
24 |
+
@flags.compile_wrap
|
25 |
+
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
|
26 |
+
freqs = freqs.to(t)
|
27 |
+
rot_dim = freqs.shape[-1]
|
28 |
+
end_index = start_index + rot_dim
|
29 |
+
assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
|
30 |
+
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
|
31 |
+
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
|
32 |
+
return torch.cat((t_left, t, t_right), dim=-1)
|
33 |
+
|
34 |
+
|
35 |
+
def centers(start, stop, num, dtype=None, device=None):
|
36 |
+
edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
|
37 |
+
return (edges[:-1] + edges[1:]) / 2
|
38 |
+
|
39 |
+
|
40 |
+
def make_grid(h_pos, w_pos):
|
41 |
+
grid = torch.stack(torch.meshgrid(h_pos, w_pos, indexing='ij'), dim=-1)
|
42 |
+
h, w, d = grid.shape
|
43 |
+
return grid.view(h * w, d)
|
44 |
+
|
45 |
+
|
46 |
+
def bounding_box(h, w, pixel_aspect_ratio=1.0):
|
47 |
+
# Adjusted dimensions
|
48 |
+
w_adj = w
|
49 |
+
h_adj = h * pixel_aspect_ratio
|
50 |
+
|
51 |
+
# Adjusted aspect ratio
|
52 |
+
ar_adj = w_adj / h_adj
|
53 |
+
|
54 |
+
# Determine bounding box based on the adjusted aspect ratio
|
55 |
+
y_min, y_max, x_min, x_max = -1.0, 1.0, -1.0, 1.0
|
56 |
+
if ar_adj > 1:
|
57 |
+
y_min, y_max = -1 / ar_adj, 1 / ar_adj
|
58 |
+
elif ar_adj < 1:
|
59 |
+
x_min, x_max = -ar_adj, ar_adj
|
60 |
+
|
61 |
+
return y_min, y_max, x_min, x_max
|
62 |
+
|
63 |
+
|
64 |
+
def make_axial_pos(h, w, pixel_aspect_ratio=1.0, align_corners=False, dtype=None, device=None):
|
65 |
+
y_min, y_max, x_min, x_max = bounding_box(h, w, pixel_aspect_ratio)
|
66 |
+
if align_corners:
|
67 |
+
h_pos = torch.linspace(y_min, y_max, h, dtype=dtype, device=device)
|
68 |
+
w_pos = torch.linspace(x_min, x_max, w, dtype=dtype, device=device)
|
69 |
+
else:
|
70 |
+
h_pos = centers(y_min, y_max, h, dtype=dtype, device=device)
|
71 |
+
w_pos = centers(x_min, x_max, w, dtype=dtype, device=device)
|
72 |
+
return make_grid(h_pos, w_pos)
|
73 |
+
|
74 |
+
|
75 |
+
def freqs_pixel(max_freq=10.0):
|
76 |
+
def init(shape):
|
77 |
+
freqs = torch.linspace(1.0, max_freq / 2, shape[-1]) * math.pi
|
78 |
+
return freqs.log().expand(shape)
|
79 |
+
return init
|
80 |
+
|
81 |
+
|
82 |
+
def freqs_pixel_log(max_freq=10.0):
|
83 |
+
def init(shape):
|
84 |
+
log_min = math.log(math.pi)
|
85 |
+
log_max = math.log(max_freq * math.pi / 2)
|
86 |
+
return torch.linspace(log_min, log_max, shape[-1]).expand(shape)
|
87 |
+
return init
|
88 |
+
|
89 |
+
|
90 |
+
class AxialRoPE(nn.Module):
|
91 |
+
def __init__(self, dim, n_heads, start_index=0, freqs_init=freqs_pixel_log(max_freq=10.0)):
|
92 |
+
super().__init__()
|
93 |
+
self.n_heads = n_heads
|
94 |
+
self.start_index = start_index
|
95 |
+
log_freqs = freqs_init((n_heads, dim // 4))
|
96 |
+
self.freqs_h = nn.Parameter(log_freqs.clone())
|
97 |
+
self.freqs_w = nn.Parameter(log_freqs.clone())
|
98 |
+
|
99 |
+
def extra_repr(self):
|
100 |
+
dim = (self.freqs_h.shape[-1] + self.freqs_w.shape[-1]) * 2
|
101 |
+
return f"dim={dim}, n_heads={self.n_heads}, start_index={self.start_index}"
|
102 |
+
|
103 |
+
def get_freqs(self, pos):
|
104 |
+
if pos.shape[-1] != 2:
|
105 |
+
raise ValueError("input shape must be (..., 2)")
|
106 |
+
freqs_h = pos[..., None, None, 0] * self.freqs_h.exp()
|
107 |
+
freqs_w = pos[..., None, None, 1] * self.freqs_w.exp()
|
108 |
+
freqs = torch.cat((freqs_h, freqs_w), dim=-1).repeat_interleave(2, dim=-1)
|
109 |
+
return freqs.transpose(-2, -3)
|
110 |
+
|
111 |
+
def forward(self, x, pos):
|
112 |
+
freqs = self.get_freqs(pos)
|
113 |
+
return apply_rotary_emb(freqs, x, self.start_index)
|
arch/hourglass/flags.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""k-diffusion transformer diffusion models, version 2.
|
2 |
+
Codes adopted from https://github.com/crowsonkb/k-diffusion
|
3 |
+
"""
|
4 |
+
|
5 |
+
from contextlib import contextmanager
|
6 |
+
from functools import update_wrapper
|
7 |
+
import os
|
8 |
+
import threading
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
|
13 |
+
def get_use_compile():
|
14 |
+
return os.environ.get("K_DIFFUSION_USE_COMPILE", "1") == "1"
|
15 |
+
|
16 |
+
|
17 |
+
def get_use_flash_attention_2():
|
18 |
+
return os.environ.get("K_DIFFUSION_USE_FLASH_2", "1") == "1"
|
19 |
+
|
20 |
+
|
21 |
+
state = threading.local()
|
22 |
+
state.checkpointing = False
|
23 |
+
|
24 |
+
|
25 |
+
@contextmanager
|
26 |
+
def checkpointing(enable=True):
|
27 |
+
try:
|
28 |
+
old_checkpointing, state.checkpointing = state.checkpointing, enable
|
29 |
+
yield
|
30 |
+
finally:
|
31 |
+
state.checkpointing = old_checkpointing
|
32 |
+
|
33 |
+
|
34 |
+
def get_checkpointing():
|
35 |
+
return getattr(state, "checkpointing", False)
|
36 |
+
|
37 |
+
|
38 |
+
class compile_wrap:
|
39 |
+
def __init__(self, function, *args, **kwargs):
|
40 |
+
self.function = function
|
41 |
+
self.args = args
|
42 |
+
self.kwargs = kwargs
|
43 |
+
self._compiled_function = None
|
44 |
+
update_wrapper(self, function)
|
45 |
+
|
46 |
+
@property
|
47 |
+
def compiled_function(self):
|
48 |
+
if self._compiled_function is not None:
|
49 |
+
return self._compiled_function
|
50 |
+
if get_use_compile():
|
51 |
+
try:
|
52 |
+
self._compiled_function = torch.compile(self.function, *self.args, **self.kwargs)
|
53 |
+
except RuntimeError:
|
54 |
+
self._compiled_function = self.function
|
55 |
+
else:
|
56 |
+
self._compiled_function = self.function
|
57 |
+
return self._compiled_function
|
58 |
+
|
59 |
+
def __call__(self, *args, **kwargs):
|
60 |
+
return self.compiled_function(*args, **kwargs)
|
arch/hourglass/flops.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""k-diffusion transformer diffusion models, version 2.
|
2 |
+
Codes adopted from https://github.com/crowsonkb/k-diffusion
|
3 |
+
"""
|
4 |
+
|
5 |
+
from contextlib import contextmanager
|
6 |
+
import math
|
7 |
+
import threading
|
8 |
+
|
9 |
+
|
10 |
+
state = threading.local()
|
11 |
+
state.flop_counter = None
|
12 |
+
|
13 |
+
|
14 |
+
@contextmanager
|
15 |
+
def flop_counter(enable=True):
|
16 |
+
try:
|
17 |
+
old_flop_counter = state.flop_counter
|
18 |
+
state.flop_counter = FlopCounter() if enable else None
|
19 |
+
yield state.flop_counter
|
20 |
+
finally:
|
21 |
+
state.flop_counter = old_flop_counter
|
22 |
+
|
23 |
+
|
24 |
+
class FlopCounter:
|
25 |
+
def __init__(self):
|
26 |
+
self.ops = []
|
27 |
+
|
28 |
+
def op(self, op, *args, **kwargs):
|
29 |
+
self.ops.append((op, args, kwargs))
|
30 |
+
|
31 |
+
@property
|
32 |
+
def flops(self):
|
33 |
+
flops = 0
|
34 |
+
for op, args, kwargs in self.ops:
|
35 |
+
flops += op(*args, **kwargs)
|
36 |
+
return flops
|
37 |
+
|
38 |
+
|
39 |
+
def op(op, *args, **kwargs):
|
40 |
+
if getattr(state, "flop_counter", None):
|
41 |
+
state.flop_counter.op(op, *args, **kwargs)
|
42 |
+
|
43 |
+
|
44 |
+
def op_linear(x, weight):
|
45 |
+
return math.prod(x) * weight[0]
|
46 |
+
|
47 |
+
|
48 |
+
def op_attention(q, k, v):
|
49 |
+
*b, s_q, d_q = q
|
50 |
+
*b, s_k, d_k = k
|
51 |
+
*b, s_v, d_v = v
|
52 |
+
return math.prod(b) * s_q * s_k * (d_q + d_v)
|
53 |
+
|
54 |
+
|
55 |
+
def op_natten(q, k, v, kernel_size):
|
56 |
+
*q_rest, d_q = q
|
57 |
+
*_, d_v = v
|
58 |
+
return math.prod(q_rest) * (d_q + d_v) * kernel_size**2
|
arch/hourglass/image_transformer_v2.py
ADDED
@@ -0,0 +1,772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""k-diffusion transformer diffusion models, version 2.
|
2 |
+
Codes adopted from https://github.com/crowsonkb/k-diffusion
|
3 |
+
"""
|
4 |
+
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from functools import lru_cache, reduce
|
7 |
+
import math
|
8 |
+
from typing import Union
|
9 |
+
|
10 |
+
from einops import rearrange
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
import torch._dynamo
|
14 |
+
from torch.nn import functional as F
|
15 |
+
|
16 |
+
from . import flags, flops
|
17 |
+
from .axial_rope import make_axial_pos
|
18 |
+
|
19 |
+
|
20 |
+
try:
|
21 |
+
import natten
|
22 |
+
except ImportError:
|
23 |
+
natten = None
|
24 |
+
|
25 |
+
try:
|
26 |
+
import flash_attn
|
27 |
+
except ImportError:
|
28 |
+
flash_attn = None
|
29 |
+
|
30 |
+
|
31 |
+
if flags.get_use_compile():
|
32 |
+
torch._dynamo.config.cache_size_limit = max(64, torch._dynamo.config.cache_size_limit)
|
33 |
+
torch._dynamo.config.suppress_errors = True
|
34 |
+
|
35 |
+
|
36 |
+
# Helpers
|
37 |
+
|
38 |
+
def zero_init(layer):
|
39 |
+
nn.init.zeros_(layer.weight)
|
40 |
+
if layer.bias is not None:
|
41 |
+
nn.init.zeros_(layer.bias)
|
42 |
+
return layer
|
43 |
+
|
44 |
+
|
45 |
+
def checkpoint(function, *args, **kwargs):
|
46 |
+
if flags.get_checkpointing():
|
47 |
+
kwargs.setdefault("use_reentrant", True)
|
48 |
+
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
|
49 |
+
else:
|
50 |
+
return function(*args, **kwargs)
|
51 |
+
|
52 |
+
|
53 |
+
def downscale_pos(pos):
|
54 |
+
pos = rearrange(pos, "... (h nh) (w nw) e -> ... h w (nh nw) e", nh=2, nw=2)
|
55 |
+
return torch.mean(pos, dim=-2)
|
56 |
+
|
57 |
+
|
58 |
+
# Param tags
|
59 |
+
|
60 |
+
def tag_param(param, tag):
|
61 |
+
if not hasattr(param, "_tags"):
|
62 |
+
param._tags = set([tag])
|
63 |
+
else:
|
64 |
+
param._tags.add(tag)
|
65 |
+
return param
|
66 |
+
|
67 |
+
|
68 |
+
def tag_module(module, tag):
|
69 |
+
for param in module.parameters():
|
70 |
+
tag_param(param, tag)
|
71 |
+
return module
|
72 |
+
|
73 |
+
|
74 |
+
def apply_wd(module):
|
75 |
+
for name, param in module.named_parameters():
|
76 |
+
if name.endswith("weight"):
|
77 |
+
tag_param(param, "wd")
|
78 |
+
return module
|
79 |
+
|
80 |
+
|
81 |
+
def filter_params(function, module):
|
82 |
+
for param in module.parameters():
|
83 |
+
tags = getattr(param, "_tags", set())
|
84 |
+
if function(tags):
|
85 |
+
yield param
|
86 |
+
|
87 |
+
|
88 |
+
# Kernels
|
89 |
+
|
90 |
+
@flags.compile_wrap
|
91 |
+
def linear_geglu(x, weight, bias=None):
|
92 |
+
x = x @ weight.mT
|
93 |
+
if bias is not None:
|
94 |
+
x = x + bias
|
95 |
+
x, gate = x.chunk(2, dim=-1)
|
96 |
+
return x * F.gelu(gate)
|
97 |
+
|
98 |
+
|
99 |
+
@flags.compile_wrap
|
100 |
+
def rms_norm(x, scale, eps):
|
101 |
+
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
|
102 |
+
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
|
103 |
+
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
|
104 |
+
return x * scale.to(x.dtype)
|
105 |
+
|
106 |
+
|
107 |
+
@flags.compile_wrap
|
108 |
+
def scale_for_cosine_sim(q, k, scale, eps):
|
109 |
+
dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32))
|
110 |
+
sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True)
|
111 |
+
sum_sq_k = torch.sum(k.to(dtype)**2, dim=-1, keepdim=True)
|
112 |
+
sqrt_scale = torch.sqrt(scale.to(dtype))
|
113 |
+
scale_q = sqrt_scale * torch.rsqrt(sum_sq_q + eps)
|
114 |
+
scale_k = sqrt_scale * torch.rsqrt(sum_sq_k + eps)
|
115 |
+
return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype)
|
116 |
+
|
117 |
+
|
118 |
+
@flags.compile_wrap
|
119 |
+
def scale_for_cosine_sim_qkv(qkv, scale, eps):
|
120 |
+
q, k, v = qkv.unbind(2)
|
121 |
+
q, k = scale_for_cosine_sim(q, k, scale[:, None], eps)
|
122 |
+
return torch.stack((q, k, v), dim=2)
|
123 |
+
|
124 |
+
|
125 |
+
# Layers
|
126 |
+
|
127 |
+
class Linear(nn.Linear):
|
128 |
+
def forward(self, x):
|
129 |
+
flops.op(flops.op_linear, x.shape, self.weight.shape)
|
130 |
+
return super().forward(x)
|
131 |
+
|
132 |
+
|
133 |
+
class LinearGEGLU(nn.Linear):
|
134 |
+
def __init__(self, in_features, out_features, bias=True):
|
135 |
+
super().__init__(in_features, out_features * 2, bias=bias)
|
136 |
+
self.out_features = out_features
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
flops.op(flops.op_linear, x.shape, self.weight.shape)
|
140 |
+
return linear_geglu(x, self.weight, self.bias)
|
141 |
+
|
142 |
+
|
143 |
+
class FourierFeatures(nn.Module):
|
144 |
+
def __init__(self, in_features, out_features, std=1.):
|
145 |
+
super().__init__()
|
146 |
+
assert out_features % 2 == 0
|
147 |
+
self.register_buffer('weight', torch.randn([out_features // 2, in_features]) * std)
|
148 |
+
|
149 |
+
def forward(self, input):
|
150 |
+
f = 2 * math.pi * input @ self.weight.T
|
151 |
+
return torch.cat([f.cos(), f.sin()], dim=-1)
|
152 |
+
|
153 |
+
class RMSNorm(nn.Module):
|
154 |
+
def __init__(self, shape, eps=1e-6):
|
155 |
+
super().__init__()
|
156 |
+
self.eps = eps
|
157 |
+
self.scale = nn.Parameter(torch.ones(shape))
|
158 |
+
|
159 |
+
def extra_repr(self):
|
160 |
+
return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
|
161 |
+
|
162 |
+
def forward(self, x):
|
163 |
+
return rms_norm(x, self.scale, self.eps)
|
164 |
+
|
165 |
+
|
166 |
+
class AdaRMSNorm(nn.Module):
|
167 |
+
def __init__(self, features, cond_features, eps=1e-6):
|
168 |
+
super().__init__()
|
169 |
+
self.eps = eps
|
170 |
+
self.linear = apply_wd(zero_init(Linear(cond_features, features, bias=False)))
|
171 |
+
tag_module(self.linear, "mapping")
|
172 |
+
|
173 |
+
def extra_repr(self):
|
174 |
+
return f"eps={self.eps},"
|
175 |
+
|
176 |
+
def forward(self, x, cond):
|
177 |
+
return rms_norm(x, self.linear(cond)[:, None, None, :] + 1, self.eps)
|
178 |
+
|
179 |
+
|
180 |
+
# Rotary position embeddings
|
181 |
+
|
182 |
+
@flags.compile_wrap
|
183 |
+
def apply_rotary_emb(x, theta, conj=False):
|
184 |
+
out_dtype = x.dtype
|
185 |
+
dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32))
|
186 |
+
d = theta.shape[-1]
|
187 |
+
assert d * 2 <= x.shape[-1]
|
188 |
+
x1, x2, x3 = x[..., :d], x[..., d : d * 2], x[..., d * 2 :]
|
189 |
+
x1, x2, theta = x1.to(dtype), x2.to(dtype), theta.to(dtype)
|
190 |
+
cos, sin = torch.cos(theta), torch.sin(theta)
|
191 |
+
sin = -sin if conj else sin
|
192 |
+
y1 = x1 * cos - x2 * sin
|
193 |
+
y2 = x2 * cos + x1 * sin
|
194 |
+
y1, y2 = y1.to(out_dtype), y2.to(out_dtype)
|
195 |
+
return torch.cat((y1, y2, x3), dim=-1)
|
196 |
+
|
197 |
+
|
198 |
+
@flags.compile_wrap
|
199 |
+
def _apply_rotary_emb_inplace(x, theta, conj):
|
200 |
+
dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32))
|
201 |
+
d = theta.shape[-1]
|
202 |
+
assert d * 2 <= x.shape[-1]
|
203 |
+
x1, x2 = x[..., :d], x[..., d : d * 2]
|
204 |
+
x1_, x2_, theta = x1.to(dtype), x2.to(dtype), theta.to(dtype)
|
205 |
+
cos, sin = torch.cos(theta), torch.sin(theta)
|
206 |
+
sin = -sin if conj else sin
|
207 |
+
y1 = x1_ * cos - x2_ * sin
|
208 |
+
y2 = x2_ * cos + x1_ * sin
|
209 |
+
x1.copy_(y1)
|
210 |
+
x2.copy_(y2)
|
211 |
+
|
212 |
+
|
213 |
+
class ApplyRotaryEmbeddingInplace(torch.autograd.Function):
|
214 |
+
@staticmethod
|
215 |
+
def forward(x, theta, conj):
|
216 |
+
_apply_rotary_emb_inplace(x, theta, conj=conj)
|
217 |
+
return x
|
218 |
+
|
219 |
+
@staticmethod
|
220 |
+
def setup_context(ctx, inputs, output):
|
221 |
+
_, theta, conj = inputs
|
222 |
+
ctx.save_for_backward(theta)
|
223 |
+
ctx.conj = conj
|
224 |
+
|
225 |
+
@staticmethod
|
226 |
+
def backward(ctx, grad_output):
|
227 |
+
theta, = ctx.saved_tensors
|
228 |
+
_apply_rotary_emb_inplace(grad_output, theta, conj=not ctx.conj)
|
229 |
+
return grad_output, None, None
|
230 |
+
|
231 |
+
|
232 |
+
def apply_rotary_emb_(x, theta):
|
233 |
+
return ApplyRotaryEmbeddingInplace.apply(x, theta, False)
|
234 |
+
|
235 |
+
|
236 |
+
class AxialRoPE(nn.Module):
|
237 |
+
def __init__(self, dim, n_heads):
|
238 |
+
super().__init__()
|
239 |
+
log_min = math.log(math.pi)
|
240 |
+
log_max = math.log(10.0 * math.pi)
|
241 |
+
freqs = torch.linspace(log_min, log_max, n_heads * dim // 4 + 1)[:-1].exp()
|
242 |
+
self.register_buffer("freqs", freqs.view(dim // 4, n_heads).T.contiguous())
|
243 |
+
|
244 |
+
def extra_repr(self):
|
245 |
+
return f"dim={self.freqs.shape[1] * 4}, n_heads={self.freqs.shape[0]}"
|
246 |
+
|
247 |
+
def forward(self, pos):
|
248 |
+
theta_h = pos[..., None, 0:1] * self.freqs.to(pos.dtype)
|
249 |
+
theta_w = pos[..., None, 1:2] * self.freqs.to(pos.dtype)
|
250 |
+
return torch.cat((theta_h, theta_w), dim=-1)
|
251 |
+
|
252 |
+
|
253 |
+
# Shifted window attention
|
254 |
+
|
255 |
+
def window(window_size, x):
|
256 |
+
*b, h, w, c = x.shape
|
257 |
+
x = torch.reshape(
|
258 |
+
x,
|
259 |
+
(*b, h // window_size, window_size, w // window_size, window_size, c),
|
260 |
+
)
|
261 |
+
x = torch.permute(
|
262 |
+
x,
|
263 |
+
(*range(len(b)), -5, -3, -4, -2, -1),
|
264 |
+
)
|
265 |
+
return x
|
266 |
+
|
267 |
+
|
268 |
+
def unwindow(x):
|
269 |
+
*b, h, w, wh, ww, c = x.shape
|
270 |
+
x = torch.permute(x, (*range(len(b)), -5, -3, -4, -2, -1))
|
271 |
+
x = torch.reshape(x, (*b, h * wh, w * ww, c))
|
272 |
+
return x
|
273 |
+
|
274 |
+
|
275 |
+
def shifted_window(window_size, window_shift, x):
|
276 |
+
x = torch.roll(x, shifts=(window_shift, window_shift), dims=(-2, -3))
|
277 |
+
windows = window(window_size, x)
|
278 |
+
return windows
|
279 |
+
|
280 |
+
|
281 |
+
def shifted_unwindow(window_shift, x):
|
282 |
+
x = unwindow(x)
|
283 |
+
x = torch.roll(x, shifts=(-window_shift, -window_shift), dims=(-2, -3))
|
284 |
+
return x
|
285 |
+
|
286 |
+
|
287 |
+
@lru_cache
|
288 |
+
def make_shifted_window_masks(n_h_w, n_w_w, w_h, w_w, shift, device=None):
|
289 |
+
ph_coords = torch.arange(n_h_w, device=device)
|
290 |
+
pw_coords = torch.arange(n_w_w, device=device)
|
291 |
+
h_coords = torch.arange(w_h, device=device)
|
292 |
+
w_coords = torch.arange(w_w, device=device)
|
293 |
+
patch_h, patch_w, q_h, q_w, k_h, k_w = torch.meshgrid(
|
294 |
+
ph_coords,
|
295 |
+
pw_coords,
|
296 |
+
h_coords,
|
297 |
+
w_coords,
|
298 |
+
h_coords,
|
299 |
+
w_coords,
|
300 |
+
indexing="ij",
|
301 |
+
)
|
302 |
+
is_top_patch = patch_h == 0
|
303 |
+
is_left_patch = patch_w == 0
|
304 |
+
q_above_shift = q_h < shift
|
305 |
+
k_above_shift = k_h < shift
|
306 |
+
q_left_of_shift = q_w < shift
|
307 |
+
k_left_of_shift = k_w < shift
|
308 |
+
m_corner = (
|
309 |
+
is_left_patch
|
310 |
+
& is_top_patch
|
311 |
+
& (q_left_of_shift == k_left_of_shift)
|
312 |
+
& (q_above_shift == k_above_shift)
|
313 |
+
)
|
314 |
+
m_left = is_left_patch & ~is_top_patch & (q_left_of_shift == k_left_of_shift)
|
315 |
+
m_top = ~is_left_patch & is_top_patch & (q_above_shift == k_above_shift)
|
316 |
+
m_rest = ~is_left_patch & ~is_top_patch
|
317 |
+
m = m_corner | m_left | m_top | m_rest
|
318 |
+
return m
|
319 |
+
|
320 |
+
|
321 |
+
def apply_window_attention(window_size, window_shift, q, k, v, scale=None):
|
322 |
+
# prep windows and masks
|
323 |
+
q_windows = shifted_window(window_size, window_shift, q)
|
324 |
+
k_windows = shifted_window(window_size, window_shift, k)
|
325 |
+
v_windows = shifted_window(window_size, window_shift, v)
|
326 |
+
b, heads, h, w, wh, ww, d_head = q_windows.shape
|
327 |
+
mask = make_shifted_window_masks(h, w, wh, ww, window_shift, device=q.device)
|
328 |
+
q_seqs = torch.reshape(q_windows, (b, heads, h, w, wh * ww, d_head))
|
329 |
+
k_seqs = torch.reshape(k_windows, (b, heads, h, w, wh * ww, d_head))
|
330 |
+
v_seqs = torch.reshape(v_windows, (b, heads, h, w, wh * ww, d_head))
|
331 |
+
mask = torch.reshape(mask, (h, w, wh * ww, wh * ww))
|
332 |
+
|
333 |
+
# do the attention here
|
334 |
+
flops.op(flops.op_attention, q_seqs.shape, k_seqs.shape, v_seqs.shape)
|
335 |
+
qkv = F.scaled_dot_product_attention(q_seqs, k_seqs, v_seqs, mask, scale=scale)
|
336 |
+
|
337 |
+
# unwindow
|
338 |
+
qkv = torch.reshape(qkv, (b, heads, h, w, wh, ww, d_head))
|
339 |
+
return shifted_unwindow(window_shift, qkv)
|
340 |
+
|
341 |
+
|
342 |
+
# Transformer layers
|
343 |
+
|
344 |
+
|
345 |
+
def use_flash_2(x):
|
346 |
+
if not flags.get_use_flash_attention_2():
|
347 |
+
return False
|
348 |
+
if flash_attn is None:
|
349 |
+
return False
|
350 |
+
if x.device.type != "cuda":
|
351 |
+
return False
|
352 |
+
if x.dtype not in (torch.float16, torch.bfloat16):
|
353 |
+
return False
|
354 |
+
return True
|
355 |
+
|
356 |
+
|
357 |
+
class SelfAttentionBlock(nn.Module):
|
358 |
+
def __init__(self, d_model, d_head, cond_features, dropout=0.0):
|
359 |
+
super().__init__()
|
360 |
+
self.d_head = d_head
|
361 |
+
self.n_heads = d_model // d_head
|
362 |
+
self.norm = AdaRMSNorm(d_model, cond_features)
|
363 |
+
self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False))
|
364 |
+
self.scale = nn.Parameter(torch.full([self.n_heads], 10.0))
|
365 |
+
self.pos_emb = AxialRoPE(d_head // 2, self.n_heads)
|
366 |
+
self.dropout = nn.Dropout(dropout)
|
367 |
+
self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False)))
|
368 |
+
|
369 |
+
def extra_repr(self):
|
370 |
+
return f"d_head={self.d_head},"
|
371 |
+
|
372 |
+
def forward(self, x, pos, cond):
|
373 |
+
skip = x
|
374 |
+
x = self.norm(x, cond)
|
375 |
+
qkv = self.qkv_proj(x)
|
376 |
+
pos = rearrange(pos, "... h w e -> ... (h w) e").to(qkv.dtype)
|
377 |
+
theta = self.pos_emb(pos)
|
378 |
+
if use_flash_2(qkv):
|
379 |
+
qkv = rearrange(qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head)
|
380 |
+
qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6)
|
381 |
+
theta = torch.stack((theta, theta, torch.zeros_like(theta)), dim=-3)
|
382 |
+
qkv = apply_rotary_emb_(qkv, theta)
|
383 |
+
flops_shape = qkv.shape[-5], qkv.shape[-2], qkv.shape[-4], qkv.shape[-1]
|
384 |
+
flops.op(flops.op_attention, flops_shape, flops_shape, flops_shape)
|
385 |
+
x = flash_attn.flash_attn_qkvpacked_func(qkv, softmax_scale=1.0)
|
386 |
+
x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2])
|
387 |
+
else:
|
388 |
+
q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head)
|
389 |
+
q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None], 1e-6)
|
390 |
+
theta = theta.movedim(-2, -3)
|
391 |
+
q = apply_rotary_emb_(q, theta)
|
392 |
+
k = apply_rotary_emb_(k, theta)
|
393 |
+
flops.op(flops.op_attention, q.shape, k.shape, v.shape)
|
394 |
+
x = F.scaled_dot_product_attention(q, k, v, scale=1.0)
|
395 |
+
x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2])
|
396 |
+
x = self.dropout(x)
|
397 |
+
x = self.out_proj(x)
|
398 |
+
return x + skip
|
399 |
+
|
400 |
+
|
401 |
+
class NeighborhoodSelfAttentionBlock(nn.Module):
|
402 |
+
def __init__(self, d_model, d_head, cond_features, kernel_size, dropout=0.0):
|
403 |
+
super().__init__()
|
404 |
+
self.d_head = d_head
|
405 |
+
self.n_heads = d_model // d_head
|
406 |
+
self.kernel_size = kernel_size
|
407 |
+
self.norm = AdaRMSNorm(d_model, cond_features)
|
408 |
+
self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False))
|
409 |
+
self.scale = nn.Parameter(torch.full([self.n_heads], 10.0))
|
410 |
+
self.pos_emb = AxialRoPE(d_head // 2, self.n_heads)
|
411 |
+
self.dropout = nn.Dropout(dropout)
|
412 |
+
self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False)))
|
413 |
+
|
414 |
+
def extra_repr(self):
|
415 |
+
return f"d_head={self.d_head}, kernel_size={self.kernel_size}"
|
416 |
+
|
417 |
+
def forward(self, x, pos, cond):
|
418 |
+
skip = x
|
419 |
+
x = self.norm(x, cond)
|
420 |
+
qkv = self.qkv_proj(x)
|
421 |
+
if natten is None:
|
422 |
+
raise ModuleNotFoundError("natten is required for neighborhood attention")
|
423 |
+
if natten.has_fused_na():
|
424 |
+
q, k, v = rearrange(qkv, "n h w (t nh e) -> t n h w nh e", t=3, e=self.d_head)
|
425 |
+
q, k = scale_for_cosine_sim(q, k, self.scale[:, None], 1e-6)
|
426 |
+
theta = self.pos_emb(pos)
|
427 |
+
q = apply_rotary_emb_(q, theta)
|
428 |
+
k = apply_rotary_emb_(k, theta)
|
429 |
+
flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size)
|
430 |
+
x = natten.functional.na2d(q, k, v, self.kernel_size, scale=1.0)
|
431 |
+
x = rearrange(x, "n h w nh e -> n h w (nh e)")
|
432 |
+
else:
|
433 |
+
q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head)
|
434 |
+
q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6)
|
435 |
+
theta = self.pos_emb(pos).movedim(-2, -4)
|
436 |
+
q = apply_rotary_emb_(q, theta)
|
437 |
+
k = apply_rotary_emb_(k, theta)
|
438 |
+
flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size)
|
439 |
+
qk = natten.functional.na2d_qk(q, k, self.kernel_size)
|
440 |
+
a = torch.softmax(qk, dim=-1).to(v.dtype)
|
441 |
+
x = natten.functional.na2d_av(a, v, self.kernel_size)
|
442 |
+
x = rearrange(x, "n nh h w e -> n h w (nh e)")
|
443 |
+
x = self.dropout(x)
|
444 |
+
x = self.out_proj(x)
|
445 |
+
return x + skip
|
446 |
+
|
447 |
+
|
448 |
+
class ShiftedWindowSelfAttentionBlock(nn.Module):
|
449 |
+
def __init__(self, d_model, d_head, cond_features, window_size, window_shift, dropout=0.0):
|
450 |
+
super().__init__()
|
451 |
+
self.d_head = d_head
|
452 |
+
self.n_heads = d_model // d_head
|
453 |
+
self.window_size = window_size
|
454 |
+
self.window_shift = window_shift
|
455 |
+
self.norm = AdaRMSNorm(d_model, cond_features)
|
456 |
+
self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False))
|
457 |
+
self.scale = nn.Parameter(torch.full([self.n_heads], 10.0))
|
458 |
+
self.pos_emb = AxialRoPE(d_head // 2, self.n_heads)
|
459 |
+
self.dropout = nn.Dropout(dropout)
|
460 |
+
self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False)))
|
461 |
+
|
462 |
+
def extra_repr(self):
|
463 |
+
return f"d_head={self.d_head}, window_size={self.window_size}, window_shift={self.window_shift}"
|
464 |
+
|
465 |
+
def forward(self, x, pos, cond):
|
466 |
+
skip = x
|
467 |
+
x = self.norm(x, cond)
|
468 |
+
qkv = self.qkv_proj(x)
|
469 |
+
q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head)
|
470 |
+
q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6)
|
471 |
+
theta = self.pos_emb(pos).movedim(-2, -4)
|
472 |
+
q = apply_rotary_emb_(q, theta)
|
473 |
+
k = apply_rotary_emb_(k, theta)
|
474 |
+
x = apply_window_attention(self.window_size, self.window_shift, q, k, v, scale=1.0)
|
475 |
+
x = rearrange(x, "n nh h w e -> n h w (nh e)")
|
476 |
+
x = self.dropout(x)
|
477 |
+
x = self.out_proj(x)
|
478 |
+
return x + skip
|
479 |
+
|
480 |
+
|
481 |
+
class FeedForwardBlock(nn.Module):
|
482 |
+
def __init__(self, d_model, d_ff, cond_features, dropout=0.0):
|
483 |
+
super().__init__()
|
484 |
+
self.norm = AdaRMSNorm(d_model, cond_features)
|
485 |
+
self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False))
|
486 |
+
self.dropout = nn.Dropout(dropout)
|
487 |
+
self.down_proj = apply_wd(zero_init(Linear(d_ff, d_model, bias=False)))
|
488 |
+
|
489 |
+
def forward(self, x, cond):
|
490 |
+
skip = x
|
491 |
+
x = self.norm(x, cond)
|
492 |
+
x = self.up_proj(x)
|
493 |
+
x = self.dropout(x)
|
494 |
+
x = self.down_proj(x)
|
495 |
+
return x + skip
|
496 |
+
|
497 |
+
|
498 |
+
class GlobalTransformerLayer(nn.Module):
|
499 |
+
def __init__(self, d_model, d_ff, d_head, cond_features, dropout=0.0):
|
500 |
+
super().__init__()
|
501 |
+
self.self_attn = SelfAttentionBlock(d_model, d_head, cond_features, dropout=dropout)
|
502 |
+
self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout)
|
503 |
+
|
504 |
+
def forward(self, x, pos, cond):
|
505 |
+
x = checkpoint(self.self_attn, x, pos, cond)
|
506 |
+
x = checkpoint(self.ff, x, cond)
|
507 |
+
return x
|
508 |
+
|
509 |
+
|
510 |
+
class NeighborhoodTransformerLayer(nn.Module):
|
511 |
+
def __init__(self, d_model, d_ff, d_head, cond_features, kernel_size, dropout=0.0):
|
512 |
+
super().__init__()
|
513 |
+
self.self_attn = NeighborhoodSelfAttentionBlock(d_model, d_head, cond_features, kernel_size, dropout=dropout)
|
514 |
+
self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout)
|
515 |
+
|
516 |
+
def forward(self, x, pos, cond):
|
517 |
+
x = checkpoint(self.self_attn, x, pos, cond)
|
518 |
+
x = checkpoint(self.ff, x, cond)
|
519 |
+
return x
|
520 |
+
|
521 |
+
|
522 |
+
class ShiftedWindowTransformerLayer(nn.Module):
|
523 |
+
def __init__(self, d_model, d_ff, d_head, cond_features, window_size, index, dropout=0.0):
|
524 |
+
super().__init__()
|
525 |
+
window_shift = window_size // 2 if index % 2 == 1 else 0
|
526 |
+
self.self_attn = ShiftedWindowSelfAttentionBlock(d_model, d_head, cond_features, window_size, window_shift, dropout=dropout)
|
527 |
+
self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout)
|
528 |
+
|
529 |
+
def forward(self, x, pos, cond):
|
530 |
+
x = checkpoint(self.self_attn, x, pos, cond)
|
531 |
+
x = checkpoint(self.ff, x, cond)
|
532 |
+
return x
|
533 |
+
|
534 |
+
|
535 |
+
class NoAttentionTransformerLayer(nn.Module):
|
536 |
+
def __init__(self, d_model, d_ff, cond_features, dropout=0.0):
|
537 |
+
super().__init__()
|
538 |
+
self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout)
|
539 |
+
|
540 |
+
def forward(self, x, pos, cond):
|
541 |
+
x = checkpoint(self.ff, x, cond)
|
542 |
+
return x
|
543 |
+
|
544 |
+
|
545 |
+
class Level(nn.ModuleList):
|
546 |
+
def forward(self, x, *args, **kwargs):
|
547 |
+
for layer in self:
|
548 |
+
x = layer(x, *args, **kwargs)
|
549 |
+
return x
|
550 |
+
|
551 |
+
|
552 |
+
# Mapping network
|
553 |
+
|
554 |
+
class MappingFeedForwardBlock(nn.Module):
|
555 |
+
def __init__(self, d_model, d_ff, dropout=0.0):
|
556 |
+
super().__init__()
|
557 |
+
self.norm = RMSNorm(d_model)
|
558 |
+
self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False))
|
559 |
+
self.dropout = nn.Dropout(dropout)
|
560 |
+
self.down_proj = apply_wd(zero_init(Linear(d_ff, d_model, bias=False)))
|
561 |
+
|
562 |
+
def forward(self, x):
|
563 |
+
skip = x
|
564 |
+
x = self.norm(x)
|
565 |
+
x = self.up_proj(x)
|
566 |
+
x = self.dropout(x)
|
567 |
+
x = self.down_proj(x)
|
568 |
+
return x + skip
|
569 |
+
|
570 |
+
|
571 |
+
class MappingNetwork(nn.Module):
|
572 |
+
def __init__(self, n_layers, d_model, d_ff, dropout=0.0):
|
573 |
+
super().__init__()
|
574 |
+
self.in_norm = RMSNorm(d_model)
|
575 |
+
self.blocks = nn.ModuleList([MappingFeedForwardBlock(d_model, d_ff, dropout=dropout) for _ in range(n_layers)])
|
576 |
+
self.out_norm = RMSNorm(d_model)
|
577 |
+
|
578 |
+
def forward(self, x):
|
579 |
+
x = self.in_norm(x)
|
580 |
+
for block in self.blocks:
|
581 |
+
x = block(x)
|
582 |
+
x = self.out_norm(x)
|
583 |
+
return x
|
584 |
+
|
585 |
+
|
586 |
+
# Token merging and splitting
|
587 |
+
|
588 |
+
class TokenMerge(nn.Module):
|
589 |
+
def __init__(self, in_features, out_features, patch_size=(2, 2)):
|
590 |
+
super().__init__()
|
591 |
+
self.h = patch_size[0]
|
592 |
+
self.w = patch_size[1]
|
593 |
+
self.proj = apply_wd(Linear(in_features * self.h * self.w, out_features, bias=False))
|
594 |
+
|
595 |
+
def forward(self, x):
|
596 |
+
x = rearrange(x, "... (h nh) (w nw) e -> ... h w (nh nw e)", nh=self.h, nw=self.w)
|
597 |
+
return self.proj(x)
|
598 |
+
|
599 |
+
|
600 |
+
class TokenSplitWithoutSkip(nn.Module):
|
601 |
+
def __init__(self, in_features, out_features, patch_size=(2, 2)):
|
602 |
+
super().__init__()
|
603 |
+
self.h = patch_size[0]
|
604 |
+
self.w = patch_size[1]
|
605 |
+
self.proj = apply_wd(Linear(in_features, out_features * self.h * self.w, bias=False))
|
606 |
+
|
607 |
+
def forward(self, x):
|
608 |
+
x = self.proj(x)
|
609 |
+
return rearrange(x, "... h w (nh nw e) -> ... (h nh) (w nw) e", nh=self.h, nw=self.w)
|
610 |
+
|
611 |
+
|
612 |
+
class TokenSplit(nn.Module):
|
613 |
+
def __init__(self, in_features, out_features, patch_size=(2, 2)):
|
614 |
+
super().__init__()
|
615 |
+
self.h = patch_size[0]
|
616 |
+
self.w = patch_size[1]
|
617 |
+
self.proj = apply_wd(Linear(in_features, out_features * self.h * self.w, bias=False))
|
618 |
+
self.fac = nn.Parameter(torch.ones(1) * 0.5)
|
619 |
+
|
620 |
+
def forward(self, x, skip):
|
621 |
+
x = self.proj(x)
|
622 |
+
x = rearrange(x, "... h w (nh nw e) -> ... (h nh) (w nw) e", nh=self.h, nw=self.w)
|
623 |
+
return torch.lerp(skip, x, self.fac.to(x.dtype))
|
624 |
+
|
625 |
+
|
626 |
+
# Configuration
|
627 |
+
|
628 |
+
@dataclass
|
629 |
+
class GlobalAttentionSpec:
|
630 |
+
d_head: int
|
631 |
+
|
632 |
+
|
633 |
+
@dataclass
|
634 |
+
class NeighborhoodAttentionSpec:
|
635 |
+
d_head: int
|
636 |
+
kernel_size: int
|
637 |
+
|
638 |
+
|
639 |
+
@dataclass
|
640 |
+
class ShiftedWindowAttentionSpec:
|
641 |
+
d_head: int
|
642 |
+
window_size: int
|
643 |
+
|
644 |
+
|
645 |
+
@dataclass
|
646 |
+
class NoAttentionSpec:
|
647 |
+
pass
|
648 |
+
|
649 |
+
|
650 |
+
@dataclass
|
651 |
+
class LevelSpec:
|
652 |
+
depth: int
|
653 |
+
width: int
|
654 |
+
d_ff: int
|
655 |
+
self_attn: Union[GlobalAttentionSpec, NeighborhoodAttentionSpec, ShiftedWindowAttentionSpec, NoAttentionSpec]
|
656 |
+
dropout: float
|
657 |
+
|
658 |
+
|
659 |
+
@dataclass
|
660 |
+
class MappingSpec:
|
661 |
+
depth: int
|
662 |
+
width: int
|
663 |
+
d_ff: int
|
664 |
+
dropout: float
|
665 |
+
|
666 |
+
|
667 |
+
# Model class
|
668 |
+
|
669 |
+
class ImageTransformerDenoiserModelV2(nn.Module):
|
670 |
+
def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_classes=0, mapping_cond_dim=0, degradation_params_dim=None):
|
671 |
+
super().__init__()
|
672 |
+
self.num_classes = num_classes
|
673 |
+
self.patch_in = TokenMerge(in_channels, levels[0].width, patch_size)
|
674 |
+
self.mapping_width = mapping.width
|
675 |
+
self.time_emb = FourierFeatures(1, mapping.width)
|
676 |
+
self.time_in_proj = Linear(mapping.width, mapping.width, bias=False)
|
677 |
+
self.aug_emb = FourierFeatures(9, mapping.width)
|
678 |
+
self.aug_in_proj = Linear(mapping.width, mapping.width, bias=False)
|
679 |
+
self.degradation_proj = Linear(degradation_params_dim, mapping.width, bias=False) if degradation_params_dim else None
|
680 |
+
self.class_emb = nn.Embedding(num_classes, mapping.width) if num_classes else None
|
681 |
+
self.mapping_cond_in_proj = Linear(mapping_cond_dim, mapping.width, bias=False) if mapping_cond_dim else None
|
682 |
+
self.mapping = tag_module(MappingNetwork(mapping.depth, mapping.width, mapping.d_ff, dropout=mapping.dropout), "mapping")
|
683 |
+
|
684 |
+
self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList()
|
685 |
+
for i, spec in enumerate(levels):
|
686 |
+
if isinstance(spec.self_attn, GlobalAttentionSpec):
|
687 |
+
layer_factory = lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout)
|
688 |
+
elif isinstance(spec.self_attn, NeighborhoodAttentionSpec):
|
689 |
+
layer_factory = lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout)
|
690 |
+
elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec):
|
691 |
+
layer_factory = lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout)
|
692 |
+
elif isinstance(spec.self_attn, NoAttentionSpec):
|
693 |
+
layer_factory = lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout)
|
694 |
+
else:
|
695 |
+
raise ValueError(f"unsupported self attention spec {spec.self_attn}")
|
696 |
+
|
697 |
+
if i < len(levels) - 1:
|
698 |
+
self.down_levels.append(Level([layer_factory(i) for i in range(spec.depth)]))
|
699 |
+
self.up_levels.append(Level([layer_factory(i + spec.depth) for i in range(spec.depth)]))
|
700 |
+
else:
|
701 |
+
self.mid_level = Level([layer_factory(i) for i in range(spec.depth)])
|
702 |
+
|
703 |
+
self.merges = nn.ModuleList([TokenMerge(spec_1.width, spec_2.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])])
|
704 |
+
self.splits = nn.ModuleList([TokenSplit(spec_2.width, spec_1.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])])
|
705 |
+
|
706 |
+
self.out_norm = RMSNorm(levels[0].width)
|
707 |
+
self.patch_out = TokenSplitWithoutSkip(levels[0].width, out_channels, patch_size)
|
708 |
+
nn.init.zeros_(self.patch_out.proj.weight)
|
709 |
+
|
710 |
+
def param_groups(self, base_lr=5e-4, mapping_lr_scale=1 / 3):
|
711 |
+
wd = filter_params(lambda tags: "wd" in tags and "mapping" not in tags, self)
|
712 |
+
no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" not in tags, self)
|
713 |
+
mapping_wd = filter_params(lambda tags: "wd" in tags and "mapping" in tags, self)
|
714 |
+
mapping_no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" in tags, self)
|
715 |
+
groups = [
|
716 |
+
{"params": list(wd), "lr": base_lr},
|
717 |
+
{"params": list(no_wd), "lr": base_lr, "weight_decay": 0.0},
|
718 |
+
{"params": list(mapping_wd), "lr": base_lr * mapping_lr_scale},
|
719 |
+
{"params": list(mapping_no_wd), "lr": base_lr * mapping_lr_scale, "weight_decay": 0.0}
|
720 |
+
]
|
721 |
+
return groups
|
722 |
+
|
723 |
+
def forward(self, x, sigma=None, aug_cond=None, class_cond=None, mapping_cond=None, degradation_params=None):
|
724 |
+
# Patching
|
725 |
+
x = x.movedim(-3, -1)
|
726 |
+
x = self.patch_in(x)
|
727 |
+
# TODO: pixel aspect ratio for nonsquare patches
|
728 |
+
pos = make_axial_pos(x.shape[-3], x.shape[-2], device=x.device).view(x.shape[-3], x.shape[-2], 2)
|
729 |
+
|
730 |
+
# Mapping network
|
731 |
+
if class_cond is None and self.class_emb is not None:
|
732 |
+
raise ValueError("class_cond must be specified if num_classes > 0")
|
733 |
+
if mapping_cond is None and self.mapping_cond_in_proj is not None:
|
734 |
+
raise ValueError("mapping_cond must be specified if mapping_cond_dim > 0")
|
735 |
+
|
736 |
+
# c_noise = torch.log(sigma) / 4
|
737 |
+
# c_noise = (sigma * 2.0 - 1.0)
|
738 |
+
# c_noise = sigma * 2 - 1
|
739 |
+
if sigma is not None:
|
740 |
+
time_emb = self.time_in_proj(self.time_emb(sigma[..., None]))
|
741 |
+
else:
|
742 |
+
time_emb = self.time_in_proj(torch.ones(1, 1, device=x.device, dtype=x.dtype).expand(x.shape[0], self.mapping_width))
|
743 |
+
# time_emb = self.time_in_proj(sigma[..., None])
|
744 |
+
|
745 |
+
aug_cond = x.new_zeros([x.shape[0], 9]) if aug_cond is None else aug_cond
|
746 |
+
aug_emb = self.aug_in_proj(self.aug_emb(aug_cond))
|
747 |
+
class_emb = self.class_emb(class_cond) if self.class_emb is not None else 0
|
748 |
+
mapping_emb = self.mapping_cond_in_proj(mapping_cond) if self.mapping_cond_in_proj is not None else 0
|
749 |
+
degradation_emb = self.degradation_proj(degradation_params) if degradation_params is not None else 0
|
750 |
+
cond = self.mapping(time_emb + aug_emb + class_emb + mapping_emb + degradation_emb)
|
751 |
+
|
752 |
+
# Hourglass transformer
|
753 |
+
skips, poses = [], []
|
754 |
+
for down_level, merge in zip(self.down_levels, self.merges):
|
755 |
+
x = down_level(x, pos, cond)
|
756 |
+
skips.append(x)
|
757 |
+
poses.append(pos)
|
758 |
+
x = merge(x)
|
759 |
+
pos = downscale_pos(pos)
|
760 |
+
|
761 |
+
x = self.mid_level(x, pos, cond)
|
762 |
+
|
763 |
+
for up_level, split, skip, pos in reversed(list(zip(self.up_levels, self.splits, skips, poses))):
|
764 |
+
x = split(x, skip)
|
765 |
+
x = up_level(x, pos, cond)
|
766 |
+
|
767 |
+
# Unpatching
|
768 |
+
x = self.out_norm(x)
|
769 |
+
x = self.patch_out(x)
|
770 |
+
x = x.movedim(-1, -3)
|
771 |
+
|
772 |
+
return x
|
arch/swinir/__init__.py
ADDED
File without changes
|
arch/swinir/swinir.py
ADDED
@@ -0,0 +1,904 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -----------------------------------------------------------------------------------
|
2 |
+
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
|
3 |
+
# Originally Written by Ze Liu, Modified by Jingyun Liang.
|
4 |
+
# -----------------------------------------------------------------------------------
|
5 |
+
# Borrowed from DifFace (https://github.com/zsyOAOA/DifFace/blob/master/models/swinir.py)
|
6 |
+
|
7 |
+
import math
|
8 |
+
from typing import Set
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import torch.utils.checkpoint as checkpoint
|
14 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
15 |
+
|
16 |
+
|
17 |
+
class Mlp(nn.Module):
|
18 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
19 |
+
super().__init__()
|
20 |
+
out_features = out_features or in_features
|
21 |
+
hidden_features = hidden_features or in_features
|
22 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
23 |
+
self.act = act_layer()
|
24 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
25 |
+
self.drop = nn.Dropout(drop)
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
x = self.fc1(x)
|
29 |
+
x = self.act(x)
|
30 |
+
x = self.drop(x)
|
31 |
+
x = self.fc2(x)
|
32 |
+
x = self.drop(x)
|
33 |
+
return x
|
34 |
+
|
35 |
+
|
36 |
+
def window_partition(x, window_size):
|
37 |
+
"""
|
38 |
+
Args:
|
39 |
+
x: (B, H, W, C)
|
40 |
+
window_size (int): window size
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
windows: (num_windows*B, window_size, window_size, C)
|
44 |
+
"""
|
45 |
+
B, H, W, C = x.shape
|
46 |
+
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
47 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
48 |
+
return windows
|
49 |
+
|
50 |
+
|
51 |
+
def window_reverse(windows, window_size, H, W):
|
52 |
+
"""
|
53 |
+
Args:
|
54 |
+
windows: (num_windows*B, window_size, window_size, C)
|
55 |
+
window_size (int): Window size
|
56 |
+
H (int): Height of image
|
57 |
+
W (int): Width of image
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
x: (B, H, W, C)
|
61 |
+
"""
|
62 |
+
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
63 |
+
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
64 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
65 |
+
return x
|
66 |
+
|
67 |
+
|
68 |
+
class WindowAttention(nn.Module):
|
69 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
70 |
+
It supports both of shifted and non-shifted window.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
dim (int): Number of input channels.
|
74 |
+
window_size (tuple[int]): The height and width of the window.
|
75 |
+
num_heads (int): Number of attention heads.
|
76 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
77 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
78 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
79 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
80 |
+
"""
|
81 |
+
|
82 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
83 |
+
|
84 |
+
super().__init__()
|
85 |
+
self.dim = dim
|
86 |
+
self.window_size = window_size # Wh, Ww
|
87 |
+
self.num_heads = num_heads
|
88 |
+
head_dim = dim // num_heads
|
89 |
+
self.scale = qk_scale or head_dim ** -0.5
|
90 |
+
|
91 |
+
# define a parameter table of relative position bias
|
92 |
+
self.relative_position_bias_table = nn.Parameter(
|
93 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
94 |
+
|
95 |
+
# get pair-wise relative position index for each token inside the window
|
96 |
+
coords_h = torch.arange(self.window_size[0])
|
97 |
+
coords_w = torch.arange(self.window_size[1])
|
98 |
+
# coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
99 |
+
# Fix: Pass indexing="ij" to avoid warning
|
100 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) # 2, Wh, Ww
|
101 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
102 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
103 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
104 |
+
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
|
105 |
+
relative_coords[:, :, 1] += self.window_size[1] - 1
|
106 |
+
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
107 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
108 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
109 |
+
|
110 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
111 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
112 |
+
self.proj = nn.Linear(dim, dim)
|
113 |
+
|
114 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
115 |
+
|
116 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
117 |
+
self.softmax = nn.Softmax(dim=-1)
|
118 |
+
|
119 |
+
def forward(self, x, mask=None):
|
120 |
+
"""
|
121 |
+
Args:
|
122 |
+
x: input features with shape of (num_windows*B, N, C)
|
123 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
124 |
+
"""
|
125 |
+
B_, N, C = x.shape
|
126 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
127 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
128 |
+
|
129 |
+
q = q * self.scale
|
130 |
+
attn = (q @ k.transpose(-2, -1))
|
131 |
+
|
132 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
133 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
134 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
135 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
136 |
+
|
137 |
+
if mask is not None:
|
138 |
+
nW = mask.shape[0]
|
139 |
+
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
140 |
+
attn = attn.view(-1, self.num_heads, N, N)
|
141 |
+
attn = self.softmax(attn)
|
142 |
+
else:
|
143 |
+
attn = self.softmax(attn)
|
144 |
+
|
145 |
+
attn = self.attn_drop(attn)
|
146 |
+
|
147 |
+
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
148 |
+
x = self.proj(x)
|
149 |
+
x = self.proj_drop(x)
|
150 |
+
return x
|
151 |
+
|
152 |
+
def extra_repr(self) -> str:
|
153 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
154 |
+
|
155 |
+
def flops(self, N):
|
156 |
+
# calculate flops for 1 window with token length of N
|
157 |
+
flops = 0
|
158 |
+
# qkv = self.qkv(x)
|
159 |
+
flops += N * self.dim * 3 * self.dim
|
160 |
+
# attn = (q @ k.transpose(-2, -1))
|
161 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
162 |
+
# x = (attn @ v)
|
163 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
164 |
+
# x = self.proj(x)
|
165 |
+
flops += N * self.dim * self.dim
|
166 |
+
return flops
|
167 |
+
|
168 |
+
|
169 |
+
class SwinTransformerBlock(nn.Module):
|
170 |
+
r""" Swin Transformer Block.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
dim (int): Number of input channels.
|
174 |
+
input_resolution (tuple[int]): Input resulotion.
|
175 |
+
num_heads (int): Number of attention heads.
|
176 |
+
window_size (int): Window size.
|
177 |
+
shift_size (int): Shift size for SW-MSA.
|
178 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
179 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
180 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
181 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
182 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
183 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
184 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
185 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
186 |
+
"""
|
187 |
+
|
188 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
|
189 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
190 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
191 |
+
super().__init__()
|
192 |
+
self.dim = dim
|
193 |
+
self.input_resolution = input_resolution
|
194 |
+
self.num_heads = num_heads
|
195 |
+
self.window_size = window_size
|
196 |
+
self.shift_size = shift_size
|
197 |
+
self.mlp_ratio = mlp_ratio
|
198 |
+
if min(self.input_resolution) <= self.window_size:
|
199 |
+
# if window size is larger than input resolution, we don't partition windows
|
200 |
+
self.shift_size = 0
|
201 |
+
self.window_size = min(self.input_resolution)
|
202 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
203 |
+
|
204 |
+
self.norm1 = norm_layer(dim)
|
205 |
+
self.attn = WindowAttention(
|
206 |
+
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
|
207 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
208 |
+
|
209 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
210 |
+
self.norm2 = norm_layer(dim)
|
211 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
212 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
213 |
+
|
214 |
+
if self.shift_size > 0:
|
215 |
+
attn_mask = self.calculate_mask(self.input_resolution)
|
216 |
+
else:
|
217 |
+
attn_mask = None
|
218 |
+
|
219 |
+
self.register_buffer("attn_mask", attn_mask)
|
220 |
+
|
221 |
+
def calculate_mask(self, x_size):
|
222 |
+
# calculate attention mask for SW-MSA
|
223 |
+
H, W = x_size
|
224 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
225 |
+
h_slices = (slice(0, -self.window_size),
|
226 |
+
slice(-self.window_size, -self.shift_size),
|
227 |
+
slice(-self.shift_size, None))
|
228 |
+
w_slices = (slice(0, -self.window_size),
|
229 |
+
slice(-self.window_size, -self.shift_size),
|
230 |
+
slice(-self.shift_size, None))
|
231 |
+
cnt = 0
|
232 |
+
for h in h_slices:
|
233 |
+
for w in w_slices:
|
234 |
+
img_mask[:, h, w, :] = cnt
|
235 |
+
cnt += 1
|
236 |
+
|
237 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
238 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
239 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
240 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
241 |
+
|
242 |
+
return attn_mask
|
243 |
+
|
244 |
+
def forward(self, x, x_size):
|
245 |
+
H, W = x_size
|
246 |
+
B, L, C = x.shape
|
247 |
+
# assert L == H * W, "input feature has wrong size"
|
248 |
+
|
249 |
+
shortcut = x
|
250 |
+
x = self.norm1(x)
|
251 |
+
x = x.view(B, H, W, C)
|
252 |
+
|
253 |
+
# cyclic shift
|
254 |
+
if self.shift_size > 0:
|
255 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
256 |
+
else:
|
257 |
+
shifted_x = x
|
258 |
+
|
259 |
+
# partition windows
|
260 |
+
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
|
261 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
|
262 |
+
|
263 |
+
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
264 |
+
if self.input_resolution == x_size:
|
265 |
+
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
266 |
+
else:
|
267 |
+
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
268 |
+
|
269 |
+
# merge windows
|
270 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
271 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
272 |
+
|
273 |
+
# reverse cyclic shift
|
274 |
+
if self.shift_size > 0:
|
275 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
276 |
+
else:
|
277 |
+
x = shifted_x
|
278 |
+
x = x.view(B, H * W, C)
|
279 |
+
|
280 |
+
# FFN
|
281 |
+
x = shortcut + self.drop_path(x)
|
282 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
283 |
+
|
284 |
+
return x
|
285 |
+
|
286 |
+
def extra_repr(self) -> str:
|
287 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
288 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
289 |
+
|
290 |
+
def flops(self):
|
291 |
+
flops = 0
|
292 |
+
H, W = self.input_resolution
|
293 |
+
# norm1
|
294 |
+
flops += self.dim * H * W
|
295 |
+
# W-MSA/SW-MSA
|
296 |
+
nW = H * W / self.window_size / self.window_size
|
297 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size)
|
298 |
+
# mlp
|
299 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
300 |
+
# norm2
|
301 |
+
flops += self.dim * H * W
|
302 |
+
return flops
|
303 |
+
|
304 |
+
|
305 |
+
class PatchMerging(nn.Module):
|
306 |
+
r""" Patch Merging Layer.
|
307 |
+
|
308 |
+
Args:
|
309 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
310 |
+
dim (int): Number of input channels.
|
311 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
312 |
+
"""
|
313 |
+
|
314 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
315 |
+
super().__init__()
|
316 |
+
self.input_resolution = input_resolution
|
317 |
+
self.dim = dim
|
318 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
319 |
+
self.norm = norm_layer(4 * dim)
|
320 |
+
|
321 |
+
def forward(self, x):
|
322 |
+
"""
|
323 |
+
x: B, H*W, C
|
324 |
+
"""
|
325 |
+
H, W = self.input_resolution
|
326 |
+
B, L, C = x.shape
|
327 |
+
assert L == H * W, "input feature has wrong size"
|
328 |
+
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
|
329 |
+
|
330 |
+
x = x.view(B, H, W, C)
|
331 |
+
|
332 |
+
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
333 |
+
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
334 |
+
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
335 |
+
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
336 |
+
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
337 |
+
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
338 |
+
|
339 |
+
x = self.norm(x)
|
340 |
+
x = self.reduction(x)
|
341 |
+
|
342 |
+
return x
|
343 |
+
|
344 |
+
def extra_repr(self) -> str:
|
345 |
+
return f"input_resolution={self.input_resolution}, dim={self.dim}"
|
346 |
+
|
347 |
+
def flops(self):
|
348 |
+
H, W = self.input_resolution
|
349 |
+
flops = H * W * self.dim
|
350 |
+
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
351 |
+
return flops
|
352 |
+
|
353 |
+
|
354 |
+
class BasicLayer(nn.Module):
|
355 |
+
""" A basic Swin Transformer layer for one stage.
|
356 |
+
|
357 |
+
Args:
|
358 |
+
dim (int): Number of input channels.
|
359 |
+
input_resolution (tuple[int]): Input resolution.
|
360 |
+
depth (int): Number of blocks.
|
361 |
+
num_heads (int): Number of attention heads.
|
362 |
+
window_size (int): Local window size.
|
363 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
364 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
365 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
366 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
367 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
368 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
369 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
370 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
371 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
372 |
+
"""
|
373 |
+
|
374 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
375 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
376 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
|
377 |
+
|
378 |
+
super().__init__()
|
379 |
+
self.dim = dim
|
380 |
+
self.input_resolution = input_resolution
|
381 |
+
self.depth = depth
|
382 |
+
self.use_checkpoint = use_checkpoint
|
383 |
+
|
384 |
+
# build blocks
|
385 |
+
self.blocks = nn.ModuleList([
|
386 |
+
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
|
387 |
+
num_heads=num_heads, window_size=window_size,
|
388 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
389 |
+
mlp_ratio=mlp_ratio,
|
390 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
391 |
+
drop=drop, attn_drop=attn_drop,
|
392 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
393 |
+
norm_layer=norm_layer)
|
394 |
+
for i in range(depth)])
|
395 |
+
|
396 |
+
# patch merging layer
|
397 |
+
if downsample is not None:
|
398 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
399 |
+
else:
|
400 |
+
self.downsample = None
|
401 |
+
|
402 |
+
def forward(self, x, x_size):
|
403 |
+
for blk in self.blocks:
|
404 |
+
if self.use_checkpoint:
|
405 |
+
x = checkpoint.checkpoint(blk, x, x_size)
|
406 |
+
else:
|
407 |
+
x = blk(x, x_size)
|
408 |
+
if self.downsample is not None:
|
409 |
+
x = self.downsample(x)
|
410 |
+
return x
|
411 |
+
|
412 |
+
def extra_repr(self) -> str:
|
413 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
414 |
+
|
415 |
+
def flops(self):
|
416 |
+
flops = 0
|
417 |
+
for blk in self.blocks:
|
418 |
+
flops += blk.flops()
|
419 |
+
if self.downsample is not None:
|
420 |
+
flops += self.downsample.flops()
|
421 |
+
return flops
|
422 |
+
|
423 |
+
|
424 |
+
class RSTB(nn.Module):
|
425 |
+
"""Residual Swin Transformer Block (RSTB).
|
426 |
+
|
427 |
+
Args:
|
428 |
+
dim (int): Number of input channels.
|
429 |
+
input_resolution (tuple[int]): Input resolution.
|
430 |
+
depth (int): Number of blocks.
|
431 |
+
num_heads (int): Number of attention heads.
|
432 |
+
window_size (int): Local window size.
|
433 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
434 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
435 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
436 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
437 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
438 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
439 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
440 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
441 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
442 |
+
img_size: Input image size.
|
443 |
+
patch_size: Patch size.
|
444 |
+
resi_connection: The convolutional block before residual connection.
|
445 |
+
"""
|
446 |
+
|
447 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
|
448 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
449 |
+
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
|
450 |
+
img_size=224, patch_size=4, resi_connection='1conv'):
|
451 |
+
super(RSTB, self).__init__()
|
452 |
+
|
453 |
+
self.dim = dim
|
454 |
+
self.input_resolution = input_resolution
|
455 |
+
|
456 |
+
self.residual_group = BasicLayer(dim=dim,
|
457 |
+
input_resolution=input_resolution,
|
458 |
+
depth=depth,
|
459 |
+
num_heads=num_heads,
|
460 |
+
window_size=window_size,
|
461 |
+
mlp_ratio=mlp_ratio,
|
462 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
463 |
+
drop=drop, attn_drop=attn_drop,
|
464 |
+
drop_path=drop_path,
|
465 |
+
norm_layer=norm_layer,
|
466 |
+
downsample=downsample,
|
467 |
+
use_checkpoint=use_checkpoint)
|
468 |
+
|
469 |
+
if resi_connection == '1conv':
|
470 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
471 |
+
elif resi_connection == '3conv':
|
472 |
+
# to save parameters and memory
|
473 |
+
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
474 |
+
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
|
475 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
476 |
+
nn.Conv2d(dim // 4, dim, 3, 1, 1))
|
477 |
+
|
478 |
+
self.patch_embed = PatchEmbed(
|
479 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
480 |
+
norm_layer=None)
|
481 |
+
|
482 |
+
self.patch_unembed = PatchUnEmbed(
|
483 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
|
484 |
+
norm_layer=None)
|
485 |
+
|
486 |
+
def forward(self, x, x_size):
|
487 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
|
488 |
+
|
489 |
+
def flops(self):
|
490 |
+
flops = 0
|
491 |
+
flops += self.residual_group.flops()
|
492 |
+
H, W = self.input_resolution
|
493 |
+
flops += H * W * self.dim * self.dim * 9
|
494 |
+
flops += self.patch_embed.flops()
|
495 |
+
flops += self.patch_unembed.flops()
|
496 |
+
|
497 |
+
return flops
|
498 |
+
|
499 |
+
|
500 |
+
class PatchEmbed(nn.Module):
|
501 |
+
r""" Image to Patch Embedding
|
502 |
+
|
503 |
+
Args:
|
504 |
+
img_size (int): Image size. Default: 224.
|
505 |
+
patch_size (int): Patch token size. Default: 4.
|
506 |
+
in_chans (int): Number of input image channels. Default: 3.
|
507 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
508 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
509 |
+
"""
|
510 |
+
|
511 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
512 |
+
super().__init__()
|
513 |
+
img_size = to_2tuple(img_size)
|
514 |
+
patch_size = to_2tuple(patch_size)
|
515 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
516 |
+
self.img_size = img_size
|
517 |
+
self.patch_size = patch_size
|
518 |
+
self.patches_resolution = patches_resolution
|
519 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
520 |
+
|
521 |
+
self.in_chans = in_chans
|
522 |
+
self.embed_dim = embed_dim
|
523 |
+
|
524 |
+
if norm_layer is not None:
|
525 |
+
self.norm = norm_layer(embed_dim)
|
526 |
+
else:
|
527 |
+
self.norm = None
|
528 |
+
|
529 |
+
def forward(self, x):
|
530 |
+
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
|
531 |
+
if self.norm is not None:
|
532 |
+
x = self.norm(x)
|
533 |
+
return x
|
534 |
+
|
535 |
+
def flops(self):
|
536 |
+
flops = 0
|
537 |
+
H, W = self.img_size
|
538 |
+
if self.norm is not None:
|
539 |
+
flops += H * W * self.embed_dim
|
540 |
+
return flops
|
541 |
+
|
542 |
+
|
543 |
+
class PatchUnEmbed(nn.Module):
|
544 |
+
r""" Image to Patch Unembedding
|
545 |
+
|
546 |
+
Args:
|
547 |
+
img_size (int): Image size. Default: 224.
|
548 |
+
patch_size (int): Patch token size. Default: 4.
|
549 |
+
in_chans (int): Number of input image channels. Default: 3.
|
550 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
551 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
552 |
+
"""
|
553 |
+
|
554 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
555 |
+
super().__init__()
|
556 |
+
img_size = to_2tuple(img_size)
|
557 |
+
patch_size = to_2tuple(patch_size)
|
558 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
559 |
+
self.img_size = img_size
|
560 |
+
self.patch_size = patch_size
|
561 |
+
self.patches_resolution = patches_resolution
|
562 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
563 |
+
|
564 |
+
self.in_chans = in_chans
|
565 |
+
self.embed_dim = embed_dim
|
566 |
+
|
567 |
+
def forward(self, x, x_size):
|
568 |
+
B, HW, C = x.shape
|
569 |
+
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
|
570 |
+
return x
|
571 |
+
|
572 |
+
def flops(self):
|
573 |
+
flops = 0
|
574 |
+
return flops
|
575 |
+
|
576 |
+
|
577 |
+
class Upsample(nn.Sequential):
|
578 |
+
"""Upsample module.
|
579 |
+
|
580 |
+
Args:
|
581 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
582 |
+
num_feat (int): Channel number of intermediate features.
|
583 |
+
"""
|
584 |
+
|
585 |
+
def __init__(self, scale, num_feat):
|
586 |
+
m = []
|
587 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
588 |
+
for _ in range(int(math.log(scale, 2))):
|
589 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
590 |
+
m.append(nn.PixelShuffle(2))
|
591 |
+
elif scale == 3:
|
592 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
593 |
+
m.append(nn.PixelShuffle(3))
|
594 |
+
else:
|
595 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
596 |
+
super(Upsample, self).__init__(*m)
|
597 |
+
|
598 |
+
|
599 |
+
class UpsampleOneStep(nn.Sequential):
|
600 |
+
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
|
601 |
+
Used in lightweight SR to save parameters.
|
602 |
+
|
603 |
+
Args:
|
604 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
605 |
+
num_feat (int): Channel number of intermediate features.
|
606 |
+
|
607 |
+
"""
|
608 |
+
|
609 |
+
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
|
610 |
+
self.num_feat = num_feat
|
611 |
+
self.input_resolution = input_resolution
|
612 |
+
m = []
|
613 |
+
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
|
614 |
+
m.append(nn.PixelShuffle(scale))
|
615 |
+
super(UpsampleOneStep, self).__init__(*m)
|
616 |
+
|
617 |
+
def flops(self):
|
618 |
+
H, W = self.input_resolution
|
619 |
+
flops = H * W * self.num_feat * 3 * 9
|
620 |
+
return flops
|
621 |
+
|
622 |
+
|
623 |
+
class SwinIR(nn.Module):
|
624 |
+
r""" SwinIR
|
625 |
+
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
|
626 |
+
|
627 |
+
Args:
|
628 |
+
img_size (int | tuple(int)): Input image size. Default 64
|
629 |
+
patch_size (int | tuple(int)): Patch size. Default: 1
|
630 |
+
in_chans (int): Number of input image channels. Default: 3
|
631 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
632 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
633 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
634 |
+
window_size (int): Window size. Default: 7
|
635 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
636 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
637 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
638 |
+
drop_rate (float): Dropout rate. Default: 0
|
639 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
640 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
641 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
642 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
643 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
644 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
645 |
+
sf: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
|
646 |
+
img_range: Image range. 1. or 255.
|
647 |
+
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
648 |
+
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
649 |
+
"""
|
650 |
+
|
651 |
+
def __init__(
|
652 |
+
self,
|
653 |
+
img_size=64,
|
654 |
+
patch_size=1,
|
655 |
+
in_chans=3,
|
656 |
+
num_out_ch=3,
|
657 |
+
embed_dim=96,
|
658 |
+
depths=[6, 6, 6, 6],
|
659 |
+
num_heads=[6, 6, 6, 6],
|
660 |
+
window_size=7,
|
661 |
+
mlp_ratio=4.,
|
662 |
+
qkv_bias=True,
|
663 |
+
qk_scale=None,
|
664 |
+
drop_rate=0.,
|
665 |
+
attn_drop_rate=0.,
|
666 |
+
drop_path_rate=0.1,
|
667 |
+
norm_layer=nn.LayerNorm,
|
668 |
+
ape=False,
|
669 |
+
patch_norm=True,
|
670 |
+
use_checkpoint=False,
|
671 |
+
sf=4,
|
672 |
+
img_range=1.,
|
673 |
+
upsampler='',
|
674 |
+
resi_connection='1conv',
|
675 |
+
unshuffle=False,
|
676 |
+
unshuffle_scale=None,
|
677 |
+
hq_key: str = "jpg",
|
678 |
+
lq_key: str = "hint",
|
679 |
+
learning_rate: float = None,
|
680 |
+
weight_decay: float = None
|
681 |
+
) -> "SwinIR":
|
682 |
+
super(SwinIR, self).__init__()
|
683 |
+
num_in_ch = in_chans * (unshuffle_scale ** 2) if unshuffle else in_chans
|
684 |
+
num_feat = 64
|
685 |
+
self.img_range = img_range
|
686 |
+
if in_chans == 3:
|
687 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
688 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
689 |
+
else:
|
690 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
691 |
+
self.upscale = sf
|
692 |
+
self.upsampler = upsampler
|
693 |
+
self.window_size = window_size
|
694 |
+
self.unshuffle_scale = unshuffle_scale
|
695 |
+
self.unshuffle = unshuffle
|
696 |
+
|
697 |
+
#####################################################################################################
|
698 |
+
################################### 1, shallow feature extraction ###################################
|
699 |
+
if unshuffle:
|
700 |
+
assert unshuffle_scale is not None
|
701 |
+
self.conv_first = nn.Sequential(
|
702 |
+
nn.PixelUnshuffle(sf),
|
703 |
+
nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1),
|
704 |
+
)
|
705 |
+
else:
|
706 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
707 |
+
|
708 |
+
#####################################################################################################
|
709 |
+
################################### 2, deep feature extraction ######################################
|
710 |
+
self.num_layers = len(depths)
|
711 |
+
self.embed_dim = embed_dim
|
712 |
+
self.ape = ape
|
713 |
+
self.patch_norm = patch_norm
|
714 |
+
self.num_features = embed_dim
|
715 |
+
self.mlp_ratio = mlp_ratio
|
716 |
+
|
717 |
+
# split image into non-overlapping patches
|
718 |
+
self.patch_embed = PatchEmbed(
|
719 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
720 |
+
norm_layer=norm_layer if self.patch_norm else None
|
721 |
+
)
|
722 |
+
num_patches = self.patch_embed.num_patches
|
723 |
+
patches_resolution = self.patch_embed.patches_resolution
|
724 |
+
self.patches_resolution = patches_resolution
|
725 |
+
|
726 |
+
# merge non-overlapping patches into image
|
727 |
+
self.patch_unembed = PatchUnEmbed(
|
728 |
+
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
|
729 |
+
norm_layer=norm_layer if self.patch_norm else None
|
730 |
+
)
|
731 |
+
|
732 |
+
# absolute position embedding
|
733 |
+
if self.ape:
|
734 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
735 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
736 |
+
|
737 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
738 |
+
|
739 |
+
# stochastic depth
|
740 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
741 |
+
|
742 |
+
# build Residual Swin Transformer blocks (RSTB)
|
743 |
+
self.layers = nn.ModuleList()
|
744 |
+
for i_layer in range(self.num_layers):
|
745 |
+
layer = RSTB(
|
746 |
+
dim=embed_dim,
|
747 |
+
input_resolution=(patches_resolution[0], patches_resolution[1]),
|
748 |
+
depth=depths[i_layer],
|
749 |
+
num_heads=num_heads[i_layer],
|
750 |
+
window_size=window_size,
|
751 |
+
mlp_ratio=self.mlp_ratio,
|
752 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
753 |
+
drop=drop_rate, attn_drop=attn_drop_rate,
|
754 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
755 |
+
norm_layer=norm_layer,
|
756 |
+
downsample=None,
|
757 |
+
use_checkpoint=use_checkpoint,
|
758 |
+
img_size=img_size,
|
759 |
+
patch_size=patch_size,
|
760 |
+
resi_connection=resi_connection
|
761 |
+
)
|
762 |
+
self.layers.append(layer)
|
763 |
+
self.norm = norm_layer(self.num_features)
|
764 |
+
|
765 |
+
# build the last conv layer in deep feature extraction
|
766 |
+
if resi_connection == '1conv':
|
767 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
768 |
+
elif resi_connection == '3conv':
|
769 |
+
# to save parameters and memory
|
770 |
+
self.conv_after_body = nn.Sequential(
|
771 |
+
nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
|
772 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
773 |
+
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
|
774 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
775 |
+
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)
|
776 |
+
)
|
777 |
+
|
778 |
+
#####################################################################################################
|
779 |
+
################################ 3, high quality image reconstruction ################################
|
780 |
+
if self.upsampler == 'pixelshuffle':
|
781 |
+
# for classical SR
|
782 |
+
self.conv_before_upsample = nn.Sequential(
|
783 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
784 |
+
nn.LeakyReLU(inplace=True)
|
785 |
+
)
|
786 |
+
self.upsample = Upsample(sf, num_feat)
|
787 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
788 |
+
elif self.upsampler == 'pixelshuffledirect':
|
789 |
+
# for lightweight SR (to save parameters)
|
790 |
+
self.upsample = UpsampleOneStep(
|
791 |
+
sf, embed_dim, num_out_ch,
|
792 |
+
(patches_resolution[0], patches_resolution[1])
|
793 |
+
)
|
794 |
+
elif self.upsampler == 'nearest+conv':
|
795 |
+
# for real-world SR (less artifacts)
|
796 |
+
self.conv_before_upsample = nn.Sequential(
|
797 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
798 |
+
nn.LeakyReLU(inplace=True)
|
799 |
+
)
|
800 |
+
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
801 |
+
if self.upscale == 4:
|
802 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
803 |
+
elif self.upscale == 8:
|
804 |
+
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
805 |
+
self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
806 |
+
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
807 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
808 |
+
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
809 |
+
else:
|
810 |
+
# for image denoising and JPEG compression artifact reduction
|
811 |
+
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
|
812 |
+
|
813 |
+
self.apply(self._init_weights)
|
814 |
+
|
815 |
+
def _init_weights(self, m: nn.Module) -> None:
|
816 |
+
if isinstance(m, nn.Linear):
|
817 |
+
trunc_normal_(m.weight, std=.02)
|
818 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
819 |
+
nn.init.constant_(m.bias, 0)
|
820 |
+
elif isinstance(m, nn.LayerNorm):
|
821 |
+
nn.init.constant_(m.bias, 0)
|
822 |
+
nn.init.constant_(m.weight, 1.0)
|
823 |
+
|
824 |
+
# TODO: What's this ?
|
825 |
+
@torch.jit.ignore
|
826 |
+
def no_weight_decay(self) -> Set[str]:
|
827 |
+
return {'absolute_pos_embed'}
|
828 |
+
|
829 |
+
@torch.jit.ignore
|
830 |
+
def no_weight_decay_keywords(self) -> Set[str]:
|
831 |
+
return {'relative_position_bias_table'}
|
832 |
+
|
833 |
+
def check_image_size(self, x: torch.Tensor) -> torch.Tensor:
|
834 |
+
_, _, h, w = x.size()
|
835 |
+
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
|
836 |
+
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
|
837 |
+
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
|
838 |
+
return x
|
839 |
+
|
840 |
+
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
841 |
+
x_size = (x.shape[2], x.shape[3])
|
842 |
+
x = self.patch_embed(x)
|
843 |
+
if self.ape:
|
844 |
+
x = x + self.absolute_pos_embed
|
845 |
+
x = self.pos_drop(x)
|
846 |
+
|
847 |
+
for layer in self.layers:
|
848 |
+
x = layer(x, x_size)
|
849 |
+
|
850 |
+
x = self.norm(x) # B L C
|
851 |
+
x = self.patch_unembed(x, x_size)
|
852 |
+
|
853 |
+
return x
|
854 |
+
|
855 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
856 |
+
H, W = x.shape[2:]
|
857 |
+
x = self.check_image_size(x)
|
858 |
+
|
859 |
+
self.mean = self.mean.type_as(x)
|
860 |
+
x = (x - self.mean) * self.img_range
|
861 |
+
|
862 |
+
if self.upsampler == 'pixelshuffle':
|
863 |
+
# for classical SR
|
864 |
+
x = self.conv_first(x)
|
865 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
866 |
+
x = self.conv_before_upsample(x)
|
867 |
+
x = self.conv_last(self.upsample(x))
|
868 |
+
elif self.upsampler == 'pixelshuffledirect':
|
869 |
+
# for lightweight SR
|
870 |
+
x = self.conv_first(x)
|
871 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
872 |
+
x = self.upsample(x)
|
873 |
+
elif self.upsampler == 'nearest+conv':
|
874 |
+
# for real-world SR
|
875 |
+
x = self.conv_first(x)
|
876 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
877 |
+
x = self.conv_before_upsample(x)
|
878 |
+
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
879 |
+
if self.upscale == 4:
|
880 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
881 |
+
elif self.upscale == 8:
|
882 |
+
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
883 |
+
x = self.lrelu(self.conv_up3(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
|
884 |
+
x = self.conv_last(self.lrelu(self.conv_hr(x)))
|
885 |
+
else:
|
886 |
+
# for image denoising and JPEG compression artifact reduction
|
887 |
+
x_first = self.conv_first(x)
|
888 |
+
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
889 |
+
x = x + self.conv_last(res)
|
890 |
+
|
891 |
+
x = x / self.img_range + self.mean
|
892 |
+
|
893 |
+
return x[:, :, :H * self.upscale, :W * self.upscale]
|
894 |
+
|
895 |
+
def flops(self) -> int:
|
896 |
+
flops = 0
|
897 |
+
H, W = self.patches_resolution
|
898 |
+
flops += H * W * 3 * self.embed_dim * 9
|
899 |
+
flops += self.patch_embed.flops()
|
900 |
+
for i, layer in enumerate(self.layers):
|
901 |
+
flops += layer.flops()
|
902 |
+
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
903 |
+
flops += self.upsample.flops()
|
904 |
+
return flops
|
packages.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
ffmpeg
|
2 |
+
libsm6
|
3 |
+
libxext6
|
requirements.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.2.2
|
2 |
+
facexlib==0.2.5
|
3 |
+
realesrgan==0.2.5
|
4 |
+
numpy
|
5 |
+
opencv-python
|
6 |
+
torchvision
|
7 |
+
pytorch-lightning==2.4.0
|
8 |
+
scipy
|
9 |
+
tqdm
|
10 |
+
lmdb
|
11 |
+
pyyaml
|
12 |
+
basicsr==1.4.2
|
13 |
+
yapf
|
14 |
+
dctorch
|
15 |
+
einops
|
16 |
+
torch-ema==0.3
|
17 |
+
huggingface_hub==0.24.5
|
18 |
+
natten==0.17.1
|
19 |
+
wandb
|
20 |
+
timm
|
21 |
+
huggingface_hub==0.24.5
|
utils/__init__.py
ADDED
File without changes
|
utils/basicsr_custom.py
ADDED
@@ -0,0 +1,954 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/data/degradations.py
|
2 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
3 |
+
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
|
4 |
+
|
5 |
+
import math
|
6 |
+
import random
|
7 |
+
import re
|
8 |
+
from abc import ABCMeta, abstractmethod
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import List, Dict
|
11 |
+
from typing import Mapping, Any
|
12 |
+
from typing import Optional, Union
|
13 |
+
|
14 |
+
import cv2
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from PIL import Image
|
18 |
+
from scipy import special
|
19 |
+
from scipy.stats import multivariate_normal
|
20 |
+
from torch import Tensor
|
21 |
+
# from torchvision.transforms.functional_tensor import rgb_to_grayscale
|
22 |
+
from torchvision.transforms._functional_tensor import rgb_to_grayscale
|
23 |
+
|
24 |
+
|
25 |
+
# -------------------------------------------------------------------- #
|
26 |
+
# --------------------------- blur kernels --------------------------- #
|
27 |
+
# -------------------------------------------------------------------- #
|
28 |
+
|
29 |
+
|
30 |
+
# --------------------------- util functions --------------------------- #
|
31 |
+
def sigma_matrix2(sig_x, sig_y, theta):
|
32 |
+
"""Calculate the rotated sigma matrix (two dimensional matrix).
|
33 |
+
|
34 |
+
Args:
|
35 |
+
sig_x (float):
|
36 |
+
sig_y (float):
|
37 |
+
theta (float): Radian measurement.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
ndarray: Rotated sigma matrix.
|
41 |
+
"""
|
42 |
+
d_matrix = np.array([[sig_x ** 2, 0], [0, sig_y ** 2]])
|
43 |
+
u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
|
44 |
+
return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
|
45 |
+
|
46 |
+
|
47 |
+
def mesh_grid(kernel_size):
|
48 |
+
"""Generate the mesh grid, centering at zero.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
kernel_size (int):
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
xy (ndarray): with the shape (kernel_size, kernel_size, 2)
|
55 |
+
xx (ndarray): with the shape (kernel_size, kernel_size)
|
56 |
+
yy (ndarray): with the shape (kernel_size, kernel_size)
|
57 |
+
"""
|
58 |
+
ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
|
59 |
+
xx, yy = np.meshgrid(ax, ax)
|
60 |
+
xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
|
61 |
+
1))).reshape(kernel_size, kernel_size, 2)
|
62 |
+
return xy, xx, yy
|
63 |
+
|
64 |
+
|
65 |
+
def pdf2(sigma_matrix, grid):
|
66 |
+
"""Calculate PDF of the bivariate Gaussian distribution.
|
67 |
+
|
68 |
+
Args:
|
69 |
+
sigma_matrix (ndarray): with the shape (2, 2)
|
70 |
+
grid (ndarray): generated by :func:`mesh_grid`,
|
71 |
+
with the shape (K, K, 2), K is the kernel size.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
kernel (ndarrray): un-normalized kernel.
|
75 |
+
"""
|
76 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
77 |
+
kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
|
78 |
+
return kernel
|
79 |
+
|
80 |
+
|
81 |
+
def cdf2(d_matrix, grid):
|
82 |
+
"""Calculate the CDF of the standard bivariate Gaussian distribution.
|
83 |
+
Used in skewed Gaussian distribution.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
d_matrix (ndarrasy): skew matrix.
|
87 |
+
grid (ndarray): generated by :func:`mesh_grid`,
|
88 |
+
with the shape (K, K, 2), K is the kernel size.
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
cdf (ndarray): skewed cdf.
|
92 |
+
"""
|
93 |
+
rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
|
94 |
+
grid = np.dot(grid, d_matrix)
|
95 |
+
cdf = rv.cdf(grid)
|
96 |
+
return cdf
|
97 |
+
|
98 |
+
|
99 |
+
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
|
100 |
+
"""Generate a bivariate isotropic or anisotropic Gaussian kernel.
|
101 |
+
|
102 |
+
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
kernel_size (int):
|
106 |
+
sig_x (float):
|
107 |
+
sig_y (float):
|
108 |
+
theta (float): Radian measurement.
|
109 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
110 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
111 |
+
isotropic (bool):
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
kernel (ndarray): normalized kernel.
|
115 |
+
"""
|
116 |
+
if grid is None:
|
117 |
+
grid, _, _ = mesh_grid(kernel_size)
|
118 |
+
if isotropic:
|
119 |
+
sigma_matrix = np.array([[sig_x ** 2, 0], [0, sig_x ** 2]])
|
120 |
+
else:
|
121 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
122 |
+
kernel = pdf2(sigma_matrix, grid)
|
123 |
+
kernel = kernel / np.sum(kernel)
|
124 |
+
return kernel
|
125 |
+
|
126 |
+
|
127 |
+
def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
|
128 |
+
"""Generate a bivariate generalized Gaussian kernel.
|
129 |
+
|
130 |
+
``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``
|
131 |
+
|
132 |
+
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
kernel_size (int):
|
136 |
+
sig_x (float):
|
137 |
+
sig_y (float):
|
138 |
+
theta (float): Radian measurement.
|
139 |
+
beta (float): shape parameter, beta = 1 is the normal distribution.
|
140 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
141 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
kernel (ndarray): normalized kernel.
|
145 |
+
"""
|
146 |
+
if grid is None:
|
147 |
+
grid, _, _ = mesh_grid(kernel_size)
|
148 |
+
if isotropic:
|
149 |
+
sigma_matrix = np.array([[sig_x ** 2, 0], [0, sig_x ** 2]])
|
150 |
+
else:
|
151 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
152 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
153 |
+
kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
|
154 |
+
kernel = kernel / np.sum(kernel)
|
155 |
+
return kernel
|
156 |
+
|
157 |
+
|
158 |
+
def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
|
159 |
+
"""Generate a plateau-like anisotropic kernel.
|
160 |
+
|
161 |
+
1 / (1+x^(beta))
|
162 |
+
|
163 |
+
Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
|
164 |
+
|
165 |
+
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
kernel_size (int):
|
169 |
+
sig_x (float):
|
170 |
+
sig_y (float):
|
171 |
+
theta (float): Radian measurement.
|
172 |
+
beta (float): shape parameter, beta = 1 is the normal distribution.
|
173 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
174 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
kernel (ndarray): normalized kernel.
|
178 |
+
"""
|
179 |
+
if grid is None:
|
180 |
+
grid, _, _ = mesh_grid(kernel_size)
|
181 |
+
if isotropic:
|
182 |
+
sigma_matrix = np.array([[sig_x ** 2, 0], [0, sig_x ** 2]])
|
183 |
+
else:
|
184 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
185 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
186 |
+
kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
|
187 |
+
kernel = kernel / np.sum(kernel)
|
188 |
+
return kernel
|
189 |
+
|
190 |
+
|
191 |
+
def random_bivariate_Gaussian(kernel_size,
|
192 |
+
sigma_x_range,
|
193 |
+
sigma_y_range,
|
194 |
+
rotation_range,
|
195 |
+
noise_range=None,
|
196 |
+
isotropic=True):
|
197 |
+
"""Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
|
198 |
+
|
199 |
+
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
kernel_size (int):
|
203 |
+
sigma_x_range (tuple): [0.6, 5]
|
204 |
+
sigma_y_range (tuple): [0.6, 5]
|
205 |
+
rotation range (tuple): [-math.pi, math.pi]
|
206 |
+
noise_range(tuple, optional): multiplicative kernel noise,
|
207 |
+
[0.75, 1.25]. Default: None
|
208 |
+
|
209 |
+
Returns:
|
210 |
+
kernel (ndarray):
|
211 |
+
"""
|
212 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
213 |
+
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
|
214 |
+
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
|
215 |
+
if isotropic is False:
|
216 |
+
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
|
217 |
+
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
|
218 |
+
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
|
219 |
+
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
|
220 |
+
else:
|
221 |
+
sigma_y = sigma_x
|
222 |
+
rotation = 0
|
223 |
+
|
224 |
+
kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
|
225 |
+
|
226 |
+
# add multiplicative noise
|
227 |
+
if noise_range is not None:
|
228 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
229 |
+
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
|
230 |
+
kernel = kernel * noise
|
231 |
+
kernel = kernel / np.sum(kernel)
|
232 |
+
return kernel
|
233 |
+
|
234 |
+
|
235 |
+
def random_bivariate_generalized_Gaussian(kernel_size,
|
236 |
+
sigma_x_range,
|
237 |
+
sigma_y_range,
|
238 |
+
rotation_range,
|
239 |
+
beta_range,
|
240 |
+
noise_range=None,
|
241 |
+
isotropic=True):
|
242 |
+
"""Randomly generate bivariate generalized Gaussian kernels.
|
243 |
+
|
244 |
+
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
kernel_size (int):
|
248 |
+
sigma_x_range (tuple): [0.6, 5]
|
249 |
+
sigma_y_range (tuple): [0.6, 5]
|
250 |
+
rotation range (tuple): [-math.pi, math.pi]
|
251 |
+
beta_range (tuple): [0.5, 8]
|
252 |
+
noise_range(tuple, optional): multiplicative kernel noise,
|
253 |
+
[0.75, 1.25]. Default: None
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
kernel (ndarray):
|
257 |
+
"""
|
258 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
259 |
+
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
|
260 |
+
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
|
261 |
+
if isotropic is False:
|
262 |
+
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
|
263 |
+
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
|
264 |
+
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
|
265 |
+
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
|
266 |
+
else:
|
267 |
+
sigma_y = sigma_x
|
268 |
+
rotation = 0
|
269 |
+
|
270 |
+
# assume beta_range[0] < 1 < beta_range[1]
|
271 |
+
if np.random.uniform() < 0.5:
|
272 |
+
beta = np.random.uniform(beta_range[0], 1)
|
273 |
+
else:
|
274 |
+
beta = np.random.uniform(1, beta_range[1])
|
275 |
+
|
276 |
+
kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
|
277 |
+
|
278 |
+
# add multiplicative noise
|
279 |
+
if noise_range is not None:
|
280 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
281 |
+
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
|
282 |
+
kernel = kernel * noise
|
283 |
+
kernel = kernel / np.sum(kernel)
|
284 |
+
return kernel
|
285 |
+
|
286 |
+
|
287 |
+
def random_bivariate_plateau(kernel_size,
|
288 |
+
sigma_x_range,
|
289 |
+
sigma_y_range,
|
290 |
+
rotation_range,
|
291 |
+
beta_range,
|
292 |
+
noise_range=None,
|
293 |
+
isotropic=True):
|
294 |
+
"""Randomly generate bivariate plateau kernels.
|
295 |
+
|
296 |
+
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
kernel_size (int):
|
300 |
+
sigma_x_range (tuple): [0.6, 5]
|
301 |
+
sigma_y_range (tuple): [0.6, 5]
|
302 |
+
rotation range (tuple): [-math.pi/2, math.pi/2]
|
303 |
+
beta_range (tuple): [1, 4]
|
304 |
+
noise_range(tuple, optional): multiplicative kernel noise,
|
305 |
+
[0.75, 1.25]. Default: None
|
306 |
+
|
307 |
+
Returns:
|
308 |
+
kernel (ndarray):
|
309 |
+
"""
|
310 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
311 |
+
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
|
312 |
+
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
|
313 |
+
if isotropic is False:
|
314 |
+
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
|
315 |
+
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
|
316 |
+
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
|
317 |
+
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
|
318 |
+
else:
|
319 |
+
sigma_y = sigma_x
|
320 |
+
rotation = 0
|
321 |
+
|
322 |
+
# TODO: this may be not proper
|
323 |
+
if np.random.uniform() < 0.5:
|
324 |
+
beta = np.random.uniform(beta_range[0], 1)
|
325 |
+
else:
|
326 |
+
beta = np.random.uniform(1, beta_range[1])
|
327 |
+
|
328 |
+
kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
|
329 |
+
# add multiplicative noise
|
330 |
+
if noise_range is not None:
|
331 |
+
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
|
332 |
+
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
|
333 |
+
kernel = kernel * noise
|
334 |
+
kernel = kernel / np.sum(kernel)
|
335 |
+
|
336 |
+
return kernel
|
337 |
+
|
338 |
+
|
339 |
+
def random_mixed_kernels(kernel_list,
|
340 |
+
kernel_prob,
|
341 |
+
kernel_size=21,
|
342 |
+
sigma_x_range=(0.6, 5),
|
343 |
+
sigma_y_range=(0.6, 5),
|
344 |
+
rotation_range=(-math.pi, math.pi),
|
345 |
+
betag_range=(0.5, 8),
|
346 |
+
betap_range=(0.5, 8),
|
347 |
+
noise_range=None):
|
348 |
+
"""Randomly generate mixed kernels.
|
349 |
+
|
350 |
+
Args:
|
351 |
+
kernel_list (tuple): a list name of kernel types,
|
352 |
+
support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
|
353 |
+
'plateau_aniso']
|
354 |
+
kernel_prob (tuple): corresponding kernel probability for each
|
355 |
+
kernel type
|
356 |
+
kernel_size (int):
|
357 |
+
sigma_x_range (tuple): [0.6, 5]
|
358 |
+
sigma_y_range (tuple): [0.6, 5]
|
359 |
+
rotation range (tuple): [-math.pi, math.pi]
|
360 |
+
beta_range (tuple): [0.5, 8]
|
361 |
+
noise_range(tuple, optional): multiplicative kernel noise,
|
362 |
+
[0.75, 1.25]. Default: None
|
363 |
+
|
364 |
+
Returns:
|
365 |
+
kernel (ndarray):
|
366 |
+
"""
|
367 |
+
kernel_type = random.choices(kernel_list, kernel_prob)[0]
|
368 |
+
if kernel_type == 'iso':
|
369 |
+
kernel = random_bivariate_Gaussian(
|
370 |
+
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
|
371 |
+
elif kernel_type == 'aniso':
|
372 |
+
kernel = random_bivariate_Gaussian(
|
373 |
+
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
|
374 |
+
elif kernel_type == 'generalized_iso':
|
375 |
+
kernel = random_bivariate_generalized_Gaussian(
|
376 |
+
kernel_size,
|
377 |
+
sigma_x_range,
|
378 |
+
sigma_y_range,
|
379 |
+
rotation_range,
|
380 |
+
betag_range,
|
381 |
+
noise_range=noise_range,
|
382 |
+
isotropic=True)
|
383 |
+
elif kernel_type == 'generalized_aniso':
|
384 |
+
kernel = random_bivariate_generalized_Gaussian(
|
385 |
+
kernel_size,
|
386 |
+
sigma_x_range,
|
387 |
+
sigma_y_range,
|
388 |
+
rotation_range,
|
389 |
+
betag_range,
|
390 |
+
noise_range=noise_range,
|
391 |
+
isotropic=False)
|
392 |
+
elif kernel_type == 'plateau_iso':
|
393 |
+
kernel = random_bivariate_plateau(
|
394 |
+
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
|
395 |
+
elif kernel_type == 'plateau_aniso':
|
396 |
+
kernel = random_bivariate_plateau(
|
397 |
+
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
|
398 |
+
return kernel
|
399 |
+
|
400 |
+
|
401 |
+
np.seterr(divide='ignore', invalid='ignore')
|
402 |
+
|
403 |
+
|
404 |
+
def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
|
405 |
+
"""2D sinc filter
|
406 |
+
|
407 |
+
Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
|
408 |
+
|
409 |
+
Args:
|
410 |
+
cutoff (float): cutoff frequency in radians (pi is max)
|
411 |
+
kernel_size (int): horizontal and vertical size, must be odd.
|
412 |
+
pad_to (int): pad kernel size to desired size, must be odd or zero.
|
413 |
+
"""
|
414 |
+
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
415 |
+
kernel = np.fromfunction(
|
416 |
+
lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
|
417 |
+
(x - (kernel_size - 1) / 2) ** 2 + (y - (kernel_size - 1) / 2) ** 2)) / (2 * np.pi * np.sqrt(
|
418 |
+
(x - (kernel_size - 1) / 2) ** 2 + (y - (kernel_size - 1) / 2) ** 2)), [kernel_size, kernel_size])
|
419 |
+
kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff ** 2 / (4 * np.pi)
|
420 |
+
kernel = kernel / np.sum(kernel)
|
421 |
+
if pad_to > kernel_size:
|
422 |
+
pad_size = (pad_to - kernel_size) // 2
|
423 |
+
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
|
424 |
+
return kernel
|
425 |
+
|
426 |
+
|
427 |
+
# ------------------------------------------------------------- #
|
428 |
+
# --------------------------- noise --------------------------- #
|
429 |
+
# ------------------------------------------------------------- #
|
430 |
+
|
431 |
+
# ----------------------- Gaussian Noise ----------------------- #
|
432 |
+
|
433 |
+
def instantiate_from_config(config: Mapping[str, Any]) -> Any:
|
434 |
+
if not "target" in config:
|
435 |
+
raise KeyError("Expected key `target` to instantiate.")
|
436 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
437 |
+
|
438 |
+
|
439 |
+
class BaseStorageBackend(metaclass=ABCMeta):
|
440 |
+
"""Abstract class of storage backends.
|
441 |
+
|
442 |
+
All backends need to implement two apis: ``get()`` and ``get_text()``.
|
443 |
+
``get()`` reads the file as a byte stream and ``get_text()`` reads the file
|
444 |
+
as texts.
|
445 |
+
"""
|
446 |
+
|
447 |
+
@property
|
448 |
+
def name(self) -> str:
|
449 |
+
return self.__class__.__name__
|
450 |
+
|
451 |
+
@abstractmethod
|
452 |
+
def get(self, filepath: str) -> bytes:
|
453 |
+
pass
|
454 |
+
|
455 |
+
|
456 |
+
class PetrelBackend(BaseStorageBackend):
|
457 |
+
"""Petrel storage backend (for internal use).
|
458 |
+
|
459 |
+
PetrelBackend supports reading and writing data to multiple clusters.
|
460 |
+
If the file path contains the cluster name, PetrelBackend will read data
|
461 |
+
from specified cluster or write data to it. Otherwise, PetrelBackend will
|
462 |
+
access the default cluster.
|
463 |
+
|
464 |
+
Args:
|
465 |
+
path_mapping (dict, optional): Path mapping dict from local path to
|
466 |
+
Petrel path. When ``path_mapping={'src': 'dst'}``, ``src`` in
|
467 |
+
``filepath`` will be replaced by ``dst``. Default: None.
|
468 |
+
enable_mc (bool, optional): Whether to enable memcached support.
|
469 |
+
Default: True.
|
470 |
+
conf_path (str, optional): Config path of Petrel client. Default: None.
|
471 |
+
`New in version 1.7.1`.
|
472 |
+
|
473 |
+
Examples:
|
474 |
+
>>> filepath1 = 's3://path/of/file'
|
475 |
+
>>> filepath2 = 'cluster-name:s3://path/of/file'
|
476 |
+
>>> client = PetrelBackend()
|
477 |
+
>>> client.get(filepath1) # get data from default cluster
|
478 |
+
>>> client.get(filepath2) # get data from 'cluster-name' cluster
|
479 |
+
"""
|
480 |
+
|
481 |
+
def __init__(self,
|
482 |
+
path_mapping: Optional[dict] = None,
|
483 |
+
enable_mc: bool = False,
|
484 |
+
conf_path: str = None):
|
485 |
+
try:
|
486 |
+
from petrel_client import client
|
487 |
+
except ImportError:
|
488 |
+
raise ImportError('Please install petrel_client to enable '
|
489 |
+
'PetrelBackend.')
|
490 |
+
|
491 |
+
self._client = client.Client(conf_path=conf_path, enable_mc=enable_mc)
|
492 |
+
assert isinstance(path_mapping, dict) or path_mapping is None
|
493 |
+
self.path_mapping = path_mapping
|
494 |
+
|
495 |
+
def _map_path(self, filepath: Union[str, Path]) -> str:
|
496 |
+
"""Map ``filepath`` to a string path whose prefix will be replaced by
|
497 |
+
:attr:`self.path_mapping`.
|
498 |
+
|
499 |
+
Args:
|
500 |
+
filepath (str): Path to be mapped.
|
501 |
+
"""
|
502 |
+
filepath = str(filepath)
|
503 |
+
if self.path_mapping is not None:
|
504 |
+
for k, v in self.path_mapping.items():
|
505 |
+
filepath = filepath.replace(k, v, 1)
|
506 |
+
return filepath
|
507 |
+
|
508 |
+
def _format_path(self, filepath: str) -> str:
|
509 |
+
"""Convert a ``filepath`` to standard format of petrel oss.
|
510 |
+
|
511 |
+
If the ``filepath`` is concatenated by ``os.path.join``, in a Windows
|
512 |
+
environment, the ``filepath`` will be the format of
|
513 |
+
's3://bucket_name\\image.jpg'. By invoking :meth:`_format_path`, the
|
514 |
+
above ``filepath`` will be converted to 's3://bucket_name/image.jpg'.
|
515 |
+
|
516 |
+
Args:
|
517 |
+
filepath (str): Path to be formatted.
|
518 |
+
"""
|
519 |
+
return re.sub(r'\\+', '/', filepath)
|
520 |
+
|
521 |
+
def get(self, filepath: Union[str, Path]) -> bytes:
|
522 |
+
"""Read data from a given ``filepath`` with 'rb' mode.
|
523 |
+
|
524 |
+
Args:
|
525 |
+
filepath (str or Path): Path to read data.
|
526 |
+
|
527 |
+
Returns:
|
528 |
+
bytes: The loaded bytes.
|
529 |
+
"""
|
530 |
+
filepath = self._map_path(filepath)
|
531 |
+
filepath = self._format_path(filepath)
|
532 |
+
value = self._client.Get(filepath)
|
533 |
+
return value
|
534 |
+
|
535 |
+
|
536 |
+
class HardDiskBackend(BaseStorageBackend):
|
537 |
+
"""Raw hard disks storage backend."""
|
538 |
+
|
539 |
+
def get(self, filepath: Union[str, Path]) -> bytes:
|
540 |
+
"""Read data from a given ``filepath`` with 'rb' mode.
|
541 |
+
|
542 |
+
Args:
|
543 |
+
filepath (str or Path): Path to read data.
|
544 |
+
|
545 |
+
Returns:
|
546 |
+
bytes: Expected bytes object.
|
547 |
+
"""
|
548 |
+
with open(filepath, 'rb') as f:
|
549 |
+
value_buf = f.read()
|
550 |
+
return value_buf
|
551 |
+
|
552 |
+
|
553 |
+
def generate_gaussian_noise(img, sigma=10, gray_noise=False):
|
554 |
+
"""Generate Gaussian noise.
|
555 |
+
|
556 |
+
Args:
|
557 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
558 |
+
sigma (float): Noise scale (measured in range 255). Default: 10.
|
559 |
+
|
560 |
+
Returns:
|
561 |
+
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
562 |
+
float32.
|
563 |
+
"""
|
564 |
+
if gray_noise:
|
565 |
+
noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
|
566 |
+
noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
|
567 |
+
else:
|
568 |
+
noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
|
569 |
+
return noise
|
570 |
+
|
571 |
+
|
572 |
+
def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
|
573 |
+
"""Add Gaussian noise.
|
574 |
+
|
575 |
+
Args:
|
576 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
577 |
+
sigma (float): Noise scale (measured in range 255). Default: 10.
|
578 |
+
|
579 |
+
Returns:
|
580 |
+
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
581 |
+
float32.
|
582 |
+
"""
|
583 |
+
noise = generate_gaussian_noise(img, sigma, gray_noise)
|
584 |
+
out = img + noise
|
585 |
+
if clip and rounds:
|
586 |
+
out = np.clip((out * 255.0).round(), 0, 255) / 255.
|
587 |
+
elif clip:
|
588 |
+
out = np.clip(out, 0, 1)
|
589 |
+
elif rounds:
|
590 |
+
out = (out * 255.0).round() / 255.
|
591 |
+
return out
|
592 |
+
|
593 |
+
|
594 |
+
def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
|
595 |
+
"""Add Gaussian noise (PyTorch version).
|
596 |
+
|
597 |
+
Args:
|
598 |
+
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
|
599 |
+
scale (float | Tensor): Noise scale. Default: 1.0.
|
600 |
+
|
601 |
+
Returns:
|
602 |
+
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
|
603 |
+
float32.
|
604 |
+
"""
|
605 |
+
b, _, h, w = img.size()
|
606 |
+
if not isinstance(sigma, (float, int)):
|
607 |
+
sigma = sigma.view(img.size(0), 1, 1, 1)
|
608 |
+
if isinstance(gray_noise, (float, int)):
|
609 |
+
cal_gray_noise = gray_noise > 0
|
610 |
+
else:
|
611 |
+
gray_noise = gray_noise.view(b, 1, 1, 1)
|
612 |
+
cal_gray_noise = torch.sum(gray_noise) > 0
|
613 |
+
|
614 |
+
if cal_gray_noise:
|
615 |
+
noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
|
616 |
+
noise_gray = noise_gray.view(b, 1, h, w)
|
617 |
+
|
618 |
+
# always calculate color noise
|
619 |
+
noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
|
620 |
+
|
621 |
+
if cal_gray_noise:
|
622 |
+
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
|
623 |
+
return noise
|
624 |
+
|
625 |
+
|
626 |
+
def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
|
627 |
+
"""Add Gaussian noise (PyTorch version).
|
628 |
+
|
629 |
+
Args:
|
630 |
+
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
|
631 |
+
scale (float | Tensor): Noise scale. Default: 1.0.
|
632 |
+
|
633 |
+
Returns:
|
634 |
+
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
|
635 |
+
float32.
|
636 |
+
"""
|
637 |
+
noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
|
638 |
+
out = img + noise
|
639 |
+
if clip and rounds:
|
640 |
+
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
641 |
+
elif clip:
|
642 |
+
out = torch.clamp(out, 0, 1)
|
643 |
+
elif rounds:
|
644 |
+
out = (out * 255.0).round() / 255.
|
645 |
+
return out
|
646 |
+
|
647 |
+
|
648 |
+
# ----------------------- Random Gaussian Noise ----------------------- #
|
649 |
+
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
|
650 |
+
sigma = np.random.uniform(sigma_range[0], sigma_range[1])
|
651 |
+
if np.random.uniform() < gray_prob:
|
652 |
+
gray_noise = True
|
653 |
+
else:
|
654 |
+
gray_noise = False
|
655 |
+
return generate_gaussian_noise(img, sigma, gray_noise)
|
656 |
+
|
657 |
+
|
658 |
+
def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
|
659 |
+
noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
|
660 |
+
out = img + noise
|
661 |
+
if clip and rounds:
|
662 |
+
out = np.clip((out * 255.0).round(), 0, 255) / 255.
|
663 |
+
elif clip:
|
664 |
+
out = np.clip(out, 0, 1)
|
665 |
+
elif rounds:
|
666 |
+
out = (out * 255.0).round() / 255.
|
667 |
+
return out
|
668 |
+
|
669 |
+
|
670 |
+
def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
|
671 |
+
sigma = torch.rand(
|
672 |
+
img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
|
673 |
+
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
|
674 |
+
gray_noise = (gray_noise < gray_prob).float()
|
675 |
+
return generate_gaussian_noise_pt(img, sigma, gray_noise)
|
676 |
+
|
677 |
+
|
678 |
+
def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
|
679 |
+
noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
|
680 |
+
out = img + noise
|
681 |
+
if clip and rounds:
|
682 |
+
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
683 |
+
elif clip:
|
684 |
+
out = torch.clamp(out, 0, 1)
|
685 |
+
elif rounds:
|
686 |
+
out = (out * 255.0).round() / 255.
|
687 |
+
return out
|
688 |
+
|
689 |
+
|
690 |
+
# ----------------------- Poisson (Shot) Noise ----------------------- #
|
691 |
+
|
692 |
+
|
693 |
+
def generate_poisson_noise(img, scale=1.0, gray_noise=False):
|
694 |
+
"""Generate poisson noise.
|
695 |
+
|
696 |
+
Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
|
697 |
+
|
698 |
+
Args:
|
699 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
700 |
+
scale (float): Noise scale. Default: 1.0.
|
701 |
+
gray_noise (bool): Whether generate gray noise. Default: False.
|
702 |
+
|
703 |
+
Returns:
|
704 |
+
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
705 |
+
float32.
|
706 |
+
"""
|
707 |
+
if gray_noise:
|
708 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
709 |
+
# round and clip image for counting vals correctly
|
710 |
+
img = np.clip((img * 255.0).round(), 0, 255) / 255.
|
711 |
+
vals = len(np.unique(img))
|
712 |
+
vals = 2 ** np.ceil(np.log2(vals))
|
713 |
+
out = np.float32(np.random.poisson(img * vals) / float(vals))
|
714 |
+
noise = out - img
|
715 |
+
if gray_noise:
|
716 |
+
noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
|
717 |
+
return noise * scale
|
718 |
+
|
719 |
+
|
720 |
+
def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
|
721 |
+
"""Add poisson noise.
|
722 |
+
|
723 |
+
Args:
|
724 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
725 |
+
scale (float): Noise scale. Default: 1.0.
|
726 |
+
gray_noise (bool): Whether generate gray noise. Default: False.
|
727 |
+
|
728 |
+
Returns:
|
729 |
+
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
|
730 |
+
float32.
|
731 |
+
"""
|
732 |
+
noise = generate_poisson_noise(img, scale, gray_noise)
|
733 |
+
out = img + noise
|
734 |
+
if clip and rounds:
|
735 |
+
out = np.clip((out * 255.0).round(), 0, 255) / 255.
|
736 |
+
elif clip:
|
737 |
+
out = np.clip(out, 0, 1)
|
738 |
+
elif rounds:
|
739 |
+
out = (out * 255.0).round() / 255.
|
740 |
+
return out
|
741 |
+
|
742 |
+
|
743 |
+
def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
|
744 |
+
"""Generate a batch of poisson noise (PyTorch version)
|
745 |
+
|
746 |
+
Args:
|
747 |
+
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
|
748 |
+
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
|
749 |
+
Default: 1.0.
|
750 |
+
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
|
751 |
+
0 for False, 1 for True. Default: 0.
|
752 |
+
|
753 |
+
Returns:
|
754 |
+
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
|
755 |
+
float32.
|
756 |
+
"""
|
757 |
+
b, _, h, w = img.size()
|
758 |
+
if isinstance(gray_noise, (float, int)):
|
759 |
+
cal_gray_noise = gray_noise > 0
|
760 |
+
else:
|
761 |
+
gray_noise = gray_noise.view(b, 1, 1, 1)
|
762 |
+
cal_gray_noise = torch.sum(gray_noise) > 0
|
763 |
+
if cal_gray_noise:
|
764 |
+
img_gray = rgb_to_grayscale(img, num_output_channels=1)
|
765 |
+
# round and clip image for counting vals correctly
|
766 |
+
img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
|
767 |
+
# use for-loop to get the unique values for each sample
|
768 |
+
vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
|
769 |
+
vals_list = [2 ** np.ceil(np.log2(vals)) for vals in vals_list]
|
770 |
+
vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
|
771 |
+
out = torch.poisson(img_gray * vals) / vals
|
772 |
+
noise_gray = out - img_gray
|
773 |
+
noise_gray = noise_gray.expand(b, 3, h, w)
|
774 |
+
|
775 |
+
# always calculate color noise
|
776 |
+
# round and clip image for counting vals correctly
|
777 |
+
img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
|
778 |
+
# use for-loop to get the unique values for each sample
|
779 |
+
vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
|
780 |
+
vals_list = [2 ** np.ceil(np.log2(vals)) for vals in vals_list]
|
781 |
+
vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
|
782 |
+
out = torch.poisson(img * vals) / vals
|
783 |
+
noise = out - img
|
784 |
+
if cal_gray_noise:
|
785 |
+
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
|
786 |
+
if not isinstance(scale, (float, int)):
|
787 |
+
scale = scale.view(b, 1, 1, 1)
|
788 |
+
return noise * scale
|
789 |
+
|
790 |
+
|
791 |
+
def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
|
792 |
+
"""Add poisson noise to a batch of images (PyTorch version).
|
793 |
+
|
794 |
+
Args:
|
795 |
+
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
|
796 |
+
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
|
797 |
+
Default: 1.0.
|
798 |
+
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
|
799 |
+
0 for False, 1 for True. Default: 0.
|
800 |
+
|
801 |
+
Returns:
|
802 |
+
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
|
803 |
+
float32.
|
804 |
+
"""
|
805 |
+
noise = generate_poisson_noise_pt(img, scale, gray_noise)
|
806 |
+
out = img + noise
|
807 |
+
if clip and rounds:
|
808 |
+
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
809 |
+
elif clip:
|
810 |
+
out = torch.clamp(out, 0, 1)
|
811 |
+
elif rounds:
|
812 |
+
out = (out * 255.0).round() / 255.
|
813 |
+
return out
|
814 |
+
|
815 |
+
|
816 |
+
# ----------------------- Random Poisson (Shot) Noise ----------------------- #
|
817 |
+
|
818 |
+
|
819 |
+
def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
|
820 |
+
scale = np.random.uniform(scale_range[0], scale_range[1])
|
821 |
+
if np.random.uniform() < gray_prob:
|
822 |
+
gray_noise = True
|
823 |
+
else:
|
824 |
+
gray_noise = False
|
825 |
+
return generate_poisson_noise(img, scale, gray_noise)
|
826 |
+
|
827 |
+
|
828 |
+
def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
|
829 |
+
noise = random_generate_poisson_noise(img, scale_range, gray_prob)
|
830 |
+
out = img + noise
|
831 |
+
if clip and rounds:
|
832 |
+
out = np.clip((out * 255.0).round(), 0, 255) / 255.
|
833 |
+
elif clip:
|
834 |
+
out = np.clip(out, 0, 1)
|
835 |
+
elif rounds:
|
836 |
+
out = (out * 255.0).round() / 255.
|
837 |
+
return out
|
838 |
+
|
839 |
+
|
840 |
+
def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
|
841 |
+
scale = torch.rand(
|
842 |
+
img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
|
843 |
+
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
|
844 |
+
gray_noise = (gray_noise < gray_prob).float()
|
845 |
+
return generate_poisson_noise_pt(img, scale, gray_noise)
|
846 |
+
|
847 |
+
|
848 |
+
def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
|
849 |
+
noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
|
850 |
+
out = img + noise
|
851 |
+
if clip and rounds:
|
852 |
+
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
853 |
+
elif clip:
|
854 |
+
out = torch.clamp(out, 0, 1)
|
855 |
+
elif rounds:
|
856 |
+
out = (out * 255.0).round() / 255.
|
857 |
+
return out
|
858 |
+
|
859 |
+
|
860 |
+
# ------------------------------------------------------------------------ #
|
861 |
+
# --------------------------- JPEG compression --------------------------- #
|
862 |
+
# ------------------------------------------------------------------------ #
|
863 |
+
|
864 |
+
|
865 |
+
def add_jpg_compression(img, quality=90):
|
866 |
+
"""Add JPG compression artifacts.
|
867 |
+
|
868 |
+
Args:
|
869 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
870 |
+
quality (float): JPG compression quality. 0 for lowest quality, 100 for
|
871 |
+
best quality. Default: 90.
|
872 |
+
|
873 |
+
Returns:
|
874 |
+
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
|
875 |
+
float32.
|
876 |
+
"""
|
877 |
+
img = np.clip(img, 0, 1)
|
878 |
+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
|
879 |
+
_, encimg = cv2.imencode('.jpg', img * 255., encode_param)
|
880 |
+
img = np.float32(cv2.imdecode(encimg, 1)) / 255.
|
881 |
+
return img
|
882 |
+
|
883 |
+
|
884 |
+
def random_add_jpg_compression(img, quality_range=(90, 100)):
|
885 |
+
"""Randomly add JPG compression artifacts.
|
886 |
+
|
887 |
+
Args:
|
888 |
+
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
|
889 |
+
quality_range (tuple[float] | list[float]): JPG compression quality
|
890 |
+
range. 0 for lowest quality, 100 for best quality.
|
891 |
+
Default: (90, 100).
|
892 |
+
|
893 |
+
Returns:
|
894 |
+
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
|
895 |
+
float32.
|
896 |
+
"""
|
897 |
+
quality = np.random.uniform(quality_range[0], quality_range[1])
|
898 |
+
return add_jpg_compression(img, int(quality))
|
899 |
+
|
900 |
+
|
901 |
+
def load_file_list(file_list_path: str) -> List[Dict[str, str]]:
|
902 |
+
files = []
|
903 |
+
with open(file_list_path, "r") as fin:
|
904 |
+
for line in fin:
|
905 |
+
path = line.strip()
|
906 |
+
if path:
|
907 |
+
files.append({"image_path": path, "prompt": ""})
|
908 |
+
return files
|
909 |
+
|
910 |
+
|
911 |
+
# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/image_datasets.py
|
912 |
+
def center_crop_arr(pil_image, image_size):
|
913 |
+
# We are not on a new enough PIL to support the `reducing_gap`
|
914 |
+
# argument, which uses BOX downsampling at powers of two first.
|
915 |
+
# Thus, we do it by hand to improve downsample quality.
|
916 |
+
while min(*pil_image.size) >= 2 * image_size:
|
917 |
+
pil_image = pil_image.resize(
|
918 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
919 |
+
)
|
920 |
+
|
921 |
+
scale = image_size / min(*pil_image.size)
|
922 |
+
pil_image = pil_image.resize(
|
923 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
924 |
+
)
|
925 |
+
|
926 |
+
arr = np.array(pil_image)
|
927 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
928 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
929 |
+
return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
|
930 |
+
|
931 |
+
|
932 |
+
# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/image_datasets.py
|
933 |
+
def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
|
934 |
+
min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
|
935 |
+
max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
|
936 |
+
smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
|
937 |
+
|
938 |
+
# We are not on a new enough PIL to support the `reducing_gap`
|
939 |
+
# argument, which uses BOX downsampling at powers of two first.
|
940 |
+
# Thus, we do it by hand to improve downsample quality.
|
941 |
+
while min(*pil_image.size) >= 2 * smaller_dim_size:
|
942 |
+
pil_image = pil_image.resize(
|
943 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
944 |
+
)
|
945 |
+
|
946 |
+
scale = smaller_dim_size / min(*pil_image.size)
|
947 |
+
pil_image = pil_image.resize(
|
948 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
949 |
+
)
|
950 |
+
|
951 |
+
arr = np.array(pil_image)
|
952 |
+
crop_y = random.randrange(arr.shape[0] - image_size + 1)
|
953 |
+
crop_x = random.randrange(arr.shape[1] - image_size + 1)
|
954 |
+
return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
|
utils/create_arch.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from arch.hourglass import image_transformer_v2 as itv2
|
2 |
+
from arch.hourglass.image_transformer_v2 import ImageTransformerDenoiserModelV2
|
3 |
+
from arch.swinir.swinir import SwinIR
|
4 |
+
|
5 |
+
|
6 |
+
def create_arch(arch, condition_channels=0):
|
7 |
+
# arch should be, e.g., swinir_XL, or hdit_XL
|
8 |
+
arch_name, arch_size = arch.split('_')
|
9 |
+
arch_config = arch_configs[arch_name][arch_size].copy()
|
10 |
+
arch_config['in_channels'] += condition_channels
|
11 |
+
return arch_name_to_object[arch_name](**arch_config)
|
12 |
+
|
13 |
+
|
14 |
+
arch_configs = {
|
15 |
+
'hdit': {
|
16 |
+
"ImageNet256Sp4": {
|
17 |
+
'in_channels': 3,
|
18 |
+
'out_channels': 3,
|
19 |
+
'widths': [256, 512, 1024],
|
20 |
+
'depths': [2, 2, 8],
|
21 |
+
'patch_size': [4, 4],
|
22 |
+
'self_attns': [
|
23 |
+
{"type": "neighborhood", "d_head": 64, "kernel_size": 7},
|
24 |
+
{"type": "neighborhood", "d_head": 64, "kernel_size": 7},
|
25 |
+
{"type": "global", "d_head": 64}
|
26 |
+
],
|
27 |
+
'mapping_depth': 2,
|
28 |
+
'mapping_width': 768,
|
29 |
+
'dropout_rate': [0, 0, 0],
|
30 |
+
'mapping_dropout_rate': 0.0
|
31 |
+
},
|
32 |
+
"XL2": {
|
33 |
+
'in_channels': 3,
|
34 |
+
'out_channels': 3,
|
35 |
+
'widths': [384, 768],
|
36 |
+
'depths': [2, 11],
|
37 |
+
'patch_size': [4, 4],
|
38 |
+
'self_attns': [
|
39 |
+
{"type": "neighborhood", "d_head": 64, "kernel_size": 7},
|
40 |
+
{"type": "global", "d_head": 64}
|
41 |
+
],
|
42 |
+
'mapping_depth': 2,
|
43 |
+
'mapping_width': 768,
|
44 |
+
'dropout_rate': [0, 0],
|
45 |
+
'mapping_dropout_rate': 0.0
|
46 |
+
}
|
47 |
+
|
48 |
+
},
|
49 |
+
'swinir': {
|
50 |
+
"M": {
|
51 |
+
'in_channels': 3,
|
52 |
+
'out_channels': 3,
|
53 |
+
'embed_dim': 120,
|
54 |
+
'depths': [6, 6, 6, 6, 6],
|
55 |
+
'num_heads': [6, 6, 6, 6, 6],
|
56 |
+
'resi_connection': '1conv',
|
57 |
+
'sf': 8
|
58 |
+
|
59 |
+
},
|
60 |
+
"L": {
|
61 |
+
'in_channels': 3,
|
62 |
+
'out_channels': 3,
|
63 |
+
'embed_dim': 180,
|
64 |
+
'depths': [6, 6, 6, 6, 6, 6, 6, 6],
|
65 |
+
'num_heads': [6, 6, 6, 6, 6, 6, 6, 6],
|
66 |
+
'resi_connection': '1conv',
|
67 |
+
'sf': 8
|
68 |
+
},
|
69 |
+
},
|
70 |
+
}
|
71 |
+
|
72 |
+
|
73 |
+
def create_swinir_model(in_channels, out_channels, embed_dim, depths, num_heads, resi_connection,
|
74 |
+
sf):
|
75 |
+
return SwinIR(
|
76 |
+
img_size=64,
|
77 |
+
patch_size=1,
|
78 |
+
in_chans=in_channels,
|
79 |
+
num_out_ch=out_channels,
|
80 |
+
embed_dim=embed_dim,
|
81 |
+
depths=depths,
|
82 |
+
num_heads=num_heads,
|
83 |
+
window_size=8,
|
84 |
+
mlp_ratio=2,
|
85 |
+
sf=sf,
|
86 |
+
img_range=1.0,
|
87 |
+
upsampler="nearest+conv",
|
88 |
+
resi_connection=resi_connection,
|
89 |
+
unshuffle=True,
|
90 |
+
unshuffle_scale=8
|
91 |
+
)
|
92 |
+
|
93 |
+
|
94 |
+
def create_hdit_model(widths,
|
95 |
+
depths,
|
96 |
+
self_attns,
|
97 |
+
dropout_rate,
|
98 |
+
mapping_depth,
|
99 |
+
mapping_width,
|
100 |
+
mapping_dropout_rate,
|
101 |
+
in_channels,
|
102 |
+
out_channels,
|
103 |
+
patch_size
|
104 |
+
):
|
105 |
+
assert len(widths) == len(depths)
|
106 |
+
assert len(widths) == len(self_attns)
|
107 |
+
assert len(widths) == len(dropout_rate)
|
108 |
+
mapping_d_ff = mapping_width * 3
|
109 |
+
d_ffs = []
|
110 |
+
for width in widths:
|
111 |
+
d_ffs.append(width * 3)
|
112 |
+
|
113 |
+
levels = []
|
114 |
+
for depth, width, d_ff, self_attn, dropout in zip(depths, widths, d_ffs, self_attns, dropout_rate):
|
115 |
+
if self_attn['type'] == 'global':
|
116 |
+
self_attn = itv2.GlobalAttentionSpec(self_attn.get('d_head', 64))
|
117 |
+
elif self_attn['type'] == 'neighborhood':
|
118 |
+
self_attn = itv2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64), self_attn.get('kernel_size', 7))
|
119 |
+
elif self_attn['type'] == 'shifted-window':
|
120 |
+
self_attn = itv2.ShiftedWindowAttentionSpec(self_attn.get('d_head', 64), self_attn['window_size'])
|
121 |
+
elif self_attn['type'] == 'none':
|
122 |
+
self_attn = itv2.NoAttentionSpec()
|
123 |
+
else:
|
124 |
+
raise ValueError(f'unsupported self attention type {self_attn["type"]}')
|
125 |
+
levels.append(itv2.LevelSpec(depth, width, d_ff, self_attn, dropout))
|
126 |
+
mapping = itv2.MappingSpec(mapping_depth, mapping_width, mapping_d_ff, mapping_dropout_rate)
|
127 |
+
model = ImageTransformerDenoiserModelV2(
|
128 |
+
levels=levels,
|
129 |
+
mapping=mapping,
|
130 |
+
in_channels=in_channels,
|
131 |
+
out_channels=out_channels,
|
132 |
+
patch_size=patch_size,
|
133 |
+
num_classes=0,
|
134 |
+
mapping_cond_dim=0,
|
135 |
+
)
|
136 |
+
|
137 |
+
return model
|
138 |
+
|
139 |
+
|
140 |
+
arch_name_to_object = {
|
141 |
+
'hdit': create_hdit_model,
|
142 |
+
'swinir': create_swinir_model,
|
143 |
+
}
|
utils/create_degradation.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from basicsr.data import degradations as degradations
|
8 |
+
from basicsr.data.transforms import augment
|
9 |
+
from basicsr.utils import img2tensor
|
10 |
+
from torch.nn.functional import interpolate
|
11 |
+
from torchvision.transforms import Compose
|
12 |
+
from utils.basicsr_custom import (
|
13 |
+
random_mixed_kernels,
|
14 |
+
random_add_gaussian_noise,
|
15 |
+
random_add_jpg_compression,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
def create_degradation(degradation):
|
20 |
+
if degradation == 'sr_bicubic_x8_gaussian_noise_005':
|
21 |
+
return Compose([
|
22 |
+
partial(down_scale, scale_factor=1.0 / 8.0, mode='bicubic'),
|
23 |
+
partial(add_gaussian_noise, std=0.05),
|
24 |
+
partial(interpolate, scale_factor=8.0, mode='nearest-exact'),
|
25 |
+
partial(torch.clip, min=0, max=1),
|
26 |
+
partial(torch.squeeze, dim=0),
|
27 |
+
lambda x: (x, None)
|
28 |
+
|
29 |
+
])
|
30 |
+
elif degradation == 'gaussian_noise_035':
|
31 |
+
return Compose([
|
32 |
+
partial(add_gaussian_noise, std=0.35),
|
33 |
+
partial(torch.clip, min=0, max=1),
|
34 |
+
partial(torch.squeeze, dim=0),
|
35 |
+
lambda x: (x, None)
|
36 |
+
|
37 |
+
])
|
38 |
+
elif degradation == 'colorization_gaussian_noise_025':
|
39 |
+
return Compose([
|
40 |
+
lambda x: torch.mean(x, dim=0, keepdim=True),
|
41 |
+
partial(add_gaussian_noise, std=0.25),
|
42 |
+
partial(torch.clip, min=0, max=1),
|
43 |
+
lambda x: (x, None)
|
44 |
+
])
|
45 |
+
elif degradation == 'random_inpainting_gaussian_noise_01':
|
46 |
+
def inpainting_dps(x):
|
47 |
+
total = x.shape[1] ** 2
|
48 |
+
# random pixel sampling
|
49 |
+
l, h = [0.9, 0.9]
|
50 |
+
prob = np.random.uniform(l, h)
|
51 |
+
mask_vec = torch.ones([1, x.shape[1] * x.shape[1]])
|
52 |
+
samples = np.random.choice(x.shape[1] * x.shape[1], int(total * prob), replace=False)
|
53 |
+
mask_vec[:, samples] = 0
|
54 |
+
mask_b = mask_vec.view(1, x.shape[1], x.shape[1])
|
55 |
+
mask_b = mask_b.repeat(3, 1, 1)
|
56 |
+
mask = torch.ones_like(x, device=x.device)
|
57 |
+
mask[:, ...] = mask_b
|
58 |
+
return add_gaussian_noise(x * mask, 0.1).clip(0, 1), None
|
59 |
+
|
60 |
+
return inpainting_dps
|
61 |
+
elif degradation == 'difface':
|
62 |
+
def deg(x):
|
63 |
+
blur_kernel_size = 41
|
64 |
+
kernel_list = ['iso', 'aniso']
|
65 |
+
kernel_prob = [0.5, 0.5]
|
66 |
+
blur_sigma = [0.1, 15]
|
67 |
+
downsample_range = [0.8, 32]
|
68 |
+
noise_range = [0, 20]
|
69 |
+
jpeg_range = [30, 100]
|
70 |
+
gt_gray = True
|
71 |
+
gray_prob = 0.01
|
72 |
+
x = x.permute(1, 2, 0).numpy()[..., ::-1].astype(np.float32)
|
73 |
+
# random horizontal flip
|
74 |
+
img_gt = augment(x.copy(), hflip=True, rotation=False)
|
75 |
+
h, w, _ = img_gt.shape
|
76 |
+
|
77 |
+
# ------------------------ generate lq image ------------------------ #
|
78 |
+
# blur
|
79 |
+
kernel = degradations.random_mixed_kernels(
|
80 |
+
kernel_list,
|
81 |
+
kernel_prob,
|
82 |
+
blur_kernel_size,
|
83 |
+
blur_sigma,
|
84 |
+
blur_sigma, [-math.pi, math.pi],
|
85 |
+
noise_range=None)
|
86 |
+
img_lq = cv2.filter2D(img_gt, -1, kernel)
|
87 |
+
# downsample
|
88 |
+
scale = np.random.uniform(downsample_range[0], downsample_range[1])
|
89 |
+
img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR)
|
90 |
+
# noise
|
91 |
+
if noise_range is not None:
|
92 |
+
img_lq = random_add_gaussian_noise(img_lq, noise_range)
|
93 |
+
# jpeg compression
|
94 |
+
if jpeg_range is not None:
|
95 |
+
img_lq = random_add_jpg_compression(img_lq, jpeg_range)
|
96 |
+
|
97 |
+
# resize to original size
|
98 |
+
img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR)
|
99 |
+
|
100 |
+
# random color jitter (only for lq)
|
101 |
+
# if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
|
102 |
+
# img_lq = self.color_jitter(img_lq, self.color_jitter_shift)
|
103 |
+
# random to gray (only for lq)
|
104 |
+
if np.random.uniform() < gray_prob:
|
105 |
+
img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY)
|
106 |
+
img_lq = np.tile(img_lq[:, :, None], [1, 1, 3])
|
107 |
+
if gt_gray: # whether convert GT to gray images
|
108 |
+
img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)
|
109 |
+
img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels
|
110 |
+
|
111 |
+
# BGR to RGB, HWC to CHW, numpy to tensor
|
112 |
+
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
|
113 |
+
|
114 |
+
# random color jitter (pytorch version) (only for lq)
|
115 |
+
# if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
|
116 |
+
# brightness = self.opt.get('brightness', (0.5, 1.5))
|
117 |
+
# contrast = self.opt.get('contrast', (0.5, 1.5))
|
118 |
+
# saturation = self.opt.get('saturation', (0, 1.5))
|
119 |
+
# hue = self.opt.get('hue', (-0.1, 0.1))
|
120 |
+
# img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue)
|
121 |
+
|
122 |
+
# round and clip
|
123 |
+
img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.
|
124 |
+
|
125 |
+
return img_lq, img_gt.clip(0, 1)
|
126 |
+
|
127 |
+
return deg
|
128 |
+
else:
|
129 |
+
raise NotImplementedError()
|
130 |
+
|
131 |
+
|
132 |
+
def down_scale(x, scale_factor, mode):
|
133 |
+
with torch.no_grad():
|
134 |
+
return interpolate(x.unsqueeze(0),
|
135 |
+
scale_factor=scale_factor,
|
136 |
+
mode=mode,
|
137 |
+
antialias=True,
|
138 |
+
align_corners=False).clip(0, 1)
|
139 |
+
|
140 |
+
|
141 |
+
def add_gaussian_noise(x, std):
|
142 |
+
with torch.no_grad():
|
143 |
+
x = x + torch.randn_like(x) * std
|
144 |
+
return x
|
utils/img_utils.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torchvision.utils import make_grid
|
2 |
+
|
3 |
+
|
4 |
+
def create_grid(img, normalize=False, num_images=5):
|
5 |
+
return make_grid(img[:num_images], padding=0, normalize=normalize, nrow=16)
|