HikariDawn committed
Commit
561c629
1 Parent(s): 193b9cd

feat: initial push

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +30 -0
  2. __assets__/logo.png +0 -0
  3. __assets__/lr_inputs/41.png +0 -0
  4. __assets__/lr_inputs/f91.jpg +0 -0
  5. __assets__/lr_inputs/image-00164.jpg +0 -0
  6. __assets__/lr_inputs/image-00186.png +0 -0
  7. __assets__/lr_inputs/image-00277.png +0 -0
  8. __assets__/lr_inputs/image-00440.png +0 -0
  9. __assets__/lr_inputs/image-00542.png +0 -0
  10. __assets__/lr_inputs/img_eva.jpeg +0 -0
  11. __assets__/lr_inputs/screenshot_resize.jpg +0 -0
  12. __assets__/visual_results/0079_2_visual.png +0 -0
  13. __assets__/visual_results/0079_visual.png +0 -0
  14. __assets__/visual_results/eva_visual.png +0 -0
  15. __assets__/visual_results/f91_visual.png +0 -0
  16. __assets__/visual_results/kiteret_visual.png +0 -0
  17. __assets__/visual_results/pokemon2_visual.png +0 -0
  18. __assets__/visual_results/pokemon_visual.png +0 -0
  19. __assets__/visual_results/wataru_visual.png +0 -0
  20. __assets__/workflow.png +0 -0
  21. app.py +117 -0
  22. architecture/cunet.py +189 -0
  23. architecture/dataset.py +106 -0
  24. architecture/discriminator.py +241 -0
  25. architecture/grl.py +616 -0
  26. architecture/grl_common/__init__.py +8 -0
  27. architecture/grl_common/common_edsr.py +227 -0
  28. architecture/grl_common/mixed_attn_block.py +1126 -0
  29. architecture/grl_common/mixed_attn_block_efficient.py +568 -0
  30. architecture/grl_common/ops.py +551 -0
  31. architecture/grl_common/resblock.py +61 -0
  32. architecture/grl_common/swin_v1_block.py +602 -0
  33. architecture/grl_common/swin_v2_block.py +306 -0
  34. architecture/grl_common/upsample.py +50 -0
  35. architecture/rrdb.py +218 -0
  36. architecture/swinir.py +874 -0
  37. dataset_curation_pipeline/IC9600/ICNet.py +151 -0
  38. dataset_curation_pipeline/IC9600/gene.py +113 -0
  39. dataset_curation_pipeline/collect.py +222 -0
  40. degradation/ESR/degradation_esr_shared.py +180 -0
  41. degradation/ESR/degradations_functionality.py +785 -0
  42. degradation/ESR/diffjpeg.py +517 -0
  43. degradation/ESR/usm_sharp.py +114 -0
  44. degradation/ESR/utils.py +126 -0
  45. degradation/degradation_esr.py +110 -0
  46. degradation/image_compression/avif.py +88 -0
  47. degradation/image_compression/heif.py +90 -0
  48. degradation/image_compression/jpeg.py +68 -0
  49. degradation/image_compression/webp.py +65 -0
  50. degradation/video_compression/h264.py +73 -0
.gitignore ADDED
@@ -0,0 +1,30 @@
1
+ datasets/*
2
+ .ipynb_checkpoints
3
+ .idea
4
+ __pycache__
5
+
6
+ datasets/
7
+ tmp_imgs
8
+ runs/
9
+ runs_last/
10
+ saved_models/*
11
+ saved_models/
12
+ pre_trained/
13
+ save_log/*
14
+ tmp/*
15
+
16
+ *.pyc
17
+ *.pth
18
+ *.png
19
+ *.jpg
20
+ *.mp4
21
+ *.txt
22
+ *.json
23
+ *.zip
24
+ *.mp4
25
+ *.csv
26
+
27
+ !__assets__/lr_inputs/*
28
+ !__assets__/*
29
+ !__assets__/visual_results/*
30
+ !requirements.txt
__assets__/logo.png ADDED
__assets__/lr_inputs/41.png ADDED
__assets__/lr_inputs/f91.jpg ADDED
__assets__/lr_inputs/image-00164.jpg ADDED
__assets__/lr_inputs/image-00186.png ADDED
__assets__/lr_inputs/image-00277.png ADDED
__assets__/lr_inputs/image-00440.png ADDED
__assets__/lr_inputs/image-00542.png ADDED
__assets__/lr_inputs/img_eva.jpeg ADDED
__assets__/lr_inputs/screenshot_resize.jpg ADDED
__assets__/visual_results/0079_2_visual.png ADDED
__assets__/visual_results/0079_visual.png ADDED
__assets__/visual_results/eva_visual.png ADDED
__assets__/visual_results/f91_visual.png ADDED
__assets__/visual_results/kiteret_visual.png ADDED
__assets__/visual_results/pokemon2_visual.png ADDED
__assets__/visual_results/pokemon_visual.png ADDED
__assets__/visual_results/wataru_visual.png ADDED
__assets__/workflow.png ADDED
app.py ADDED
@@ -0,0 +1,117 @@
1
+ import os, sys
2
+ import cv2
3
+ import gradio as gr
4
+ import torch
5
+ import numpy as np
6
+ from torchvision.utils import save_image
7
+
8
+
9
+ # Import files from the local folder
10
+ root_path = os.path.abspath('.')
11
+ sys.path.append(root_path)
12
+ from test_code.inference import super_resolve_img
13
+ from test_code.test_utils import load_grl, load_rrdb
14
+
15
+
16
+ def auto_download_if_needed(weight_path):
17
+ if os.path.exists(weight_path):
18
+ return
19
+
20
+ if not os.path.exists("pretrained"):
21
+ os.makedirs("pretrained")
22
+
23
+ if weight_path == "pretrained/4x_APISR_GRL_GAN_generator.pth":
24
+ os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth")
25
+ os.system("mv 4x_APISR_GRL_GAN_generator.pth pretrained")
26
+
27
+ if weight_path == "pretrained/2x_APISR_RRDB_GAN_generator.pth":
28
+ os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth")
29
+ os.system("mv 2x_APISR_RRDB_GAN_generator.pth pretrained")
30
+
31
+
32
+
33
+ def inference(img_path, model_name):
34
+
35
+ try:
36
+ weight_dtype = torch.float32
37
+
38
+ # Load the model
39
+ if model_name == "4xGRL":
40
+ weight_path = "pretrained/4x_APISR_GRL_GAN_generator.pth"
41
+ auto_download_if_needed(weight_path)
42
+ generator = load_grl(weight_path, scale=4) # Directly use default way now
43
+
44
+ elif model_name == "2xRRDB":
45
+ weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
46
+ auto_download_if_needed(weight_path)
47
+ generator = load_rrdb(weight_path, scale=2) # Directly use default way now
48
+
49
+ else:
50
+ raise gr.Error(f"Unsupported model_name: {model_name}")
51
+
52
+ generator = generator.to(dtype=weight_dtype)
53
+
54
+
55
+ # By default, we automatically crop the input to match the 4x size
56
+ super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, crop_for_4x=True)
57
+ save_image(super_resolved_img, "SR_result.png")
58
+ outputs = cv2.imread("SR_result.png")
59
+ outputs = cv2.cvtColor(outputs, cv2.COLOR_BGR2RGB) # cv2.imread loads BGR; convert to RGB for the numpy output
60
+
61
+ return outputs
62
+
63
+
64
+ except Exception as error:
65
+ raise gr.Error(f"global exception: {error}")
66
+
67
+
68
+
69
+ if __name__ == '__main__':
70
+
71
+ MARKDOWN = \
72
+ """
73
+ ## APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024)
74
+
75
+ [GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)
76
+
77
+ If APISR is helpful for you, please help star the GitHub Repo. Thanks!
78
+ """
79
+
80
+ block = gr.Blocks().queue()
81
+ with block:
82
+ with gr.Row():
83
+ gr.Markdown(MARKDOWN)
84
+ with gr.Row(elem_classes=["container"]):
85
+ with gr.Column(scale=2):
86
+ input_image = gr.Image(type="filepath", label="Input")
87
+ model_name = gr.Dropdown(
88
+ [
89
+ "2xRRDB",
90
+ "4xGRL"
91
+ ],
92
+ type="value",
93
+ value="4xGRL",
94
+ label="model",
95
+ )
96
+ run_btn = gr.Button(value="Submit")
97
+
98
+ with gr.Column(scale=3):
99
+ output_image = gr.Image(type="numpy", label="Output image")
100
+
101
+ with gr.Row(elem_classes=["container"]):
102
+ gr.Examples(
103
+ [
104
+ ["__assets__/lr_inputs/image-00277.png"],
105
+ ["__assets__/lr_inputs/image-00542.png"],
106
+ ["__assets__/lr_inputs/41.png"],
107
+ ["__assets__/lr_inputs/f91.jpg"],
108
+ ["__assets__/lr_inputs/image-00440.png"],
109
+ ["__assets__/lr_inputs/image-00164.jpg"],
110
+ ["__assets__/lr_inputs/img_eva.jpeg"],
111
+ ],
112
+ [input_image],
113
+ )
114
+
115
+ run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])
116
+
117
+ block.launch()
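
Note: the Gradio app above is a thin wrapper around the helpers it imports from test_code. Below is a minimal sketch of the same 4x super-resolution step without the UI, assuming the GRL weights were already downloaded to pretrained/ and using only the calls that app.py itself makes:

    import torch
    from torchvision.utils import save_image

    from test_code.inference import super_resolve_img
    from test_code.test_utils import load_grl

    # Assumes pretrained/4x_APISR_GRL_GAN_generator.pth is already present locally
    generator = load_grl("pretrained/4x_APISR_GRL_GAN_generator.pth", scale=4)
    generator = generator.to(dtype=torch.float32)

    # Same call as in inference() above; crop_for_4x crops the input for the 4x setting
    sr_tensor = super_resolve_img(generator, "__assets__/lr_inputs/41.png",
                                  output_path=None, weight_dtype=torch.float32,
                                  crop_for_4x=True)
    save_image(sr_tensor, "SR_result.png")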
architecture/cunet.py ADDED
@@ -0,0 +1,189 @@
1
+ # Github Repository: https://github.com/bilibili/ailab/blob/main/Real-CUGAN/README_EN.md
2
+ # Code snippet (with certain modifications) from: https://github.com/bilibili/ailab/blob/main/Real-CUGAN/VapourSynth/upcunet_v3_vs.py
3
+
4
+ import torch
5
+ from torch import nn as nn
6
+ from torch.nn import functional as F
7
+ import os, sys
8
+ import numpy as np
9
+ from time import time as ttime, sleep
10
+
11
+
12
+ class UNet_Full(nn.Module):
13
+
14
+ def __init__(self):
15
+ super(UNet_Full, self).__init__()
16
+ self.unet1 = UNet1(3, 3, deconv=True)
17
+ self.unet2 = UNet2(3, 3, deconv=False)
18
+
19
+ def forward(self, x):
20
+ n, c, h0, w0 = x.shape
21
+
22
+ ph = ((h0 - 1) // 2 + 1) * 2
23
+ pw = ((w0 - 1) // 2 + 1) * 2
24
+ x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), 'reflect') # In order to ensure that it can be divided by 2
25
+
26
+ x1 = self.unet1(x)
27
+ x2 = self.unet2(x1)
28
+
29
+ x1 = F.pad(x1, (-20, -20, -20, -20))
30
+ output = torch.add(x2, x1)
31
+
32
+ if (w0 != pw or h0 != ph):
33
+ output = output[:, :, :h0 * 2, :w0 * 2]
34
+
35
+ return output
36
+
37
+
38
+ class SEBlock(nn.Module):
39
+ def __init__(self, in_channels, reduction=8, bias=False):
40
+ super(SEBlock, self).__init__()
41
+ self.conv1 = nn.Conv2d(in_channels, in_channels // reduction, 1, 1, 0, bias=bias)
42
+ self.conv2 = nn.Conv2d(in_channels // reduction, in_channels, 1, 1, 0, bias=bias)
43
+
44
+ def forward(self, x):
45
+ if ("Half" in x.type()): # torch.HalfTensor/torch.cuda.HalfTensor
46
+ x0 = torch.mean(x.float(), dim=(2, 3), keepdim=True).half()
47
+ else:
48
+ x0 = torch.mean(x, dim=(2, 3), keepdim=True)
49
+ x0 = self.conv1(x0)
50
+ x0 = F.relu(x0, inplace=True)
51
+ x0 = self.conv2(x0)
52
+ x0 = torch.sigmoid(x0)
53
+ x = torch.mul(x, x0)
54
+ return x
55
+
56
+ class UNetConv(nn.Module):
57
+ def __init__(self, in_channels, mid_channels, out_channels, se):
58
+ super(UNetConv, self).__init__()
59
+ self.conv = nn.Sequential(
60
+ nn.Conv2d(in_channels, mid_channels, 3, 1, 0),
61
+ nn.LeakyReLU(0.1, inplace=True),
62
+ nn.Conv2d(mid_channels, out_channels, 3, 1, 0),
63
+ nn.LeakyReLU(0.1, inplace=True),
64
+ )
65
+ if se:
66
+ self.seblock = SEBlock(out_channels, reduction=8, bias=True)
67
+ else:
68
+ self.seblock = None
69
+
70
+ def forward(self, x):
71
+ z = self.conv(x)
72
+ if self.seblock is not None:
73
+ z = self.seblock(z)
74
+ return z
75
+
76
+ class UNet1(nn.Module):
77
+ def __init__(self, in_channels, out_channels, deconv):
78
+ super(UNet1, self).__init__()
79
+ self.conv1 = UNetConv(in_channels, 32, 64, se=False)
80
+ self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
81
+ self.conv2 = UNetConv(64, 128, 64, se=True)
82
+ self.conv2_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
83
+ self.conv3 = nn.Conv2d(64, 64, 3, 1, 0)
84
+
85
+ if deconv:
86
+ self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
87
+ else:
88
+ self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
89
+
90
+ for m in self.modules():
91
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
92
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
93
+ elif isinstance(m, nn.Linear):
94
+ nn.init.normal_(m.weight, 0, 0.01)
95
+ if m.bias is not None:
96
+ nn.init.constant_(m.bias, 0)
97
+
98
+ def forward(self, x):
99
+ x1 = self.conv1(x)
100
+ x2 = self.conv1_down(x1)
101
+ x2 = F.leaky_relu(x2, 0.1, inplace=True)
102
+ x2 = self.conv2(x2)
103
+ x2 = self.conv2_up(x2)
104
+ x2 = F.leaky_relu(x2, 0.1, inplace=True)
105
+
106
+ x1 = F.pad(x1, (-4, -4, -4, -4))
107
+ x3 = self.conv3(x1 + x2)
108
+ x3 = F.leaky_relu(x3, 0.1, inplace=True)
109
+ z = self.conv_bottom(x3)
110
+ return z
111
+
112
+
113
+ class UNet2(nn.Module):
114
+ def __init__(self, in_channels, out_channels, deconv):
115
+ super(UNet2, self).__init__()
116
+
117
+ self.conv1 = UNetConv(in_channels, 32, 64, se=False)
118
+ self.conv1_down = nn.Conv2d(64, 64, 2, 2, 0)
119
+ self.conv2 = UNetConv(64, 64, 128, se=True)
120
+ self.conv2_down = nn.Conv2d(128, 128, 2, 2, 0)
121
+ self.conv3 = UNetConv(128, 256, 128, se=True)
122
+ self.conv3_up = nn.ConvTranspose2d(128, 128, 2, 2, 0)
123
+ self.conv4 = UNetConv(128, 64, 64, se=True)
124
+ self.conv4_up = nn.ConvTranspose2d(64, 64, 2, 2, 0)
125
+ self.conv5 = nn.Conv2d(64, 64, 3, 1, 0)
126
+
127
+ if deconv:
128
+ self.conv_bottom = nn.ConvTranspose2d(64, out_channels, 4, 2, 3)
129
+ else:
130
+ self.conv_bottom = nn.Conv2d(64, out_channels, 3, 1, 0)
131
+
132
+ for m in self.modules():
133
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
134
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
135
+ elif isinstance(m, nn.Linear):
136
+ nn.init.normal_(m.weight, 0, 0.01)
137
+ if m.bias is not None:
138
+ nn.init.constant_(m.bias, 0)
139
+
140
+ def forward(self, x):
141
+ x1 = self.conv1(x)
142
+ x2 = self.conv1_down(x1)
143
+ x2 = F.leaky_relu(x2, 0.1, inplace=True)
144
+ x2 = self.conv2(x2)
145
+
146
+ x3 = self.conv2_down(x2)
147
+ x3 = F.leaky_relu(x3, 0.1, inplace=True)
148
+ x3 = self.conv3(x3)
149
+ x3 = self.conv3_up(x3)
150
+ x3 = F.leaky_relu(x3, 0.1, inplace=True)
151
+
152
+ x2 = F.pad(x2, (-4, -4, -4, -4))
153
+ x4 = self.conv4(x2 + x3)
154
+ x4 = self.conv4_up(x4)
155
+ x4 = F.leaky_relu(x4, 0.1, inplace=True)
156
+
157
+ x1 = F.pad(x1, (-16, -16, -16, -16))
158
+ x5 = self.conv5(x1 + x4)
159
+ x5 = F.leaky_relu(x5, 0.1, inplace=True)
160
+
161
+ z = self.conv_bottom(x5)
162
+ return z
163
+
164
+
165
+
166
+ def main():
167
+ root_path = os.path.abspath('.')
168
+ sys.path.append(root_path)
169
+
170
+ from opt import opt # Manage GPU to choose
171
+ import time
172
+
173
+ model = UNet_Full().cuda()
174
+ pytorch_total_params = sum(p.numel() for p in model.parameters())
175
+ print(f"CuNet has param {pytorch_total_params//1000} K params")
176
+
177
+
178
+ # Count the number of FLOPs to double check
179
+ x = torch.randn((1, 3, 180, 180)).cuda()
180
+ start = time.time()
181
+ x = model(x)
182
+ print("output size is ", x.shape)
183
+ total = time.time() - start
184
+ print(total)
185
+
186
+
187
+
188
+ if __name__ == "__main__":
189
+ main()
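
A minimal sanity-check sketch for the Real-CUGAN-style network above (CPU, no GPU or opt.py needed): the forward pass reflect-pads the input and returns a tensor with exactly twice the spatial size.

    import torch
    from architecture.cunet import UNet_Full

    model = UNet_Full().eval()
    with torch.no_grad():
        x = torch.randn(1, 3, 180, 180)   # N, C, H, W
        y = model(x)
    print(y.shape)                         # torch.Size([1, 3, 360, 360]): 2x upsampling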
architecture/dataset.py ADDED
@@ -0,0 +1,106 @@
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.autograd import Variable
6
+ from torchvision.models import vgg19
7
+ import torchvision.transforms as transforms
8
+ from torch.utils.data import DataLoader, Dataset
9
+ from torchvision.utils import save_image, make_grid
10
+ from torchvision.transforms import ToTensor
11
+
12
+ import numpy as np
13
+ import cv2
14
+ import glob
15
+ import random
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+
19
+
20
+ # from degradation.degradation_main import degredate_process, preparation
21
+ from opt import opt
22
+
23
+
24
+ class ImageDataset(Dataset):
25
+ @torch.no_grad()
26
+ def __init__(self, train_lr_paths, degrade_hr_paths, train_hr_paths):
27
+ # print("low_res path sample is ", train_lr_paths[0])
28
+ # print(train_hr_paths[0])
29
+ # hr_height, hr_width = hr_shape
30
+ self.transform = transforms.Compose(
31
+ [
32
+ transforms.ToTensor(),
33
+ ]
34
+ )
35
+
36
+ self.files_lr = train_lr_paths
37
+ self.files_degrade_hr = degrade_hr_paths
38
+ self.files_hr = train_hr_paths
39
+
40
+ assert(len(self.files_lr) == len(self.files_hr))
41
+ assert(len(self.files_lr) == len(self.files_degrade_hr))
42
+
43
+
44
+ def augment(self, imgs, hflip=True, rotation=True):
45
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
46
+
47
+ All the images in the list use the same augmentation.
48
+
49
+ Args:
50
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
51
+ is an ndarray, it will be transformed to a list.
52
+ hflip (bool): Horizontal flip. Default: True.
53
+ rotation (bool): Rotation. Default: True.
54
+
55
+ Returns:
56
+ imgs (list[ndarray] | ndarray): Augmented images and flows. If returned
57
+ results only have one element, just return ndarray.
58
+
59
+ """
60
+ hflip = hflip and random.random() < 0.5
61
+ vflip = rotation and random.random() < 0.5
62
+ rot90 = rotation and random.random() < 0.5
63
+
64
+ def _augment(img):
65
+ if hflip: # horizontal
66
+ cv2.flip(img, 1, img)
67
+ if vflip: # vertical
68
+ cv2.flip(img, 0, img)
69
+ if rot90:
70
+ img = img.transpose(1, 0, 2)
71
+ return img
72
+
73
+
74
+ if not isinstance(imgs, list):
75
+ imgs = [imgs]
76
+
77
+ imgs = [_augment(img) for img in imgs]
78
+ if len(imgs) == 1:
79
+ imgs = imgs[0]
80
+
81
+
82
+ return imgs
83
+
84
+
85
+ def __getitem__(self, index):
86
+
87
+ # Read File
88
+ img_lr = cv2.imread(self.files_lr[index % len(self.files_lr)]) # Should be BGR
89
+ img_degrade_hr = cv2.imread(self.files_degrade_hr[index % len(self.files_degrade_hr)])
90
+ img_hr = cv2.imread(self.files_hr[index % len(self.files_hr)])
91
+
92
+ # Augmentation
93
+ if random.random() < opt["augment_prob"]:
94
+ img_lr, img_degrade_hr, img_hr = self.augment([img_lr, img_degrade_hr, img_hr])
95
+
96
+ # Transform to Tensor
97
+ img_lr = self.transform(img_lr)
98
+ img_degrade_hr = self.transform(img_degrade_hr)
99
+ img_hr = self.transform(img_hr) # ToTensor() is already in the range [0, 1]
100
+
101
+
102
+ return {"lr": img_lr, "degrade_hr": img_degrade_hr, "hr": img_hr}
103
+
104
+ def __len__(self):
105
+ assert(len(self.files_hr) == len(self.files_lr))
106
+ return len(self.files_hr)
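
A minimal usage sketch for the paired dataset above. The path lists below are hypothetical placeholders; the real lists come from the training pipeline, and note that this module reads opt["augment_prob"] from opt.py at import time.

    import glob
    from torch.utils.data import DataLoader
    from architecture.dataset import ImageDataset

    # Hypothetical path lists; all three must have the same length and be index-aligned
    lr_paths         = sorted(glob.glob("datasets/train_lr/*.png"))
    degrade_hr_paths = sorted(glob.glob("datasets/train_degrade_hr/*.png"))
    hr_paths         = sorted(glob.glob("datasets/train_hr/*.png"))

    dataset = ImageDataset(lr_paths, degrade_hr_paths, hr_paths)
    loader  = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    batch = next(iter(loader))
    # Each item is a dict of BGR tensors in [0, 1]: batch["lr"], batch["degrade_hr"], batch["hr"]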
architecture/discriminator.py ADDED
@@ -0,0 +1,241 @@
1
+ # -*- coding: utf-8 -*-
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+ from torch.nn.utils import spectral_norm
5
+ import torch
6
+ import functools
7
+
8
+ class UNetDiscriminatorSN(nn.Module):
9
+ """Defines a U-Net discriminator with spectral normalization (SN)
10
+
11
+ It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
12
+
13
+ Args:
14
+ num_in_ch (int): Channel number of inputs. Default: 3.
15
+ num_feat (int): Channel number of base intermediate features. Default: 64.
16
+ skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
17
+ """
18
+
19
+ def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
20
+ super(UNetDiscriminatorSN, self).__init__()
21
+ self.skip_connection = skip_connection
22
+ norm = spectral_norm
23
+ # the first convolution
24
+ self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
25
+ # downsample
26
+ self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
27
+ self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
28
+ self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
29
+ # upsample
30
+ self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
31
+ self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
32
+ self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
33
+ # extra convolutions
34
+ self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
35
+ self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
36
+ self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
37
+
38
+ def forward(self, x):
39
+
40
+ # downsample
41
+ x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
42
+ x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
43
+ x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
44
+ x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
45
+
46
+ # upsample
47
+ x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
48
+ x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
49
+
50
+ if self.skip_connection:
51
+ x4 = x4 + x2
52
+ x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
53
+ x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
54
+
55
+ if self.skip_connection:
56
+ x5 = x5 + x1
57
+ x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
58
+ x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
59
+
60
+ if self.skip_connection:
61
+ x6 = x6 + x0
62
+
63
+ # extra convolutions
64
+ out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
65
+ out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
66
+ out = self.conv9(out)
67
+
68
+ return out
69
+
70
+
71
+
72
+ def get_conv_layer(input_nc, ndf, kernel_size, stride, padding, bias=True, use_sn=False):
73
+ if not use_sn:
74
+ return nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
75
+ return spectral_norm(nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
76
+
77
+
78
+ class PatchDiscriminator(nn.Module):
79
+ """Defines a PatchGAN discriminator; the receptive field of the default config is 70x70.
80
+
81
+ Args:
82
+ use_sn (bool): Whether to use spectral_norm. If use_sn is True, then norm_type should be 'none'.
83
+ """
84
+
85
+ def __init__(self,
86
+ num_in_ch,
87
+ num_feat=64,
88
+ num_layers=3,
89
+ max_nf_mult=8,
90
+ norm_type='batch',
91
+ use_sigmoid=False,
92
+ use_sn=False):
93
+ super(PatchDiscriminator, self).__init__()
94
+
95
+ norm_layer = self._get_norm_layer(norm_type)
96
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
97
+ use_bias = norm_layer.func != nn.BatchNorm2d
98
+ else:
99
+ use_bias = norm_layer != nn.BatchNorm2d
100
+
101
+ kw = 4
102
+ padw = 1
103
+ sequence = [
104
+ get_conv_layer(num_in_ch, num_feat, kernel_size=kw, stride=2, padding=padw, use_sn=use_sn),
105
+ nn.LeakyReLU(0.2, True)
106
+ ]
107
+ nf_mult = 1
108
+ nf_mult_prev = 1
109
+ for n in range(1, num_layers): # gradually increase the number of filters
110
+ nf_mult_prev = nf_mult
111
+ nf_mult = min(2**n, max_nf_mult)
112
+ sequence += [
113
+ get_conv_layer(
114
+ num_feat * nf_mult_prev,
115
+ num_feat * nf_mult,
116
+ kernel_size=kw,
117
+ stride=2,
118
+ padding=padw,
119
+ bias=use_bias,
120
+ use_sn=use_sn),
121
+ norm_layer(num_feat * nf_mult),
122
+ nn.LeakyReLU(0.2, True)
123
+ ]
124
+
125
+ nf_mult_prev = nf_mult
126
+ nf_mult = min(2**num_layers, max_nf_mult)
127
+ sequence += [
128
+ get_conv_layer(
129
+ num_feat * nf_mult_prev,
130
+ num_feat * nf_mult,
131
+ kernel_size=kw,
132
+ stride=1,
133
+ padding=padw,
134
+ bias=use_bias,
135
+ use_sn=use_sn),
136
+ norm_layer(num_feat * nf_mult),
137
+ nn.LeakyReLU(0.2, True)
138
+ ]
139
+
140
+ # output 1 channel prediction map (I think this is essentially a pixel-by-pixel feedback map)
141
+ sequence += [get_conv_layer(num_feat * nf_mult, 1, kernel_size=kw, stride=1, padding=padw, use_sn=use_sn)]
142
+
143
+ if use_sigmoid:
144
+ sequence += [nn.Sigmoid()]
145
+ self.model = nn.Sequential(*sequence)
146
+
147
+ def _get_norm_layer(self, norm_type='batch'):
148
+ if norm_type == 'batch':
149
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
150
+ elif norm_type == 'instance':
151
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
152
+ elif norm_type == 'batchnorm2d':
153
+ norm_layer = nn.BatchNorm2d
154
+ elif norm_type == 'none':
155
+ norm_layer = nn.Identity
156
+ else:
157
+ raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
158
+
159
+ return norm_layer
160
+
161
+ def forward(self, x):
162
+ return self.model(x)
163
+
164
+
165
+ class MultiScaleDiscriminator(nn.Module):
166
+ """Define a multi-scale discriminator, where each discriminator is an instance of PatchDiscriminator.
167
+
168
+ Args:
169
+ num_layers (int or list): If the type of this variable is int, then degrade to PatchDiscriminator.
170
+ If the type of this variable is list, then the length of the list is
171
+ the number of discriminators.
172
+ use_downscale (bool): Progressively downscale the input before feeding it into the different discriminators.
173
+ If set to True, then the discriminators are usually the same.
174
+ """
175
+
176
+ def __init__(self,
177
+ num_in_ch,
178
+ num_feat=64,
179
+ num_layers=[3, 3, 3],
180
+ max_nf_mult=8,
181
+ norm_type='none',
182
+ use_sigmoid=False,
183
+ use_sn=True,
184
+ use_downscale=True):
185
+ super(MultiScaleDiscriminator, self).__init__()
186
+
187
+ if isinstance(num_layers, int):
188
+ num_layers = [num_layers]
189
+
190
+ # check whether the discriminators are the same
191
+ if use_downscale:
192
+ assert len(set(num_layers)) == 1
193
+ self.use_downscale = use_downscale
194
+
195
+ self.num_dis = len(num_layers)
196
+ self.dis_list = nn.ModuleList()
197
+ for nl in num_layers:
198
+ self.dis_list.append(
199
+ PatchDiscriminator(
200
+ num_in_ch,
201
+ num_feat=num_feat,
202
+ num_layers=nl,
203
+ max_nf_mult=max_nf_mult,
204
+ norm_type=norm_type,
205
+ use_sigmoid=use_sigmoid,
206
+ use_sn=use_sn,
207
+ ))
208
+
209
+ def forward(self, x):
210
+ outs = []
211
+ h, w = x.size()[2:]
212
+
213
+ y = x
214
+ for i in range(self.num_dis):
215
+ if i != 0 and self.use_downscale:
216
+ y = F.interpolate(y, size=(h // 2, w // 2), mode='bilinear', align_corners=True)
217
+ h, w = y.size()[2:]
218
+ outs.append(self.dis_list[i](y))
219
+
220
+ return outs
221
+
222
+
223
+ def main():
224
+ from pthflops import count_ops
225
+ from torchsummary import summary
226
+
227
+ model = UNetDiscriminatorSN(3)
228
+ pytorch_total_params = sum(p.numel() for p in model.parameters())
229
+
230
+ # Create a network and a corresponding input
231
+ device = 'cuda'
232
+ inp = torch.rand(1, 3, 400, 400)
233
+
234
+ # Count the number of FLOPs
235
+ count_ops(model, inp)
236
+ summary(model.cuda(), (3, 400, 400), batch_size=1)
237
+ # print(f"pathGAN has param {pytorch_total_params//1000} K params")
238
+
239
+
240
+ if __name__ == "__main__":
241
+ main()
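
A minimal sketch of the U-Net discriminator's output shape: it returns one unbounded realness logit per pixel at the input resolution, which is what the skip-connected decoder above is built for.

    import torch
    from architecture.discriminator import UNetDiscriminatorSN

    disc = UNetDiscriminatorSN(num_in_ch=3).eval()
    with torch.no_grad():
        fake_sr = torch.rand(1, 3, 256, 256)   # stand-in for a generator output
        logits  = disc(fake_sr)
    print(logits.shape)                        # torch.Size([1, 1, 256, 256])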
architecture/grl.py ADDED
@@ -0,0 +1,616 @@
1
+ """
2
+ Efficient and Explicit Modelling of Image Hierarchies for Image Restoration
3
+ Image restoration transformers with global, regional, and local modelling
4
+ A clean version of the network.
5
+ Shared buffers are used for relative_coords_table, relative_position_index, and attn_mask.
6
+ """
7
+ import cv2
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torchvision.transforms import ToTensor
12
+ from torchvision.utils import save_image
13
+ from fairscale.nn import checkpoint_wrapper
14
+ from omegaconf import OmegaConf
15
+ from timm.models.layers import to_2tuple, trunc_normal_
16
+
17
+ # Import files from local folder
18
+ import os, sys
19
+ root_path = os.path.abspath('.')
20
+ sys.path.append(root_path)
21
+
22
+ from architecture.grl_common import Upsample, UpsampleOneStep
23
+ from architecture.grl_common.mixed_attn_block_efficient import (
24
+ _get_stripe_info,
25
+ EfficientMixAttnTransformerBlock,
26
+ )
27
+ from architecture.grl_common.ops import (
28
+ bchw_to_blc,
29
+ blc_to_bchw,
30
+ calculate_mask,
31
+ calculate_mask_all,
32
+ get_relative_coords_table_all,
33
+ get_relative_position_index_simple,
34
+ )
35
+ from architecture.grl_common.swin_v1_block import (
36
+ build_last_conv,
37
+ )
38
+
39
+
40
+ class TransformerStage(nn.Module):
41
+ """Transformer stage.
42
+ Args:
43
+ dim (int): Number of input channels.
44
+ input_resolution (tuple[int]): Input resolution.
45
+ depth (int): Number of blocks.
46
+ num_heads_window (list[int]): Number of window attention heads in different layers.
47
+ num_heads_stripe (list[int]): Number of stripe attention heads in different layers.
48
+ stripe_size (list[int]): Stripe size. Default: [8, 8]
49
+ stripe_groups (list[int]): Number of stripe groups. Default: [None, None].
50
+ stripe_shift (bool): whether to shift the stripes. This is used as an ablation study.
51
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
52
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
53
+ qkv_proj_type (str): QKV projection type. Default: linear. Choices: linear, separable_conv.
54
+ anchor_proj_type (str): Anchor projection type. Default: avgpool. Choices: avgpool, maxpool, conv2d, separable_conv, patchmerging.
55
+ anchor_one_stage (bool): Whether to use one operator or multiple progressive operators to reduce feature map resolution. Default: True.
56
+ anchor_window_down_factor (int): The downscale factor used to get the anchors.
57
+ drop (float, optional): Dropout rate. Default: 0.0
58
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
59
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
60
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
61
+ pretrained_window_size (list[int]): pretrained window size. This is actually not used. Default: [0, 0].
62
+ pretrained_stripe_size (list[int]): pretrained stripe size. This is actually not used. Default: [0, 0].
63
+ conv_type: The convolutional block before residual connection.
64
+ init_method: initialization method of the weight parameters used to train large scale models.
65
+ Choices: n, normal -- Swin V1 init method.
66
+ l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
67
+ r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
68
+ w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
69
+ t, trunc_normal_ -- nn.Linear, trunc_normal; nn.Conv2d, weight_rescale
70
+ fairscale_checkpoint (bool): Whether to use fairscale checkpoint.
71
+ offload_to_cpu (bool): used by fairscale_checkpoint
72
+ args:
73
+ out_proj_type (str): Type of the output projection in the self-attention modules. Default: linear. Choices: linear, conv2d.
74
+ local_connection (bool): Whether to enable the local modelling module (two convs followed by Channel attention). For the GRL base model, this is used.
75
+ euclidean_dist (bool): use Euclidean distance or inner product as the similarity metric. An ablation study.
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ dim,
81
+ input_resolution,
82
+ depth,
83
+ num_heads_window,
84
+ num_heads_stripe,
85
+ window_size,
86
+ stripe_size,
87
+ stripe_groups,
88
+ stripe_shift,
89
+ mlp_ratio=4.0,
90
+ qkv_bias=True,
91
+ qkv_proj_type="linear",
92
+ anchor_proj_type="avgpool",
93
+ anchor_one_stage=True,
94
+ anchor_window_down_factor=1,
95
+ drop=0.0,
96
+ attn_drop=0.0,
97
+ drop_path=0.0,
98
+ norm_layer=nn.LayerNorm,
99
+ pretrained_window_size=[0, 0],
100
+ pretrained_stripe_size=[0, 0],
101
+ conv_type="1conv",
102
+ init_method="",
103
+ fairscale_checkpoint=False,
104
+ offload_to_cpu=False,
105
+ args=None,
106
+ ):
107
+ super().__init__()
108
+
109
+ self.dim = dim
110
+ self.input_resolution = input_resolution
111
+ self.init_method = init_method
112
+
113
+ self.blocks = nn.ModuleList()
114
+ for i in range(depth):
115
+ block = EfficientMixAttnTransformerBlock(
116
+ dim=dim,
117
+ input_resolution=input_resolution,
118
+ num_heads_w=num_heads_window,
119
+ num_heads_s=num_heads_stripe,
120
+ window_size=window_size,
121
+ window_shift=i % 2 == 0,
122
+ stripe_size=stripe_size,
123
+ stripe_groups=stripe_groups,
124
+ stripe_type="H" if i % 2 == 0 else "W",
125
+ stripe_shift=i % 4 in [2, 3] if stripe_shift else False,
126
+ mlp_ratio=mlp_ratio,
127
+ qkv_bias=qkv_bias,
128
+ qkv_proj_type=qkv_proj_type,
129
+ anchor_proj_type=anchor_proj_type,
130
+ anchor_one_stage=anchor_one_stage,
131
+ anchor_window_down_factor=anchor_window_down_factor,
132
+ drop=drop,
133
+ attn_drop=attn_drop,
134
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
135
+ norm_layer=norm_layer,
136
+ pretrained_window_size=pretrained_window_size,
137
+ pretrained_stripe_size=pretrained_stripe_size,
138
+ res_scale=0.1 if init_method == "r" else 1.0,
139
+ args=args,
140
+ )
141
+ # print(fairscale_checkpoint, offload_to_cpu)
142
+ if fairscale_checkpoint:
143
+ block = checkpoint_wrapper(block, offload_to_cpu=offload_to_cpu)
144
+ self.blocks.append(block)
145
+
146
+ self.conv = build_last_conv(conv_type, dim)
147
+
148
+ def _init_weights(self):
149
+ for n, m in self.named_modules():
150
+ if self.init_method == "w":
151
+ if isinstance(m, (nn.Linear, nn.Conv2d)) and n.find("cpb_mlp") < 0:
152
+ print("nn.Linear and nn.Conv2d weight initialization")
153
+ m.weight.data *= 0.1
154
+ elif self.init_method == "l":
155
+ if isinstance(m, nn.LayerNorm):
156
+ print("nn.LayerNorm initialization")
157
+ nn.init.constant_(m.bias, 0)
158
+ nn.init.constant_(m.weight, 0)
159
+ elif self.init_method.find("t") >= 0:
160
+ scale = 0.1 ** (len(self.init_method) - 1) * int(self.init_method[-1])
161
+ if isinstance(m, nn.Linear) and n.find("cpb_mlp") < 0:
162
+ trunc_normal_(m.weight, std=scale)
163
+ elif isinstance(m, nn.Conv2d):
164
+ m.weight.data *= 0.1
165
+ print(
166
+ "Initialization nn.Linear - trunc_normal; nn.Conv2d - weight rescale."
167
+ )
168
+ else:
169
+ raise NotImplementedError(
170
+ f"Parameter initialization method {self.init_method} not implemented in TransformerStage."
171
+ )
172
+
173
+ def forward(self, x, x_size, table_index_mask):
174
+ res = x
175
+ for blk in self.blocks:
176
+ res = blk(res, x_size, table_index_mask)
177
+ res = bchw_to_blc(self.conv(blc_to_bchw(res, x_size)))
178
+
179
+ return res + x
180
+
181
+ def flops(self):
182
+ pass
183
+
184
+
185
+ class GRL(nn.Module):
186
+ r"""Image restoration transformer with global, non-local, and local connections
187
+ Args:
188
+ img_size (int | list[int]): Input image size. Default 64
189
+ in_channels (int): Number of input image channels. Default: 3
190
+ out_channels (int): Number of output image channels. Default: None
191
+ embed_dim (int): Patch embedding dimension. Default: 96
192
+ upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
193
+ img_range (float): Image range. 1. or 255.
194
+ upsampler (str): The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
195
+ depths (list[int]): Depth of each Swin Transformer layer.
196
+ num_heads_window (list[int]): Number of window attention heads in different layers.
197
+ num_heads_stripe (list[int]): Number of stripe attention heads in different layers.
198
+ window_size (int): Window size. Default: 8.
199
+ stripe_size (list[int]): Stripe size. Default: [8, 8]
200
+ stripe_groups (list[int]): Number of stripe groups. Default: [None, None].
201
+ stripe_shift (bool): whether to shift the stripes. This is used as an ablation study.
202
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
203
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
204
+ qkv_proj_type (str): QKV projection type. Default: linear. Choices: linear, separable_conv.
205
+ anchor_proj_type (str): Anchor projection type. Default: avgpool. Choices: avgpool, maxpool, conv2d, separable_conv, patchmerging.
206
+ anchor_one_stage (bool): Whether to use one operator or multiple progressive operators to reduce feature map resolution. Default: True.
207
+ anchor_window_down_factor (int): The downscale factor used to get the anchors.
208
+ out_proj_type (str): Type of the output projection in the self-attention modules. Default: linear. Choices: linear, conv2d.
209
+ local_connection (bool): Whether to enable the local modelling module (two convs followed by Channel attention). For GRL base model, this is used.
210
+ drop_rate (float): Dropout rate. Default: 0
211
+ attn_drop_rate (float): Attention dropout rate. Default: 0
212
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
213
+ pretrained_window_size (list[int]): pretrained window size. This is actually not used. Default: [0, 0].
214
+ pretrained_stripe_size (list[int]): pretrained stripe size. This is actually not used. Default: [0, 0].
215
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
216
+ conv_type (str): The convolutional block before residual connection. Default: 1conv. Choices: 1conv, 3conv, 1conv1x1, linear
217
+ init_method: initialization method of the weight parameters used to train large scale models.
218
+ Choices: n, normal -- Swin V1 init method.
219
+ l, layernorm -- Swin V2 init method. Zero the weight and bias in the post layer normalization layer.
220
+ r, res_rescale -- EDSR rescale method. Rescale the residual blocks with a scaling factor 0.1
221
+ w, weight_rescale -- MSRResNet rescale method. Rescale the weight parameter in residual blocks with a scaling factor 0.1
222
+ t, trunc_normal_ -- nn.Linear, trunc_normal; nn.Conv2d, weight_rescale
223
+ fairscale_checkpoint (bool): Whether to use fairscale checkpoint.
224
+ offload_to_cpu (bool): used by fairscale_checkpoint
225
+ euclidean_dist (bool): use Euclidean distance or inner product as the similarity metric. An ablation study.
226
+
227
+ """
228
+
229
+ def __init__(
230
+ self,
231
+ img_size=64,
232
+ in_channels=3,
233
+ out_channels=None,
234
+ embed_dim=96,
235
+ upscale=2,
236
+ img_range=1.0,
237
+ upsampler="",
238
+ depths=[6, 6, 6, 6, 6, 6],
239
+ num_heads_window=[3, 3, 3, 3, 3, 3],
240
+ num_heads_stripe=[3, 3, 3, 3, 3, 3],
241
+ window_size=8,
242
+ stripe_size=[8, 8], # used for stripe window attention
243
+ stripe_groups=[None, None],
244
+ stripe_shift=False,
245
+ mlp_ratio=4.0,
246
+ qkv_bias=True,
247
+ qkv_proj_type="linear",
248
+ anchor_proj_type="avgpool",
249
+ anchor_one_stage=True,
250
+ anchor_window_down_factor=1,
251
+ out_proj_type="linear",
252
+ local_connection=False,
253
+ drop_rate=0.0,
254
+ attn_drop_rate=0.0,
255
+ drop_path_rate=0.1,
256
+ norm_layer=nn.LayerNorm,
257
+ pretrained_window_size=[0, 0],
258
+ pretrained_stripe_size=[0, 0],
259
+ conv_type="1conv",
260
+ init_method="n", # initialization method of the weight parameters used to train large scale models.
261
+ fairscale_checkpoint=False, # fairscale activation checkpointing
262
+ offload_to_cpu=False,
263
+ euclidean_dist=False,
264
+ **kwargs,
265
+ ):
266
+ super(GRL, self).__init__()
267
+ # Process the input arguments
268
+ out_channels = out_channels or in_channels
269
+ self.in_channels = in_channels
270
+ self.out_channels = out_channels
271
+ num_out_feats = 64
272
+ self.embed_dim = embed_dim
273
+ self.upscale = upscale
274
+ self.upsampler = upsampler
275
+ self.img_range = img_range
276
+ if in_channels == 3:
277
+ rgb_mean = (0.4488, 0.4371, 0.4040)
278
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
279
+ else:
280
+ self.mean = torch.zeros(1, 1, 1, 1)
281
+
282
+ max_stripe_size = max([0 if s is None else s for s in stripe_size])
283
+ max_stripe_groups = max([0 if s is None else s for s in stripe_groups])
284
+ max_stripe_groups *= anchor_window_down_factor
285
+ self.pad_size = max(window_size, max_stripe_size, max_stripe_groups)
286
+ # if max_stripe_size >= window_size:
287
+ # self.pad_size *= anchor_window_down_factor
288
+ # if stripe_groups[0] is None and stripe_groups[1] is None:
289
+ # self.pad_size = max(stripe_size)
290
+ # else:
291
+ # self.pad_size = window_size
292
+ self.input_resolution = to_2tuple(img_size)
293
+ self.window_size = to_2tuple(window_size)
294
+ self.shift_size = [w // 2 for w in self.window_size]
295
+ self.stripe_size = stripe_size
296
+ self.stripe_groups = stripe_groups
297
+ self.pretrained_window_size = pretrained_window_size
298
+ self.pretrained_stripe_size = pretrained_stripe_size
299
+ self.anchor_window_down_factor = anchor_window_down_factor
300
+
301
+ # Head of the network. First convolution.
302
+ self.conv_first = nn.Conv2d(in_channels, embed_dim, 3, 1, 1)
303
+
304
+ # Body of the network
305
+ self.norm_start = norm_layer(embed_dim)
306
+ self.pos_drop = nn.Dropout(p=drop_rate)
307
+
308
+ # stochastic depth
309
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
310
+ # stochastic depth decay rule
311
+ args = OmegaConf.create(
312
+ {
313
+ "out_proj_type": out_proj_type,
314
+ "local_connection": local_connection,
315
+ "euclidean_dist": euclidean_dist,
316
+ }
317
+ )
318
+ for k, v in self.set_table_index_mask(self.input_resolution).items():
319
+ self.register_buffer(k, v)
320
+
321
+ self.layers = nn.ModuleList()
322
+ for i in range(len(depths)):
323
+ layer = TransformerStage(
324
+ dim=embed_dim,
325
+ input_resolution=self.input_resolution,
326
+ depth=depths[i],
327
+ num_heads_window=num_heads_window[i],
328
+ num_heads_stripe=num_heads_stripe[i],
329
+ window_size=self.window_size,
330
+ stripe_size=stripe_size,
331
+ stripe_groups=stripe_groups,
332
+ stripe_shift=stripe_shift,
333
+ mlp_ratio=mlp_ratio,
334
+ qkv_bias=qkv_bias,
335
+ qkv_proj_type=qkv_proj_type,
336
+ anchor_proj_type=anchor_proj_type,
337
+ anchor_one_stage=anchor_one_stage,
338
+ anchor_window_down_factor=anchor_window_down_factor,
339
+ drop=drop_rate,
340
+ attn_drop=attn_drop_rate,
341
+ drop_path=dpr[
342
+ sum(depths[:i]) : sum(depths[: i + 1])
343
+ ], # no impact on SR results
344
+ norm_layer=norm_layer,
345
+ pretrained_window_size=pretrained_window_size,
346
+ pretrained_stripe_size=pretrained_stripe_size,
347
+ conv_type=conv_type,
348
+ init_method=init_method,
349
+ fairscale_checkpoint=fairscale_checkpoint,
350
+ offload_to_cpu=offload_to_cpu,
351
+ args=args,
352
+ )
353
+ self.layers.append(layer)
354
+ self.norm_end = norm_layer(embed_dim)
355
+
356
+ # Tail of the network
357
+ self.conv_after_body = build_last_conv(conv_type, embed_dim)
358
+
359
+ #####################################################################################################
360
+ ################################ 3, high quality image reconstruction ################################
361
+ if self.upsampler == "pixelshuffle":
362
+ # for classical SR
363
+ self.conv_before_upsample = nn.Sequential(
364
+ nn.Conv2d(embed_dim, num_out_feats, 3, 1, 1), nn.LeakyReLU(inplace=True)
365
+ )
366
+ self.upsample = Upsample(upscale, num_out_feats)
367
+ self.conv_last = nn.Conv2d(num_out_feats, out_channels, 3, 1, 1)
368
+ elif self.upsampler == "pixelshuffledirect":
369
+ # for lightweight SR (to save parameters)
370
+ self.upsample = UpsampleOneStep(
371
+ upscale,
372
+ embed_dim,
373
+ out_channels,
374
+ )
375
+ elif self.upsampler == "nearest+conv":
376
+ # for real-world SR (less artifacts)
377
+ assert self.upscale == 4, "only support x4 now."
378
+ self.conv_before_upsample = nn.Sequential(
379
+ nn.Conv2d(embed_dim, num_out_feats, 3, 1, 1), nn.LeakyReLU(inplace=True)
380
+ )
381
+ self.conv_up1 = nn.Conv2d(num_out_feats, num_out_feats, 3, 1, 1)
382
+ self.conv_up2 = nn.Conv2d(num_out_feats, num_out_feats, 3, 1, 1)
383
+ self.conv_hr = nn.Conv2d(num_out_feats, num_out_feats, 3, 1, 1)
384
+ self.conv_last = nn.Conv2d(num_out_feats, out_channels, 3, 1, 1)
385
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
386
+ else:
387
+ # for image denoising and JPEG compression artifact reduction
388
+ self.conv_last = nn.Conv2d(embed_dim, out_channels, 3, 1, 1)
389
+
390
+ self.apply(self._init_weights)
391
+ if init_method in ["l", "w"] or init_method.find("t") >= 0:
392
+ for layer in self.layers:
393
+ layer._init_weights()
394
+
395
+ def set_table_index_mask(self, x_size):
396
+ """
397
+ Two used cases:
398
+ 1) At initialization: set the shared buffers.
399
+ 2) During forward pass: get the new buffers if the resolution of the input changes
400
+ """
401
+ # ss - stripe_size, sss - stripe_shift_size
402
+ ss, sss = _get_stripe_info(self.stripe_size, self.stripe_groups, True, x_size)
403
+ df = self.anchor_window_down_factor
404
+
405
+ table_w = get_relative_coords_table_all(
406
+ self.window_size, self.pretrained_window_size
407
+ )
408
+ table_sh = get_relative_coords_table_all(ss, self.pretrained_stripe_size, df)
409
+ table_sv = get_relative_coords_table_all(
410
+ ss[::-1], self.pretrained_stripe_size, df
411
+ )
412
+
413
+ index_w = get_relative_position_index_simple(self.window_size)
414
+ index_sh_a2w = get_relative_position_index_simple(ss, df, False)
415
+ index_sh_w2a = get_relative_position_index_simple(ss, df, True)
416
+ index_sv_a2w = get_relative_position_index_simple(ss[::-1], df, False)
417
+ index_sv_w2a = get_relative_position_index_simple(ss[::-1], df, True)
418
+
419
+ mask_w = calculate_mask(x_size, self.window_size, self.shift_size)
420
+ mask_sh_a2w = calculate_mask_all(x_size, ss, sss, df, False)
421
+ mask_sh_w2a = calculate_mask_all(x_size, ss, sss, df, True)
422
+ mask_sv_a2w = calculate_mask_all(x_size, ss[::-1], sss[::-1], df, False)
423
+ mask_sv_w2a = calculate_mask_all(x_size, ss[::-1], sss[::-1], df, True)
424
+ return {
425
+ "table_w": table_w,
426
+ "table_sh": table_sh,
427
+ "table_sv": table_sv,
428
+ "index_w": index_w,
429
+ "index_sh_a2w": index_sh_a2w,
430
+ "index_sh_w2a": index_sh_w2a,
431
+ "index_sv_a2w": index_sv_a2w,
432
+ "index_sv_w2a": index_sv_w2a,
433
+ "mask_w": mask_w,
434
+ "mask_sh_a2w": mask_sh_a2w,
435
+ "mask_sh_w2a": mask_sh_w2a,
436
+ "mask_sv_a2w": mask_sv_a2w,
437
+ "mask_sv_w2a": mask_sv_w2a,
438
+ }
439
+
440
+ def get_table_index_mask(self, device=None, input_resolution=None):
441
+ # Used during forward pass
442
+ if input_resolution == self.input_resolution:
443
+ return {
444
+ "table_w": self.table_w,
445
+ "table_sh": self.table_sh,
446
+ "table_sv": self.table_sv,
447
+ "index_w": self.index_w,
448
+ "index_sh_a2w": self.index_sh_a2w,
449
+ "index_sh_w2a": self.index_sh_w2a,
450
+ "index_sv_a2w": self.index_sv_a2w,
451
+ "index_sv_w2a": self.index_sv_w2a,
452
+ "mask_w": self.mask_w,
453
+ "mask_sh_a2w": self.mask_sh_a2w,
454
+ "mask_sh_w2a": self.mask_sh_w2a,
455
+ "mask_sv_a2w": self.mask_sv_a2w,
456
+ "mask_sv_w2a": self.mask_sv_w2a,
457
+ }
458
+ else:
459
+ table_index_mask = self.set_table_index_mask(input_resolution)
460
+ for k, v in table_index_mask.items():
461
+ table_index_mask[k] = v.to(device)
462
+ return table_index_mask
463
+
464
+ def _init_weights(self, m):
465
+ if isinstance(m, nn.Linear):
466
+ # Only used to initialize linear layers
467
+ # weight_shape = m.weight.shape
468
+ # if weight_shape[0] > 256 and weight_shape[1] > 256:
469
+ # std = 0.004
470
+ # else:
471
+ # std = 0.02
472
+ # print(f"Standard deviation during initialization {std}.")
473
+ trunc_normal_(m.weight, std=0.02)
474
+ if isinstance(m, nn.Linear) and m.bias is not None:
475
+ nn.init.constant_(m.bias, 0)
476
+ elif isinstance(m, nn.LayerNorm):
477
+ nn.init.constant_(m.bias, 0)
478
+ nn.init.constant_(m.weight, 1.0)
479
+
480
+ @torch.jit.ignore
481
+ def no_weight_decay(self):
482
+ return {"absolute_pos_embed"}
483
+
484
+ @torch.jit.ignore
485
+ def no_weight_decay_keywords(self):
486
+ return {"relative_position_bias_table"}
487
+
488
+ def check_image_size(self, x):
489
+ _, _, h, w = x.size()
490
+ mod_pad_h = (self.pad_size - h % self.pad_size) % self.pad_size
491
+ mod_pad_w = (self.pad_size - w % self.pad_size) % self.pad_size
492
+ # print("padding size", h, w, self.pad_size, mod_pad_h, mod_pad_w)
493
+
494
+ try:
495
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
496
+ except BaseException:
497
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant")
498
+ return x
499
+
500
+ def forward_features(self, x):
501
+ x_size = (x.shape[2], x.shape[3])
502
+ x = bchw_to_blc(x)
503
+ x = self.norm_start(x)
504
+ x = self.pos_drop(x)
505
+
506
+ table_index_mask = self.get_table_index_mask(x.device, x_size)
507
+ for layer in self.layers:
508
+ x = layer(x, x_size, table_index_mask)
509
+
510
+ x = self.norm_end(x) # B L C
511
+ x = blc_to_bchw(x, x_size)
512
+
513
+ return x
514
+
515
+ def forward(self, x):
516
+ H, W = x.shape[2:]
517
+ x = self.check_image_size(x)
518
+
519
+ self.mean = self.mean.type_as(x)
520
+ x = (x - self.mean) * self.img_range
521
+
522
+ if self.upsampler == "pixelshuffle":
523
+ # for classical SR
524
+ x = self.conv_first(x)
525
+ x = self.conv_after_body(self.forward_features(x)) + x
526
+ x = self.conv_before_upsample(x)
527
+ x = self.conv_last(self.upsample(x))
528
+ elif self.upsampler == "pixelshuffledirect":
529
+ # for lightweight SR
530
+ x = self.conv_first(x)
531
+ x = self.conv_after_body(self.forward_features(x)) + x
532
+ x = self.upsample(x)
533
+ elif self.upsampler == "nearest+conv":
534
+ # for real-world SR (claimed to have less artifacts)
535
+ x = self.conv_first(x)
536
+ x = self.conv_after_body(self.forward_features(x)) + x
537
+ x = self.conv_before_upsample(x)
538
+ x = self.lrelu(
539
+ self.conv_up1(
540
+ torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
541
+ )
542
+ )
543
+ x = self.lrelu(
544
+ self.conv_up2(
545
+ torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
546
+ )
547
+ )
548
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
549
+ else:
550
+ # for image denoising and JPEG compression artifact reduction
551
+ x_first = self.conv_first(x)
552
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
553
+ if self.in_channels == self.out_channels:
554
+ x = x + self.conv_last(res)
555
+ else:
556
+ x = self.conv_last(res)
557
+
558
+ x = x / self.img_range + self.mean
559
+
560
+ return x[:, :, : H * self.upscale, : W * self.upscale]
561
+
562
+ def flops(self):
563
+ pass
564
+
565
+ def convert_checkpoint(self, state_dict):
566
+ for k in list(state_dict.keys()):
567
+ if (
568
+ k.find("relative_coords_table") >= 0
569
+ or k.find("relative_position_index") >= 0
570
+ or k.find("attn_mask") >= 0
571
+ or k.find("model.table_") >= 0
572
+ or k.find("model.index_") >= 0
573
+ or k.find("model.mask_") >= 0
574
+ # or k.find(".upsample.") >= 0
575
+ ):
576
+ state_dict.pop(k)
577
+ print(k)
578
+ return state_dict
579
+
580
+
581
+ if __name__ == "__main__":
582
+ # The version of GRL we use
583
+ model = GRL(
584
+ upscale = 4,
585
+ img_size = 64,
586
+ window_size = 8,
587
+ depths = [4, 4, 4, 4],
588
+ embed_dim = 64,
589
+ num_heads_window = [2, 2, 2, 2],
590
+ num_heads_stripe = [2, 2, 2, 2],
591
+ mlp_ratio = 2,
592
+ qkv_proj_type = "linear",
593
+ anchor_proj_type = "avgpool",
594
+ anchor_window_down_factor = 2,
595
+ out_proj_type = "linear",
596
+ conv_type = "1conv",
597
+ upsampler = "nearest+conv", # Change
598
+ ).cuda()
599
+
600
+ # Parameter analysis
601
+ num_params = 0
602
+ for p in model.parameters():
603
+ if p.requires_grad:
604
+ num_params += p.numel()
605
+ print(f"Number of parameters {num_params / 10 ** 6: 0.2f}")
606
+
607
+ # Print param
608
+ for name, param in model.named_parameters():
609
+ print(name, param.dtype)
610
+
611
+
612
+ # Count the number of FLOPs to double check
613
+ x = torch.randn((1, 3, 180, 180)).cuda() # Don't use input size that is too big (we don't have @torch.no_grad here)
614
+ x = model(x)
615
+ print("output size is ", x.shape)
616
+
architecture/grl_common/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from architecture.grl_common.resblock import ResBlock
2
+ from architecture.grl_common.upsample import (
3
+ Upsample,
4
+ UpsampleOneStep,
5
+ )
6
+
7
+
8
+ __all__ = ["Upsample", "UpsampleOneStep", "ResBlock"]
architecture/grl_common/common_edsr.py ADDED
@@ -0,0 +1,227 @@
1
+ """
2
+ EDSR common.py
3
+ Since a lot of models are developed on top of EDSR, here we include some common functions from EDSR.
4
+ In this repository, the common functions are used by edsr_esa.py and ipt.py
5
+ """
6
+
7
+
8
+ import math
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ def default_conv(in_channels, out_channels, kernel_size, bias=True):
16
+ return nn.Conv2d(
17
+ in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias
18
+ )
19
+
20
+
21
+ class MeanShift(nn.Conv2d):
22
+ def __init__(
23
+ self,
24
+ rgb_range,
25
+ rgb_mean=(0.4488, 0.4371, 0.4040),
26
+ rgb_std=(1.0, 1.0, 1.0),
27
+ sign=-1,
28
+ ):
29
+
30
+ super(MeanShift, self).__init__(3, 3, kernel_size=1)
31
+ std = torch.Tensor(rgb_std)
32
+ self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
33
+ self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
34
+ for p in self.parameters():
35
+ p.requires_grad = False
36
+
37
+
38
+ class BasicBlock(nn.Sequential):
39
+ def __init__(
40
+ self,
41
+ conv,
42
+ in_channels,
43
+ out_channels,
44
+ kernel_size,
45
+ stride=1,
46
+ bias=False,
47
+ bn=True,
48
+ act=nn.ReLU(True),
49
+ ):
50
+
51
+ m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
52
+ if bn:
53
+ m.append(nn.BatchNorm2d(out_channels))
54
+ if act is not None:
55
+ m.append(act)
56
+
57
+ super(BasicBlock, self).__init__(*m)
58
+
59
+
60
+ class ESA(nn.Module):
61
+ def __init__(self, esa_channels, n_feats):
62
+ super(ESA, self).__init__()
63
+ f = esa_channels
64
+ self.conv1 = nn.Conv2d(n_feats, f, kernel_size=1)
65
+ self.conv_f = nn.Conv2d(f, f, kernel_size=1)
66
+ # self.conv_max = conv(f, f, kernel_size=3, padding=1)
67
+ self.conv2 = nn.Conv2d(f, f, kernel_size=3, stride=2, padding=0)
68
+ self.conv3 = nn.Conv2d(f, f, kernel_size=3, padding=1)
69
+ # self.conv3_ = conv(f, f, kernel_size=3, padding=1)
70
+ self.conv4 = nn.Conv2d(f, n_feats, kernel_size=1)
71
+ self.sigmoid = nn.Sigmoid()
72
+ # self.relu = nn.ReLU(inplace=True)
73
+
74
+ def forward(self, x):
75
+ c1_ = self.conv1(x)
76
+ c1 = self.conv2(c1_)
77
+ v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
78
+ c3 = self.conv3(v_max)
79
+ # v_range = self.relu(self.conv_max(v_max))
80
+ # c3 = self.relu(self.conv3(v_range))
81
+ # c3 = self.conv3_(c3)
82
+ c3 = F.interpolate(
83
+ c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False
84
+ )
85
+ cf = self.conv_f(c1_)
86
+ c4 = self.conv4(c3 + cf)
87
+ m = self.sigmoid(c4)
88
+
89
+ return x * m
90
+
91
+
92
+ # class ESA(nn.Module):
93
+ # def __init__(self, esa_channels, n_feats, conv=nn.Conv2d):
94
+ # super(ESA, self).__init__()
95
+ # f = n_feats // 4
96
+ # self.conv1 = conv(n_feats, f, kernel_size=1)
97
+ # self.conv_f = conv(f, f, kernel_size=1)
98
+ # self.conv_max = conv(f, f, kernel_size=3, padding=1)
99
+ # self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
100
+ # self.conv3 = conv(f, f, kernel_size=3, padding=1)
101
+ # self.conv3_ = conv(f, f, kernel_size=3, padding=1)
102
+ # self.conv4 = conv(f, n_feats, kernel_size=1)
103
+ # self.sigmoid = nn.Sigmoid()
104
+ # self.relu = nn.ReLU(inplace=True)
105
+ #
106
+ # def forward(self, x):
107
+ # c1_ = (self.conv1(x))
108
+ # c1 = self.conv2(c1_)
109
+ # v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
110
+ # v_range = self.relu(self.conv_max(v_max))
111
+ # c3 = self.relu(self.conv3(v_range))
112
+ # c3 = self.conv3_(c3)
113
+ # c3 = F.interpolate(c3, (x.size(2), x.size(3)), mode='bilinear', align_corners=False)
114
+ # cf = self.conv_f(c1_)
115
+ # c4 = self.conv4(c3 + cf)
116
+ # m = self.sigmoid(c4)
117
+ #
118
+ # return x * m
119
+
120
+
121
+ class ResBlock(nn.Module):
122
+ def __init__(
123
+ self,
124
+ conv,
125
+ n_feats,
126
+ kernel_size,
127
+ bias=True,
128
+ bn=False,
129
+ act=nn.ReLU(True),
130
+ res_scale=1,
131
+ esa_block=True,
132
+ depth_wise_kernel=7,
133
+ ):
134
+
135
+ super(ResBlock, self).__init__()
136
+ m = []
137
+ for i in range(2):
138
+ m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
139
+ if bn:
140
+ m.append(nn.BatchNorm2d(n_feats))
141
+ if i == 0:
142
+ m.append(act)
143
+
144
+ self.body = nn.Sequential(*m)
145
+ self.esa_block = esa_block
146
+ if self.esa_block:
147
+ esa_channels = 16
148
+ self.c5 = nn.Conv2d(
149
+ n_feats,
150
+ n_feats,
151
+ depth_wise_kernel,
152
+ padding=depth_wise_kernel // 2,
153
+ groups=n_feats,
154
+ bias=True,
155
+ )
156
+ self.esa = ESA(esa_channels, n_feats)
157
+ self.res_scale = res_scale
158
+
159
+ def forward(self, x):
160
+ res = self.body(x).mul(self.res_scale)
161
+ res += x
162
+ if self.esa_block:
163
+ res = self.esa(self.c5(res))
164
+
165
+ return res
166
+
167
+
168
+ class Upsampler(nn.Sequential):
169
+ def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
170
+
171
+ m = []
172
+ if (scale & (scale - 1)) == 0: # Is scale = 2^n?
173
+ for _ in range(int(math.log(scale, 2))):
174
+ m.append(conv(n_feats, 4 * n_feats, 3, bias))
175
+ m.append(nn.PixelShuffle(2))
176
+ if bn:
177
+ m.append(nn.BatchNorm2d(n_feats))
178
+ if act == "relu":
179
+ m.append(nn.ReLU(True))
180
+ elif act == "prelu":
181
+ m.append(nn.PReLU(n_feats))
182
+
183
+ elif scale == 3:
184
+ m.append(conv(n_feats, 9 * n_feats, 3, bias))
185
+ m.append(nn.PixelShuffle(3))
186
+ if bn:
187
+ m.append(nn.BatchNorm2d(n_feats))
188
+ if act == "relu":
189
+ m.append(nn.ReLU(True))
190
+ elif act == "prelu":
191
+ m.append(nn.PReLU(n_feats))
192
+ else:
193
+ raise NotImplementedError
194
+
195
+ super(Upsampler, self).__init__(*m)
196
+
197
+
198
+ class LiteUpsampler(nn.Sequential):
199
+ def __init__(self, conv, scale, n_feats, n_out=3, bn=False, act=False, bias=True):
200
+
201
+ m = []
202
+ m.append(conv(n_feats, n_out * (scale**2), 3, bias))
203
+ m.append(nn.PixelShuffle(scale))
204
+ # if (scale & (scale - 1)) == 0: # Is scale = 2^n?
205
+ # for _ in range(int(math.log(scale, 2))):
206
+ # m.append(conv(n_feats, 4 * n_out, 3, bias))
207
+ # m.append(nn.PixelShuffle(2))
208
+ # if bn:
209
+ # m.append(nn.BatchNorm2d(n_out))
210
+ # if act == 'relu':
211
+ # m.append(nn.ReLU(True))
212
+ # elif act == 'prelu':
213
+ # m.append(nn.PReLU(n_out))
214
+
215
+ # elif scale == 3:
216
+ # m.append(conv(n_feats, 9 * n_out, 3, bias))
217
+ # m.append(nn.PixelShuffle(3))
218
+ # if bn:
219
+ # m.append(nn.BatchNorm2d(n_out))
220
+ # if act == 'relu':
221
+ # m.append(nn.ReLU(True))
222
+ # elif act == 'prelu':
223
+ # m.append(nn.PReLU(n_out))
224
+ # else:
225
+ # raise NotImplementedError
226
+
227
+ super(LiteUpsampler, self).__init__(*m)
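
For reference, a minimal sketch of how the EDSR-style blocks above compose; the wiring and sizes here are illustrative assumptions, not the model definitions actually used in this repository.

    import torch
    import torch.nn as nn
    from architecture.grl_common.common_edsr import ResBlock, Upsampler, default_conv

    # Toy pipeline (illustration only): head conv -> two residual blocks -> x2 pixel-shuffle upsampler -> RGB conv.
    n_feats, scale = 64, 2
    head = default_conv(3, n_feats, 3)
    body = nn.Sequential(*[ResBlock(default_conv, n_feats, 3, esa_block=False) for _ in range(2)])
    tail = nn.Sequential(Upsampler(default_conv, scale, n_feats), default_conv(n_feats, 3, 3))

    x = torch.randn(1, 3, 32, 32)   # low-resolution input
    out = tail(body(head(x)))       # -> (1, 3, 64, 64) for scale=2
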
architecture/grl_common/mixed_attn_block.py ADDED
@@ -0,0 +1,1126 @@
1
+ import math
2
+ from abc import ABC
3
+ from math import prod
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from architecture.grl_common.ops import (
9
+ bchw_to_bhwc,
10
+ bchw_to_blc,
11
+ blc_to_bchw,
12
+ blc_to_bhwc,
13
+ calculate_mask,
14
+ calculate_mask_all,
15
+ get_relative_coords_table_all,
16
+ get_relative_position_index_simple,
17
+ window_partition,
18
+ window_reverse,
19
+ )
20
+ from architecture.grl_common.swin_v1_block import Mlp
21
+ from timm.models.layers import DropPath
22
+
23
+
24
+ class CPB_MLP(nn.Sequential):
25
+ def __init__(self, in_channels, out_channels, channels=512):
26
+ m = [
27
+ nn.Linear(in_channels, channels, bias=True),
28
+ nn.ReLU(inplace=True),
29
+ nn.Linear(channels, out_channels, bias=False),
30
+ ]
31
+ super(CPB_MLP, self).__init__(*m)
32
+
33
+
34
+ class AffineTransformWindow(nn.Module):
35
+ r"""Affine transformation of the attention map.
36
+ The window is a square window.
37
+ Supports attention between different window sizes
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ num_heads,
43
+ input_resolution,
44
+ window_size,
45
+ pretrained_window_size=[0, 0],
46
+ shift_size=0,
47
+ anchor_window_down_factor=1,
48
+ args=None,
49
+ ):
50
+ super(AffineTransformWindow, self).__init__()
51
+ # print("AffineTransformWindow", args)
52
+ self.num_heads = num_heads
53
+ self.input_resolution = input_resolution
54
+ self.window_size = window_size
55
+ self.pretrained_window_size = pretrained_window_size
56
+ self.shift_size = shift_size
57
+ self.anchor_window_down_factor = anchor_window_down_factor
58
+ self.use_buffer = args.use_buffer
59
+
60
+ logit_scale = torch.log(10 * torch.ones((num_heads, 1, 1)))
61
+ self.logit_scale = nn.Parameter(logit_scale, requires_grad=True)
62
+
63
+ # mlp to generate continuous relative position bias
64
+ self.cpb_mlp = CPB_MLP(2, num_heads)
65
+ if self.use_buffer:
66
+ table = get_relative_coords_table_all(
67
+ window_size, pretrained_window_size, anchor_window_down_factor
68
+ )
69
+ index = get_relative_position_index_simple(
70
+ window_size, anchor_window_down_factor
71
+ )
72
+ self.register_buffer("relative_coords_table", table)
73
+ self.register_buffer("relative_position_index", index)
74
+
75
+ if self.shift_size > 0:
76
+ attn_mask = calculate_mask(
77
+ input_resolution, self.window_size, self.shift_size
78
+ )
79
+ else:
80
+ attn_mask = None
81
+ self.register_buffer("attn_mask", attn_mask)
82
+
83
+ def forward(self, attn, x_size):
84
+ B_, H, N, _ = attn.shape
85
+ device = attn.device
86
+ # logit scale
87
+ attn = attn * torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
88
+
89
+ # relative position bias
90
+ if self.use_buffer:
91
+ table = self.relative_coords_table
92
+ index = self.relative_position_index
93
+ else:
94
+ table = get_relative_coords_table_all(
95
+ self.window_size,
96
+ self.pretrained_window_size,
97
+ self.anchor_window_down_factor,
98
+ ).to(device)
99
+ index = get_relative_position_index_simple(
100
+ self.window_size, self.anchor_window_down_factor
101
+ ).to(device)
102
+
103
+ bias_table = self.cpb_mlp(table) # 2*Wh-1, 2*Ww-1, num_heads
104
+ bias_table = bias_table.view(-1, self.num_heads)
105
+
106
+ win_dim = prod(self.window_size)
107
+ bias = bias_table[index.view(-1)]
108
+ bias = bias.view(win_dim, win_dim, -1).permute(2, 0, 1).contiguous()
109
+ # nH, Wh*Ww, Wh*Ww
110
+ bias = 16 * torch.sigmoid(bias)
111
+ attn = attn + bias.unsqueeze(0)
112
+
113
+ # W-MSA/SW-MSA
114
+ if self.use_buffer:
115
+ mask = self.attn_mask
116
+ # during test and window shift, recalculate the mask
117
+ if self.input_resolution != x_size and self.shift_size > 0:
118
+ mask = calculate_mask(x_size, self.window_size, self.shift_size)
119
+ mask = mask.to(attn.device)
120
+ else:
121
+ if self.shift_size > 0:
122
+ mask = calculate_mask(x_size, self.window_size, self.shift_size)
123
+ mask = mask.to(attn.device)
124
+ else:
125
+ mask = None
126
+
127
+ # shift attention mask
128
+ if mask is not None:
129
+ nW = mask.shape[0]
130
+ mask = mask.unsqueeze(1).unsqueeze(0)
131
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask
132
+ attn = attn.view(-1, self.num_heads, N, N)
133
+
134
+ return attn
135
+
136
+
137
+ class AffineTransformStripe(nn.Module):
138
+ r"""Affine transformation of the attention map.
139
+ The window is a stripe window. Supports attention between different window sizes
140
+ """
141
+
142
+ def __init__(
143
+ self,
144
+ num_heads,
145
+ input_resolution,
146
+ stripe_size,
147
+ stripe_groups,
148
+ stripe_shift,
149
+ pretrained_stripe_size=[0, 0],
150
+ anchor_window_down_factor=1,
151
+ window_to_anchor=True,
152
+ args=None,
153
+ ):
154
+ super(AffineTransformStripe, self).__init__()
155
+ self.num_heads = num_heads
156
+ self.input_resolution = input_resolution
157
+ self.stripe_size = stripe_size
158
+ self.stripe_groups = stripe_groups
159
+ self.pretrained_stripe_size = pretrained_stripe_size
160
+ # TODO: be careful when determining the pretrained_stripe_size
161
+ self.stripe_shift = stripe_shift
162
+ stripe_size, shift_size = self._get_stripe_info(input_resolution)
163
+ self.anchor_window_down_factor = anchor_window_down_factor
164
+ self.window_to_anchor = window_to_anchor
165
+ self.use_buffer = args.use_buffer
166
+
167
+ logit_scale = torch.log(10 * torch.ones((num_heads, 1, 1)))
168
+ self.logit_scale = nn.Parameter(logit_scale, requires_grad=True)
169
+
170
+ # mlp to generate continuous relative position bias
171
+ self.cpb_mlp = CPB_MLP(2, num_heads)
172
+ if self.use_buffer:
173
+ table = get_relative_coords_table_all(
174
+ stripe_size, pretrained_stripe_size, anchor_window_down_factor
175
+ )
176
+ index = get_relative_position_index_simple(
177
+ stripe_size, anchor_window_down_factor, window_to_anchor
178
+ )
179
+ self.register_buffer("relative_coords_table", table)
180
+ self.register_buffer("relative_position_index", index)
181
+
182
+ if self.stripe_shift:
183
+ attn_mask = calculate_mask_all(
184
+ input_resolution,
185
+ stripe_size,
186
+ shift_size,
187
+ anchor_window_down_factor,
188
+ window_to_anchor,
189
+ )
190
+ else:
191
+ attn_mask = None
192
+ self.register_buffer("attn_mask", attn_mask)
193
+
194
+ def forward(self, attn, x_size):
195
+ B_, H, N1, N2 = attn.shape
196
+ device = attn.device
197
+ # logit scale
198
+ attn = attn * torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
199
+
200
+ # relative position bias
201
+ stripe_size, shift_size = self._get_stripe_info(x_size)
202
+ fixed_stripe_size = (
203
+ self.stripe_groups[0] is None and self.stripe_groups[1] is None
204
+ )
205
+ if not self.use_buffer or (
206
+ self.use_buffer
207
+ and self.input_resolution != x_size
208
+ and not fixed_stripe_size
209
+ ):
210
+ # during test and stripe size is not fixed.
211
+ pretrained_stripe_size = (
212
+ self.pretrained_stripe_size
213
+ ) # or stripe_size; Needs further pondering
214
+ table = get_relative_coords_table_all(
215
+ stripe_size, pretrained_stripe_size, self.anchor_window_down_factor
216
+ )
217
+ table = table.to(device)
218
+ index = get_relative_position_index_simple(
219
+ stripe_size, self.anchor_window_down_factor, self.window_to_anchor
220
+ ).to(device)
221
+ else:
222
+ table = self.relative_coords_table
223
+ index = self.relative_position_index
224
+ # The same table size-> 1, Wh+AWh-1, Ww+AWw-1, 2
225
+ # But different index size -> # Wh*Ww, AWh*AWw
226
+ # if N1 < N2:
227
+ # index = index.transpose(0, 1)
228
+
229
+ bias_table = self.cpb_mlp(table).view(-1, self.num_heads)
230
+ # if not self.training:
231
+ # print(bias_table.shape, index.max(), index.min())
232
+ bias = bias_table[index.view(-1)]
233
+ bias = bias.view(N1, N2, -1).permute(2, 0, 1).contiguous()
234
+ # nH, Wh*Ww, Wh*Ww
235
+ bias = 16 * torch.sigmoid(bias)
236
+ # print(N1, N2, attn.shape, bias.unsqueeze(0).shape)
237
+ attn = attn + bias.unsqueeze(0)
238
+
239
+ # W-MSA/SW-MSA
240
+ if self.use_buffer:
241
+ mask = self.attn_mask
242
+ # during test and window shift, recalculate the mask
243
+ if self.input_resolution != x_size and self.stripe_shift > 0:
244
+ mask = calculate_mask_all(
245
+ x_size,
246
+ stripe_size,
247
+ shift_size,
248
+ self.anchor_window_down_factor,
249
+ self.window_to_anchor,
250
+ )
251
+ mask = mask.to(device)
252
+ else:
253
+ if self.stripe_shift > 0:
254
+ mask = calculate_mask_all(
255
+ x_size,
256
+ stripe_size,
257
+ shift_size,
258
+ self.anchor_window_down_factor,
259
+ self.window_to_anchor,
260
+ )
261
+ mask = mask.to(attn.device)
262
+ else:
263
+ mask = None
264
+
265
+ # shift attention mask
266
+ if mask is not None:
267
+ nW = mask.shape[0]
268
+ mask = mask.unsqueeze(1).unsqueeze(0)
269
+ attn = attn.view(B_ // nW, nW, self.num_heads, N1, N2) + mask
270
+ attn = attn.view(-1, self.num_heads, N1, N2)
271
+
272
+ return attn
273
+
274
+ def _get_stripe_info(self, input_resolution):
275
+ stripe_size, shift_size = [], []
276
+ for s, g, d in zip(self.stripe_size, self.stripe_groups, input_resolution):
277
+ if g is None:
278
+ stripe_size.append(s)
279
+ shift_size.append(s // 2 if self.stripe_shift else 0)
280
+ else:
281
+ stripe_size.append(d // g)
282
+ shift_size.append(0 if g == 1 else d // (g * 2))
283
+ return stripe_size, shift_size
284
+
285
+
286
+ class Attention(ABC, nn.Module):
287
+ def __init__(self):
288
+ super(Attention, self).__init__()
289
+
290
+ def attn(self, q, k, v, attn_transform, x_size, reshape=True):
291
+ # cosine attention map
292
+ B_, _, H, head_dim = q.shape
293
+ if self.euclidean_dist:
294
+ attn = torch.norm(q.unsqueeze(-2) - k.unsqueeze(-3), dim=-1)
295
+ else:
296
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
297
+ attn = attn_transform(attn, x_size)
298
+ # attention
299
+ attn = self.softmax(attn)
300
+ attn = self.attn_drop(attn)
301
+ x = attn @ v # B_, H, N1, head_dim
302
+ if reshape:
303
+ x = x.transpose(1, 2).reshape(B_, -1, H * head_dim)
304
+ # B_, N, C
305
+ return x
306
+
307
+
308
+ class WindowAttention(Attention):
309
+ r"""Window attention. QKV is the input to the forward method.
310
+ Args:
311
+ num_heads (int): Number of attention heads.
312
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
313
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
314
+ """
315
+
316
+ def __init__(
317
+ self,
318
+ input_resolution,
319
+ window_size,
320
+ num_heads,
321
+ window_shift=False,
322
+ attn_drop=0.0,
323
+ pretrained_window_size=[0, 0],
324
+ args=None,
325
+ ):
326
+
327
+ super(WindowAttention, self).__init__()
328
+ self.input_resolution = input_resolution
329
+ self.window_size = window_size
330
+ self.pretrained_window_size = pretrained_window_size
331
+ self.num_heads = num_heads
332
+ self.shift_size = window_size[0] // 2 if window_shift else 0
333
+ self.euclidean_dist = args.euclidean_dist
334
+
335
+ self.attn_transform = AffineTransformWindow(
336
+ num_heads,
337
+ input_resolution,
338
+ window_size,
339
+ pretrained_window_size,
340
+ self.shift_size,
341
+ args=args,
342
+ )
343
+ self.attn_drop = nn.Dropout(attn_drop)
344
+ self.softmax = nn.Softmax(dim=-1)
345
+
346
+ def forward(self, qkv, x_size):
347
+ """
348
+ Args:
349
+ qkv: input QKV features with shape of (B, L, 3C)
350
+ x_size: use x_size to determine whether the relative positional bias table and index
351
+ need to be regenerated.
352
+ """
353
+ H, W = x_size
354
+ B, L, C = qkv.shape
355
+ qkv = qkv.view(B, H, W, C)
356
+
357
+ # cyclic shift
358
+ if self.shift_size > 0:
359
+ qkv = torch.roll(
360
+ qkv, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
361
+ )
362
+
363
+ # partition windows
364
+ qkv = window_partition(qkv, self.window_size) # nW*B, wh, ww, C
365
+ qkv = qkv.view(-1, prod(self.window_size), C) # nW*B, wh*ww, C
366
+
367
+ B_, N, _ = qkv.shape
368
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
369
+ q, k, v = qkv[0], qkv[1], qkv[2]
370
+
371
+ # attention
372
+ x = self.attn(q, k, v, self.attn_transform, x_size)
373
+
374
+ # merge windows
375
+ x = x.view(-1, *self.window_size, C // 3)
376
+ x = window_reverse(x, self.window_size, x_size) # B, H, W, C/3
377
+
378
+ # reverse cyclic shift
379
+ if self.shift_size > 0:
380
+ x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
381
+ x = x.view(B, L, C // 3)
382
+
383
+ return x
384
+
385
+ def extra_repr(self) -> str:
386
+ return (
387
+ f"window_size={self.window_size}, shift_size={self.shift_size}, "
388
+ f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
389
+ )
390
+
391
+ def flops(self, N):
392
+ # calculate flops for 1 window with token length of N
393
+ flops = 0
394
+ # qkv = self.qkv(x)
395
+ flops += N * self.dim * 3 * self.dim
396
+ # attn = (q @ k.transpose(-2, -1))
397
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
398
+ # x = (attn @ v)
399
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
400
+ # x = self.proj(x)
401
+ flops += N * self.dim * self.dim
402
+ return flops
403
+
404
+
405
+ class StripeAttention(Attention):
406
+ r"""Stripe attention
407
+ Args:
408
+ stripe_size (tuple[int]): The height and width of the stripe.
409
+ num_heads (int): Number of attention heads.
410
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
411
+ pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training.
412
+ """
413
+
414
+ def __init__(
415
+ self,
416
+ input_resolution,
417
+ stripe_size,
418
+ stripe_groups,
419
+ stripe_shift,
420
+ num_heads,
421
+ attn_drop=0.0,
422
+ pretrained_stripe_size=[0, 0],
423
+ args=None,
424
+ ):
425
+
426
+ super(StripeAttention, self).__init__()
427
+ self.input_resolution = input_resolution
428
+ self.stripe_size = stripe_size # Wh, Ww
429
+ self.stripe_groups = stripe_groups
430
+ self.stripe_shift = stripe_shift
431
+ self.num_heads = num_heads
432
+ self.pretrained_stripe_size = pretrained_stripe_size
433
+ self.euclidean_dist = args.euclidean_dist
434
+
435
+ self.attn_transform = AffineTransformStripe(
436
+ num_heads,
437
+ input_resolution,
438
+ stripe_size,
439
+ stripe_groups,
440
+ stripe_shift,
441
+ pretrained_stripe_size,
442
+ anchor_window_down_factor=1,
443
+ args=args,
444
+ )
445
+ self.attn_drop = nn.Dropout(attn_drop)
446
+ self.softmax = nn.Softmax(dim=-1)
447
+
448
+ def forward(self, qkv, x_size):
449
+ """
450
+ Args:
451
+ qkv: input QKV features with shape of (B, L, 3C)
452
+ x_size: use x_size to determine whether the relative positional bias table and index
453
+ need to be regenerated.
454
+ """
455
+ H, W = x_size
456
+ B, L, C = qkv.shape
457
+ qkv = qkv.view(B, H, W, C)
458
+
459
+ running_stripe_size, running_shift_size = self.attn_transform._get_stripe_info(
460
+ x_size
461
+ )
462
+ # cyclic shift
463
+ if self.stripe_shift:
464
+ qkv = torch.roll(
465
+ qkv,
466
+ shifts=(-running_shift_size[0], -running_shift_size[1]),
467
+ dims=(1, 2),
468
+ )
469
+
470
+ # partition windows
471
+ qkv = window_partition(qkv, running_stripe_size) # nW*B, wh, ww, C
472
+ qkv = qkv.view(-1, prod(running_stripe_size), C) # nW*B, wh*ww, C
473
+
474
+ B_, N, _ = qkv.shape
475
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
476
+ q, k, v = qkv[0], qkv[1], qkv[2]
477
+
478
+ # attention
479
+ x = self.attn(q, k, v, self.attn_transform, x_size)
480
+
481
+ # merge windows
482
+ x = x.view(-1, *running_stripe_size, C // 3)
483
+ x = window_reverse(x, running_stripe_size, x_size) # B H W C/3
484
+
485
+ # reverse the shift
486
+ if self.stripe_shift:
487
+ x = torch.roll(x, shifts=running_shift_size, dims=(1, 2))
488
+
489
+ x = x.view(B, L, C // 3)
490
+ return x
491
+
492
+ def extra_repr(self) -> str:
493
+ return (
494
+ f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, "
495
+ f"pretrained_stripe_size={self.pretrained_stripe_size}, num_heads={self.num_heads}"
496
+ )
497
+
498
+ def flops(self, N):
499
+ # calculate flops for 1 window with token length of N
500
+ flops = 0
501
+ # qkv = self.qkv(x)
502
+ flops += N * self.dim * 3 * self.dim
503
+ # attn = (q @ k.transpose(-2, -1))
504
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
505
+ # x = (attn @ v)
506
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
507
+ # x = self.proj(x)
508
+ flops += N * self.dim * self.dim
509
+ return flops
510
+
511
+
512
+ class AnchorStripeAttention(Attention):
513
+ r"""Stripe attention
514
+ Args:
515
+ stripe_size (tuple[int]): The height and width of the stripe.
516
+ num_heads (int): Number of attention heads.
517
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
518
+ pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training.
519
+ """
520
+
521
+ def __init__(
522
+ self,
523
+ input_resolution,
524
+ stripe_size,
525
+ stripe_groups,
526
+ stripe_shift,
527
+ num_heads,
528
+ attn_drop=0.0,
529
+ pretrained_stripe_size=[0, 0],
530
+ anchor_window_down_factor=1,
531
+ args=None,
532
+ ):
533
+
534
+ super(AnchorStripeAttention, self).__init__()
535
+ self.input_resolution = input_resolution
536
+ self.stripe_size = stripe_size # Wh, Ww
537
+ self.stripe_groups = stripe_groups
538
+ self.stripe_shift = stripe_shift
539
+ self.num_heads = num_heads
540
+ self.pretrained_stripe_size = pretrained_stripe_size
541
+ self.anchor_window_down_factor = anchor_window_down_factor
542
+ self.euclidean_dist = args.euclidean_dist
543
+
544
+ self.attn_transform1 = AffineTransformStripe(
545
+ num_heads,
546
+ input_resolution,
547
+ stripe_size,
548
+ stripe_groups,
549
+ stripe_shift,
550
+ pretrained_stripe_size,
551
+ anchor_window_down_factor,
552
+ window_to_anchor=False,
553
+ args=args,
554
+ )
555
+
556
+ self.attn_transform2 = AffineTransformStripe(
557
+ num_heads,
558
+ input_resolution,
559
+ stripe_size,
560
+ stripe_groups,
561
+ stripe_shift,
562
+ pretrained_stripe_size,
563
+ anchor_window_down_factor,
564
+ window_to_anchor=True,
565
+ args=args,
566
+ )
567
+
568
+ self.attn_drop = nn.Dropout(attn_drop)
569
+ self.softmax = nn.Softmax(dim=-1)
570
+
571
+ def forward(self, qkv, anchor, x_size):
572
+ """
573
+ Args:
574
+ qkv: input features with shape of (B, L, C)
575
+ anchor:
576
+ x_size: use stripe_size to determine whether the relative positional bias table and index
577
+ need to be regenerated.
578
+ """
579
+ H, W = x_size
580
+ B, L, C = qkv.shape
581
+ qkv = qkv.view(B, H, W, C)
582
+
583
+ stripe_size, shift_size = self.attn_transform1._get_stripe_info(x_size)
584
+ anchor_stripe_size = [s // self.anchor_window_down_factor for s in stripe_size]
585
+ anchor_shift_size = [s // self.anchor_window_down_factor for s in shift_size]
586
+ # cyclic shift
587
+ if self.stripe_shift:
588
+ qkv = torch.roll(qkv, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
589
+ anchor = torch.roll(
590
+ anchor,
591
+ shifts=(-anchor_shift_size[0], -anchor_shift_size[1]),
592
+ dims=(1, 2),
593
+ )
594
+
595
+ # partition windows
596
+ qkv = window_partition(qkv, stripe_size) # nW*B, wh, ww, C
597
+ qkv = qkv.view(-1, prod(stripe_size), C) # nW*B, wh*ww, C
598
+ anchor = window_partition(anchor, anchor_stripe_size)
599
+ anchor = anchor.view(-1, prod(anchor_stripe_size), C // 3)
600
+
601
+ B_, N1, _ = qkv.shape
602
+ N2 = anchor.shape[1]
603
+ qkv = qkv.reshape(B_, N1, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
604
+ q, k, v = qkv[0], qkv[1], qkv[2]
605
+ anchor = anchor.reshape(B_, N2, self.num_heads, -1).permute(0, 2, 1, 3)
606
+
607
+ # attention
608
+ x = self.attn(anchor, k, v, self.attn_transform1, x_size, False)
609
+ x = self.attn(q, anchor, x, self.attn_transform2, x_size)
610
+
611
+ # merge windows
612
+ x = x.view(B_, *stripe_size, C // 3)
613
+ x = window_reverse(x, stripe_size, x_size) # B H' W' C
614
+
615
+ # reverse the shift
616
+ if self.stripe_shift:
617
+ x = torch.roll(x, shifts=shift_size, dims=(1, 2))
618
+
619
+ x = x.view(B, H * W, C // 3)
620
+ return x
621
+
622
+ def extra_repr(self) -> str:
623
+ return (
624
+ f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, "
625
+ f"pretrained_stripe_size={self.pretrained_stripe_size}, num_heads={self.num_heads}, anchor_window_down_factor={self.anchor_window_down_factor}"
626
+ )
627
+
628
+ def flops(self, N):
629
+ # calculate flops for 1 window with token length of N
630
+ flops = 0
631
+ # qkv = self.qkv(x)
632
+ flops += N * self.dim * 3 * self.dim
633
+ # attn = (q @ k.transpose(-2, -1))
634
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
635
+ # x = (attn @ v)
636
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
637
+ # x = self.proj(x)
638
+ flops += N * self.dim * self.dim
639
+ return flops
640
+
641
+
642
+ class SeparableConv(nn.Sequential):
643
+ def __init__(self, in_channels, out_channels, kernel_size, stride, bias, args):
644
+ m = [
645
+ nn.Conv2d(
646
+ in_channels,
647
+ in_channels,
648
+ kernel_size,
649
+ stride,
650
+ kernel_size // 2,
651
+ groups=in_channels,
652
+ bias=bias,
653
+ )
654
+ ]
655
+ if args.separable_conv_act:
656
+ m.append(nn.GELU())
657
+ m.append(nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=bias))
658
+ super(SeparableConv, self).__init__(*m)
659
+
660
+
661
+ class QKVProjection(nn.Module):
662
+ def __init__(self, dim, qkv_bias, proj_type, args):
663
+ super(QKVProjection, self).__init__()
664
+ self.proj_type = proj_type
665
+ if proj_type == "linear":
666
+ self.body = nn.Linear(dim, dim * 3, bias=qkv_bias)
667
+ else:
668
+ self.body = SeparableConv(dim, dim * 3, 3, 1, qkv_bias, args)
669
+
670
+ def forward(self, x, x_size):
671
+ if self.proj_type == "separable_conv":
672
+ x = blc_to_bchw(x, x_size)
673
+ x = self.body(x)
674
+ if self.proj_type == "separable_conv":
675
+ x = bchw_to_blc(x)
676
+ return x
677
+
678
+
679
+ class PatchMerging(nn.Module):
680
+ r"""Patch Merging Layer.
681
+ Args:
682
+ dim (int): Number of input channels.
683
+ """
684
+
685
+ def __init__(self, in_dim, out_dim):
686
+ super().__init__()
687
+ self.in_dim = in_dim
688
+ self.out_dim = out_dim
689
+ self.reduction = nn.Linear(4 * in_dim, out_dim, bias=False)
690
+
691
+ def forward(self, x, x_size):
692
+ """
693
+ x: B, H*W, C
694
+ """
695
+ H, W = x_size
696
+ B, L, C = x.shape
697
+ assert L == H * W, "input feature has wrong size"
698
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
699
+
700
+ x = x.view(B, H, W, C)
701
+
702
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
703
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
704
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
705
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
706
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
707
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
708
+
709
+ x = self.reduction(x)
710
+
711
+ return x
712
+
713
+
714
+ class AnchorLinear(nn.Module):
715
+ r"""Linear anchor projection layer
716
+ Args:
717
+ dim (int): Number of input channels.
718
+ """
719
+
720
+ def __init__(self, in_channels, out_channels, down_factor, pooling_mode, bias):
721
+ super().__init__()
722
+ self.down_factor = down_factor
723
+ if pooling_mode == "maxpool":
724
+ self.pooling = nn.MaxPool2d(down_factor, down_factor)
725
+ elif pooling_mode == "avgpool":
726
+ self.pooling = nn.AvgPool2d(down_factor, down_factor)
727
+ self.reduction = nn.Linear(in_channels, out_channels, bias=bias)
728
+
729
+ def forward(self, x, x_size):
730
+ """
731
+ x: B, H*W, C
732
+ """
733
+ x = blc_to_bchw(x, x_size)
734
+ x = bchw_to_blc(self.pooling(x))
735
+ x = blc_to_bhwc(self.reduction(x), [s // self.down_factor for s in x_size])
736
+ return x
737
+
738
+
739
+ class AnchorProjection(nn.Module):
740
+ def __init__(self, dim, proj_type, one_stage, anchor_window_down_factor, args):
741
+ super(AnchorProjection, self).__init__()
742
+ self.proj_type = proj_type
743
+ self.body = nn.ModuleList([])
744
+ if one_stage:
745
+ if proj_type == "patchmerging":
746
+ m = PatchMerging(dim, dim // 2)
747
+ elif proj_type == "conv2d":
748
+ kernel_size = anchor_window_down_factor + 1
749
+ stride = anchor_window_down_factor
750
+ padding = kernel_size // 2
751
+ m = nn.Conv2d(dim, dim // 2, kernel_size, stride, padding)
752
+ elif proj_type == "separable_conv":
753
+ kernel_size = anchor_window_down_factor + 1
754
+ stride = anchor_window_down_factor
755
+ m = SeparableConv(dim, dim // 2, kernel_size, stride, True, args)
756
+ elif proj_type.find("pool") >= 0:
757
+ m = AnchorLinear(
758
+ dim, dim // 2, anchor_window_down_factor, proj_type, True
759
+ )
760
+ self.body.append(m)
761
+ else:
762
+ for i in range(int(math.log2(anchor_window_down_factor))):
763
+ cin = dim if i == 0 else dim // 2
764
+ if proj_type == "patchmerging":
765
+ m = PatchMerging(cin, dim // 2)
766
+ elif proj_type == "conv2d":
767
+ m = nn.Conv2d(cin, dim // 2, 3, 2, 1)
768
+ elif proj_type == "separable_conv":
769
+ m = SeparableConv(cin, dim // 2, 3, 2, True, args)
770
+ self.body.append(m)
771
+
772
+ def forward(self, x, x_size):
773
+ if self.proj_type.find("conv") >= 0:
774
+ x = blc_to_bchw(x, x_size)
775
+ for m in self.body:
776
+ x = m(x)
777
+ x = bchw_to_bhwc(x)
778
+ elif self.proj_type.find("pool") >= 0:
779
+ for m in self.body:
780
+ x = m(x, x_size)
781
+ else:
782
+ for i, m in enumerate(self.body):
783
+ x = m(x, [s // 2**i for s in x_size])
784
+ x = blc_to_bhwc(x, [s // 2 ** (i + 1) for s in x_size])
785
+ return x
786
+
787
+
788
+ class MixedAttention(nn.Module):
789
+ r"""Mixed window attention and stripe attention
790
+ Args:
791
+ dim (int): Number of input channels.
792
+ stripe_size (tuple[int]): The height and width of the stripe.
793
+ num_heads (int): Number of attention heads.
794
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
795
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
796
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
797
+ pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training.
798
+ """
799
+
800
+ def __init__(
801
+ self,
802
+ dim,
803
+ input_resolution,
804
+ num_heads_w,
805
+ num_heads_s,
806
+ window_size,
807
+ window_shift,
808
+ stripe_size,
809
+ stripe_groups,
810
+ stripe_shift,
811
+ qkv_bias=True,
812
+ qkv_proj_type="linear",
813
+ anchor_proj_type="separable_conv",
814
+ anchor_one_stage=True,
815
+ anchor_window_down_factor=1,
816
+ attn_drop=0.0,
817
+ proj_drop=0.0,
818
+ pretrained_window_size=[0, 0],
819
+ pretrained_stripe_size=[0, 0],
820
+ args=None,
821
+ ):
822
+
823
+ super(MixedAttention, self).__init__()
824
+ self.dim = dim
825
+ self.input_resolution = input_resolution
826
+ self.use_anchor = anchor_window_down_factor > 1
827
+ self.args = args
828
+ # print(args)
829
+ self.qkv = QKVProjection(dim, qkv_bias, qkv_proj_type, args)
830
+ if self.use_anchor:
831
+ # anchor is only used for stripe attention
832
+ self.anchor = AnchorProjection(
833
+ dim, anchor_proj_type, anchor_one_stage, anchor_window_down_factor, args
834
+ )
835
+
836
+ self.window_attn = WindowAttention(
837
+ input_resolution,
838
+ window_size,
839
+ num_heads_w,
840
+ window_shift,
841
+ attn_drop,
842
+ pretrained_window_size,
843
+ args,
844
+ )
845
+
846
+ if self.args.double_window:
847
+ self.stripe_attn = WindowAttention(
848
+ input_resolution,
849
+ window_size,
850
+ num_heads_w,
851
+ window_shift,
852
+ attn_drop,
853
+ pretrained_window_size,
854
+ args,
855
+ )
856
+ else:
857
+ if self.use_anchor:
858
+ self.stripe_attn = AnchorStripeAttention(
859
+ input_resolution,
860
+ stripe_size,
861
+ stripe_groups,
862
+ stripe_shift,
863
+ num_heads_s,
864
+ attn_drop,
865
+ pretrained_stripe_size,
866
+ anchor_window_down_factor,
867
+ args,
868
+ )
869
+ else:
870
+ if self.args.stripe_square:
871
+ self.stripe_attn = StripeAttention(
872
+ input_resolution,
873
+ window_size,
874
+ [None, None],
875
+ window_shift,
876
+ num_heads_s,
877
+ attn_drop,
878
+ pretrained_stripe_size,
879
+ args,
880
+ )
881
+ else:
882
+ self.stripe_attn = StripeAttention(
883
+ input_resolution,
884
+ stripe_size,
885
+ stripe_groups,
886
+ stripe_shift,
887
+ num_heads_s,
888
+ attn_drop,
889
+ pretrained_stripe_size,
890
+ args,
891
+ )
892
+ if self.args.out_proj_type == "linear":
893
+ self.proj = nn.Linear(dim, dim)
894
+ else:
895
+ self.proj = nn.Conv2d(dim, dim, 3, 1, 1)
896
+ self.proj_drop = nn.Dropout(proj_drop)
897
+
898
+ def forward(self, x, x_size):
899
+ """
900
+ Args:
901
+ x: input features with shape of (B, L, C)
902
+ stripe_size: use stripe_size to determine whether the relative positional bias table and index
903
+ need to be regenerated.
904
+ """
905
+ B, L, C = x.shape
906
+
907
+ # qkv projection
908
+ qkv = self.qkv(x, x_size)
909
+ qkv_window, qkv_stripe = torch.split(qkv, C * 3 // 2, dim=-1)
910
+ # anchor projection
911
+ if self.use_anchor:
912
+ anchor = self.anchor(x, x_size)
913
+
914
+ # attention
915
+ x_window = self.window_attn(qkv_window, x_size)
916
+ if self.use_anchor:
917
+ x_stripe = self.stripe_attn(qkv_stripe, anchor, x_size)
918
+ else:
919
+ x_stripe = self.stripe_attn(qkv_stripe, x_size)
920
+ x = torch.cat([x_window, x_stripe], dim=-1)
921
+
922
+ # output projection
923
+ if self.args.out_proj_type == "linear":
924
+ x = self.proj(x)
925
+ else:
926
+ x = blc_to_bchw(x, x_size)
927
+ x = bchw_to_blc(self.proj(x))
928
+ x = self.proj_drop(x)
929
+ return x
930
+
931
+ def extra_repr(self) -> str:
932
+ return f"dim={self.dim}, input_resolution={self.input_resolution}"
933
+
934
+ def flops(self, N):
935
+ # calculate flops for 1 window with token length of N
936
+ flops = 0
937
+ # qkv = self.qkv(x)
938
+ flops += N * self.dim * 3 * self.dim
939
+ # attn = (q @ k.transpose(-2, -1))
940
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
941
+ # x = (attn @ v)
942
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
943
+ # x = self.proj(x)
944
+ flops += N * self.dim * self.dim
945
+ return flops
946
+
947
+
948
+ class ChannelAttention(nn.Module):
949
+ """Channel attention used in RCAN.
950
+ Args:
951
+ num_feat (int): Channel number of intermediate features.
952
+ reduction (int): Channel reduction factor. Default: 16.
953
+ """
954
+
955
+ def __init__(self, num_feat, reduction=16):
956
+ super(ChannelAttention, self).__init__()
957
+ self.attention = nn.Sequential(
958
+ nn.AdaptiveAvgPool2d(1),
959
+ nn.Conv2d(num_feat, num_feat // reduction, 1, padding=0),
960
+ nn.ReLU(inplace=True),
961
+ nn.Conv2d(num_feat // reduction, num_feat, 1, padding=0),
962
+ nn.Sigmoid(),
963
+ )
964
+
965
+ def forward(self, x):
966
+ y = self.attention(x)
967
+ return x * y
968
+
969
+
970
+ class CAB(nn.Module):
971
+ def __init__(self, num_feat, compress_ratio=4, reduction=18):
972
+ super(CAB, self).__init__()
973
+
974
+ self.cab = nn.Sequential(
975
+ nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
976
+ nn.GELU(),
977
+ nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
978
+ ChannelAttention(num_feat, reduction),
979
+ )
980
+
981
+ def forward(self, x, x_size):
982
+ x = self.cab(blc_to_bchw(x, x_size).contiguous())
983
+ return bchw_to_blc(x)
984
+
985
+
986
+ class MixAttnTransformerBlock(nn.Module):
987
+ r"""Mix attention transformer block with shared QKV projection and output projection for mixed attention modules.
988
+ Args:
989
+ dim (int): Number of input channels.
990
+ input_resolution (tuple[int]): Input resolution.
991
+ num_heads (int): Number of attention heads.
992
+ window_size (int): Window size.
993
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
994
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
995
+ drop (float, optional): Dropout rate. Default: 0.0
996
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
997
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
998
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
999
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
1000
+ pretrained_stripe_size (int): Stripe size in pre-training.
1001
+ attn_type (str, optional): Attention type. Default: cwhv.
1002
+ c: residual blocks
1003
+ w: window attention
1004
+ h: horizontal stripe attention
1005
+ v: vertical stripe attention
1006
+ """
1007
+
1008
+ def __init__(
1009
+ self,
1010
+ dim,
1011
+ input_resolution,
1012
+ num_heads_w,
1013
+ num_heads_s,
1014
+ window_size=7,
1015
+ window_shift=False,
1016
+ stripe_size=[8, 8],
1017
+ stripe_groups=[None, None],
1018
+ stripe_shift=False,
1019
+ stripe_type="H",
1020
+ mlp_ratio=4.0,
1021
+ qkv_bias=True,
1022
+ qkv_proj_type="linear",
1023
+ anchor_proj_type="separable_conv",
1024
+ anchor_one_stage=True,
1025
+ anchor_window_down_factor=1,
1026
+ drop=0.0,
1027
+ attn_drop=0.0,
1028
+ drop_path=0.0,
1029
+ act_layer=nn.GELU,
1030
+ norm_layer=nn.LayerNorm,
1031
+ pretrained_window_size=[0, 0],
1032
+ pretrained_stripe_size=[0, 0],
1033
+ res_scale=1.0,
1034
+ args=None,
1035
+ ):
1036
+ super().__init__()
1037
+ self.dim = dim
1038
+ self.input_resolution = input_resolution
1039
+ self.num_heads_w = num_heads_w
1040
+ self.num_heads_s = num_heads_s
1041
+ self.window_size = window_size
1042
+ self.window_shift = window_shift
1043
+ self.stripe_shift = stripe_shift
1044
+ self.stripe_type = stripe_type
1045
+ self.args = args
1046
+ if self.stripe_type == "W":
1047
+ self.stripe_size = stripe_size[::-1]
1048
+ self.stripe_groups = stripe_groups[::-1]
1049
+ else:
1050
+ self.stripe_size = stripe_size
1051
+ self.stripe_groups = stripe_groups
1052
+ self.mlp_ratio = mlp_ratio
1053
+ self.res_scale = res_scale
1054
+
1055
+ self.attn = MixedAttention(
1056
+ dim,
1057
+ input_resolution,
1058
+ num_heads_w,
1059
+ num_heads_s,
1060
+ window_size,
1061
+ window_shift,
1062
+ self.stripe_size,
1063
+ self.stripe_groups,
1064
+ stripe_shift,
1065
+ qkv_bias,
1066
+ qkv_proj_type,
1067
+ anchor_proj_type,
1068
+ anchor_one_stage,
1069
+ anchor_window_down_factor,
1070
+ attn_drop,
1071
+ drop,
1072
+ pretrained_window_size,
1073
+ pretrained_stripe_size,
1074
+ args,
1075
+ )
1076
+ self.norm1 = norm_layer(dim)
1077
+ if self.args.local_connection:
1078
+ self.conv = CAB(dim)
1079
+
1080
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
1081
+
1082
+ self.mlp = Mlp(
1083
+ in_features=dim,
1084
+ hidden_features=int(dim * mlp_ratio),
1085
+ act_layer=act_layer,
1086
+ drop=drop,
1087
+ )
1088
+ self.norm2 = norm_layer(dim)
1089
+
1090
+ def forward(self, x, x_size):
1091
+ # Mixed attention
1092
+ if self.args.local_connection:
1093
+ x = (
1094
+ x
1095
+ + self.res_scale * self.drop_path(self.norm1(self.attn(x, x_size)))
1096
+ + self.conv(x, x_size)
1097
+ )
1098
+ else:
1099
+ x = x + self.res_scale * self.drop_path(self.norm1(self.attn(x, x_size)))
1100
+ # FFN
1101
+ x = x + self.res_scale * self.drop_path(self.norm2(self.mlp(x)))
1102
+
1103
+ return x
1104
+
1105
+ def extra_repr(self) -> str:
1106
+ return (
1107
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads=({self.num_heads_w}, {self.num_heads_s}), "
1108
+ f"window_size={self.window_size}, window_shift={self.window_shift}, "
1109
+ f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, self.stripe_type={self.stripe_type}, "
1110
+ f"mlp_ratio={self.mlp_ratio}, res_scale={self.res_scale}"
1111
+ )
1112
+
1113
+
1114
+ # def flops(self):
1115
+ # flops = 0
1116
+ # H, W = self.input_resolution
1117
+ # # norm1
1118
+ # flops += self.dim * H * W
1119
+ # # W-MSA/SW-MSA
1120
+ # nW = H * W / self.stripe_size[0] / self.stripe_size[1]
1121
+ # flops += nW * self.attn.flops(self.stripe_size[0] * self.stripe_size[1])
1122
+ # # mlp
1123
+ # flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
1124
+ # # norm2
1125
+ # flops += self.dim * H * W
1126
+ # return flops
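
For reference, an illustrative sketch of the channel-attention block (CAB) defined above acting on a (B, H*W, C) token sequence; the shapes below are arbitrary example values.

    import torch
    from architecture.grl_common.mixed_attn_block import CAB

    cab = CAB(num_feat=64)                  # 3x3 conv -> GELU -> 3x3 conv -> channel attention
    tokens = torch.randn(2, 48 * 48, 64)    # B, H*W, C token layout used throughout this module
    out = cab(tokens, x_size=(48, 48))      # returned in the same (B, H*W, C) layout: (2, 2304, 64)
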
architecture/grl_common/mixed_attn_block_efficient.py ADDED
@@ -0,0 +1,568 @@
1
+ import math
2
+ from abc import ABC
3
+ from math import prod
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from timm.models.layers import DropPath
9
+
10
+
11
+ from architecture.grl_common.mixed_attn_block import (
12
+ AnchorProjection,
13
+ CAB,
14
+ CPB_MLP,
15
+ QKVProjection,
16
+ )
17
+ from architecture.grl_common.ops import (
18
+ window_partition,
19
+ window_reverse,
20
+ )
21
+ from architecture.grl_common.swin_v1_block import Mlp
22
+
23
+
24
+ class AffineTransform(nn.Module):
25
+ r"""Affine transformation of the attention map.
26
+ The window could be a square window or a stripe window. Supports attention between different window sizes
27
+ """
28
+
29
+ def __init__(self, num_heads):
30
+ super(AffineTransform, self).__init__()
31
+ logit_scale = torch.log(10 * torch.ones((num_heads, 1, 1)))
32
+ self.logit_scale = nn.Parameter(logit_scale, requires_grad=True)
33
+
34
+ # mlp to generate continuous relative position bias
35
+ self.cpb_mlp = CPB_MLP(2, num_heads)
36
+
37
+ def forward(self, attn, relative_coords_table, relative_position_index, mask):
38
+ B_, H, N1, N2 = attn.shape
39
+ # logit scale
40
+ attn = attn * torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
41
+
42
+ bias_table = self.cpb_mlp(relative_coords_table) # 2*Wh-1, 2*Ww-1, num_heads
43
+ bias_table = bias_table.view(-1, H)
44
+
45
+ bias = bias_table[relative_position_index.view(-1)]
46
+ bias = bias.view(N1, N2, -1).permute(2, 0, 1).contiguous()
47
+ # nH, Wh*Ww, Wh*Ww
48
+ bias = 16 * torch.sigmoid(bias)
49
+ attn = attn + bias.unsqueeze(0)
50
+
51
+ # W-MSA/SW-MSA
52
+ # shift attention mask
53
+ if mask is not None:
54
+ nW = mask.shape[0]
55
+ mask = mask.unsqueeze(1).unsqueeze(0)
56
+ attn = attn.view(B_ // nW, nW, H, N1, N2) + mask
57
+ attn = attn.view(-1, H, N1, N2)
58
+
59
+ return attn
60
+
61
+
62
+ def _get_stripe_info(stripe_size_in, stripe_groups_in, stripe_shift, input_resolution):
63
+ stripe_size, shift_size = [], []
64
+ for s, g, d in zip(stripe_size_in, stripe_groups_in, input_resolution):
65
+ if g is None:
66
+ stripe_size.append(s)
67
+ shift_size.append(s // 2 if stripe_shift else 0)
68
+ else:
69
+ stripe_size.append(d // g)
70
+ shift_size.append(0 if g == 1 else d // (g * 2))
71
+ return stripe_size, shift_size
72
+
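
As a quick illustration of the helper above, with arbitrary example numbers: a None entry in stripe_groups keeps the requested stripe size, while an integer entry divides that input dimension into that many stripes, and shifts are half a stripe when shifting is enabled.

    # stripe_size_in=[8, 8], stripe_groups_in=[None, 4], stripe_shift=True, input_resolution=(64, 64)
    # height: group is None -> stripe 8, shift 8 // 2 = 4
    # width:  group is 4    -> stripe 64 // 4 = 16, shift 64 // (4 * 2) = 8
    _get_stripe_info([8, 8], [None, 4], True, (64, 64))   # -> ([8, 16], [4, 8])
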
73
+
74
+ class Attention(ABC, nn.Module):
75
+ def __init__(self):
76
+ super(Attention, self).__init__()
77
+
78
+ def attn(self, q, k, v, attn_transform, table, index, mask, reshape=True):
79
+ # q, k, v: # nW*B, H, wh*ww, dim
80
+ # cosine attention map
81
+ B_, _, H, head_dim = q.shape
82
+ if self.euclidean_dist:
83
+ # print("use euclidean distance")
84
+ attn = torch.norm(q.unsqueeze(-2) - k.unsqueeze(-3), dim=-1)
85
+ else:
86
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
87
+ attn = attn_transform(attn, table, index, mask)
88
+ # attention
89
+ attn = self.softmax(attn)
90
+ attn = self.attn_drop(attn)
91
+ x = attn @ v # B_, H, N1, head_dim
92
+ if reshape:
93
+ x = x.transpose(1, 2).reshape(B_, -1, H * head_dim)
94
+ # B_, N, C
95
+ return x
96
+
97
+
98
+ class WindowAttention(Attention):
99
+ r"""Window attention. QKV is the input to the forward method.
100
+ Args:
101
+ num_heads (int): Number of attention heads.
102
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
103
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
104
+ """
105
+
106
+ def __init__(
107
+ self,
108
+ input_resolution,
109
+ window_size,
110
+ num_heads,
111
+ window_shift=False,
112
+ attn_drop=0.0,
113
+ pretrained_window_size=[0, 0],
114
+ args=None,
115
+ ):
116
+
117
+ super(WindowAttention, self).__init__()
118
+ self.input_resolution = input_resolution
119
+ self.window_size = window_size
120
+ self.pretrained_window_size = pretrained_window_size
121
+ self.num_heads = num_heads
122
+ self.shift_size = window_size[0] // 2 if window_shift else 0
123
+ self.euclidean_dist = args.euclidean_dist
124
+
125
+ self.attn_transform = AffineTransform(num_heads)
126
+ self.attn_drop = nn.Dropout(attn_drop)
127
+ self.softmax = nn.Softmax(dim=-1)
128
+
129
+ def forward(self, qkv, x_size, table, index, mask):
130
+ """
131
+ Args:
132
+ qkv: input QKV features with shape of (B, L, 3C)
133
+ x_size: use x_size to determine whether the relative positional bias table and index
134
+ need to be regenerated.
135
+ """
136
+ H, W = x_size
137
+ B, L, C = qkv.shape
138
+ qkv = qkv.view(B, H, W, C)
139
+
140
+ # cyclic shift
141
+ if self.shift_size > 0:
142
+ qkv = torch.roll(
143
+ qkv, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
144
+ )
145
+
146
+ # partition windows
147
+ qkv = window_partition(qkv, self.window_size) # nW*B, wh, ww, C
148
+ qkv = qkv.view(-1, prod(self.window_size), C) # nW*B, wh*ww, C
149
+
150
+ B_, N, _ = qkv.shape
151
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
152
+ q, k, v = qkv[0], qkv[1], qkv[2] # nW*B, H, wh*ww, dim
153
+
154
+ # attention
155
+ x = self.attn(q, k, v, self.attn_transform, table, index, mask)
156
+
157
+ # merge windows
158
+ x = x.view(-1, *self.window_size, C // 3)
159
+ x = window_reverse(x, self.window_size, x_size) # B, H, W, C/3
160
+
161
+ # reverse cyclic shift
162
+ if self.shift_size > 0:
163
+ x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
164
+ x = x.view(B, L, C // 3)
165
+
166
+ return x
167
+
168
+ def extra_repr(self) -> str:
169
+ return (
170
+ f"window_size={self.window_size}, shift_size={self.shift_size}, "
171
+ f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
172
+ )
173
+
174
+ def flops(self, N):
175
+ pass
176
+
177
+
178
+ class AnchorStripeAttention(Attention):
179
+ r"""Stripe attention
180
+ Args:
181
+ stripe_size (tuple[int]): The height and width of the stripe.
182
+ num_heads (int): Number of attention heads.
183
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
184
+ pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training.
185
+ """
186
+
187
+ def __init__(
188
+ self,
189
+ input_resolution,
190
+ stripe_size,
191
+ stripe_groups,
192
+ stripe_shift,
193
+ num_heads,
194
+ attn_drop=0.0,
195
+ pretrained_stripe_size=[0, 0],
196
+ anchor_window_down_factor=1,
197
+ args=None,
198
+ ):
199
+
200
+ super(AnchorStripeAttention, self).__init__()
201
+ self.input_resolution = input_resolution
202
+ self.stripe_size = stripe_size # Wh, Ww
203
+ self.stripe_groups = stripe_groups
204
+ self.stripe_shift = stripe_shift
205
+ self.num_heads = num_heads
206
+ self.pretrained_stripe_size = pretrained_stripe_size
207
+ self.anchor_window_down_factor = anchor_window_down_factor
208
+ self.euclidean_dist = args.euclidean_dist
209
+
210
+ self.attn_transform1 = AffineTransform(num_heads)
211
+ self.attn_transform2 = AffineTransform(num_heads)
212
+
213
+ self.attn_drop = nn.Dropout(attn_drop)
214
+ self.softmax = nn.Softmax(dim=-1)
215
+
216
+ def forward(
217
+ self, qkv, anchor, x_size, table, index_a2w, index_w2a, mask_a2w, mask_w2a
218
+ ):
219
+ """
220
+ Args:
221
+ qkv: input features with shape of (B, L, C)
222
+ anchor:
223
+ x_size: use stripe_size to determine whether the relative positional bias table and index
224
+ need to be regenerated.
225
+ """
226
+ H, W = x_size
227
+ B, L, C = qkv.shape
228
+ qkv = qkv.view(B, H, W, C)
229
+
230
+ stripe_size, shift_size = _get_stripe_info(
231
+ self.stripe_size, self.stripe_groups, self.stripe_shift, x_size
232
+ )
233
+ anchor_stripe_size = [s // self.anchor_window_down_factor for s in stripe_size]
234
+ anchor_shift_size = [s // self.anchor_window_down_factor for s in shift_size]
235
+ # cyclic shift
236
+ if self.stripe_shift:
237
+ qkv = torch.roll(qkv, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))
238
+ anchor = torch.roll(
239
+ anchor,
240
+ shifts=(-anchor_shift_size[0], -anchor_shift_size[1]),
241
+ dims=(1, 2),
242
+ )
243
+
244
+ # partition windows
245
+ qkv = window_partition(qkv, stripe_size) # nW*B, wh, ww, C
246
+ qkv = qkv.view(-1, prod(stripe_size), C) # nW*B, wh*ww, C
247
+ anchor = window_partition(anchor, anchor_stripe_size)
248
+ anchor = anchor.view(-1, prod(anchor_stripe_size), C // 3)
249
+
250
+ B_, N1, _ = qkv.shape
251
+ N2 = anchor.shape[1]
252
+ qkv = qkv.reshape(B_, N1, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
253
+ q, k, v = qkv[0], qkv[1], qkv[2]
254
+ anchor = anchor.reshape(B_, N2, self.num_heads, -1).permute(0, 2, 1, 3)
255
+
256
+ # attention
257
+ x = self.attn(
258
+ anchor, k, v, self.attn_transform1, table, index_a2w, mask_a2w, False
259
+ )
260
+ x = self.attn(q, anchor, x, self.attn_transform2, table, index_w2a, mask_w2a)
261
+
262
+ # merge windows
263
+ x = x.view(B_, *stripe_size, C // 3)
264
+ x = window_reverse(x, stripe_size, x_size) # B H' W' C
265
+
266
+ # reverse the shift
267
+ if self.stripe_shift:
268
+ x = torch.roll(x, shifts=shift_size, dims=(1, 2))
269
+
270
+ x = x.view(B, H * W, C // 3)
271
+ return x
272
+
273
+ def extra_repr(self) -> str:
274
+ return (
275
+ f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, "
276
+ f"pretrained_stripe_size={self.pretrained_stripe_size}, num_heads={self.num_heads}, anchor_window_down_factor={self.anchor_window_down_factor}"
277
+ )
278
+
279
+ def flops(self, N):
280
+ pass
281
+
282
+
283
+ class MixedAttention(nn.Module):
284
+ r"""Mixed window attention and stripe attention
285
+ Args:
286
+ dim (int): Number of input channels.
287
+ stripe_size (tuple[int]): The height and width of the stripe.
288
+ num_heads (int): Number of attention heads.
289
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
290
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
291
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
292
+ pretrained_stripe_size (tuple[int]): The height and width of the stripe in pre-training.
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ dim,
298
+ input_resolution,
299
+ num_heads_w,
300
+ num_heads_s,
301
+ window_size,
302
+ window_shift,
303
+ stripe_size,
304
+ stripe_groups,
305
+ stripe_shift,
306
+ qkv_bias=True,
307
+ qkv_proj_type="linear",
308
+ anchor_proj_type="separable_conv",
309
+ anchor_one_stage=True,
310
+ anchor_window_down_factor=1,
311
+ attn_drop=0.0,
312
+ proj_drop=0.0,
313
+ pretrained_window_size=[0, 0],
314
+ pretrained_stripe_size=[0, 0],
315
+ args=None,
316
+ ):
317
+
318
+ super(MixedAttention, self).__init__()
319
+ self.dim = dim
320
+ self.input_resolution = input_resolution
321
+ self.args = args
322
+ # print(args)
323
+ self.qkv = QKVProjection(dim, qkv_bias, qkv_proj_type, args)
324
+ # anchor is only used for stripe attention
325
+ self.anchor = AnchorProjection(
326
+ dim, anchor_proj_type, anchor_one_stage, anchor_window_down_factor, args
327
+ )
328
+
329
+ self.window_attn = WindowAttention(
330
+ input_resolution,
331
+ window_size,
332
+ num_heads_w,
333
+ window_shift,
334
+ attn_drop,
335
+ pretrained_window_size,
336
+ args,
337
+ )
338
+ self.stripe_attn = AnchorStripeAttention(
339
+ input_resolution,
340
+ stripe_size,
341
+ stripe_groups,
342
+ stripe_shift,
343
+ num_heads_s,
344
+ attn_drop,
345
+ pretrained_stripe_size,
346
+ anchor_window_down_factor,
347
+ args,
348
+ )
349
+ self.proj = nn.Linear(dim, dim)
350
+ self.proj_drop = nn.Dropout(proj_drop)
351
+
352
+ def forward(self, x, x_size, table_index_mask):
353
+ """
354
+ Args:
355
+ x: input features with shape of (B, L, C)
356
+ stripe_size: use stripe_size to determine whether the relative positional bias table and index
357
+ need to be regenerated.
358
+ """
359
+ B, L, C = x.shape
360
+
361
+ # qkv projection
362
+ qkv = self.qkv(x, x_size)
363
+ qkv_window, qkv_stripe = torch.split(qkv, C * 3 // 2, dim=-1)
364
+ # anchor projection
365
+ anchor = self.anchor(x, x_size)
366
+
367
+ # attention
368
+ x_window = self.window_attn(
369
+ qkv_window, x_size, *self._get_table_index_mask(table_index_mask, True)
370
+ )
371
+ x_stripe = self.stripe_attn(
372
+ qkv_stripe,
373
+ anchor,
374
+ x_size,
375
+ *self._get_table_index_mask(table_index_mask, False),
376
+ )
377
+ x = torch.cat([x_window, x_stripe], dim=-1)
378
+
379
+ # output projection
380
+ x = self.proj(x)
381
+ x = self.proj_drop(x)
382
+ return x
383
+
384
+ def _get_table_index_mask(self, table_index_mask, window_attn=True):
385
+ if window_attn:
386
+ return (
387
+ table_index_mask["table_w"],
388
+ table_index_mask["index_w"],
389
+ table_index_mask["mask_w"],
390
+ )
391
+ else:
392
+ return (
393
+ table_index_mask["table_s"],
394
+ table_index_mask["index_a2w"],
395
+ table_index_mask["index_w2a"],
396
+ table_index_mask["mask_a2w"],
397
+ table_index_mask["mask_w2a"],
398
+ )
399
+
400
+ def extra_repr(self) -> str:
401
+ return f"dim={self.dim}, input_resolution={self.input_resolution}"
402
+
403
+ def flops(self, N):
404
+ pass
405
+
406
+
407
+ class EfficientMixAttnTransformerBlock(nn.Module):
408
+ r"""Mix attention transformer block with shared QKV projection and output projection for mixed attention modules.
409
+ Args:
410
+ dim (int): Number of input channels.
411
+ input_resolution (tuple[int]): Input resolution.
412
+ num_heads (int): Number of attention heads.
413
+ window_size (int): Window size.
414
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
415
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
416
+ drop (float, optional): Dropout rate. Default: 0.0
417
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
418
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
419
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
420
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
421
+ pretrained_stripe_size (tuple[int]): Stripe size in pre-training.
422
+ attn_type (str, optional): Attention type. Default: cwhv.
423
+ c: residual blocks
424
+ w: window attention
425
+ h: horizontal stripe attention
426
+ v: vertical stripe attention
427
+ """
428
+
429
+ def __init__(
430
+ self,
431
+ dim,
432
+ input_resolution,
433
+ num_heads_w,
434
+ num_heads_s,
435
+ window_size=7,
436
+ window_shift=False,
437
+ stripe_size=[8, 8],
438
+ stripe_groups=[None, None],
439
+ stripe_shift=False,
440
+ stripe_type="H",
441
+ mlp_ratio=4.0,
442
+ qkv_bias=True,
443
+ qkv_proj_type="linear",
444
+ anchor_proj_type="separable_conv",
445
+ anchor_one_stage=True,
446
+ anchor_window_down_factor=1,
447
+ drop=0.0,
448
+ attn_drop=0.0,
449
+ drop_path=0.0,
450
+ act_layer=nn.GELU,
451
+ norm_layer=nn.LayerNorm,
452
+ pretrained_window_size=[0, 0],
453
+ pretrained_stripe_size=[0, 0],
454
+ res_scale=1.0,
455
+ args=None,
456
+ ):
457
+ super().__init__()
458
+ self.dim = dim
459
+ self.input_resolution = input_resolution
460
+ self.num_heads_w = num_heads_w
461
+ self.num_heads_s = num_heads_s
462
+ self.window_size = window_size
463
+ self.window_shift = window_shift
464
+ self.stripe_shift = stripe_shift
465
+ self.stripe_type = stripe_type
466
+ self.args = args
467
+ if self.stripe_type == "W":
468
+ self.stripe_size = stripe_size[::-1]
469
+ self.stripe_groups = stripe_groups[::-1]
470
+ else:
471
+ self.stripe_size = stripe_size
472
+ self.stripe_groups = stripe_groups
473
+ self.mlp_ratio = mlp_ratio
474
+ self.res_scale = res_scale
475
+
476
+ self.attn = MixedAttention(
477
+ dim,
478
+ input_resolution,
479
+ num_heads_w,
480
+ num_heads_s,
481
+ window_size,
482
+ window_shift,
483
+ self.stripe_size,
484
+ self.stripe_groups,
485
+ stripe_shift,
486
+ qkv_bias,
487
+ qkv_proj_type,
488
+ anchor_proj_type,
489
+ anchor_one_stage,
490
+ anchor_window_down_factor,
491
+ attn_drop,
492
+ drop,
493
+ pretrained_window_size,
494
+ pretrained_stripe_size,
495
+ args,
496
+ )
497
+ self.norm1 = norm_layer(dim)
498
+ if self.args.local_connection:
499
+ self.conv = CAB(dim)
500
+
501
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
502
+
503
+ self.mlp = Mlp(
504
+ in_features=dim,
505
+ hidden_features=int(dim * mlp_ratio),
506
+ act_layer=act_layer,
507
+ drop=drop,
508
+ )
509
+ self.norm2 = norm_layer(dim)
510
+
511
+ def _get_table_index_mask(self, all_table_index_mask):
512
+ table_index_mask = {
513
+ "table_w": all_table_index_mask["table_w"],
514
+ "index_w": all_table_index_mask["index_w"],
515
+ }
516
+ if self.stripe_type == "W":
517
+ table_index_mask["table_s"] = all_table_index_mask["table_sv"]
518
+ table_index_mask["index_a2w"] = all_table_index_mask["index_sv_a2w"]
519
+ table_index_mask["index_w2a"] = all_table_index_mask["index_sv_w2a"]
520
+ else:
521
+ table_index_mask["table_s"] = all_table_index_mask["table_sh"]
522
+ table_index_mask["index_a2w"] = all_table_index_mask["index_sh_a2w"]
523
+ table_index_mask["index_w2a"] = all_table_index_mask["index_sh_w2a"]
524
+ if self.window_shift:
525
+ table_index_mask["mask_w"] = all_table_index_mask["mask_w"]
526
+ else:
527
+ table_index_mask["mask_w"] = None
528
+ if self.stripe_shift:
529
+ if self.stripe_type == "W":
530
+ table_index_mask["mask_a2w"] = all_table_index_mask["mask_sv_a2w"]
531
+ table_index_mask["mask_w2a"] = all_table_index_mask["mask_sv_w2a"]
532
+ else:
533
+ table_index_mask["mask_a2w"] = all_table_index_mask["mask_sh_a2w"]
534
+ table_index_mask["mask_w2a"] = all_table_index_mask["mask_sh_w2a"]
535
+ else:
536
+ table_index_mask["mask_a2w"] = None
537
+ table_index_mask["mask_w2a"] = None
538
+ return table_index_mask
539
+
540
+ def forward(self, x, x_size, all_table_index_mask):
541
+ # Mixed attention
542
+ table_index_mask = self._get_table_index_mask(all_table_index_mask)
543
+ if self.args.local_connection:
544
+ x = (
545
+ x
546
+ + self.res_scale
547
+ * self.drop_path(self.norm1(self.attn(x, x_size, table_index_mask)))
548
+ + self.conv(x, x_size)
549
+ )
550
+ else:
551
+ x = x + self.res_scale * self.drop_path(
552
+ self.norm1(self.attn(x, x_size, table_index_mask))
553
+ )
554
+ # FFN
555
+ x = x + self.res_scale * self.drop_path(self.norm2(self.mlp(x)))
556
+
557
+ return x
558
+
559
+ def extra_repr(self) -> str:
560
+ return (
561
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads=({self.num_heads_w}, {self.num_heads_s}), "
562
+ f"window_size={self.window_size}, window_shift={self.window_shift}, "
563
+ f"stripe_size={self.stripe_size}, stripe_groups={self.stripe_groups}, stripe_shift={self.stripe_shift}, stripe_type={self.stripe_type}, "
564
+ f"mlp_ratio={self.mlp_ratio}, res_scale={self.res_scale}"
565
+ )
566
+
567
+ def flops(self):
568
+ pass
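Note on the shared QKV projection above: MixedAttention computes one QKV tensor of width 3*C and splits it channel-wise between the window branch and the stripe branch. A minimal sketch of that split with hypothetical sizes (the values are illustrative, not taken from a GRL config):

import torch

B, L, C = 2, 64, 32                              # hypothetical batch, tokens, channels
qkv = torch.randn(B, L, 3 * C)                   # stand-in for the shared QKV projection output
qkv_window, qkv_stripe = torch.split(qkv, C * 3 // 2, dim=-1)
print(qkv_window.shape, qkv_stripe.shape)        # both torch.Size([2, 64, 48])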
architecture/grl_common/ops.py ADDED
@@ -0,0 +1,551 @@
1
+ from math import prod
2
+ from typing import Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ from timm.models.layers import to_2tuple
7
+
8
+
9
+ def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:
10
+ """Permutes a tensor from the shape (B, C, H, W) to (B, H, W, C)."""
11
+ return x.permute(0, 2, 3, 1)
12
+
13
+
14
+ def bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor:
15
+ """Permutes a tensor from the shape (B, H, W, C) to (B, C, H, W)."""
16
+ return x.permute(0, 3, 1, 2)
17
+
18
+
19
+ def bchw_to_blc(x: torch.Tensor) -> torch.Tensor:
20
+ """Rearrange a tensor from the shape (B, C, H, W) to (B, L, C)."""
21
+ return x.flatten(2).transpose(1, 2)
22
+
23
+
24
+ def blc_to_bchw(x: torch.Tensor, x_size: Tuple) -> torch.Tensor:
25
+ """Rearrange a tensor from the shape (B, L, C) to (B, C, H, W)."""
26
+ B, L, C = x.shape
27
+ return x.transpose(1, 2).view(B, C, *x_size)
28
+
29
+
30
+ def blc_to_bhwc(x: torch.Tensor, x_size: Tuple) -> torch.Tensor:
31
+ """Rearrange a tensor from the shape (B, L, C) to (B, H, W, C)."""
32
+ B, L, C = x.shape
33
+ return x.view(B, *x_size, C)
34
+
35
+
36
+ def window_partition(x, window_size: Tuple[int, int]):
37
+ """
38
+ Args:
39
+ x: (B, H, W, C)
40
+ window_size (int): window size
41
+
42
+ Returns:
43
+ windows: (num_windows*B, window_size, window_size, C)
44
+ """
45
+ B, H, W, C = x.shape
46
+ x = x.view(
47
+ B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C
48
+ )
49
+ windows = (
50
+ x.permute(0, 1, 3, 2, 4, 5)
51
+ .contiguous()
52
+ .view(-1, window_size[0], window_size[1], C)
53
+ )
54
+ return windows
55
+
56
+
57
+ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
58
+ """
59
+ Args:
60
+ windows: (num_windows * B, window_size[0], window_size[1], C)
61
+ window_size (Tuple[int, int]): Window size
62
+ img_size (Tuple[int, int]): Image size
63
+
64
+ Returns:
65
+ x: (B, H, W, C)
66
+ """
67
+ H, W = img_size
68
+ B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
69
+ x = windows.view(
70
+ B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1
71
+ )
72
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
73
+ return x
74
+
75
+
76
+ def _fill_window(input_resolution, window_size, shift_size=None):
77
+ if shift_size is None:
78
+ shift_size = [s // 2 for s in window_size]
79
+
80
+ img_mask = torch.zeros((1, *input_resolution, 1)) # 1 H W 1
81
+ h_slices = (
82
+ slice(0, -window_size[0]),
83
+ slice(-window_size[0], -shift_size[0]),
84
+ slice(-shift_size[0], None),
85
+ )
86
+ w_slices = (
87
+ slice(0, -window_size[1]),
88
+ slice(-window_size[1], -shift_size[1]),
89
+ slice(-shift_size[1], None),
90
+ )
91
+ cnt = 0
92
+ for h in h_slices:
93
+ for w in w_slices:
94
+ img_mask[:, h, w, :] = cnt
95
+ cnt += 1
96
+
97
+ mask_windows = window_partition(img_mask, window_size)
98
+ # nW, window_size, window_size, 1
99
+ mask_windows = mask_windows.view(-1, prod(window_size))
100
+ return mask_windows
101
+
102
+
103
+ #####################################
104
+ # Different versions of the functions
105
+ # 1) Swin Transformer, SwinIR, Square window attention in GRL;
106
+ # 2) Early development of the decomposition-based efficient attention mechanism (efficient_win_attn.py);
107
+ # 3) GRL. Window-anchor attention mechanism.
108
+ # 1) & 3) are still useful
109
+ #####################################
110
+
111
+
112
+ def calculate_mask(input_resolution, window_size, shift_size):
113
+ """
114
+ Use case: 1)
115
+ """
116
+ # calculate attention mask for SW-MSA
117
+ if isinstance(shift_size, int):
118
+ shift_size = to_2tuple(shift_size)
119
+ mask_windows = _fill_window(input_resolution, window_size, shift_size)
120
+
121
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
122
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
123
+ attn_mask == 0, float(0.0)
124
+ ) # nW, window_size**2, window_size**2
125
+
126
+ return attn_mask
127
+
128
+
129
+ def calculate_mask_all(
130
+ input_resolution,
131
+ window_size,
132
+ shift_size,
133
+ anchor_window_down_factor=1,
134
+ window_to_anchor=True,
135
+ ):
136
+ """
137
+ Use case: 3)
138
+ """
139
+ # calculate attention mask for SW-MSA
140
+ anchor_resolution = [s // anchor_window_down_factor for s in input_resolution]
141
+ aws = [s // anchor_window_down_factor for s in window_size]
142
+ anchor_shift = [s // anchor_window_down_factor for s in shift_size]
143
+
144
+ # mask of window1: nW, Wh*Ww
145
+ mask_windows = _fill_window(input_resolution, window_size, shift_size)
146
+ # mask of window2: nW, AWh*AWw
147
+ mask_anchor = _fill_window(anchor_resolution, aws, anchor_shift)
148
+
149
+ if window_to_anchor:
150
+ attn_mask = mask_windows.unsqueeze(2) - mask_anchor.unsqueeze(1)
151
+ else:
152
+ attn_mask = mask_anchor.unsqueeze(2) - mask_windows.unsqueeze(1)
153
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
154
+ attn_mask == 0, float(0.0)
155
+ ) # nW, Wh*Ww, AWh*AWw
156
+
157
+ return attn_mask
158
+
159
+
160
+ def calculate_win_mask(
161
+ input_resolution1, input_resolution2, window_size1, window_size2
162
+ ):
163
+ """
164
+ Use case: 2)
165
+ """
166
+ # calculate attention mask for SW-MSA
167
+
168
+ # mask of window1: nW, Wh*Ww
169
+ mask_windows1 = _fill_window(input_resolution1, window_size1)
170
+ # mask of window2: nW, AWh*AWw
171
+ mask_windows2 = _fill_window(input_resolution2, window_size2)
172
+
173
+ attn_mask = mask_windows1.unsqueeze(2) - mask_windows2.unsqueeze(1)
174
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
175
+ attn_mask == 0, float(0.0)
176
+ ) # nW, Wh*Ww, AWh*AWw
177
+
178
+ return attn_mask
179
+
180
+
181
+ def _get_meshgrid_coords(start_coords, end_coords):
182
+ coord_h = torch.arange(start_coords[0], end_coords[0])
183
+ coord_w = torch.arange(start_coords[1], end_coords[1])
184
+ coords = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")) # 2, Wh, Ww
185
+ coords = torch.flatten(coords, 1) # 2, Wh*Ww
186
+ return coords
187
+
188
+
189
+ def get_relative_coords_table(
190
+ window_size, pretrained_window_size=[0, 0], anchor_window_down_factor=1
191
+ ):
192
+ """
193
+ Use case: 1)
194
+ """
195
+ # get relative_coords_table
196
+ ws = window_size
197
+ aws = [w // anchor_window_down_factor for w in window_size]
198
+ pws = pretrained_window_size
199
+ paws = [w // anchor_window_down_factor for w in pretrained_window_size]
200
+
201
+ ts = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)]
202
+ pts = [(w1 + w2) // 2 for w1, w2 in zip(pws, paws)]
203
+
204
+ # TODO: The pretrained window size and pretrained anchor window size are only used here.
205
+ # TODO: Investigate whether it is really important to use this setting when finetuning large window size
206
+ # TODO: based on pretrained weights with small window size.
207
+
208
+ coord_h = torch.arange(-(ts[0] - 1), ts[0], dtype=torch.float32)
209
+ coord_w = torch.arange(-(ts[1] - 1), ts[1], dtype=torch.float32)
210
+ table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute(
211
+ 1, 2, 0
212
+ )
213
+ table = table.contiguous().unsqueeze(0) # 1, Wh+AWh-1, Ww+AWw-1, 2
214
+ if pts[0] > 0:
215
+ table[:, :, :, 0] /= pts[0] - 1
216
+ table[:, :, :, 1] /= pts[1] - 1
217
+ else:
218
+ table[:, :, :, 0] /= ts[0] - 1
219
+ table[:, :, :, 1] /= ts[1] - 1
220
+ table *= 8 # normalize to -8, 8
221
+ table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8)
222
+ return table
223
+
224
+
225
+ def get_relative_coords_table_all(
226
+ window_size, pretrained_window_size=[0, 0], anchor_window_down_factor=1
227
+ ):
228
+ """
229
+ Use case: 3)
230
+
231
+ Support all window shapes.
232
+ Args:
233
+ window_size (tuple[int]): Height and width of the attention window.
234
+ pretrained_window_size (tuple[int]): Window size used in pre-training.
235
+ anchor_window_down_factor (int): Down-sampling factor of the anchor window.
236
+
237
+ Returns:
238
+ Relative coordinate table of shape (1, Wh+AWh-1, Ww+AWw-1, 2).
239
+ """
240
+ # get relative_coords_table
241
+ ws = window_size
242
+ aws = [w // anchor_window_down_factor for w in window_size]
243
+ pws = pretrained_window_size
244
+ paws = [w // anchor_window_down_factor for w in pretrained_window_size]
245
+
246
+ # positive table size: (Ww - 1) - (Ww - AWw) // 2
247
+ ts_p = [w1 - 1 - (w1 - w2) // 2 for w1, w2 in zip(ws, aws)]
248
+ # negative table size: -(AWw - 1) - (Ww - AWw) // 2
249
+ ts_n = [-(w2 - 1) - (w1 - w2) // 2 for w1, w2 in zip(ws, aws)]
250
+ pts = [w1 - 1 - (w1 - w2) // 2 for w1, w2 in zip(pws, paws)]
251
+
252
+ # TODO: The pretrained window size and pretrained anchor window size are only used here.
253
+ # TODO: Investigate whether it is really important to use this setting when finetuning large window size
254
+ # TODO: based on pretrained weights with small window size.
255
+
256
+ coord_h = torch.arange(ts_n[0], ts_p[0] + 1, dtype=torch.float32)
257
+ coord_w = torch.arange(ts_n[1], ts_p[1] + 1, dtype=torch.float32)
258
+ table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute(
259
+ 1, 2, 0
260
+ )
261
+ table = table.contiguous().unsqueeze(0) # 1, Wh+AWh-1, Ww+AWw-1, 2
262
+ if pts[0] > 0:
263
+ table[:, :, :, 0] /= pts[0]
264
+ table[:, :, :, 1] /= pts[1]
265
+ else:
266
+ table[:, :, :, 0] /= ts_p[0]
267
+ table[:, :, :, 1] /= ts_p[1]
268
+ table *= 8 # normalize to -8, 8
269
+ table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8)
270
+ # 1, Wh+AWh-1, Ww+AWw-1, 2
271
+ return table
272
+
273
+
274
+ def coords_diff(coords1, coords2, max_diff):
275
+ # Relative coordinates lie in [-(max_diff - 1), max_diff - 1]; shift them so indexing starts from 0
276
+ coords = coords1[:, :, None] - coords2[:, None, :] # 2, Wh*Ww, AWh*AWw
277
+ coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, AWh*AWw, 2
278
+ coords[:, :, 0] += max_diff[0] - 1 # shift to start from 0
279
+ coords[:, :, 1] += max_diff[1] - 1
280
+ coords[:, :, 0] *= 2 * max_diff[1] - 1
281
+ idx = coords.sum(-1) # Wh*Ww, AWh*AWw
282
+ return idx
283
+
284
+
285
+ def get_relative_position_index(
286
+ window_size, anchor_window_down_factor=1, window_to_anchor=True
287
+ ):
288
+ """
289
+ Use case: 1)
290
+ """
291
+ # get pair-wise relative position index for each token inside the window
292
+ ws = window_size
293
+ aws = [w // anchor_window_down_factor for w in window_size]
294
+ coords_anchor_end = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)]
295
+ coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)]
296
+
297
+ coords = _get_meshgrid_coords((0, 0), window_size) # 2, Wh*Ww
298
+ coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end)
299
+ # 2, AWh*AWw
300
+
301
+ if window_to_anchor:
302
+ idx = coords_diff(coords, coords_anchor, max_diff=coords_anchor_end)
303
+ else:
304
+ idx = coords_diff(coords_anchor, coords, max_diff=coords_anchor_end)
305
+ return idx # Wh*Ww, AWh*AWw or AWh*AWw, Wh*Ww
306
+
307
+
308
+ def coords_diff_odd(coords1, coords2, start_coord, max_diff):
309
+ # The coordinates start from (-start_coord[0], -start_coord[1])
310
+ coords = coords1[:, :, None] - coords2[:, None, :] # 2, Wh*Ww, AWh*AWw
311
+ coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, AWh*AWw, 2
312
+ coords[:, :, 0] += start_coord[0] # shift to start from 0
313
+ coords[:, :, 1] += start_coord[1]
314
+ coords[:, :, 0] *= max_diff
315
+ idx = coords.sum(-1) # Wh*Ww, AWh*AWw
316
+ return idx
317
+
318
+
319
+ def get_relative_position_index_all(
320
+ window_size, anchor_window_down_factor=1, window_to_anchor=True
321
+ ):
322
+ """
323
+ Use case: 3)
324
+ Support all window shapes:
325
+ square window - square window
326
+ rectangular window - rectangular window
327
+ window - anchor
328
+ anchor - window
329
+ [8, 8] - [8, 8]
330
+ [4, 86] - [2, 43]
331
+ """
332
+ # get pair-wise relative position index for each token inside the window
333
+ ws = window_size
334
+ aws = [w // anchor_window_down_factor for w in window_size]
335
+ coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)]
336
+ coords_anchor_end = [s + w2 for s, w2 in zip(coords_anchor_start, aws)]
337
+
338
+ coords = _get_meshgrid_coords((0, 0), window_size) # 2, Wh*Ww
339
+ coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end)
340
+ # 2, AWh*AWw
341
+
342
+ max_horizontal_diff = aws[1] + ws[1] - 1
343
+ if window_to_anchor:
344
+ offset = [w2 + s - 1 for s, w2 in zip(coords_anchor_start, aws)]
345
+ idx = coords_diff_odd(coords, coords_anchor, offset, max_horizontal_diff)
346
+ else:
347
+ offset = [w1 - s - 1 for s, w1 in zip(coords_anchor_start, ws)]
348
+ idx = coords_diff_odd(coords_anchor, coords, offset, max_horizontal_diff)
349
+ return idx # Wh*Ww, AWh*AWw or AWh*AWw, Wh*Ww
350
+
351
+
352
+ def get_relative_position_index_simple(
353
+ window_size, anchor_window_down_factor=1, window_to_anchor=True
354
+ ):
355
+ """
356
+ Use case: 3)
357
+ This is a simplified version of get_relative_position_index_all
358
+ The start coordinate of anchor window is also (0, 0)
359
+ get pair-wise relative position index for each token inside the window
360
+ """
361
+ ws = window_size
362
+ aws = [w // anchor_window_down_factor for w in window_size]
363
+
364
+ coords = _get_meshgrid_coords((0, 0), window_size) # 2, Wh*Ww
365
+ coords_anchor = _get_meshgrid_coords((0, 0), aws)
366
+ # 2, AWh*AWw
367
+
368
+ max_horizontal_diff = aws[1] + ws[1] - 1
369
+ if window_to_anchor:
370
+ offset = [w2 - 1 for w2 in aws]
371
+ idx = coords_diff_odd(coords, coords_anchor, offset, max_horizontal_diff)
372
+ else:
373
+ offset = [w1 - 1 for w1 in ws]
374
+ idx = coords_diff_odd(coords_anchor, coords, offset, max_horizontal_diff)
375
+ return idx # Wh*Ww, AWh*AWw or AWh*AWw, Wh*Ww
376
+
377
+
378
+ # def get_relative_position_index(window_size):
379
+ # # This is a very early version
380
+ # # get pair-wise relative position index for each token inside the window
381
+ # coords = _get_meshgrid_coords(start_coords=(0, 0), end_coords=window_size)
382
+
383
+ # coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
384
+ # coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
385
+ # coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
386
+ # coords[:, :, 1] += window_size[1] - 1
387
+ # coords[:, :, 0] *= 2 * window_size[1] - 1
388
+ # idx = coords.sum(-1) # Wh*Ww, Wh*Ww
389
+ # return idx
390
+
391
+
392
+ def get_relative_win_position_index(window_size, anchor_window_size):
393
+ """
394
+ Use case: 2)
395
+ """
396
+ # get pair-wise relative position index for each token inside the window
397
+ ws = window_size
398
+ aws = anchor_window_size
399
+ coords_anchor_end = [(w1 + w2) // 2 for w1, w2 in zip(ws, aws)]
400
+ coords_anchor_start = [(w1 - w2) // 2 for w1, w2 in zip(ws, aws)]
401
+
402
+ coords = _get_meshgrid_coords((0, 0), window_size) # 2, Wh*Ww
403
+ coords_anchor = _get_meshgrid_coords(coords_anchor_start, coords_anchor_end)
404
+ # 2, AWh*AWw
405
+ coords = coords[:, :, None] - coords_anchor[:, None, :] # 2, Wh*Ww, AWh*AWw
406
+ coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, AWh*AWw, 2
407
+ coords[:, :, 0] += coords_anchor_end[0] - 1 # shift to start from 0
408
+ coords[:, :, 1] += coords_anchor_end[1] - 1
409
+ coords[:, :, 0] *= 2 * coords_anchor_end[1] - 1
410
+ idx = coords.sum(-1) # Wh*Ww, AWh*AWw
411
+ return idx
412
+
413
+
414
+ # def get_relative_coords_table(window_size, pretrained_window_size):
415
+ # # This is a very early version
416
+ # # get relative_coords_table
417
+ # ws = window_size
418
+ # pws = pretrained_window_size
419
+ # coord_h = torch.arange(-(ws[0] - 1), ws[0], dtype=torch.float32)
420
+ # coord_w = torch.arange(-(ws[1] - 1), ws[1], dtype=torch.float32)
421
+ # table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing='ij')).permute(1, 2, 0)
422
+ # table = table.contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
423
+ # if pws[0] > 0:
424
+ # table[:, :, :, 0] /= pws[0] - 1
425
+ # table[:, :, :, 1] /= pws[1] - 1
426
+ # else:
427
+ # table[:, :, :, 0] /= ws[0] - 1
428
+ # table[:, :, :, 1] /= ws[1] - 1
429
+ # table *= 8 # normalize to -8, 8
430
+ # table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8)
431
+ # return table
432
+
433
+
434
+ def get_relative_win_coords_table(
435
+ window_size,
436
+ anchor_window_size,
437
+ pretrained_window_size=[0, 0],
438
+ pretrained_anchor_window_size=[0, 0],
439
+ ):
440
+ """
441
+ Use case: 2)
442
+ """
443
+ # get relative_coords_table
444
+ ws = window_size
445
+ aws = anchor_window_size
446
+ pws = pretrained_window_size
447
+ paws = pretrained_anchor_window_size
448
+
449
+ # TODO: The pretrained window size and pretrained anchor window size are only used here.
450
+ # TODO: Investigate whether it is really important to use this setting when finetuning large window size
451
+ # TODO: based on pretrained weights with small window size.
452
+
453
+ table_size = [(wsi + awsi) // 2 for wsi, awsi in zip(ws, aws)]
454
+ table_size_pretrained = [(pwsi + pawsi) // 2 for pwsi, pawsi in zip(pws, paws)]
455
+ coord_h = torch.arange(-(table_size[0] - 1), table_size[0], dtype=torch.float32)
456
+ coord_w = torch.arange(-(table_size[1] - 1), table_size[1], dtype=torch.float32)
457
+ table = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")).permute(
458
+ 1, 2, 0
459
+ )
460
+ table = table.contiguous().unsqueeze(0) # 1, Wh+AWh-1, Ww+AWw-1, 2
461
+ if table_size_pretrained[0] > 0:
462
+ table[:, :, :, 0] /= table_size_pretrained[0] - 1
463
+ table[:, :, :, 1] /= table_size_pretrained[1] - 1
464
+ else:
465
+ table[:, :, :, 0] /= table_size[0] - 1
466
+ table[:, :, :, 1] /= table_size[1] - 1
467
+ table *= 8 # normalize to -8, 8
468
+ table = torch.sign(table) * torch.log2(torch.abs(table) + 1.0) / np.log2(8)
469
+ return table
470
+
471
+
472
+ if __name__ == "__main__":
473
+ table = get_relative_coords_table_all((4, 86), anchor_window_down_factor=2)
474
+ table = table.view(-1, 2)
475
+ index1 = get_relative_position_index_all((4, 86), 2, False)
476
+ index2 = get_relative_position_index_simple((4, 86), 2, False)
477
+ print(index2)
478
+ index3 = get_relative_position_index_all((4, 86), 2)
479
+ index4 = get_relative_position_index_simple((4, 86), 2)
480
+ print(index4)
481
+ print(
482
+ table.shape,
483
+ index2.shape,
484
+ index2.max(),
485
+ index2.min(),
486
+ index4.shape,
487
+ index4.max(),
488
+ index4.min(),
489
+ torch.allclose(index1, index2),
490
+ torch.allclose(index3, index4),
491
+ )
492
+
493
+ table = get_relative_coords_table_all((4, 86), anchor_window_down_factor=1)
494
+ table = table.view(-1, 2)
495
+ index1 = get_relative_position_index_all((4, 86), 1, False)
496
+ index2 = get_relative_position_index_simple((4, 86), 1, False)
497
+ # print(index1)
498
+ index3 = get_relative_position_index_all((4, 86), 1)
499
+ index4 = get_relative_position_index_simple((4, 86), 1)
500
+ # print(index2)
501
+ print(
502
+ table.shape,
503
+ index2.shape,
504
+ index2.max(),
505
+ index2.min(),
506
+ index4.shape,
507
+ index4.max(),
508
+ index4.min(),
509
+ torch.allclose(index1, index2),
510
+ torch.allclose(index3, index4),
511
+ )
512
+
513
+ table = get_relative_coords_table_all((8, 8), anchor_window_down_factor=2)
514
+ table = table.view(-1, 2)
515
+ index1 = get_relative_position_index_all((8, 8), 2, False)
516
+ index2 = get_relative_position_index_simple((8, 8), 2, False)
517
+ # print(index1)
518
+ index3 = get_relative_position_index_all((8, 8), 2)
519
+ index4 = get_relative_position_index_simple((8, 8), 2)
520
+ # print(index2)
521
+ print(
522
+ table.shape,
523
+ index2.shape,
524
+ index2.max(),
525
+ index2.min(),
526
+ index4.shape,
527
+ index4.max(),
528
+ index4.min(),
529
+ torch.allclose(index1, index2),
530
+ torch.allclose(index3, index4),
531
+ )
532
+
533
+ table = get_relative_coords_table_all((8, 8), anchor_window_down_factor=1)
534
+ table = table.view(-1, 2)
535
+ index1 = get_relative_position_index_all((8, 8), 1, False)
536
+ index2 = get_relative_position_index_simple((8, 8), 1, False)
537
+ # print(index1)
538
+ index3 = get_relative_position_index_all((8, 8), 1)
539
+ index4 = get_relative_position_index_simple((8, 8), 1)
540
+ # print(index2)
541
+ print(
542
+ table.shape,
543
+ index2.shape,
544
+ index2.max(),
545
+ index2.min(),
546
+ index4.shape,
547
+ index4.max(),
548
+ index4.min(),
549
+ torch.allclose(index1, index2),
550
+ torch.allclose(index3, index4),
551
+ )
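The helpers above are easiest to verify with a small round trip. A sketch, assuming torch is installed and the repository root is on PYTHONPATH so that architecture.grl_common is importable:

import torch
from architecture.grl_common.ops import calculate_mask, window_partition, window_reverse

# window_partition followed by window_reverse is lossless when H and W are multiples of the window
x = torch.randn(2, 16, 16, 32)                     # B, H, W, C
windows = window_partition(x, (8, 8))              # num_windows*B, wh, ww, C
assert windows.shape == (8, 8, 8, 32)
assert torch.equal(window_reverse(windows, (8, 8), (16, 16)), x)

# shifted-window attention mask: one (wh*ww, wh*ww) block per window
mask = calculate_mask((16, 16), (8, 8), 4)
print(mask.shape)                                  # torch.Size([4, 64, 64])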
architecture/grl_common/resblock.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ class ResBlock(nn.Module):
5
+ """Residual block without BN.
6
+
7
+ It has a style of:
8
+
9
+ ::
10
+
11
+ ---Conv-ReLU-Conv-+-
12
+ |________________|
13
+
14
+ Args:
15
+ num_feats (int): Channel number of intermediate features.
16
+ Default: 64.
17
+ res_scale (float): Used to scale the residual before addition.
18
+ Default: 1.0.
19
+ """
20
+
21
+ def __init__(self, num_feats=64, res_scale=1.0, bias=True, shortcut=True):
22
+ super().__init__()
23
+ self.res_scale = res_scale
24
+ self.shortcut = shortcut
25
+ self.conv1 = nn.Conv2d(num_feats, num_feats, 3, 1, 1, bias=bias)
26
+ self.conv2 = nn.Conv2d(num_feats, num_feats, 3, 1, 1, bias=bias)
27
+ self.relu = nn.ReLU(inplace=True)
28
+
29
+ def forward(self, x):
30
+ """Forward function.
31
+
32
+ Args:
33
+ x (Tensor): Input tensor with shape (n, c, h, w).
34
+
35
+ Returns:
36
+ Tensor: Forward results.
37
+ """
38
+
39
+ identity = x
40
+ out = self.conv2(self.relu(self.conv1(x)))
41
+ if self.shortcut:
42
+ return identity + out * self.res_scale
43
+ else:
44
+ return out * self.res_scale
45
+
46
+
47
+ class ResBlockWrapper(ResBlock):
48
+ """ResBlock wrapper that accepts transformer tokens of shape (B, L, C)."""
49
+
50
+ def __init__(self, num_feats, bias=True, shortcut=True):
51
+ super(ResBlockWrapper, self).__init__(
52
+ num_feats=num_feats, bias=bias, shortcut=shortcut
53
+ )
54
+
55
+ def forward(self, x, x_size):
56
+ H, W = x_size
57
+ B, L, C = x.shape
58
+ x = x.view(B, H, W, C).permute(0, 3, 1, 2)
59
+ x = super(ResBlockWrapper, self).forward(x)
60
+ x = x.flatten(2).permute(0, 2, 1)
61
+ return x
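A quick smoke test for ResBlock; a sketch assuming the repository root is on PYTHONPATH:

import torch
from architecture.grl_common.resblock import ResBlock

block = ResBlock(num_feats=64, res_scale=1.0)
x = torch.randn(1, 64, 32, 32)        # N, C, H, W
y = block(x)                          # identity + res_scale * conv branch
assert y.shape == x.shape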
architecture/grl_common/swin_v1_block.py ADDED
@@ -0,0 +1,602 @@
1
+ from math import prod
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from architecture.grl_common.ops import (
6
+ bchw_to_blc,
7
+ blc_to_bchw,
8
+ calculate_mask,
9
+ window_partition,
10
+ window_reverse,
11
+ )
12
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
13
+
14
+
15
+ class Mlp(nn.Module):
16
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
17
+
18
+ def __init__(
19
+ self,
20
+ in_features,
21
+ hidden_features=None,
22
+ out_features=None,
23
+ act_layer=nn.GELU,
24
+ drop=0.0,
25
+ ):
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ drop_probs = to_2tuple(drop)
30
+
31
+ self.fc1 = nn.Linear(in_features, hidden_features)
32
+ self.act = act_layer()
33
+ self.drop1 = nn.Dropout(drop_probs[0])
34
+ self.fc2 = nn.Linear(hidden_features, out_features)
35
+ self.drop2 = nn.Dropout(drop_probs[1])
36
+
37
+ def forward(self, x):
38
+ x = self.fc1(x)
39
+ x = self.act(x)
40
+ x = self.drop1(x)
41
+ x = self.fc2(x)
42
+ x = self.drop2(x)
43
+ return x
44
+
45
+
46
+ class WindowAttentionV1(nn.Module):
47
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
48
+ It supports both shifted and non-shifted windows.
49
+ Args:
50
+ dim (int): Number of input channels.
51
+ window_size (tuple[int]): The height and width of the window.
52
+ num_heads (int): Number of attention heads.
53
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
54
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
55
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
56
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ dim,
62
+ window_size,
63
+ num_heads,
64
+ qkv_bias=True,
65
+ qk_scale=None,
66
+ attn_drop=0.0,
67
+ proj_drop=0.0,
68
+ use_pe=True,
69
+ ):
70
+
71
+ super().__init__()
72
+ self.dim = dim
73
+ self.window_size = window_size # Wh, Ww
74
+ self.num_heads = num_heads
75
+ head_dim = dim // num_heads
76
+ self.scale = qk_scale or head_dim**-0.5
77
+ self.use_pe = use_pe
78
+
79
+ if self.use_pe:
80
+ # define a parameter table of relative position bias
81
+ ws = self.window_size
82
+ table = torch.zeros((2 * ws[0] - 1) * (2 * ws[1] - 1), num_heads)
83
+ self.relative_position_bias_table = nn.Parameter(table)
84
+ # 2*Wh-1 * 2*Ww-1, nH
85
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
86
+
87
+ self.get_relative_position_index(self.window_size)
88
+
89
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
90
+ self.attn_drop = nn.Dropout(attn_drop)
91
+ self.proj = nn.Linear(dim, dim)
92
+
93
+ self.proj_drop = nn.Dropout(proj_drop)
94
+
95
+ self.softmax = nn.Softmax(dim=-1)
96
+
97
+ def get_relative_position_index(self, window_size):
98
+ # get pair-wise relative position index for each token inside the window
99
+ coord_h = torch.arange(window_size[0])
100
+ coord_w = torch.arange(window_size[1])
101
+ coords = torch.stack(torch.meshgrid([coord_h, coord_w], indexing="ij")) # 2, Wh, Ww
102
+ coords = torch.flatten(coords, 1) # 2, Wh*Ww
103
+ coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
104
+ coords = coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
105
+ coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
106
+ coords[:, :, 1] += window_size[1] - 1
107
+ coords[:, :, 0] *= 2 * window_size[1] - 1
108
+ relative_position_index = coords.sum(-1) # Wh*Ww, Wh*Ww
109
+ self.register_buffer("relative_position_index", relative_position_index)
110
+
111
+ def forward(self, x, mask=None):
112
+ """
113
+ Args:
114
+ x: input features with shape of (num_windows*B, N, C)
115
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
116
+ """
117
+ B_, N, C = x.shape
118
+
119
+ # qkv projection
120
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
121
+ q, k, v = qkv[0], qkv[1], qkv[2]
122
+
123
+ # attention map
124
+ q = q * self.scale
125
+ attn = q @ k.transpose(-2, -1)
126
+
127
+ # positional encoding
128
+ if self.use_pe:
129
+ win_dim = prod(self.window_size)
130
+ bias = self.relative_position_bias_table[
131
+ self.relative_position_index.view(-1)
132
+ ]
133
+ bias = bias.view(win_dim, win_dim, -1).permute(2, 0, 1).contiguous()
134
+ # nH, Wh*Ww, Wh*Ww
135
+ attn = attn + bias.unsqueeze(0)
136
+
137
+ # shift attention mask
138
+ if mask is not None:
139
+ nW = mask.shape[0]
140
+ mask = mask.unsqueeze(1).unsqueeze(0)
141
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask
142
+ attn = attn.view(-1, self.num_heads, N, N)
143
+
144
+ # attention
145
+ attn = self.softmax(attn)
146
+ attn = self.attn_drop(attn)
147
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
148
+
149
+ # output projection
150
+ x = self.proj(x)
151
+ x = self.proj_drop(x)
152
+ return x
153
+
154
+ def extra_repr(self) -> str:
155
+ return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
156
+
157
+ def flops(self, N):
158
+ # calculate flops for 1 window with token length of N
159
+ flops = 0
160
+ # qkv = self.qkv(x)
161
+ flops += N * self.dim * 3 * self.dim
162
+ # attn = (q @ k.transpose(-2, -1))
163
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
164
+ # x = (attn @ v)
165
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
166
+ # x = self.proj(x)
167
+ flops += N * self.dim * self.dim
168
+ return flops
169
+
170
+
171
+ class WindowAttentionWrapperV1(WindowAttentionV1):
172
+ def __init__(self, shift_size, input_resolution, **kwargs):
173
+ super(WindowAttentionWrapperV1, self).__init__(**kwargs)
174
+ self.shift_size = shift_size
175
+ self.input_resolution = input_resolution
176
+
177
+ if self.shift_size > 0:
178
+ attn_mask = calculate_mask(input_resolution, self.window_size, shift_size)
179
+ else:
180
+ attn_mask = None
181
+ self.register_buffer("attn_mask", attn_mask)
182
+
183
+ def forward(self, x, x_size):
184
+ H, W = x_size
185
+ B, L, C = x.shape
186
+ x = x.view(B, H, W, C)
187
+
188
+ # cyclic shift
189
+ if self.shift_size > 0:
190
+ x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
191
+
192
+ # partition windows
193
+ x = window_partition(x, self.window_size) # nW*B, wh, ww, C
194
+ x = x.view(-1, prod(self.window_size), C) # nW*B, wh*ww, C
195
+
196
+ # W-MSA/SW-MSA (recompute the attention mask when the test image size differs from the training resolution)
197
+ if self.input_resolution == x_size:
198
+ attn_mask = self.attn_mask
199
+ else:
200
+ attn_mask = calculate_mask(x_size, self.window_size, self.shift_size)
201
+ attn_mask = attn_mask.to(x.device)
202
+
203
+ # attention
204
+ x = super(WindowAttentionWrapperV1, self).forward(x, mask=attn_mask)
205
+ # nW*B, wh*ww, C
206
+
207
+ # merge windows
208
+ x = x.view(-1, *self.window_size, C)
209
+ x = window_reverse(x, self.window_size, x_size) # B, H, W, C
210
+
211
+ # reverse cyclic shift
212
+ if self.shift_size > 0:
213
+ x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
214
+ x = x.view(B, H * W, C)
215
+
216
+ return x
217
+
218
+
219
+ class SwinTransformerBlockV1(nn.Module):
220
+ r"""Swin Transformer Block.
221
+ Args:
222
+ dim (int): Number of input channels.
223
+ input_resolution (tuple[int]): Input resolution.
224
+ num_heads (int): Number of attention heads.
225
+ window_size (int): Window size.
226
+ shift_size (int): Shift size for SW-MSA.
227
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
228
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
229
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
230
+ drop (float, optional): Dropout rate. Default: 0.0
231
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
232
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
233
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
234
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
235
+ """
236
+
237
+ def __init__(
238
+ self,
239
+ dim,
240
+ input_resolution,
241
+ num_heads,
242
+ window_size=7,
243
+ shift_size=0,
244
+ mlp_ratio=4.0,
245
+ qkv_bias=True,
246
+ qk_scale=None,
247
+ drop=0.0,
248
+ attn_drop=0.0,
249
+ drop_path=0.0,
250
+ act_layer=nn.GELU,
251
+ norm_layer=nn.LayerNorm,
252
+ use_pe=True,
253
+ res_scale=1.0,
254
+ ):
255
+ super().__init__()
256
+ self.dim = dim
257
+ self.input_resolution = input_resolution
258
+ self.num_heads = num_heads
259
+ self.window_size = window_size
260
+ self.shift_size = shift_size
261
+ self.mlp_ratio = mlp_ratio
262
+ if min(self.input_resolution) <= self.window_size:
263
+ # if window size is larger than input resolution, we don't partition windows
264
+ self.shift_size = 0
265
+ self.window_size = min(self.input_resolution)
266
+ assert (
267
+ 0 <= self.shift_size < self.window_size
268
+ ), "shift_size must be in the range [0, window_size)"
269
+ self.res_scale = res_scale
270
+
271
+ self.norm1 = norm_layer(dim)
272
+ self.attn = WindowAttentionWrapperV1(
273
+ shift_size=self.shift_size,
274
+ input_resolution=self.input_resolution,
275
+ dim=dim,
276
+ window_size=to_2tuple(self.window_size),
277
+ num_heads=num_heads,
278
+ qkv_bias=qkv_bias,
279
+ qk_scale=qk_scale,
280
+ attn_drop=attn_drop,
281
+ proj_drop=drop,
282
+ use_pe=use_pe,
283
+ )
284
+
285
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
286
+
287
+ self.norm2 = norm_layer(dim)
288
+ self.mlp = Mlp(
289
+ in_features=dim,
290
+ hidden_features=int(dim * mlp_ratio),
291
+ act_layer=act_layer,
292
+ drop=drop,
293
+ )
294
+
295
+ def forward(self, x, x_size):
296
+ # Window attention
297
+ x = x + self.res_scale * self.drop_path(self.attn(self.norm1(x), x_size))
298
+ # FFN
299
+ x = x + self.res_scale * self.drop_path(self.mlp(self.norm2(x)))
300
+
301
+ return x
302
+
303
+ def extra_repr(self) -> str:
304
+ return (
305
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
306
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}, res_scale={self.res_scale}"
307
+ )
308
+
309
+ def flops(self):
310
+ flops = 0
311
+ H, W = self.input_resolution
312
+ # norm1
313
+ flops += self.dim * H * W
314
+ # W-MSA/SW-MSA
315
+ nW = H * W / self.window_size / self.window_size
316
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
317
+ # mlp
318
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
319
+ # norm2
320
+ flops += self.dim * H * W
321
+ return flops
322
+
323
+
324
+ class PatchMerging(nn.Module):
325
+ r"""Patch Merging Layer.
326
+ Args:
327
+ input_resolution (tuple[int]): Resolution of input feature.
328
+ dim (int): Number of input channels.
329
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
330
+ """
331
+
332
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
333
+ super().__init__()
334
+ self.input_resolution = input_resolution
335
+ self.dim = dim
336
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
337
+ self.norm = norm_layer(4 * dim)
338
+
339
+ def forward(self, x):
340
+ """
341
+ x: B, H*W, C
342
+ """
343
+ H, W = self.input_resolution
344
+ B, L, C = x.shape
345
+ assert L == H * W, "input feature has wrong size"
346
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
347
+
348
+ x = x.view(B, H, W, C)
349
+
350
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
351
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
352
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
353
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
354
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
355
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
356
+
357
+ x = self.norm(x)
358
+ x = self.reduction(x)
359
+
360
+ return x
361
+
362
+ def extra_repr(self) -> str:
363
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
364
+
365
+ def flops(self):
366
+ H, W = self.input_resolution
367
+ flops = H * W * self.dim
368
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
369
+ return flops
370
+
371
+
372
+ class PatchEmbed(nn.Module):
373
+ r"""Image to Patch Embedding
374
+ Args:
375
+ img_size (int): Image size. Default: 224.
376
+ patch_size (int): Patch token size. Default: 4.
377
+ in_chans (int): Number of input image channels. Default: 3.
378
+ embed_dim (int): Number of linear projection output channels. Default: 96.
379
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
380
+ """
381
+
382
+ def __init__(
383
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
384
+ ):
385
+ super().__init__()
386
+ img_size = to_2tuple(img_size)
387
+ patch_size = to_2tuple(patch_size)
388
+ patches_resolution = [
389
+ img_size[0] // patch_size[0],
390
+ img_size[1] // patch_size[1],
391
+ ]
392
+ self.img_size = img_size
393
+ self.patch_size = patch_size
394
+ self.patches_resolution = patches_resolution
395
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
396
+
397
+ self.in_chans = in_chans
398
+ self.embed_dim = embed_dim
399
+
400
+ if norm_layer is not None:
401
+ self.norm = norm_layer(embed_dim)
402
+ else:
403
+ self.norm = None
404
+
405
+ def forward(self, x):
406
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
407
+ if self.norm is not None:
408
+ x = self.norm(x)
409
+ return x
410
+
411
+ def flops(self):
412
+ flops = 0
413
+ H, W = self.img_size
414
+ if self.norm is not None:
415
+ flops += H * W * self.embed_dim
416
+ return flops
417
+
418
+
419
+ class PatchUnEmbed(nn.Module):
420
+ r"""Image to Patch Unembedding
421
+ Args:
422
+ img_size (int): Image size. Default: 224.
423
+ patch_size (int): Patch token size. Default: 4.
424
+ in_chans (int): Number of input image channels. Default: 3.
425
+ embed_dim (int): Number of linear projection output channels. Default: 96.
426
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
427
+ """
428
+
429
+ def __init__(
430
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
431
+ ):
432
+ super().__init__()
433
+ img_size = to_2tuple(img_size)
434
+ patch_size = to_2tuple(patch_size)
435
+ patches_resolution = [
436
+ img_size[0] // patch_size[0],
437
+ img_size[1] // patch_size[1],
438
+ ]
439
+ self.img_size = img_size
440
+ self.patch_size = patch_size
441
+ self.patches_resolution = patches_resolution
442
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
443
+
444
+ self.in_chans = in_chans
445
+ self.embed_dim = embed_dim
446
+
447
+ def forward(self, x, x_size):
448
+ B, HW, C = x.shape
449
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B, C, Ph, Pw
450
+ return x
451
+
452
+ def flops(self):
453
+ flops = 0
454
+ return flops
455
+
456
+
457
+ class Linear(nn.Linear):
458
+ def __init__(self, in_features, out_features, bias=True):
459
+ super(Linear, self).__init__(in_features, out_features, bias)
460
+
461
+ def forward(self, x):
462
+ B, C, H, W = x.shape
463
+ x = bchw_to_blc(x)
464
+ x = super(Linear, self).forward(x)
465
+ x = blc_to_bchw(x, (H, W))
466
+ return x
467
+
468
+
469
+ def build_last_conv(conv_type, dim):
470
+ if conv_type == "1conv":
471
+ block = nn.Conv2d(dim, dim, 3, 1, 1)
472
+ elif conv_type == "3conv":
473
+ # to save parameters and memory
474
+ block = nn.Sequential(
475
+ nn.Conv2d(dim, dim // 4, 3, 1, 1),
476
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
477
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
478
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
479
+ nn.Conv2d(dim // 4, dim, 3, 1, 1),
480
+ )
481
+ elif conv_type == "1conv1x1":
482
+ block = nn.Conv2d(dim, dim, 1, 1, 0)
483
+ elif conv_type == "linear":
484
+ block = Linear(dim, dim)
485
+ return block
486
+
487
+
488
+ # class BasicLayer(nn.Module):
489
+ # """A basic Swin Transformer layer for one stage.
490
+ # Args:
491
+ # dim (int): Number of input channels.
492
+ # input_resolution (tuple[int]): Input resolution.
493
+ # depth (int): Number of blocks.
494
+ # num_heads (int): Number of attention heads.
495
+ # window_size (int): Local window size.
496
+ # mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
497
+ # qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
498
+ # qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
499
+ # drop (float, optional): Dropout rate. Default: 0.0
500
+ # attn_drop (float, optional): Attention dropout rate. Default: 0.0
501
+ # drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
502
+ # norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
503
+ # downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
504
+ # args: Additional arguments
505
+ # """
506
+
507
+ # def __init__(
508
+ # self,
509
+ # dim,
510
+ # input_resolution,
511
+ # depth,
512
+ # num_heads,
513
+ # window_size,
514
+ # mlp_ratio=4.0,
515
+ # qkv_bias=True,
516
+ # qk_scale=None,
517
+ # drop=0.0,
518
+ # attn_drop=0.0,
519
+ # drop_path=0.0,
520
+ # norm_layer=nn.LayerNorm,
521
+ # downsample=None,
522
+ # args=None,
523
+ # ):
524
+
525
+ # super().__init__()
526
+ # self.dim = dim
527
+ # self.input_resolution = input_resolution
528
+ # self.depth = depth
529
+
530
+ # # build blocks
531
+ # self.blocks = nn.ModuleList(
532
+ # [
533
+ # _parse_block(
534
+ # dim=dim,
535
+ # input_resolution=input_resolution,
536
+ # num_heads=num_heads,
537
+ # window_size=window_size,
538
+ # shift_size=0
539
+ # if args.no_shift
540
+ # else (0 if (i % 2 == 0) else window_size // 2),
541
+ # mlp_ratio=mlp_ratio,
542
+ # qkv_bias=qkv_bias,
543
+ # qk_scale=qk_scale,
544
+ # drop=drop,
545
+ # attn_drop=attn_drop,
546
+ # drop_path=drop_path[i]
547
+ # if isinstance(drop_path, list)
548
+ # else drop_path,
549
+ # norm_layer=norm_layer,
550
+ # stripe_type="H" if (i % 2 == 0) else "W",
551
+ # args=args,
552
+ # )
553
+ # for i in range(depth)
554
+ # ]
555
+ # )
556
+ # # self.blocks = nn.ModuleList(
557
+ # # [
558
+ # # STV1Block(
559
+ # # dim=dim,
560
+ # # input_resolution=input_resolution,
561
+ # # num_heads=num_heads,
562
+ # # window_size=window_size,
563
+ # # shift_size=0 if (i % 2 == 0) else window_size // 2,
564
+ # # mlp_ratio=mlp_ratio,
565
+ # # qkv_bias=qkv_bias,
566
+ # # qk_scale=qk_scale,
567
+ # # drop=drop,
568
+ # # attn_drop=attn_drop,
569
+ # # drop_path=drop_path[i]
570
+ # # if isinstance(drop_path, list)
571
+ # # else drop_path,
572
+ # # norm_layer=norm_layer,
573
+ # # )
574
+ # # for i in range(depth)
575
+ # # ]
576
+ # # )
577
+
578
+ # # patch merging layer
579
+ # if downsample is not None:
580
+ # self.downsample = downsample(
581
+ # input_resolution, dim=dim, norm_layer=norm_layer
582
+ # )
583
+ # else:
584
+ # self.downsample = None
585
+
586
+ # def forward(self, x, x_size):
587
+ # for blk in self.blocks:
588
+ # x = blk(x, x_size)
589
+ # if self.downsample is not None:
590
+ # x = self.downsample(x)
591
+ # return x
592
+
593
+ # def extra_repr(self) -> str:
594
+ # return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
595
+
596
+ # def flops(self):
597
+ # flops = 0
598
+ # for blk in self.blocks:
599
+ # flops += blk.flops()
600
+ # if self.downsample is not None:
601
+ # flops += self.downsample.flops()
602
+ # return flops
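A minimal forward pass through SwinTransformerBlockV1 on the (B, H*W, C) token layout it expects; a sketch with hypothetical sizes, assuming torch and timm are installed and the repository is importable:

import torch
from architecture.grl_common.swin_v1_block import SwinTransformerBlockV1

blk = SwinTransformerBlockV1(
    dim=96, input_resolution=(64, 64), num_heads=6, window_size=8, shift_size=4
)
x = torch.randn(1, 64 * 64, 96)       # B, H*W, C
y = blk(x, (64, 64))                  # shape is preserved
assert y.shape == x.shape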
architecture/grl_common/swin_v2_block.py ADDED
@@ -0,0 +1,306 @@
1
+ import math
2
+ from math import prod
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from architecture.grl_common.ops import (
8
+ calculate_mask,
9
+ get_relative_coords_table,
10
+ get_relative_position_index,
11
+ window_partition,
12
+ window_reverse,
13
+ )
14
+ from architecture.grl_common.swin_v1_block import Mlp
15
+ from timm.models.layers import DropPath, to_2tuple
16
+
17
+
18
+ class WindowAttentionV2(nn.Module):
19
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
20
+ It supports both shifted and non-shifted windows.
21
+ Args:
22
+ dim (int): Number of input channels.
23
+ window_size (tuple[int]): The height and width of the window.
24
+ num_heads (int): Number of attention heads.
25
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
26
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
27
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
28
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ dim,
34
+ window_size,
35
+ num_heads,
36
+ qkv_bias=True,
37
+ attn_drop=0.0,
38
+ proj_drop=0.0,
39
+ pretrained_window_size=[0, 0],
40
+ use_pe=True,
41
+ ):
42
+
43
+ super().__init__()
44
+ self.dim = dim
45
+ self.window_size = window_size # Wh, Ww
46
+ self.pretrained_window_size = pretrained_window_size
47
+ self.num_heads = num_heads
48
+ self.use_pe = use_pe
49
+
50
+ self.logit_scale = nn.Parameter(
51
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
52
+ )
53
+
54
+ if self.use_pe:
55
+ # mlp to generate continuous relative position bias
56
+ self.cpb_mlp = nn.Sequential(
57
+ nn.Linear(2, 512, bias=True),
58
+ nn.ReLU(inplace=True),
59
+ nn.Linear(512, num_heads, bias=False),
60
+ )
61
+ table = get_relative_coords_table(window_size, pretrained_window_size)
62
+ index = get_relative_position_index(window_size)
63
+ self.register_buffer("relative_coords_table", table)
64
+ self.register_buffer("relative_position_index", index)
65
+
66
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
67
+ # self.qkv = nn.Linear(dim, dim * 3, bias=False)
68
+ # if qkv_bias:
69
+ # self.q_bias = nn.Parameter(torch.zeros(dim))
70
+ # self.v_bias = nn.Parameter(torch.zeros(dim))
71
+ # else:
72
+ # self.q_bias = None
73
+ # self.v_bias = None
74
+ self.attn_drop = nn.Dropout(attn_drop)
75
+ self.proj = nn.Linear(dim, dim)
76
+ self.proj_drop = nn.Dropout(proj_drop)
77
+ self.softmax = nn.Softmax(dim=-1)
78
+
79
+ def forward(self, x, mask=None):
80
+ """
81
+ Args:
82
+ x: input features with shape of (num_windows*B, N, C)
83
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
84
+ """
85
+ B_, N, C = x.shape
86
+
87
+ # qkv projection
88
+ # qkv_bias = None
89
+ # if self.q_bias is not None:
90
+ # qkv_bias = torch.cat(
91
+ # (
92
+ # self.q_bias,
93
+ # torch.zeros_like(self.v_bias, requires_grad=False),
94
+ # self.v_bias,
95
+ # )
96
+ # )
97
+ # qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
98
+ qkv = self.qkv(x)
99
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
100
+ q, k, v = qkv[0], qkv[1], qkv[2]
101
+
102
+ # cosine attention map
103
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
104
+ logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
105
+ attn = attn * logit_scale
106
+
107
+ # positional encoding
108
+ if self.use_pe:
109
+ bias_table = self.cpb_mlp(self.relative_coords_table)
110
+ bias_table = bias_table.view(-1, self.num_heads)
111
+
112
+ win_dim = prod(self.window_size)
113
+ bias = bias_table[self.relative_position_index.view(-1)]
114
+ bias = bias.view(win_dim, win_dim, -1).permute(2, 0, 1).contiguous()
115
+ # nH, Wh*Ww, Wh*Ww
116
+ bias = 16 * torch.sigmoid(bias)
117
+ attn = attn + bias.unsqueeze(0)
118
+
119
+ # shift attention mask
120
+ if mask is not None:
121
+ nW = mask.shape[0]
122
+ mask = mask.unsqueeze(1).unsqueeze(0)
123
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask
124
+ attn = attn.view(-1, self.num_heads, N, N)
125
+
126
+ # attention
127
+ attn = self.softmax(attn)
128
+ attn = self.attn_drop(attn)
129
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
130
+
131
+ # output projection
132
+ x = self.proj(x)
133
+ x = self.proj_drop(x)
134
+ return x
135
+
136
+ def extra_repr(self) -> str:
137
+ return (
138
+ f"dim={self.dim}, window_size={self.window_size}, "
139
+ f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
140
+ )
141
+
142
+ def flops(self, N):
143
+ # calculate flops for 1 window with token length of N
144
+ flops = 0
145
+ # qkv = self.qkv(x)
146
+ flops += N * self.dim * 3 * self.dim
147
+ # attn = (q @ k.transpose(-2, -1))
148
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
149
+ # x = (attn @ v)
150
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
151
+ # x = self.proj(x)
152
+ flops += N * self.dim * self.dim
153
+ return flops
154
+
155
+
156
+ class WindowAttentionWrapperV2(WindowAttentionV2):
157
+ def __init__(self, shift_size, input_resolution, **kwargs):
158
+ super(WindowAttentionWrapperV2, self).__init__(**kwargs)
159
+ self.shift_size = shift_size
160
+ self.input_resolution = input_resolution
161
+
162
+ if self.shift_size > 0:
163
+ attn_mask = calculate_mask(input_resolution, self.window_size, shift_size)
164
+ else:
165
+ attn_mask = None
166
+ self.register_buffer("attn_mask", attn_mask)
167
+
168
+ def forward(self, x, x_size):
169
+ H, W = x_size
170
+ B, L, C = x.shape
171
+ x = x.view(B, H, W, C)
172
+
173
+ # cyclic shift
174
+ if self.shift_size > 0:
175
+ x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
176
+
177
+ # partition windows
178
+ x = window_partition(x, self.window_size) # nW*B, wh, ww, C
179
+ x = x.view(-1, prod(self.window_size), C) # nW*B, wh*ww, C
180
+
181
+ # W-MSA/SW-MSA
182
+ if self.input_resolution == x_size:
183
+ attn_mask = self.attn_mask
184
+ else:
185
+ attn_mask = calculate_mask(x_size, self.window_size, self.shift_size)
186
+ attn_mask = attn_mask.to(x.device)
187
+
188
+ # attention
189
+ x = super(WindowAttentionWrapperV2, self).forward(x, mask=attn_mask)
190
+ # nW*B, wh*ww, C
191
+
192
+ # merge windows
193
+ x = x.view(-1, *self.window_size, C)
194
+ x = window_reverse(x, self.window_size, x_size) # B, H, W, C
195
+
196
+ # reverse cyclic shift
197
+ if self.shift_size > 0:
198
+ x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
199
+ x = x.view(B, H * W, C)
200
+
201
+ return x
202
+
203
+
204
+ class SwinTransformerBlockV2(nn.Module):
205
+ r"""Swin Transformer Block.
206
+ Args:
207
+ dim (int): Number of input channels.
208
+ input_resolution (tuple[int]): Input resolution.
209
+ num_heads (int): Number of attention heads.
210
+ window_size (int): Window size.
211
+ shift_size (int): Shift size for SW-MSA.
212
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
213
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
214
+ drop (float, optional): Dropout rate. Default: 0.0
215
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
216
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
217
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
218
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
219
+ pretrained_window_size (int): Window size in pre-training.
220
+ """
221
+
222
+ def __init__(
223
+ self,
224
+ dim,
225
+ input_resolution,
226
+ num_heads,
227
+ window_size=7,
228
+ shift_size=0,
229
+ mlp_ratio=4.0,
230
+ qkv_bias=True,
231
+ drop=0.0,
232
+ attn_drop=0.0,
233
+ drop_path=0.0,
234
+ act_layer=nn.GELU,
235
+ norm_layer=nn.LayerNorm,
236
+ pretrained_window_size=0,
237
+ use_pe=True,
238
+ res_scale=1.0,
239
+ ):
240
+ super().__init__()
241
+ self.dim = dim
242
+ self.input_resolution = input_resolution
243
+ self.num_heads = num_heads
244
+ self.window_size = window_size
245
+ self.shift_size = shift_size
246
+ self.mlp_ratio = mlp_ratio
247
+ if min(self.input_resolution) <= self.window_size:
248
+ # if window size is larger than input resolution, we don't partition windows
249
+ self.shift_size = 0
250
+ self.window_size = min(self.input_resolution)
251
+ assert (
252
+ 0 <= self.shift_size < self.window_size
253
+ ), "shift_size must be in [0, window_size)"
254
+ self.res_scale = res_scale
255
+
256
+ self.attn = WindowAttentionWrapperV2(
257
+ shift_size=self.shift_size,
258
+ input_resolution=self.input_resolution,
259
+ dim=dim,
260
+ window_size=to_2tuple(self.window_size),
261
+ num_heads=num_heads,
262
+ qkv_bias=qkv_bias,
263
+ attn_drop=attn_drop,
264
+ proj_drop=drop,
265
+ pretrained_window_size=to_2tuple(pretrained_window_size),
266
+ use_pe=use_pe,
267
+ )
268
+ self.norm1 = norm_layer(dim)
269
+
270
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
271
+
272
+ self.mlp = Mlp(
273
+ in_features=dim,
274
+ hidden_features=int(dim * mlp_ratio),
275
+ act_layer=act_layer,
276
+ drop=drop,
277
+ )
278
+ self.norm2 = norm_layer(dim)
279
+
280
+ def forward(self, x, x_size):
281
+ # Window attention
282
+ x = x + self.res_scale * self.drop_path(self.norm1(self.attn(x, x_size)))
283
+ # FFN
284
+ x = x + self.res_scale * self.drop_path(self.norm2(self.mlp(x)))
285
+
286
+ return x
287
+
288
+ def extra_repr(self) -> str:
289
+ return (
290
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
291
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}, res_scale={self.res_scale}"
292
+ )
293
+
294
+ def flops(self):
295
+ flops = 0
296
+ H, W = self.input_resolution
297
+ # norm1
298
+ flops += self.dim * H * W
299
+ # W-MSA/SW-MSA
300
+ nW = H * W / self.window_size / self.window_size
301
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
302
+ # mlp
303
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
304
+ # norm2
305
+ flops += self.dim * H * W
306
+ return flops
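+
+ # Minimal usage sketch (illustrative values; assumes Mlp, window_partition,
+ # window_reverse and calculate_mask are the helpers defined earlier in this module):
+ # blk = SwinTransformerBlockV2(dim=96, input_resolution=(64, 64), num_heads=6,
+ #                              window_size=8, shift_size=4)
+ # tokens = torch.randn(1, 64 * 64, 96)
+ # out = blk(tokens, (64, 64))   # output keeps the token layout: (1, 4096, 96)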
architecture/grl_common/upsample.py ADDED
@@ -0,0 +1,50 @@
1
+ import math
2
+
3
+ import torch.nn as nn
4
+
5
+
6
+ class Upsample(nn.Module):
7
+ """Upsample module.
8
+ Args:
9
+ scale (int): Scale factor. Supported scales: 2^n and 3.
10
+ num_feat (int): Channel number of intermediate features.
11
+ """
12
+
13
+ def __init__(self, scale, num_feat):
14
+ super(Upsample, self).__init__()
15
+ m = []
16
+ if (scale & (scale - 1)) == 0: # scale = 2^n
17
+ for _ in range(int(math.log(scale, 2))):
18
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
19
+ m.append(nn.PixelShuffle(2))
20
+ elif scale == 3:
21
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
22
+ m.append(nn.PixelShuffle(3))
23
+ else:
24
+ raise ValueError(
25
+ f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
26
+ )
27
+ self.up = nn.Sequential(*m)
28
+
29
+ def forward(self, x):
30
+ return self.up(x)
31
+
32
+
33
+ class UpsampleOneStep(nn.Module):
34
+ """UpsampleOneStep module (unlike Upsample, it always uses a single conv followed by a single pixelshuffle).
35
+ Used in lightweight SR to save parameters.
36
+ Args:
37
+ scale (int): Scale factor. Supported scales: 2^n and 3.
38
+ num_feat (int): Channel number of intermediate features.
39
+ """
40
+
41
+ def __init__(self, scale, num_feat, num_out_ch):
42
+ super(UpsampleOneStep, self).__init__()
43
+ self.num_feat = num_feat
44
+ m = []
45
+ m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
46
+ m.append(nn.PixelShuffle(scale))
47
+ self.up = nn.Sequential(*m)
48
+
49
+ def forward(self, x):
50
+ return self.up(x)
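+
+ # Usage sketch (illustrative shapes, torch import assumed):
+ # up = Upsample(scale=4, num_feat=64)                 # two conv + PixelShuffle(2) stages
+ # y = up(torch.randn(1, 64, 32, 32))                  # -> (1, 64, 128, 128)
+ # one_step = UpsampleOneStep(scale=4, num_feat=64, num_out_ch=3)
+ # y = one_step(torch.randn(1, 64, 32, 32))            # -> (1, 3, 128, 128)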
architecture/rrdb.py ADDED
@@ -0,0 +1,218 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Paper Github Repository: https://github.com/xinntao/Real-ESRGAN
4
+ # Code snippet from: https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/archs/rrdbnet_arch.py
5
+ # Paper: https://arxiv.org/pdf/2107.10833.pdf
6
+
7
+ import os, sys
8
+ import torch
9
+ from torch import nn as nn
10
+ from torch.nn import functional as F
11
+ from itertools import repeat
12
+ from torch.nn import init as init
13
+ from torch.nn.modules.batchnorm import _BatchNorm
14
+
15
+
16
+ def pixel_unshuffle(x, scale):
17
+ """ Pixel unshuffle.
18
+
19
+ Args:
20
+ x (Tensor): Input feature with shape (b, c, hh, hw).
21
+ scale (int): Downsample ratio.
22
+
23
+ Returns:
24
+ Tensor: the pixel unshuffled feature.
25
+ """
26
+ b, c, hh, hw = x.size()
27
+ out_channel = c * (scale**2)
28
+ assert hh % scale == 0 and hw % scale == 0
29
+ h = hh // scale
30
+ w = hw // scale
31
+ x_view = x.view(b, c, h, scale, w, scale)
32
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
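+ # Shape example (illustration only): a (1, 3, 64, 64) input with scale=2 becomes
+ # (1, 12, 32, 32) -- channels grow by scale**2 while H and W shrink by scale.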
33
+
34
+ def make_layer(basic_block, num_basic_block, **kwarg):
35
+ """Make layers by stacking the same blocks.
36
+
37
+ Args:
38
+ basic_block (nn.module): nn.module class for basic block.
39
+ num_basic_block (int): number of blocks.
40
+
41
+ Returns:
42
+ nn.Sequential: Stacked blocks in nn.Sequential.
43
+ """
44
+ layers = []
45
+ for _ in range(num_basic_block):
46
+ layers.append(basic_block(**kwarg))
47
+ return nn.Sequential(*layers)
48
+
49
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
50
+ """Initialize network weights.
51
+
52
+ Args:
53
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
54
+ scale (float): Scale initialized weights, especially for residual
55
+ blocks. Default: 1.
56
+ bias_fill (float): The value to fill bias. Default: 0
57
+ kwargs (dict): Other arguments for initialization function.
58
+ """
59
+ if not isinstance(module_list, list):
60
+ module_list = [module_list]
61
+ for module in module_list:
62
+ for m in module.modules():
63
+ if isinstance(m, nn.Conv2d):
64
+ init.kaiming_normal_(m.weight, **kwargs)
65
+ m.weight.data *= scale
66
+ if m.bias is not None:
67
+ m.bias.data.fill_(bias_fill)
68
+ elif isinstance(m, nn.Linear):
69
+ init.kaiming_normal_(m.weight, **kwargs)
70
+ m.weight.data *= scale
71
+ if m.bias is not None:
72
+ m.bias.data.fill_(bias_fill)
73
+ elif isinstance(m, _BatchNorm):
74
+ init.constant_(m.weight, 1)
75
+ if m.bias is not None:
76
+ m.bias.data.fill_(bias_fill)
77
+
78
+ class ResidualDenseBlock(nn.Module):
79
+ """Residual Dense Block.
80
+
81
+ Used in RRDB block in ESRGAN.
82
+
83
+ Args:
84
+ num_feat (int): Channel number of intermediate features.
85
+ num_grow_ch (int): Channels for each growth.
86
+ """
87
+
88
+ def __init__(self, num_feat=64, num_grow_ch=32):
89
+ super(ResidualDenseBlock, self).__init__()
90
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
91
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
92
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
93
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
94
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
95
+
96
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
97
+
98
+ # initialization
99
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
100
+
101
+ def forward(self, x):
102
+ x1 = self.lrelu(self.conv1(x))
103
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
104
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
105
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
106
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
107
+ # Empirically, we use 0.2 to scale the residual for better performance
108
+ return x5 * 0.2 + x
109
+
110
+
111
+ class RRDB(nn.Module):
112
+ """Residual in Residual Dense Block.
113
+
114
+ Used in RRDB-Net in ESRGAN.
115
+
116
+ Args:
117
+ num_feat (int): Channel number of intermediate features.
118
+ num_grow_ch (int): Channels for each growth.
119
+ """
120
+
121
+ def __init__(self, num_feat, num_grow_ch=32):
122
+ super(RRDB, self).__init__()
123
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
124
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
125
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
126
+
127
+ def forward(self, x):
128
+ out = self.rdb1(x)
129
+ out = self.rdb2(out)
130
+ out = self.rdb3(out)
131
+ # Empirically, we use 0.2 to scale the residual for better performance
132
+ return out * 0.2 + x
133
+
134
+
135
+
136
+ class RRDBNet(nn.Module):
137
+ """Networks consisting of Residual in Residual Dense Block, which is used
138
+ in ESRGAN.
139
+
140
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
141
+
142
+ We extend ESRGAN for scale x2 and scale x1.
143
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
144
+ We first employ the pixel-unshuffle (an inverse operation of pixelshuffle) to reduce the spatial size
145
+ and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
146
+
147
+ Args:
148
+ num_in_ch (int): Channel number of inputs.
149
+ num_out_ch (int): Channel number of outputs.
150
+ num_feat (int): Channel number of intermediate features.
151
+ Default: 64
152
+ num_block (int): Block number in the trunk network. Defaults: 6 for our Anime training cases
153
+ num_grow_ch (int): Channels for each growth. Default: 32.
154
+ """
155
+
156
+ def __init__(self, num_in_ch, num_out_ch, scale, num_feat=64, num_block=6, num_grow_ch=32):
157
+
158
+ super(RRDBNet, self).__init__()
159
+ self.scale = scale
160
+ if scale == 2:
161
+ num_in_ch = num_in_ch * 4
162
+ elif scale == 1:
163
+ num_in_ch = num_in_ch * 16
164
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
165
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
166
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
167
+ # upsample
168
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
169
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
170
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
171
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
172
+
173
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
174
+
175
+ def forward(self, x):
176
+ if self.scale == 2:
177
+ feat = pixel_unshuffle(x, scale=2)
178
+ elif self.scale == 1:
179
+ feat = pixel_unshuffle(x, scale=4)
180
+ else:
181
+ feat = x
182
+ feat = self.conv_first(feat)
183
+ body_feat = self.conv_body(self.body(feat))
184
+ feat = feat + body_feat
185
+ # upsample
186
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
187
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
188
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
189
+ return out
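+
+ # Scale handling sketch (illustrative shapes):
+ # scale=4: a (1, 3, 64, 64) input passes through the trunk at 64x64, then two 2x stages -> (1, 3, 256, 256)
+ # scale=2: pixel_unshuffle(x, 2) gives (1, 12, 32, 32), so the 4x upsampling path yields 2x overall
+ # scale=1: pixel_unshuffle(x, 4) gives (1, 48, 16, 16), so the output matches the input resolution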
190
+
191
+
192
+
193
+ def main():
194
+ root_path = os.path.abspath('.')
195
+ sys.path.append(root_path)
196
+
197
+ from opt import opt # Manage GPU to choose
198
+ from pthflops import count_ops
199
+ from torchsummary import summary
200
+ import time
201
+
202
+ # We use a 6-block RRDB by default.
203
+ model = RRDBNet(3, 3, scale=4).cuda()   # scale is a required argument; 4 is used here as an example
204
+ pytorch_total_params = sum(p.numel() for p in model.parameters())
205
+ print(f"RRDB has {pytorch_total_params//1000}K params")
206
+
207
+
208
+ # Run a single forward pass to sanity-check the output size and measure latency
209
+ x = torch.randn((1, 3, 180, 180)).cuda()
210
+ start = time.time()
211
+ x = model(x)
212
+ print("output size is ", x.shape)
213
+ total = time.time() - start
214
+ print(total)
215
+
216
+
217
+ if __name__ == "__main__":
218
+ main()
architecture/swinir.py ADDED
@@ -0,0 +1,874 @@
1
+ # -----------------------------------------------------------------------------------
2
+ # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
3
+ # Originally Written by Ze Liu, Modified by Jingyun Liang.
4
+ # -----------------------------------------------------------------------------------
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint as checkpoint
11
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12
+
13
+
14
+ class Mlp(nn.Module):
15
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
16
+ super().__init__()
17
+ out_features = out_features or in_features
18
+ hidden_features = hidden_features or in_features
19
+ self.fc1 = nn.Linear(in_features, hidden_features)
20
+ self.act = act_layer()
21
+ self.fc2 = nn.Linear(hidden_features, out_features)
22
+ self.drop = nn.Dropout(drop)
23
+
24
+ def forward(self, x):
25
+ x = self.fc1(x)
26
+ x = self.act(x)
27
+ x = self.drop(x)
28
+ x = self.fc2(x)
29
+ x = self.drop(x)
30
+ return x
31
+
32
+
33
+ def window_partition(x, window_size):
34
+ """
35
+ Args:
36
+ x: (B, H, W, C)
37
+ window_size (int): window size
38
+
39
+ Returns:
40
+ windows: (num_windows*B, window_size, window_size, C)
41
+ """
42
+ B, H, W, C = x.shape
43
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
44
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
45
+ return windows
46
+
47
+
48
+ def window_reverse(windows, window_size, H, W):
49
+ """
50
+ Args:
51
+ windows: (num_windows*B, window_size, window_size, C)
52
+ window_size (int): Window size
53
+ H (int): Height of image
54
+ W (int): Width of image
55
+
56
+ Returns:
57
+ x: (B, H, W, C)
58
+ """
59
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
60
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
61
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
62
+ return x
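+
+ # Round-trip sketch (illustrative shapes): with window_size=8 and x of shape (1, 64, 64, 96),
+ # window_partition(x, 8) yields (64, 8, 8, 96) windows and
+ # window_reverse(windows, 8, 64, 64) restores the original (1, 64, 64, 96) layout.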
63
+
64
+
65
+ class WindowAttention(nn.Module):
66
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
67
+ It supports both of shifted and non-shifted window.
68
+
69
+ Args:
70
+ dim (int): Number of input channels.
71
+ window_size (tuple[int]): The height and width of the window.
72
+ num_heads (int): Number of attention heads.
73
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
74
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
75
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
76
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
77
+ """
78
+
79
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
80
+
81
+ super().__init__()
82
+ self.dim = dim
83
+ self.window_size = window_size # Wh, Ww
84
+ self.num_heads = num_heads
85
+ head_dim = dim // num_heads
86
+ self.scale = qk_scale or head_dim ** -0.5
87
+
88
+ # define a parameter table of relative position bias
89
+ self.relative_position_bias_table = nn.Parameter(
90
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
91
+
92
+ # get pair-wise relative position index for each token inside the window
93
+ coords_h = torch.arange(self.window_size[0])
94
+ coords_w = torch.arange(self.window_size[1])
95
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
96
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
97
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
98
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
99
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
100
+ relative_coords[:, :, 1] += self.window_size[1] - 1
101
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
102
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
103
+ self.register_buffer("relative_position_index", relative_position_index)
104
+
105
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(dim, dim)
108
+
109
+ self.proj_drop = nn.Dropout(proj_drop)
110
+
111
+ trunc_normal_(self.relative_position_bias_table, std=.02)
112
+ self.softmax = nn.Softmax(dim=-1)
113
+
114
+ def forward(self, x, mask=None):
115
+ """
116
+ Args:
117
+ x: input features with shape of (num_windows*B, N, C)
118
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
119
+ """
120
+ B_, N, C = x.shape
121
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
122
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
123
+
124
+ q = q * self.scale
125
+ attn = (q @ k.transpose(-2, -1))
126
+
127
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
128
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
129
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
130
+ attn = attn + relative_position_bias.unsqueeze(0)
131
+
132
+ if mask is not None:
133
+ nW = mask.shape[0]
134
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
135
+ attn = attn.view(-1, self.num_heads, N, N)
136
+ attn = self.softmax(attn)
137
+ else:
138
+ attn = self.softmax(attn)
139
+
140
+ attn = self.attn_drop(attn)
141
+
142
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
143
+ x = self.proj(x)
144
+ x = self.proj_drop(x)
145
+ return x
146
+
147
+ def extra_repr(self) -> str:
148
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
149
+
150
+ def flops(self, N):
151
+ # calculate flops for 1 window with token length of N
152
+ flops = 0
153
+ # qkv = self.qkv(x)
154
+ flops += N * self.dim * 3 * self.dim
155
+ # attn = (q @ k.transpose(-2, -1))
156
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
157
+ # x = (attn @ v)
158
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
159
+ # x = self.proj(x)
160
+ flops += N * self.dim * self.dim
161
+ return flops
162
+
163
+
164
+ class SwinTransformerBlock(nn.Module):
165
+ r""" Swin Transformer Block.
166
+
167
+ Args:
168
+ dim (int): Number of input channels.
169
+ input_resolution (tuple[int]): Input resolution.
170
+ num_heads (int): Number of attention heads.
171
+ window_size (int): Window size.
172
+ shift_size (int): Shift size for SW-MSA.
173
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
174
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
175
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
176
+ drop (float, optional): Dropout rate. Default: 0.0
177
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
178
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
179
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
180
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
181
+ """
182
+
183
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
184
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
185
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
186
+ super().__init__()
187
+ self.dim = dim
188
+ self.input_resolution = input_resolution
189
+ self.num_heads = num_heads
190
+ self.window_size = window_size
191
+ self.shift_size = shift_size
192
+ self.mlp_ratio = mlp_ratio
193
+ if min(self.input_resolution) <= self.window_size:
194
+ # if window size is larger than input resolution, we don't partition windows
195
+ self.shift_size = 0
196
+ self.window_size = min(self.input_resolution)
197
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
198
+
199
+ self.norm1 = norm_layer(dim)
200
+ self.attn = WindowAttention(
201
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
202
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
203
+
204
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
205
+ self.norm2 = norm_layer(dim)
206
+ mlp_hidden_dim = int(dim * mlp_ratio)
207
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
208
+
209
+ if self.shift_size > 0:
210
+ attn_mask = self.calculate_mask(self.input_resolution)
211
+ else:
212
+ attn_mask = None
213
+
214
+ self.register_buffer("attn_mask", attn_mask)
215
+
216
+ def calculate_mask(self, x_size):
217
+ # calculate attention mask for SW-MSA
218
+ H, W = x_size
219
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
220
+ h_slices = (slice(0, -self.window_size),
221
+ slice(-self.window_size, -self.shift_size),
222
+ slice(-self.shift_size, None))
223
+ w_slices = (slice(0, -self.window_size),
224
+ slice(-self.window_size, -self.shift_size),
225
+ slice(-self.shift_size, None))
226
+ cnt = 0
227
+ for h in h_slices:
228
+ for w in w_slices:
229
+ img_mask[:, h, w, :] = cnt
230
+ cnt += 1
231
+
232
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
233
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
234
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
235
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
236
+
237
+ return attn_mask
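+
+ # Mask sketch (illustrative): for an 8x8 input with window_size=8 and shift_size=4,
+ # this returns a single (1, 64, 64) mask whose entries are 0 for token pairs in the
+ # same shifted region and -100.0 (effectively -inf after softmax) otherwise.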
238
+
239
+ def forward(self, x, x_size):
240
+ H, W = x_size
241
+ B, L, C = x.shape
242
+ # assert L == H * W, "input feature has wrong size"
243
+
244
+ shortcut = x
245
+ x = self.norm1(x)
246
+ x = x.view(B, H, W, C)
247
+
248
+ # cyclic shift
249
+ if self.shift_size > 0:
250
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
251
+ else:
252
+ shifted_x = x
253
+
254
+ # partition windows
255
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
256
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
257
+
258
+ # W-MSA/SW-MSA (recompute the attention mask when the test image size differs from the training resolution)
259
+ if self.input_resolution == x_size:
260
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
261
+ else:
262
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
263
+
264
+ # merge windows
265
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
266
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
267
+
268
+ # reverse cyclic shift
269
+ if self.shift_size > 0:
270
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
271
+ else:
272
+ x = shifted_x
273
+ x = x.view(B, H * W, C)
274
+
275
+ # FFN
276
+ x = shortcut + self.drop_path(x)
277
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
278
+
279
+ return x
280
+
281
+ def extra_repr(self) -> str:
282
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
283
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
284
+
285
+ def flops(self):
286
+ flops = 0
287
+ H, W = self.input_resolution
288
+ # norm1
289
+ flops += self.dim * H * W
290
+ # W-MSA/SW-MSA
291
+ nW = H * W / self.window_size / self.window_size
292
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
293
+ # mlp
294
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
295
+ # norm2
296
+ flops += self.dim * H * W
297
+ return flops
298
+
299
+
300
+ class PatchMerging(nn.Module):
301
+ r""" Patch Merging Layer.
302
+
303
+ Args:
304
+ input_resolution (tuple[int]): Resolution of input feature.
305
+ dim (int): Number of input channels.
306
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
307
+ """
308
+
309
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
310
+ super().__init__()
311
+ self.input_resolution = input_resolution
312
+ self.dim = dim
313
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
314
+ self.norm = norm_layer(4 * dim)
315
+
316
+ def forward(self, x):
317
+ """
318
+ x: B, H*W, C
319
+ """
320
+ H, W = self.input_resolution
321
+ B, L, C = x.shape
322
+ assert L == H * W, "input feature has wrong size"
323
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
324
+
325
+ x = x.view(B, H, W, C)
326
+
327
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
328
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
329
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
330
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
331
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
332
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
333
+
334
+ x = self.norm(x)
335
+ x = self.reduction(x)
336
+
337
+ return x
338
+
339
+ def extra_repr(self) -> str:
340
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
341
+
342
+ def flops(self):
343
+ H, W = self.input_resolution
344
+ flops = H * W * self.dim
345
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
346
+ return flops
347
+
348
+
349
+ class BasicLayer(nn.Module):
350
+ """ A basic Swin Transformer layer for one stage.
351
+
352
+ Args:
353
+ dim (int): Number of input channels.
354
+ input_resolution (tuple[int]): Input resolution.
355
+ depth (int): Number of blocks.
356
+ num_heads (int): Number of attention heads.
357
+ window_size (int): Local window size.
358
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
359
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
360
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
361
+ drop (float, optional): Dropout rate. Default: 0.0
362
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
363
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
364
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
365
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
366
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
367
+ """
368
+
369
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
370
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
371
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
372
+
373
+ super().__init__()
374
+ self.dim = dim
375
+ self.input_resolution = input_resolution
376
+ self.depth = depth
377
+ self.use_checkpoint = use_checkpoint
378
+
379
+ # build blocks
380
+ self.blocks = nn.ModuleList([
381
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
382
+ num_heads=num_heads, window_size=window_size,
383
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
384
+ mlp_ratio=mlp_ratio,
385
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
386
+ drop=drop, attn_drop=attn_drop,
387
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
388
+ norm_layer=norm_layer)
389
+ for i in range(depth)])
390
+
391
+ # patch merging layer
392
+ if downsample is not None:
393
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
394
+ else:
395
+ self.downsample = None
396
+
397
+ def forward(self, x, x_size):
398
+ for blk in self.blocks:
399
+ if self.use_checkpoint:
400
+ x = checkpoint.checkpoint(blk, x, x_size)
401
+ else:
402
+ x = blk(x, x_size)
403
+ if self.downsample is not None:
404
+ x = self.downsample(x)
405
+ return x
406
+
407
+ def extra_repr(self) -> str:
408
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
409
+
410
+ def flops(self):
411
+ flops = 0
412
+ for blk in self.blocks:
413
+ flops += blk.flops()
414
+ if self.downsample is not None:
415
+ flops += self.downsample.flops()
416
+ return flops
417
+
418
+
419
+ class RSTB(nn.Module):
420
+ """Residual Swin Transformer Block (RSTB).
421
+
422
+ Args:
423
+ dim (int): Number of input channels.
424
+ input_resolution (tuple[int]): Input resolution.
425
+ depth (int): Number of blocks.
426
+ num_heads (int): Number of attention heads.
427
+ window_size (int): Local window size.
428
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
429
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
430
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
431
+ drop (float, optional): Dropout rate. Default: 0.0
432
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
433
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
434
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
435
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
436
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
437
+ img_size: Input image size.
438
+ patch_size: Patch size.
439
+ resi_connection: The convolutional block before residual connection.
440
+ """
441
+
442
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
443
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
444
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
445
+ img_size=224, patch_size=4, resi_connection='1conv'):
446
+ super(RSTB, self).__init__()
447
+
448
+ self.dim = dim
449
+ self.input_resolution = input_resolution
450
+
451
+ self.residual_group = BasicLayer(dim=dim,
452
+ input_resolution=input_resolution,
453
+ depth=depth,
454
+ num_heads=num_heads,
455
+ window_size=window_size,
456
+ mlp_ratio=mlp_ratio,
457
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
458
+ drop=drop, attn_drop=attn_drop,
459
+ drop_path=drop_path,
460
+ norm_layer=norm_layer,
461
+ downsample=downsample,
462
+ use_checkpoint=use_checkpoint)
463
+
464
+ if resi_connection == '1conv':
465
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
466
+ elif resi_connection == '3conv':
467
+ # to save parameters and memory
468
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
469
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
470
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
471
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
472
+
473
+ self.patch_embed = PatchEmbed(
474
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
475
+ norm_layer=None)
476
+
477
+ self.patch_unembed = PatchUnEmbed(
478
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
479
+ norm_layer=None)
480
+
481
+ def forward(self, x, x_size):
482
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
483
+
484
+ def flops(self):
485
+ flops = 0
486
+ flops += self.residual_group.flops()
487
+ H, W = self.input_resolution
488
+ flops += H * W * self.dim * self.dim * 9
489
+ flops += self.patch_embed.flops()
490
+ flops += self.patch_unembed.flops()
491
+
492
+ return flops
493
+
494
+
495
+ class PatchEmbed(nn.Module):
496
+ r""" Image to Patch Embedding
497
+
498
+ Args:
499
+ img_size (int): Image size. Default: 224.
500
+ patch_size (int): Patch token size. Default: 4.
501
+ in_chans (int): Number of input image channels. Default: 3.
502
+ embed_dim (int): Number of linear projection output channels. Default: 96.
503
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
504
+ """
505
+
506
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
507
+ super().__init__()
508
+ img_size = to_2tuple(img_size)
509
+ patch_size = to_2tuple(patch_size)
510
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
511
+ self.img_size = img_size
512
+ self.patch_size = patch_size
513
+ self.patches_resolution = patches_resolution
514
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
515
+
516
+ self.in_chans = in_chans
517
+ self.embed_dim = embed_dim
518
+
519
+ if norm_layer is not None:
520
+ self.norm = norm_layer(embed_dim)
521
+ else:
522
+ self.norm = None
523
+
524
+ def forward(self, x):
525
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
526
+ if self.norm is not None:
527
+ x = self.norm(x)
528
+ return x
529
+
530
+ def flops(self):
531
+ flops = 0
532
+ H, W = self.img_size
533
+ if self.norm is not None:
534
+ flops += H * W * self.embed_dim
535
+ return flops
536
+
537
+
538
+ class PatchUnEmbed(nn.Module):
539
+ r""" Image to Patch Unembedding
540
+
541
+ Args:
542
+ img_size (int): Image size. Default: 224.
543
+ patch_size (int): Patch token size. Default: 4.
544
+ in_chans (int): Number of input image channels. Default: 3.
545
+ embed_dim (int): Number of linear projection output channels. Default: 96.
546
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
547
+ """
548
+
549
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
550
+ super().__init__()
551
+ img_size = to_2tuple(img_size)
552
+ patch_size = to_2tuple(patch_size)
553
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
554
+ self.img_size = img_size
555
+ self.patch_size = patch_size
556
+ self.patches_resolution = patches_resolution
557
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
558
+
559
+ self.in_chans = in_chans
560
+ self.embed_dim = embed_dim
561
+
562
+ def forward(self, x, x_size):
563
+ B, HW, C = x.shape
564
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
565
+ return x
566
+
567
+ def flops(self):
568
+ flops = 0
569
+ return flops
570
+
571
+
572
+ class Upsample(nn.Sequential):
573
+ """Upsample module.
574
+
575
+ Args:
576
+ scale (int): Scale factor. Supported scales: 2^n and 3.
577
+ num_feat (int): Channel number of intermediate features.
578
+ """
579
+
580
+ def __init__(self, scale, num_feat):
581
+ m = []
582
+ if (scale & (scale - 1)) == 0: # scale = 2^n
583
+ for _ in range(int(math.log(scale, 2))):
584
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
585
+ m.append(nn.PixelShuffle(2))
586
+ elif scale == 3:
587
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
588
+ m.append(nn.PixelShuffle(3))
589
+ else:
590
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
591
+ super(Upsample, self).__init__(*m)
592
+
593
+
594
+ class UpsampleOneStep(nn.Sequential):
595
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
596
+ Used in lightweight SR to save parameters.
597
+
598
+ Args:
599
+ scale (int): Scale factor. Supported scales: 2^n and 3.
600
+ num_feat (int): Channel number of intermediate features.
601
+
602
+ """
603
+
604
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
605
+ self.num_feat = num_feat
606
+ self.input_resolution = input_resolution
607
+ m = []
608
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
609
+ m.append(nn.PixelShuffle(scale))
610
+ super(UpsampleOneStep, self).__init__(*m)
611
+
612
+ def flops(self):
613
+ H, W = self.input_resolution
614
+ flops = H * W * self.num_feat * 3 * 9
615
+ return flops
616
+
617
+
618
+ class SwinIR(nn.Module):
619
+ r""" SwinIR
620
+ A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
621
+
622
+ Args:
623
+ img_size (int | tuple(int)): Input image size. Default 64
624
+ patch_size (int | tuple(int)): Patch size. Default: 1
625
+ in_chans (int): Number of input image channels. Default: 3
626
+ embed_dim (int): Patch embedding dimension. Default: 96
627
+ depths (tuple(int)): Depth of each Swin Transformer layer.
628
+ num_heads (tuple(int)): Number of attention heads in different layers.
629
+ window_size (int): Window size. Default: 7
630
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
631
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
632
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
633
+ drop_rate (float): Dropout rate. Default: 0
634
+ attn_drop_rate (float): Attention dropout rate. Default: 0
635
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
636
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
637
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
638
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
639
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
640
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
641
+ img_range: Image range. 1. or 255.
642
+ upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
643
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
644
+ """
645
+
646
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
647
+ embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
648
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
649
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
650
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
651
+ use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
652
+ **kwargs):
653
+ super(SwinIR, self).__init__()
654
+ num_in_ch = in_chans
655
+ num_out_ch = in_chans
656
+ num_feat = 64
657
+ self.img_range = img_range
658
+ if in_chans == 3:
659
+ rgb_mean = (0.4488, 0.4371, 0.4040)
660
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
661
+ else:
662
+ self.mean = torch.zeros(1, 1, 1, 1)
663
+ self.upscale = upscale
664
+ self.upsampler = upsampler
665
+ self.window_size = window_size
666
+
667
+ #####################################################################################################
668
+ ################################### 1, shallow feature extraction ###################################
669
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
670
+
671
+ #####################################################################################################
672
+ ################################### 2, deep feature extraction ######################################
673
+ self.num_layers = len(depths)
674
+ self.embed_dim = embed_dim
675
+ self.ape = ape
676
+ self.patch_norm = patch_norm
677
+ self.num_features = embed_dim
678
+ self.mlp_ratio = mlp_ratio
679
+
680
+ # split image into non-overlapping patches
681
+ self.patch_embed = PatchEmbed(
682
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
683
+ norm_layer=norm_layer if self.patch_norm else None)
684
+ num_patches = self.patch_embed.num_patches
685
+ patches_resolution = self.patch_embed.patches_resolution
686
+ self.patches_resolution = patches_resolution
687
+
688
+ # merge non-overlapping patches into image
689
+ self.patch_unembed = PatchUnEmbed(
690
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
691
+ norm_layer=norm_layer if self.patch_norm else None)
692
+
693
+ # absolute position embedding
694
+ if self.ape:
695
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
696
+ trunc_normal_(self.absolute_pos_embed, std=.02)
697
+
698
+ self.pos_drop = nn.Dropout(p=drop_rate)
699
+
700
+ # stochastic depth
701
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
702
+
703
+ # build Residual Swin Transformer blocks (RSTB)
704
+ self.layers = nn.ModuleList()
705
+ for i_layer in range(self.num_layers):
706
+ layer = RSTB(dim=embed_dim,
707
+ input_resolution=(patches_resolution[0],
708
+ patches_resolution[1]),
709
+ depth=depths[i_layer],
710
+ num_heads=num_heads[i_layer],
711
+ window_size=window_size,
712
+ mlp_ratio=self.mlp_ratio,
713
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
714
+ drop=drop_rate, attn_drop=attn_drop_rate,
715
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
716
+ norm_layer=norm_layer,
717
+ downsample=None,
718
+ use_checkpoint=use_checkpoint,
719
+ img_size=img_size,
720
+ patch_size=patch_size,
721
+ resi_connection=resi_connection
722
+
723
+ )
724
+ self.layers.append(layer)
725
+ self.norm = norm_layer(self.num_features)
726
+
727
+ # build the last conv layer in deep feature extraction
728
+ if resi_connection == '1conv':
729
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
730
+ elif resi_connection == '3conv':
731
+ # to save parameters and memory
732
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
733
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
734
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
735
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
736
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
737
+
738
+ #####################################################################################################
739
+ ################################ 3, high quality image reconstruction ################################
740
+ if self.upsampler == 'pixelshuffle':
741
+ # for classical SR
742
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
743
+ nn.LeakyReLU(inplace=True))
744
+ self.upsample = Upsample(upscale, num_feat)
745
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
746
+ elif self.upsampler == 'pixelshuffledirect':
747
+ # for lightweight SR (to save parameters)
748
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
749
+ (patches_resolution[0], patches_resolution[1]))
750
+ elif self.upsampler == 'nearest+conv':
751
+ # for real-world SR (less artifacts)
752
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
753
+ nn.LeakyReLU(inplace=True))
754
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
755
+ if self.upscale == 4:
756
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
757
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
758
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
759
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
760
+ else:
761
+ # for image denoising and JPEG compression artifact reduction
762
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
763
+
764
+ self.apply(self._init_weights)
765
+
766
+ def _init_weights(self, m):
767
+ if isinstance(m, nn.Linear):
768
+ trunc_normal_(m.weight, std=.02)
769
+ if isinstance(m, nn.Linear) and m.bias is not None:
770
+ nn.init.constant_(m.bias, 0)
771
+ elif isinstance(m, nn.LayerNorm):
772
+ nn.init.constant_(m.bias, 0)
773
+ nn.init.constant_(m.weight, 1.0)
774
+
775
+ @torch.jit.ignore
776
+ def no_weight_decay(self):
777
+ return {'absolute_pos_embed'}
778
+
779
+ @torch.jit.ignore
780
+ def no_weight_decay_keywords(self):
781
+ return {'relative_position_bias_table'}
782
+
783
+ def check_image_size(self, x):
784
+ _, _, h, w = x.size()
785
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
786
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
787
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
788
+ return x
789
+
790
+ def forward_features(self, x):
791
+ x_size = (x.shape[2], x.shape[3])
792
+ x = self.patch_embed(x)
793
+ if self.ape:
794
+ x = x + self.absolute_pos_embed
795
+ x = self.pos_drop(x)
796
+
797
+ for layer in self.layers:
798
+ x = layer(x, x_size)
799
+
800
+ x = self.norm(x) # B L C
801
+ x = self.patch_unembed(x, x_size)
802
+
803
+ return x
804
+
805
+ def forward(self, x):
806
+ H, W = x.shape[2:]
807
+ x = self.check_image_size(x)
808
+
809
+ self.mean = self.mean.type_as(x)
810
+ x = (x - self.mean) * self.img_range
811
+
812
+ if self.upsampler == 'pixelshuffle':
813
+ # for classical SR
814
+ x = self.conv_first(x)
815
+ x = self.conv_after_body(self.forward_features(x)) + x
816
+ x = self.conv_before_upsample(x)
817
+ x = self.conv_last(self.upsample(x))
818
+ elif self.upsampler == 'pixelshuffledirect':
819
+ # for lightweight SR
820
+ x = self.conv_first(x)
821
+ x = self.conv_after_body(self.forward_features(x)) + x
822
+ x = self.upsample(x)
823
+ elif self.upsampler == 'nearest+conv':
824
+ # for real-world SR
825
+ x = self.conv_first(x)
826
+ x = self.conv_after_body(self.forward_features(x)) + x
827
+ x = self.conv_before_upsample(x)
828
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
829
+ if self.upscale == 4:
830
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
831
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
832
+ else:
833
+ # for image denoising and JPEG compression artifact reduction
834
+ x_first = self.conv_first(x)
835
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
836
+ x = x + self.conv_last(res)
837
+
838
+ x = x / self.img_range + self.mean
839
+
840
+ return x[:, :, :H*self.upscale, :W*self.upscale]
841
+
842
+ def flops(self):
843
+ flops = 0
844
+ H, W = self.patches_resolution
845
+ flops += H * W * 3 * self.embed_dim * 9
846
+ flops += self.patch_embed.flops()
847
+ for i, layer in enumerate(self.layers):
848
+ flops += layer.flops()
849
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
850
+ flops += self.upsample.flops()
851
+ return flops
852
+
853
+
854
+ if __name__ == '__main__':
855
+ upscale = 4
856
+ window_size = 8
857
+ height = (1024 // upscale // window_size + 1) * window_size
858
+ width = (720 // upscale // window_size + 1) * window_size
859
+ model = SwinIR(upscale=2, img_size=(height, width),
860
+ window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
861
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect').cuda()
862
+ print(model)
863
+
864
+ pytorch_total_params = sum(p.numel() for p in model.parameters())
865
+ print(f"pathGAN has param {pytorch_total_params//1000} K params")
866
+
867
+
868
+ # Count the time
869
+ import time
870
+ x = torch.randn((1, 3, 180, 180)).cuda()
871
+ start = time.time()
872
+ x = model(x)
873
+ total = time.time() - start
874
+ print("total time spent is ", total)
dataset_curation_pipeline/IC9600/ICNet.py ADDED
@@ -0,0 +1,151 @@
1
+ import torch
2
+ import torchvision
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
+
8
+ class slam(nn.Module):
9
+ def __init__(self, spatial_dim):
10
+ super(slam,self).__init__()
11
+ self.spatial_dim = spatial_dim
12
+ self.linear = nn.Sequential(
13
+ nn.Linear(spatial_dim**2,512),
14
+ nn.ReLU(),
15
+ nn.Linear(512,1),
16
+ nn.Sigmoid()
17
+ )
18
+
19
+ def forward(self, feature):
20
+ n,c,h,w = feature.shape
21
+ if (h != self.spatial_dim):
22
+ x = F.interpolate(feature,size=(self.spatial_dim,self.spatial_dim),mode= "bilinear", align_corners=True)
23
+ else:
24
+ x = feature
25
+
26
+
27
+ x = x.view(n,c,-1)
28
+ x = self.linear(x)
29
+ x = x.unsqueeze(dim =3)
30
+ out = x.expand_as(feature)*feature
31
+
32
+ return out
33
+
34
+
35
+ class to_map(nn.Module):
36
+ def __init__(self,channels):
37
+ super(to_map,self).__init__()
38
+ self.to_map = nn.Sequential(
39
+ nn.Conv2d(in_channels=channels,out_channels=1, kernel_size=1,stride=1),
40
+ nn.Sigmoid()
41
+ )
42
+
43
+ def forward(self,feature):
44
+ return self.to_map(feature)
45
+
46
+
47
+ class conv_bn_relu(nn.Module):
48
+ def __init__(self,in_channels, out_channels, kernel_size = 3, padding = 1, stride = 1):
49
+ super(conv_bn_relu,self).__init__()
50
+ self.conv = nn.Conv2d(in_channels= in_channels, out_channels= out_channels, kernel_size= kernel_size, padding= padding, stride = stride)
51
+ self.bn = nn.BatchNorm2d(out_channels)
52
+ self.relu = nn.ReLU()
53
+
54
+ def forward(self,x):
55
+ x = self.conv(x)
56
+ x = self.bn(x)
57
+ x = self.relu(x)
58
+ return x
59
+
60
+
61
+
62
+ class up_conv_bn_relu(nn.Module):
63
+ def __init__(self, up_size, in_channels, out_channels = 64, kernel_size = 1, padding = 0, stride = 1):
64
+ super(up_conv_bn_relu,self).__init__()
65
+ self.upSample = nn.Upsample(size = (up_size,up_size),mode="bilinear",align_corners=True)
66
+ self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
67
+ self.bn = nn.BatchNorm2d(num_features=out_channels)
68
+ self.act = nn.ReLU()
69
+
70
+ def forward(self,x):
71
+ x = self.upSample(x)
72
+ x = self.conv(x)
73
+ x = self.bn(x)
74
+ x = self.act(x)
75
+ return x
76
+
77
+
78
+
79
+ class ICNet(nn.Module):
80
+ def __init__(self, is_pretrain = True, size1 = 512, size2 = 256):
81
+ super(ICNet,self).__init__()
82
+ resnet18Pretrained1 = torchvision.models.resnet18(pretrained= is_pretrain)
83
+ resnet18Pretrained2 = torchvision.models.resnet18(pretrained= is_pretrain)
84
+
85
+ self.size1 = size1
86
+ self.size2 = size2
87
+
88
+ ## detail branch
89
+ self.b1_1 = nn.Sequential(*list(resnet18Pretrained1.children())[:5])
90
+ self.b1_1_slam = slam(32)
91
+
92
+ self.b1_2 = list(resnet18Pretrained1.children())[5]
93
+ self.b1_2_slam = slam(32)
94
+
95
+ ## context branch
96
+ self.b2_1 = nn.Sequential(*list(resnet18Pretrained2.children())[:5])
97
+ self.b2_1_slam = slam(32)
98
+
99
+ self.b2_2 = list(resnet18Pretrained2.children())[5]
100
+ self.b2_2_slam = slam(32)
101
+
102
+ self.b2_3 = list(resnet18Pretrained2.children())[6]
103
+ self.b2_3_slam = slam(16)
104
+
105
+ self.b2_4 = list(resnet18Pretrained2.children())[7]
106
+ self.b2_4_slam = slam(8)
107
+
108
+ ## upsample
109
+ self.upsize = size1 // 8
110
+ self.up1 = up_conv_bn_relu(up_size = self.upsize, in_channels = 128, out_channels = 256)
111
+ self.up2 = up_conv_bn_relu(up_size = self.upsize, in_channels = 512, out_channels = 256)
112
+
113
+ ## map prediction head
114
+ self.to_map_f = conv_bn_relu(256*2,256*2)
115
+ self.to_map_f_slam = slam(32)
116
+ self.to_map = to_map(256*2)
117
+
118
+ ## score prediction head
119
+ self.to_score_f = conv_bn_relu(256*2,256*2)
120
+ self.to_score_f_slam = slam(32)
121
+ self.head = nn.Sequential(
122
+ nn.Linear(256*2,512),
123
+ nn.ReLU(),
124
+ nn.Linear(512,1),
125
+ nn.Sigmoid()
126
+ )
127
+ self.avgpool = nn.AdaptiveAvgPool2d((1,1))
128
+
129
+
130
+ def forward(self,x1):
131
+ assert(x1.shape[2] == x1.shape[3] == self.size1)
132
+ x2 = F.interpolate(x1, size= (self.size2,self.size2), mode = "bilinear", align_corners= True)
133
+
134
+ x1 = self.b1_2_slam(self.b1_2(self.b1_1_slam(self.b1_1(x1))))
135
+ x2 = self.b2_2_slam(self.b2_2(self.b2_1_slam(self.b2_1(x2))))
136
+ x2 = self.b2_4_slam(self.b2_4(self.b2_3_slam(self.b2_3(x2))))
137
+
138
+
139
+ x1 = self.up1(x1)
140
+ x2 = self.up2(x2)
141
+ x_cat = torch.cat((x1,x2),dim = 1)
142
+
143
+ cly_map = self.to_map(self.to_map_f_slam(self.to_map_f(x_cat)))
144
+
145
+ score_feature = self.to_score_f_slam(self.to_score_f(x_cat))
146
+ score_feature = self.avgpool(score_feature)
147
+ score_feature = score_feature.squeeze()
148
+ score = self.head(score_feature)
149
+ score = score.squeeze()
150
+
151
+ return score,cly_map
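+
+ # Usage sketch (is_pretrain=False simply skips the torchvision weight download in this example):
+ # model = ICNet(is_pretrain=False).eval()
+ # score, cly_map = model(torch.randn(1, 3, 512, 512))
+ # score is a scalar in [0, 1]; cly_map has shape (1, 1, 64, 64) for the default size1=512.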
dataset_curation_pipeline/IC9600/gene.py ADDED
@@ -0,0 +1,113 @@
1
+ import argparse
2
+ import os, sys
3
+ import torch
4
+ import cv2
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from matplotlib import pyplot as plt
10
+ from tqdm import tqdm
11
+
12
+ # Import files from the local folder
13
+ root_path = os.path.abspath('.')
14
+ sys.path.append(root_path)
15
+ from opt import opt
16
+ from dataset_curation_pipeline.IC9600.ICNet import ICNet
17
+
18
+
19
+
20
+ inference_transform = transforms.Compose([
21
+ transforms.Resize((512,512)),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
24
+ ])
25
+
26
+ def blend(ori_img, ic_img, alpha = 0.8, cm = plt.get_cmap("magma")):
27
+ cm_ic_map = cm(ic_img)
28
+ heatmap = Image.fromarray((cm_ic_map[:, :, -2::-1]*255).astype(np.uint8))
29
+ ori_img = Image.fromarray(ori_img)
30
+ blend = Image.blend(ori_img,heatmap,alpha=alpha)
31
+ blend = np.array(blend)
32
+ return blend
33
+
34
+
35
+ def infer_one_image(model, img_path):
36
+ with torch.no_grad():
37
+ ori_img = Image.open(img_path).convert("RGB")
38
+ ori_height = ori_img.height
39
+ ori_width = ori_img.width
40
+ img = inference_transform(ori_img)
41
+ img = img.cuda()
42
+ img = img.unsqueeze(0)
43
+ ic_score, ic_map = model(img)
44
+ ic_score = ic_score.item()
45
+
46
+
47
+ # ic_map = F.interpolate(ic_map, (ori_height, ori_width), mode = 'bilinear')
48
+
49
+ ## gene ic map
50
+ # ic_map_np = ic_map.squeeze().detach().cpu().numpy()
51
+ # out_ic_map_name = os.path.basename(img_path).split('.')[0] + '_' + str(ic_score)[:7] + '.npy'
52
+ # out_ic_map_path = os.path.join(args.output, out_ic_map_name)
53
+ # np.save(out_ic_map_path, ic_map_np)
54
+
55
+ ## gene blend map
56
+ # ic_map_img = (ic_map * 255).round().squeeze().detach().cpu().numpy().astype('uint8')
57
+ # blend_img = blend(np.array(ori_img), ic_map_img)
58
+ # out_blend_img_name = os.path.basename(img_path).split('.')[0] + '.png'
59
+ # out_blend_img_path = os.path.join(args.output, out_blend_img_name)
60
+ # cv2.imwrite(out_blend_img_path, blend_img)
61
+ return ic_score
62
+
63
+
64
+
65
+ def infer_directory(img_dir):
66
+ imgs = sorted(os.listdir(img_dir))
67
+ scores = []
68
+ for img in tqdm(imgs):
69
+ img_path = os.path.join(img_dir, img)
70
+ score = infer_one_image(model, img_path)
71
+
72
+ scores.append((score, img_path))
73
+ print(img_path, score)
74
+
75
+ scores = sorted(scores, key=lambda x: x[0])
76
+ scores = scores[::-1]
77
+
78
+ for score in scores[:50]:
79
+ print(score)
80
+
81
+
82
+ if __name__ == "__main__":
83
+ parser = argparse.ArgumentParser()
84
+ parser.add_argument('-i', '--input', type = str, default = './example')
85
+ parser.add_argument('-o', '--output', type = str, default = './out')
86
+ parser.add_argument('-d', '--device', type = int, default=0)
87
+
88
+ args = parser.parse_args()
89
+
90
+ model = ICNet()
91
+ model.load_state_dict(torch.load('./checkpoint/ck.pth',map_location=torch.device('cpu')))
92
+ model.eval()
93
+ device = torch.device(args.device)
94
+ model.to(device)
95
+
96
+ inference_transform = transforms.Compose([
97
+ transforms.Resize((512,512)),
98
+ transforms.ToTensor(),
99
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
100
+ ])
101
+ if os.path.isfile(args.input):
102
+ infer_one_image(model, args.input)
103
+ else:
104
+ infer_directory(args.input)
105
+
106
+
107
+
108
+
109
+
110
+
111
+
112
+
113
+
dataset_curation_pipeline/collect.py ADDED
@@ -0,0 +1,222 @@
1
+ '''
2
+ This file implements the whole dataset curation pipeline: it collects the least compressed and most informative frames from video sources.
3
+ '''
4
+ import os, time, sys
5
+ import shutil
6
+ import cv2
7
+ import torch
8
+ import argparse
9
+
10
+ # Import files from the local folder
11
+ root_path = os.path.abspath('.')
12
+ sys.path.append(root_path)
13
+ from opt import opt
14
+ from dataset_curation_pipeline.IC9600.gene import infer_one_image
15
+ from dataset_curation_pipeline.IC9600.ICNet import ICNet
16
+
17
+
18
+ class video_scoring:
19
+
20
+ def __init__(self, IC9600_pretrained_weight_path) -> None:
21
+
22
+ # Init the model
23
+ self.scorer = ICNet()
24
+ self.scorer.load_state_dict(torch.load(IC9600_pretrained_weight_path, map_location=torch.device('cpu')))
25
+ self.scorer.eval().cuda()
26
+
27
+
28
+ def select_frame(self, skip_num, img_lists, target_frame_num, save_dir, output_name_head, partition_idx):
29
+ ''' Execution of scoring to all I-Frame in img_folder and select target_frame to return back
30
+ Args:
31
+ skip_num (int): Only 1 in skip_num will be chosen to accelerate.
32
+ img_lists (str): The image lists of all files we want to process
33
+ target_frame_num (int): The number of frames we need to choose
34
+ save_dir (str): The path where we save those images
35
+ output_name_head (str): This is the input video name head
36
+ partition_idx (int): The partition idx
37
+ '''
38
+
39
+ stores = []
40
+ for idx, image_path in enumerate(sorted(img_lists)):
41
+ if idx % skip_num != 0:
42
+ # We only process 1 in skip_num frames to accelerate and also to reduce the chance of repeated scenes.
43
+ continue
44
+
45
+
46
+ # Evaluate the image complexity score for this image
47
+ score = infer_one_image(self.scorer, image_path)
48
+
49
+ if verbose:
50
+ print(image_path, score)
51
+ stores.append((score, image_path))
52
+
53
55
+
56
+
57
+ # Find the top most scores' images
58
+ stores.sort(key=lambda x:x[0])
59
+ selected = stores[-target_frame_num:]
60
+ # print(len(stores), len(selected))
61
+ if verbose:
62
+ print("The lowest selected score is ", selected[0])    # informational log only
63
+
64
+
65
+ # Store the selected images
66
+ for idx, (score, img_path) in enumerate(selected):
67
+ output_name = output_name_head + "_" +str(partition_idx)+ "_" + str(idx) + ".png"
68
+ output_path = os.path.join(save_dir, output_name)
69
+ shutil.copyfile(img_path, output_path)
70
+
71
+
72
+ def run(self, skip_num, img_folder, target_frame_num, save_dir, output_name_head, partition_num):
73
+ ''' Score all I-Frames in img_folder and select the top target_frame_num frames
74
+ Args:
75
+ skip_num (int): Only 1 in skip_num will be chosen to accelerate.
76
+ img_folder (str): The image folder of all I-Frames we need to process
77
+ target_frame_num (int): The number of frames we need to choose
78
+ save_dir (str): The path where we save those images
79
+ output_name_head (str): This is the input video name head
80
+ partition_num (int): The number of partitions to split the video into
81
+ '''
82
+ assert(target_frame_num%partition_num == 0)
83
+
84
+ img_lists = []
85
+ for img_name in sorted(os.listdir(img_folder)):
86
+ path = os.path.join(img_folder, img_name)
87
+ img_lists.append(path)
88
+ length = len(img_lists)
89
+ unit_length = (length // partition_num)
90
+ target_partition_num = target_frame_num // partition_num
91
+
92
+ # Cut the folder to several partition and select those with the highest score
93
+ for idx in range(partition_num):
94
+ select_lists = img_lists[unit_length*idx : unit_length*(idx+1)]
95
+ self.select_frame(skip_num, select_lists, target_partition_num, save_dir, output_name_head, idx)
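
To make the partition arithmetic above concrete, a tiny illustration with hypothetical numbers (not taken from any real run):

# 400 extracted I-frames, 8 partitions, 16 frames wanted in total
length, partition_num, target_frame_num = 400, 8, 16
unit_length = length // partition_num                        # 50 consecutive frames per partition
target_partition_num = target_frame_num // partition_num     # keep the 2 highest-scoring frames per partition
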
96
+
97
+
98
+ class frame_collector:
99
+
100
+ def __init__(self, IC9600_pretrained_weight_path, verbose) -> None:
101
+
102
+ self.scoring = video_scoring(IC9600_pretrained_weight_path)
103
+ self.verbose = verbose
104
+
105
+
106
+ def video_split_by_IFrame(self, video_path, tmp_path):
107
+ ''' Split the video into its I-Frames
108
+ Args:
109
+ video_path (str): The directory to a single video
110
+ tmp_path (str): A temporary working directory that will be deleted at the end
111
+ '''
112
+
113
+ # Prepare the work folder needed
114
+ if os.path.exists(tmp_path):
115
+ shutil.rmtree(tmp_path)
116
+ os.makedirs(tmp_path)
117
+
118
+
119
+ # Split Video I-frame
120
+ cmd = "ffmpeg -i " + video_path + " -loglevel error -vf select='eq(pict_type\,I)' -vsync 2 -f image2 -q:v 1 " + tmp_path + "/image-%06d.png" # %06d supports up to 999,999 I-Frames per video
121
+
122
+ if self.verbose:
123
+ print(cmd)
124
+ os.system(cmd)
125
+
126
+
127
+
128
+ def collect_frames(self, video_folder_dir, save_dir, tmp_path, skip_num, target_frames, partition_num):
129
+ ''' Automatically collect frames from the video dir
130
+ Args:
131
+ video_folder_dir (str): The directory of all videos input
132
+ save_dir (str): The directory we will store the selected frames
133
+ tmp_path (str): A temporary working directory that will be deleted at the end
134
+ skip_num (int): Only 1 in skip_num will be chosen to accelerate.
135
+ target_frames (list): [# of frames for video under 30 min, # of frames for video over 30 min]
136
+ partition_num (int): The number of partitions to split the video into
137
+ '''
138
+
139
+ # Iterate all video under video_folder_dir
140
+ for video_name in sorted(os.listdir(video_folder_dir)):
141
+ # Sanity check for this video file format
142
+ info = video_name.split('.')
143
+ if info[-1] not in ['mp4', 'mkv', '']:
144
+ continue
145
+ output_name_head, extension = info
146
+
147
+
148
+ # Get info of this video
149
+ video_path = os.path.join(video_folder_dir, video_name)
150
+ duration = get_duration(video_path) # unit in minutes
151
+ print("We are processing " + video_path + " with duration " + str(duration) + " min")
152
+
153
+
154
+ # Split the video to I-frame
155
+ self.video_split_by_IFrame(video_path, tmp_path)
156
+
157
+
158
+ # Score the frames and select those top scored frames we need
159
+ if duration <= 30:
160
+ target_frame_num = target_frames[0]
161
+ else:
162
+ target_frame_num = target_frames[1]
163
+
164
+ self.scoring.run(skip_num, tmp_path, target_frame_num, save_dir, output_name_head, partition_num)
165
+
166
+
167
+ # Remove folders if needed
168
+
169
+
170
+ def get_duration(filename):
171
+ video = cv2.VideoCapture(filename)
172
+ fps = video.get(cv2.CAP_PROP_FPS)
173
+ frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
174
+ seconds = frame_count / fps
175
+ minutes = int(seconds / 60)
176
+ return minutes
177
+
178
+
179
+ if __name__ == "__main__":
180
+
181
+ # Fundamental setting
182
+ parser = argparse.ArgumentParser()
183
+ parser.add_argument('--video_folder_dir', type = str, default = '../anime_videos', help = "A folder with video sources")
184
+ parser.add_argument('--IC9600_pretrained_weight_path', type = str, default = "pretrained/ck.pth", help = "The pretrained IC9600 weight")
185
+ parser.add_argument('--save_dir', type = str, default = 'APISR_dataset', help = "The folder to store filtered dataset")
186
+ parser.add_argument('--skip_num', type = int, default = 5, help = "Only 1 in skip_num will be chosen in sequential I-frames to accelerate.")
187
+ parser.add_argument('--target_frames', type = list, default = [16, 24], help = "[# of frames for video under 30 min, # of frames for video over 30 min]")
188
+ parser.add_argument('--partition_num', type = int, default = 8, help = "The number of partition we want to crop the video to, to increase diversity of sampling")
189
+ parser.add_argument('--verbose', type = bool, default = True, help = "Whether we print log message")
190
+ args = parser.parse_args()
191
+
192
+
193
+ # Transform to variable
194
+ video_folder_dir = args.video_folder_dir
195
+ IC9600_pretrained_weight_path = args.IC9600_pretrained_weight_path
196
+ save_dir = args.save_dir
197
+ skip_num = args.skip_num
198
+ target_frames = args.target_frames # [# of frames for video under 30 min, # of frames for video over 30 min]
199
+ partition_num = args.partition_num
200
+ verbose = args.verbose
201
+
202
+
203
+ # Secondary setting
204
+ tmp_path = "tmp_dataset"
205
+
206
+
207
+ # Prepare
208
+ if os.path.exists(save_dir):
209
+ shutil.rmtree(save_dir)
210
+ os.makedirs(save_dir)
211
+
212
+
213
+ # Process
214
+ start = time.time()
215
+
216
+ obj = frame_collector(IC9600_pretrained_weight_path, verbose)
217
+ obj.collect_frames(video_folder_dir, save_dir, tmp_path, skip_num, target_frames, partition_num)
218
+
219
+ total_time = (time.time() - start)//60
220
+ print("Total time spent is {} min".format(total_time))
221
+
222
+ shutil.rmtree(tmp_path)
degradation/ESR/degradation_esr_shared.py ADDED
@@ -0,0 +1,180 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import argparse
4
+ import cv2
5
+ import torch
6
+ import numpy as np
7
+ import os, shutil, time
8
+ import sys, random
9
+ from multiprocessing import Pool
10
+ from os import path as osp
11
+ from tqdm import tqdm
12
+ from math import log10, sqrt
13
+ import torch.nn.functional as F
14
+
15
+ root_path = os.path.abspath('.')
16
+ sys.path.append(root_path)
17
+ from degradation.ESR.degradations_functionality import *
18
+ from degradation.ESR.diffjpeg import *
19
+ from degradation.ESR.utils import filter2D
20
+ from degradation.image_compression.jpeg import JPEG
21
+ from degradation.image_compression.webp import WEBP
22
+ from degradation.image_compression.heif import HEIF
23
+ from degradation.image_compression.avif import AVIF
24
+ from opt import opt
25
+
26
+
27
+ def PSNR(original, compressed):
28
+ mse = np.mean((original - compressed) ** 2)
29
+ if mse == 0: # an MSE of zero means there is no noise in the signal,
30
+ # so PSNR is not meaningful; return a large value instead.
31
+ return 100
32
+ max_pixel = 255.0
33
+ psnr = 20 * log10(max_pixel / sqrt(mse))
34
+ return psnr
35
+
36
+
37
+
38
+ def downsample_1st(out, opt):
39
+ # Resize with different mode
40
+ updown_type = random.choices(['up', 'down', 'keep'], opt['resize_prob'])[0]
41
+ if updown_type == 'up':
42
+ scale = np.random.uniform(1, opt['resize_range'][1])
43
+ elif updown_type == 'down':
44
+ scale = np.random.uniform(opt['resize_range'][0], 1)
45
+ else:
46
+ scale = 1
47
+ mode = random.choice(opt['resize_options'])
48
+ out = F.interpolate(out, scale_factor=scale, mode=mode)
49
+
50
+ return out
51
+
52
+
53
+ def downsample_2nd(out, opt, ori_h, ori_w):
54
+ # Second Resize for 4x scaling
55
+ if opt['scale'] == 4:
56
+ updown_type = random.choices(['up', 'down', 'keep'], opt['resize_prob2'])[0]
57
+ if updown_type == 'up':
58
+ scale = np.random.uniform(1, opt['resize_range2'][1])
59
+ elif updown_type == 'down':
60
+ scale = np.random.uniform(opt['resize_range2'][0], 1)
61
+ else:
62
+ scale = 1
63
+ mode = random.choice(opt['resize_options'])
64
+ # Revert the resize here to the original version; consecutive scale-factor resizes are no longer used
65
+ # out = F.interpolate(out, scale_factor=scale, mode=mode)
66
+ out = F.interpolate(
67
+ out, size=(int(ori_h / opt['scale'] * scale), int(ori_w / opt['scale'] * scale)), mode=mode
68
+ )
69
+
70
+ return out
71
+
72
+
73
+ def common_degradation(out, opt, kernels, process_id, verbose = False):
74
+ jpeger = DiffJPEG(differentiable=False).cuda()
75
+ kernel1, kernel2 = kernels
76
+
77
+
78
+ downsample_1st_position = random.choices([0, 1, 2])[0]
79
+ if opt['scale'] == 4:
80
+ # Only do the second downsample at 4x scale
81
+ downsample_2nd_position = random.choices([0, 1, 2])[0]
82
+ else:
83
+ # print("We don't use the second resize")
84
+ downsample_2nd_position = -1
85
+
86
+
87
+ ####---------------------------- First Degradation ----------------------------------####
88
+ batch_size, _, ori_h, ori_w = out.size()
89
+
90
+ if downsample_1st_position == 0:
91
+ out = downsample_1st(out, opt)
92
+
93
+ # Bluring kernel
94
+ out = filter2D(out, kernel1)
95
+ if verbose: print(f"(1st) blur noise")
96
+
97
+
98
+ if downsample_1st_position == 1:
99
+ out = downsample_1st(out, opt)
100
+
101
+
102
+ # Noise effect (gaussian / poisson)
103
+ gray_noise_prob = opt['gray_noise_prob']
104
+ if np.random.uniform() < opt['gaussian_noise_prob']:
105
+ # Gaussian noise
106
+ out = random_add_gaussian_noise_pt(
107
+ out, sigma_range=opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
108
+ name = "gaussian_noise"
109
+ else:
110
+ # Poisson noise
111
+ out = random_add_poisson_noise_pt(
112
+ out,
113
+ scale_range=opt['poisson_scale_range'],
114
+ gray_prob=gray_noise_prob,
115
+ clip=True,
116
+ rounds=False)
117
+ name = "poisson_noise"
118
+ if verbose: print("(1st) " + str(name))
119
+
120
+
121
+ if downsample_1st_position == 2:
122
+ out = downsample_1st(out, opt)
123
+
124
+
125
+ # Choose an image compression codec (All degradation batch use the same codec)
126
+ image_codec = random.choices(opt['compression_codec1'], opt['compression_codec_prob1'])[0] # All lower case
127
+ if image_codec == "jpeg":
128
+ out = JPEG.compress_tensor(out)
129
+ elif image_codec == "webp":
130
+ try:
131
+ out = WEBP.compress_tensor(out, idx=process_id)
132
+ except Exception:
133
+ print("There is exception again in webp!")
134
+ out = WEBP.compress_tensor(out, idx=process_id)
135
+ elif image_codec == "heif":
136
+ out = HEIF.compress_tensor(out, idx=process_id)
137
+ elif image_codec == "avif":
138
+ out = AVIF.compress_tensor(out, idx=process_id)
139
+ else:
140
+ raise NotImplementedError("We don't have such image compression designed!")
141
+ # ##########################################################################################
142
+
143
+
144
+ # ####---------------------------- Second Degradation ----------------------------------####
145
+ if downsample_2nd_position == 0:
146
+ out = downsample_2nd(out, opt, ori_h, ori_w)
147
+
148
+
149
+ # Add blur 2nd time
150
+ if np.random.uniform() < opt['second_blur_prob']:
151
+ # This second blur is not always applied (it is triggered with probability second_blur_prob)
152
+ if verbose: print("(2nd) blur noise")
153
+ out = filter2D(out, kernel2)
154
+
155
+
156
+ if downsample_2nd_position == 1:
157
+ out = downsample_2nd(out, opt, ori_h, ori_w)
158
+
159
+
160
+ # Add noise 2nd time
161
+ gray_noise_prob = opt['gray_noise_prob2']
162
+ if np.random.uniform() < opt['gaussian_noise_prob2']:
163
+ # gaussian noise
164
+ if verbose: print("(2nd) gaussian noise")
165
+ out = random_add_gaussian_noise_pt(
166
+ out, sigma_range=opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
167
+ name = "gaussian_noise"
168
+ else:
169
+ # poisson noise
170
+ if verbose: print("(2nd) poisson noise")
171
+ out = random_add_poisson_noise_pt(
172
+ out, scale_range=opt['poisson_scale_range2'], gray_prob=gray_noise_prob, clip=True, rounds=False)
173
+ name = "poisson_noise"
174
+
175
+
176
+ if downsample_2nd_position == 2:
177
+ out = downsample_2nd(out, opt, ori_h, ori_w)
178
+
179
+
180
+ return out
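
A minimal sketch of how common_degradation might be called. Assumptions: opt is the shared option dict imported at the top of this file, the blur kernels follow the (batch, k, k) layout used by Real-ESRGAN-style filter2D implementations (this repo's filter2D in degradation/ESR/utils.py is not shown here), and the tensors sit on the GPU because the codec and DiffJPEG paths are CUDA-backed:

import torch
from degradation.ESR.degradations_functionality import random_mixed_kernels

batch = torch.rand(4, 3, 256, 256).cuda()          # stand-in GT crops in [0, 1]

# Hypothetical kernel settings; real training samples them per batch from opt
k_np = random_mixed_kernels(['iso'], [1.0], kernel_size=21, sigma_x_range=(0.2, 3))
kernel = torch.FloatTensor(k_np).unsqueeze(0).expand(4, -1, -1).cuda()

lr_batch = common_degradation(batch, opt, (kernel, kernel), process_id=0)
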
degradation/ESR/degradations_functionality.py ADDED
@@ -0,0 +1,785 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import cv2
4
+ import math
5
+ import numpy as np
6
+ import random
7
+ import torch
8
+ from scipy import special
9
+ from scipy.stats import multivariate_normal
10
+ from torchvision.transforms.functional_tensor import rgb_to_grayscale
11
+
12
+ # -------------------------------------------------------------------- #
13
+ # --------------------------- blur kernels --------------------------- #
14
+ # -------------------------------------------------------------------- #
15
+
16
+
17
+ # --------------------------- util functions --------------------------- #
18
+ def sigma_matrix2(sig_x, sig_y, theta):
19
+ """Calculate the rotated sigma matrix (two dimensional matrix).
20
+
21
+ Args:
22
+ sig_x (float):
23
+ sig_y (float):
24
+ theta (float): Radian measurement.
25
+
26
+ Returns:
27
+ ndarray: Rotated sigma matrix.
28
+ """
29
+ d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
30
+ u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
31
+ return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
32
+
33
+
34
+ def mesh_grid(kernel_size):
35
+ """Generate the mesh grid, centering at zero.
36
+
37
+ Args:
38
+ kernel_size (int):
39
+
40
+ Returns:
41
+ xy (ndarray): with the shape (kernel_size, kernel_size, 2)
42
+ xx (ndarray): with the shape (kernel_size, kernel_size)
43
+ yy (ndarray): with the shape (kernel_size, kernel_size)
44
+ """
45
+ ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
46
+ xx, yy = np.meshgrid(ax, ax)
47
+ xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
48
+ 1))).reshape(kernel_size, kernel_size, 2)
49
+ return xy, xx, yy
50
+
51
+
52
+ def pdf2(sigma_matrix, grid):
53
+ """Calculate PDF of the bivariate Gaussian distribution.
54
+
55
+ Args:
56
+ sigma_matrix (ndarray): with the shape (2, 2)
57
+ grid (ndarray): generated by :func:`mesh_grid`,
58
+ with the shape (K, K, 2), K is the kernel size.
59
+
60
+ Returns:
61
+ kernel (ndarray): un-normalized kernel.
62
+ """
63
+ inverse_sigma = np.linalg.inv(sigma_matrix)
64
+ kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
65
+ return kernel
66
+
67
+
68
+ def cdf2(d_matrix, grid):
69
+ """Calculate the CDF of the standard bivariate Gaussian distribution.
70
+ Used in skewed Gaussian distribution.
71
+
72
+ Args:
73
+ d_matrix (ndarray): skew matrix.
74
+ grid (ndarray): generated by :func:`mesh_grid`,
75
+ with the shape (K, K, 2), K is the kernel size.
76
+
77
+ Returns:
78
+ cdf (ndarray): skewed cdf.
79
+ """
80
+ rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
81
+ grid = np.dot(grid, d_matrix)
82
+ cdf = rv.cdf(grid)
83
+ return cdf
84
+
85
+
86
+ def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
87
+ """Generate a bivariate isotropic or anisotropic Gaussian kernel.
88
+
89
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
90
+
91
+ Args:
92
+ kernel_size (int):
93
+ sig_x (float):
94
+ sig_y (float):
95
+ theta (float): Radian measurement.
96
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
97
+ with the shape (K, K, 2), K is the kernel size. Default: None
98
+ isotropic (bool):
99
+
100
+ Returns:
101
+ kernel (ndarray): normalized kernel.
102
+ """
103
+ if grid is None:
104
+ grid, _, _ = mesh_grid(kernel_size)
105
+ if isotropic:
106
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
107
+ else:
108
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
109
+ kernel = pdf2(sigma_matrix, grid)
110
+ kernel = kernel / np.sum(kernel)
111
+ return kernel
112
+
113
+
114
+ def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
115
+ """Generate a bivariate generalized Gaussian kernel.
116
+ Described in `Parameter Estimation For Multivariate Generalized
117
+ Gaussian Distributions`_
118
+ by Pascal et al. (2013).
119
+
120
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
121
+
122
+ Args:
123
+ kernel_size (int):
124
+ sig_x (float):
125
+ sig_y (float):
126
+ theta (float): Radian measurement.
127
+ beta (float): shape parameter, beta = 1 is the normal distribution.
128
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
129
+ with the shape (K, K, 2), K is the kernel size. Default: None
130
+
131
+ Returns:
132
+ kernel (ndarray): normalized kernel.
133
+
134
+ .. _Parameter Estimation For Multivariate Generalized Gaussian
135
+ Distributions: https://arxiv.org/abs/1302.6498
136
+ """
137
+ if grid is None:
138
+ grid, _, _ = mesh_grid(kernel_size)
139
+ if isotropic:
140
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
141
+ else:
142
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
143
+ inverse_sigma = np.linalg.inv(sigma_matrix)
144
+ kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
145
+ kernel = kernel / np.sum(kernel)
146
+ return kernel
147
+
148
+
149
+ def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
150
+ """Generate a plateau-like anisotropic kernel.
151
+ 1 / (1+x^(beta))
152
+
153
+ Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
154
+
155
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
156
+
157
+ Args:
158
+ kernel_size (int):
159
+ sig_x (float):
160
+ sig_y (float):
161
+ theta (float): Radian measurement.
162
+ beta (float): shape parameter, beta = 1 is the normal distribution.
163
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
164
+ with the shape (K, K, 2), K is the kernel size. Default: None
165
+
166
+ Returns:
167
+ kernel (ndarray): normalized kernel.
168
+ """
169
+ if grid is None:
170
+ grid, _, _ = mesh_grid(kernel_size)
171
+ if isotropic:
172
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
173
+ else:
174
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
175
+ inverse_sigma = np.linalg.inv(sigma_matrix)
176
+ kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
177
+ kernel = kernel / np.sum(kernel)
178
+ return kernel
179
+
180
+
181
+ def random_bivariate_Gaussian(kernel_size,
182
+ sigma_x_range,
183
+ sigma_y_range,
184
+ rotation_range,
185
+ noise_range=None,
186
+ isotropic=True):
187
+ """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
188
+
189
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
190
+
191
+ Args:
192
+ kernel_size (int):
193
+ sigma_x_range (tuple): [0.6, 5]
194
+ sigma_y_range (tuple): [0.6, 5]
195
+ rotation range (tuple): [-math.pi, math.pi]
196
+ noise_range(tuple, optional): multiplicative kernel noise,
197
+ [0.75, 1.25]. Default: None
198
+
199
+ Returns:
200
+ kernel (ndarray):
201
+ """
202
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
203
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
204
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
205
+ if isotropic is False:
206
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
207
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
208
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
209
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
210
+ else:
211
+ sigma_y = sigma_x
212
+ rotation = 0
213
+
214
+ kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
215
+
216
+ # add multiplicative noise
217
+ if noise_range is not None:
218
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
219
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
220
+ kernel = kernel * noise
221
+ kernel = kernel / np.sum(kernel)
222
+ return kernel
223
+
224
+
225
+ def random_bivariate_generalized_Gaussian(kernel_size,
226
+ sigma_x_range,
227
+ sigma_y_range,
228
+ rotation_range,
229
+ beta_range,
230
+ noise_range=None,
231
+ isotropic=True):
232
+ """Randomly generate bivariate generalized Gaussian kernels.
233
+
234
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
235
+
236
+ Args:
237
+ kernel_size (int):
238
+ sigma_x_range (tuple): [0.6, 5]
239
+ sigma_y_range (tuple): [0.6, 5]
240
+ rotation range (tuple): [-math.pi, math.pi]
241
+ beta_range (tuple): [0.5, 8]
242
+ noise_range(tuple, optional): multiplicative kernel noise,
243
+ [0.75, 1.25]. Default: None
244
+
245
+ Returns:
246
+ kernel (ndarray):
247
+ """
248
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
249
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
250
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
251
+ if isotropic is False:
252
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
253
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
254
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
255
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
256
+ else:
257
+ sigma_y = sigma_x
258
+ rotation = 0
259
+
260
+ # assume beta_range[0] < 1 < beta_range[1]
261
+ if np.random.uniform() < 0.5:
262
+ beta = np.random.uniform(beta_range[0], 1)
263
+ else:
264
+ beta = np.random.uniform(1, beta_range[1])
265
+
266
+ kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
267
+
268
+ # add multiplicative noise
269
+ if noise_range is not None:
270
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
271
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
272
+ kernel = kernel * noise
273
+ kernel = kernel / np.sum(kernel)
274
+ return kernel
275
+
276
+
277
+ def random_bivariate_plateau(kernel_size,
278
+ sigma_x_range,
279
+ sigma_y_range,
280
+ rotation_range,
281
+ beta_range,
282
+ noise_range=None,
283
+ isotropic=True):
284
+ """Randomly generate bivariate plateau kernels.
285
+
286
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
287
+
288
+ Args:
289
+ kernel_size (int):
290
+ sigma_x_range (tuple): [0.6, 5]
291
+ sigma_y_range (tuple): [0.6, 5]
292
+ rotation range (tuple): [-math.pi/2, math.pi/2]
293
+ beta_range (tuple): [1, 4]
294
+ noise_range(tuple, optional): multiplicative kernel noise,
295
+ [0.75, 1.25]. Default: None
296
+
297
+ Returns:
298
+ kernel (ndarray):
299
+ """
300
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
301
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
302
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
303
+ if isotropic is False:
304
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
305
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
306
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
307
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
308
+ else:
309
+ sigma_y = sigma_x
310
+ rotation = 0
311
+
312
+ # TODO: this may not be appropriate
313
+ if np.random.uniform() < 0.5:
314
+ beta = np.random.uniform(beta_range[0], 1)
315
+ else:
316
+ beta = np.random.uniform(1, beta_range[1])
317
+
318
+ kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
319
+ # add multiplicative noise
320
+ if noise_range is not None:
321
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
322
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
323
+ kernel = kernel * noise
324
+ kernel = kernel / np.sum(kernel)
325
+
326
+ return kernel
327
+
328
+
329
+ def random_mixed_kernels(kernel_list,
330
+ kernel_prob,
331
+ kernel_size=21,
332
+ sigma_x_range=(0.6, 5),
333
+ sigma_y_range=(0.6, 5),
334
+ rotation_range=(-math.pi, math.pi),
335
+ betag_range=(0.5, 8),
336
+ betap_range=(0.5, 8),
337
+ noise_range=None):
338
+ """Randomly generate mixed kernels.
339
+
340
+ Args:
341
+ kernel_list (tuple): a list name of kernel types,
342
+ support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
343
+ 'plateau_aniso']
344
+ kernel_prob (tuple): corresponding kernel probability for each
345
+ kernel type
346
+ kernel_size (int):
347
+ sigma_x_range (tuple): [0.6, 5]
348
+ sigma_y_range (tuple): [0.6, 5]
349
+ rotation range (tuple): [-math.pi, math.pi]
350
+ beta_range (tuple): [0.5, 8]
351
+ noise_range(tuple, optional): multiplicative kernel noise,
352
+ [0.75, 1.25]. Default: None
353
+
354
+ Returns:
355
+ kernel (ndarray):
356
+ """
357
+ kernel_type = random.choices(kernel_list, kernel_prob)[0]
358
+ if kernel_type == 'iso':
359
+ kernel = random_bivariate_Gaussian(
360
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
361
+ elif kernel_type == 'aniso':
362
+ kernel = random_bivariate_Gaussian(
363
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
364
+ elif kernel_type == 'generalized_iso':
365
+ kernel = random_bivariate_generalized_Gaussian(
366
+ kernel_size,
367
+ sigma_x_range,
368
+ sigma_y_range,
369
+ rotation_range,
370
+ betag_range,
371
+ noise_range=noise_range,
372
+ isotropic=True)
373
+ elif kernel_type == 'generalized_aniso':
374
+ kernel = random_bivariate_generalized_Gaussian(
375
+ kernel_size,
376
+ sigma_x_range,
377
+ sigma_y_range,
378
+ rotation_range,
379
+ betag_range,
380
+ noise_range=noise_range,
381
+ isotropic=False)
382
+ elif kernel_type == 'plateau_iso':
383
+ kernel = random_bivariate_plateau(
384
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
385
+ elif kernel_type == 'plateau_aniso':
386
+ kernel = random_bivariate_plateau(
387
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
388
+ return kernel
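
A short usage sketch for the kernel generators above. The kernel list and probabilities mirror typical Real-ESRGAN-style settings and are illustrative only, not this repo's training configuration:

import math
import numpy as np

kernel = random_mixed_kernels(
    kernel_list=['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
    kernel_prob=[0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    kernel_size=21,
    sigma_x_range=(0.2, 3),
    sigma_y_range=(0.2, 3),
    rotation_range=(-math.pi, math.pi),
    betag_range=(0.5, 4),
    betap_range=(1, 2),
    noise_range=None)
assert kernel.shape == (21, 21) and np.isclose(kernel.sum(), 1.0)   # a normalized 2D blur kernel
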
389
+
390
+
391
+ np.seterr(divide='ignore', invalid='ignore')
392
+
393
+
394
+ def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
395
+ """2D sinc filter, ref: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
396
+ =====> Investigate this part thoroughly; the effect achievable here determines the later release!
397
+ Args:
398
+ cutoff (float): cutoff frequency in radians (pi is max)
399
+ kernel_size (int): horizontal and vertical size, must be odd.
400
+ pad_to (int): pad kernel size to desired size, must be odd or zero.
401
+ """
402
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
403
+ kernel = np.fromfunction(
404
+ lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
405
+ (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
406
+ (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
407
+ kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
408
+ kernel = kernel / np.sum(kernel)
409
+ if pad_to > kernel_size:
410
+ pad_size = (pad_to - kernel_size) // 2
411
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
412
+ return kernel
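
A short usage sketch for the sinc (ringing/overshoot) kernel above; the cutoff range below is an assumption, not this repo's configuration:

import numpy as np

omega_c = np.random.uniform(np.pi / 3, np.pi)                      # cutoff frequency in radians
sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size=21, pad_to=21)
assert sinc_kernel.shape == (21, 21) and np.isclose(sinc_kernel.sum(), 1.0)
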
413
+
414
+
415
+ # ------------------------------------------------------------- #
416
+ # --------------------------- noise --------------------------- #
417
+ # ------------------------------------------------------------- #
418
+
419
+ # ----------------------- Gaussian Noise ----------------------- #
420
+
421
+
422
+ def generate_gaussian_noise(img, sigma=10, gray_noise=False):
423
+ """Generate Gaussian noise.
424
+
425
+ Args:
426
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
427
+ sigma (float): Noise scale (measured in range 255). Default: 10.
428
+
429
+ Returns:
430
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
431
+ float32.
432
+ """
433
+ if gray_noise:
434
+ noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
435
+ noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
436
+ else:
437
+ noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
438
+ return noise
439
+
440
+
441
+ def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
442
+ """Add Gaussian noise.
443
+
444
+ Args:
445
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
446
+ sigma (float): Noise scale (measured in range 255). Default: 10.
447
+
448
+ Returns:
449
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
450
+ float32.
451
+ """
452
+ noise = generate_gaussian_noise(img, sigma, gray_noise)
453
+ out = img + noise
454
+ if clip and rounds:
455
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
456
+ elif clip:
457
+ out = np.clip(out, 0, 1)
458
+ elif rounds:
459
+ out = (out * 255.0).round() / 255.
460
+ return out
461
+
462
+
463
+ def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
464
+ """Add Gaussian noise (PyTorch version).
465
+
466
+ Args:
467
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
468
+ sigma (float | Tensor): One sigma per sample in the batch (shared within that sample).
469
+ gray_noise (float | Tensor): Either 1 or 0 (per sample).
470
+
471
+ Returns:
472
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
473
+ float32.
474
+ """
475
+ b, _, h, w = img.size()
476
+ if not isinstance(sigma, (float, int)):
477
+ sigma = sigma.view(img.size(0), 1, 1, 1)
478
+ if isinstance(gray_noise, (float, int)):
479
+ cal_gray_noise = gray_noise > 0
480
+ else:
481
+ gray_noise = gray_noise.view(b, 1, 1, 1)
482
+ cal_gray_noise = torch.sum(gray_noise) > 0
483
+
484
+ if cal_gray_noise:
485
+ noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
486
+ noise_gray = noise_gray.view(b, 1, h, w)
487
+
488
+ # always calculate color noise
489
+ noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
490
+
491
+ if cal_gray_noise:
492
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
493
+ return noise
494
+
495
+
496
+ def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
497
+ """Add Gaussian noise (PyTorch version).
498
+
499
+ Args:
500
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
501
+ scale (float | Tensor): Noise scale. Default: 1.0.
502
+
503
+ Returns:
504
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
505
+ float32.
506
+ """
507
+ noise = generate_gaussian_noise_pt(img, sigma, gray_noise)  # sigma sets the per-sample noise strength; gray_noise flags grayscale-only noise
508
+ out = img + noise
509
+ if clip and rounds:
510
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
511
+ elif clip:
512
+ out = torch.clamp(out, 0, 1)
513
+ elif rounds:
514
+ out = (out * 255.0).round() / 255.
515
+ return out
516
+
517
+
518
+ # ----------------------- Random Gaussian Noise ----------------------- #
519
+ def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
520
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
521
+ if np.random.uniform() < gray_prob:
522
+ gray_noise = True
523
+ else:
524
+ gray_noise = False
525
+ return generate_gaussian_noise(img, sigma, gray_noise)
526
+
527
+
528
+ def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
529
+ noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
530
+ out = img + noise
531
+ if clip and rounds:
532
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
533
+ elif clip:
534
+ out = np.clip(out, 0, 1)
535
+ elif rounds:
536
+ out = (out * 255.0).round() / 255.
537
+ return out
538
+
539
+
540
+ def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
541
+ sigma = torch.rand(img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
542
+
543
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
544
+ gray_noise = (gray_noise < gray_prob).float()
545
+ return generate_gaussian_noise_pt(img, sigma, gray_noise)
546
+
547
+
548
+ def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
549
+ # sigma_range defines the range of the noise strength
550
+ noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
551
+ out = img + noise
552
+ if clip and rounds:
553
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
554
+ elif clip:
555
+ out = torch.clamp(out, 0, 1)
556
+ elif rounds:
557
+ out = (out * 255.0).round() / 255.
558
+ return out
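
A minimal sketch of the batched random Gaussian noise helper above; sigma_range and gray_prob are illustrative values, not this repo's settings:

import torch

imgs = torch.rand(4, 3, 64, 64)          # batch in [0, 1]
noisy = random_add_gaussian_noise_pt(imgs, sigma_range=(1, 30), gray_prob=0.4, clip=True, rounds=False)
# each sample draws its own sigma in [1, 30]; roughly 40% of samples receive grayscale (luminance-only) noise
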
559
+
560
+
561
+ # ----------------------- Poisson (Shot) Noise ----------------------- #
562
+
563
+
564
+ def generate_poisson_noise(img, scale=1.0, gray_noise=False):
565
+ """Generate poisson noise.
566
+
567
+ Ref: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
568
+
569
+ Args:
570
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
571
+ scale (float): Noise scale. Default: 1.0.
572
+ gray_noise (bool): Whether generate gray noise. Default: False.
573
+
574
+ Returns:
575
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
576
+ float32.
577
+ """
578
+ if gray_noise:
579
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
580
+ # round and clip image for counting vals correctly
581
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
582
+ vals = len(np.unique(img))
583
+ vals = 2**np.ceil(np.log2(vals))
584
+ out = np.float32(np.random.poisson(img * vals) / float(vals))
585
+ noise = out - img
586
+ if gray_noise:
587
+ noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
588
+ return noise * scale
589
+
590
+
591
+ def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
592
+ """Add poisson noise.
593
+
594
+ Args:
595
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
596
+ scale (float): Noise scale. Default: 1.0.
597
+ gray_noise (bool): Whether generate gray noise. Default: False.
598
+
599
+ Returns:
600
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
601
+ float32.
602
+ """
603
+ noise = generate_poisson_noise(img, scale, gray_noise)
604
+ out = img + noise
605
+ if clip and rounds:
606
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
607
+ elif clip:
608
+ out = np.clip(out, 0, 1)
609
+ elif rounds:
610
+ out = (out * 255.0).round() / 255.
611
+ return out
612
+
613
+
614
+ def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
615
+ """Generate a batch of poisson noise (PyTorch version)
616
+
617
+ Args:
618
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
619
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
620
+ Default: 1.0.
621
+ Can also be given per-sample as a Tensor.
622
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
623
+ 0 for False, 1 for True. Default: 0.
624
+ Can also be given per-sample as a Tensor.
625
+
626
+ Returns:
627
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
628
+ float32.
629
+ """
630
+ b, _, h, w = img.size()
631
+ if isinstance(gray_noise, (float, int)):
632
+ cal_gray_noise = gray_noise > 0
633
+ else:
634
+ gray_noise = gray_noise.view(b, 1, 1, 1)
635
+ # Slightly different from the original paper here: with my current batch size of 128, essentially every batch contains some gray noise
636
+ cal_gray_noise = torch.sum(gray_noise) > 0
637
+ if cal_gray_noise:
638
+ # This is arguably not very efficient: samples that do not receive gray noise still pay for this computation, and the gray-noise probability is currently very low
639
+ img_gray = rgb_to_grayscale(img, num_output_channels=1)  # returns only the single luminance channel
640
+ # round and clip image for counting vals correctly, ensure that it only has 256 possible floats at the end
641
+ img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
642
+ # use for-loop to get the unique values for each sample
643
+
644
+ # Note: the noise added here depends entirely on the color diversity of this single image, which should explain why the noise is more visible on flat, uniform images
645
+ vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
646
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
647
+ vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
648
+
649
+ # Since the img is in range [0, 1], the noise from the Poisson distribution should also lie in [0, 1]
650
+ # Note: this is just my personal understanding: for monotonous images the Poisson noise concentrates around a single high point, whereas images with many unique values get a much broader spread (visible in plots of the Poisson distribution)
651
+ out = torch.poisson(img_gray * vals) / vals
652
+ noise_gray = out - img_gray
653
+ noise_gray = noise_gray.expand(b, 3, h, w)
654
+
655
+ # always calculate color noise
656
+ # round and clip image for counting vals correctly
657
+ img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
658
+ # use for-loop to get the unique values for each sample
659
+ vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
660
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
661
+ vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
662
+ out = torch.poisson(img * vals) / vals  # the output is still non-negative
663
+ noise = out - img  # this can produce negative values
664
+ if cal_gray_noise:
665
+ # Note: per sample, the color noise is either used in full or fully replaced by the gray noise
666
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise # In this place, I don't know why it sometimes runs out of memory
667
+ if not isinstance(scale, (float, int)):
668
+ scale = scale.view(b, 1, 1, 1)
669
+
670
+ # Note: the noise values produced here fall roughly in the range -0.x to +0.x; negative values effectively weaken the pixel values
671
+ # print("poisson noise range is ", sorted(torch.unique(noise))[:10])
672
+ # print(sorted(torch.unique(noise))[-10:])
673
+ return noise * scale
674
+
675
+
676
+ def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
677
+ """Add poisson noise to a batch of images (PyTorch version).
678
+
679
+ Args:
680
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
681
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
682
+ Default: 1.0.
683
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
684
+ 0 for False, 1 for True. Default: 0.
685
+
686
+ Returns:
687
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
688
+ float32.
689
+ """
690
+ noise = generate_poisson_noise_pt(img, scale, gray_noise)
691
+ out = img + noise
692
+ if clip and rounds:
693
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
694
+ elif clip:
695
+ out = torch.clamp(out, 0, 1)
696
+ elif rounds:
697
+ out = (out * 255.0).round() / 255.
698
+ return out
699
+
700
+
701
+ # ----------------------- Random Poisson (Shot) Noise ----------------------- #
702
+
703
+
704
+ def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
705
+ scale = np.random.uniform(scale_range[0], scale_range[1])
706
+ if np.random.uniform() < gray_prob:
707
+ gray_noise = True
708
+ else:
709
+ gray_noise = False
710
+ return generate_poisson_noise(img, scale, gray_noise)
711
+
712
+
713
+ def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
714
+ noise = random_generate_poisson_noise(img, scale_range, gray_prob)
715
+ out = img + noise
716
+ if clip and rounds:
717
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
718
+ elif clip:
719
+ out = np.clip(out, 0, 1)
720
+ elif rounds:
721
+ out = (out * 255.0).round() / 255.
722
+ return out
723
+
724
+
725
+ def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
726
+ # scale_range defines the range of the noise scale
727
+ # img.size(0): each image in the batch gets its own scale level
728
+ scale = torch.rand(img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
729
+
730
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
731
+ gray_noise = (gray_noise < gray_prob).float()
732
+ return generate_poisson_noise_pt(img, scale, gray_noise)  # scale and gray_noise should both be batched tensors
733
+
734
+
735
+ def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
736
+ noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
737
+ out = img + noise
738
+ if clip and rounds:
739
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
740
+ elif clip:
741
+ out = torch.clamp(out, 0, 1)
742
+ elif rounds:
743
+ out = (out * 255.0).round() / 255.
744
+ return out
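
A matching sketch for the batched random Poisson (shot) noise helper above, again with illustrative parameter values. As the translated notes above point out, the effective noise depends on how many unique gray levels each image contains, so flat images end up with visibly stronger shot noise:

import torch

imgs = torch.rand(4, 3, 64, 64)
noisy = random_add_poisson_noise_pt(imgs, scale_range=(0.05, 3), gray_prob=0.4, clip=True, rounds=False)
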
745
+
746
+
747
+ # ------------------------------------------------------------------------ #
748
+ # --------------------------- JPEG compression --------------------------- #
749
+ # ------------------------------------------------------------------------ #
750
+
751
+
752
+ def add_jpg_compression(img, quality=90):
753
+ """Add JPG compression artifacts.
754
+
755
+ Args:
756
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
757
+ quality (float): JPG compression quality. 0 for lowest quality, 100 for
758
+ best quality. Default: 90.
759
+
760
+ Returns:
761
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
762
+ float32.
763
+ """
764
+ img = np.clip(img, 0, 1)
765
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
766
+ _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
767
+ img = np.float32(cv2.imdecode(encimg, 1)) / 255.
768
+ return img
769
+
770
+
771
+ def random_add_jpg_compression(img, quality_range=(90, 100)):
772
+ """Randomly add JPG compression artifacts.
773
+
774
+ Args:
775
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
776
+ quality_range (tuple[float] | list[float]): JPG compression quality
777
+ range. 0 for lowest quality, 100 for best quality.
778
+ Default: (90, 100).
779
+
780
+ Returns:
781
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
782
+ float32.
783
+ """
784
+ quality = np.random.uniform(quality_range[0], quality_range[1])
785
+ return add_jpg_compression(img, quality)
degradation/ESR/diffjpeg.py ADDED
@@ -0,0 +1,517 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ """
4
+ Modified from https://github.com/mlomnitz/DiffJPEG
5
+
6
+ For images not divisible by 8
7
+ https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
8
+ """
9
+ import itertools
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import functional as F
14
+
15
+ # ------------------------ utils ------------------------#
16
+ y_table = np.array(
17
+ [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
18
+ [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
19
+ [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
20
+ dtype=np.float32).T
21
+ y_table = nn.Parameter(torch.from_numpy(y_table))
22
+ c_table = np.empty((8, 8), dtype=np.float32)
23
+ c_table.fill(99)
24
+ c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T
25
+ c_table = nn.Parameter(torch.from_numpy(c_table))
26
+
27
+
28
+ def diff_round(x):
29
+ """ Differentiable rounding function
30
+ """
31
+ return torch.round(x) + (x - torch.round(x))**3
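
A brief note on why this rounding helps training: with residual r = x - round(x) in [-0.5, 0.5], the output round(x) + r**3 stays within 0.125 of hard rounding, while the gradient 3 * r**2 is non-zero away from integers, so gradients can flow through the quantization step. A tiny check under these assumptions:

import torch

x = torch.tensor(1.3, requires_grad=True)
y = diff_round(x)        # ~1.027, close to round(1.3) = 1
y.backward()
# x.grad = 3 * (1.3 - 1.0)**2 = 0.27, a usable non-zero gradient
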
32
+
33
+
34
+ def quality_to_factor(quality):
35
+ """ Calculate factor corresponding to quality
36
+
37
+ Args:
38
+ quality(float): Quality for jpeg compression.
39
+
40
+ Returns:
41
+ float: Compression factor.
42
+ """
43
+ if quality < 50:
44
+ quality = 5000. / quality
45
+ else:
46
+ quality = 200. - quality * 2
47
+ return quality / 100.
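
As a quick check of this quality-to-factor mapping: quality 10 gives 5000/10/100 = 5.0 (coarse quantization), quality 50 gives 5000/50/100 = 1.0, and quality 90 gives (200 - 180)/100 = 0.2 (mild quantization); the factor scales the JPEG quantization tables defined above.

assert quality_to_factor(10) == 5.0
assert quality_to_factor(50) == 1.0
assert abs(quality_to_factor(90) - 0.2) < 1e-9
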
48
+
49
+
50
+ # ------------------------ compression ------------------------#
51
+ class RGB2YCbCrJpeg(nn.Module):
52
+ """ Converts RGB image to YCbCr
53
+ """
54
+
55
+ def __init__(self):
56
+ super(RGB2YCbCrJpeg, self).__init__()
57
+ matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
58
+ dtype=np.float32).T
59
+ self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
60
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
61
+
62
+ def forward(self, image):
63
+ """
64
+ Args:
65
+ image(Tensor): batch x 3 x height x width
66
+
67
+ Returns:
68
+ Tensor: batch x height x width x 3
69
+ """
70
+ image = image.permute(0, 2, 3, 1)
71
+ result = torch.tensordot(image, self.matrix, dims=1) + self.shift
72
+ return result.view(image.shape)
73
+
74
+
75
+ class ChromaSubsampling(nn.Module):
76
+ """ Chroma subsampling on CbCr channels
77
+ """
78
+
79
+ def __init__(self):
80
+ super(ChromaSubsampling, self).__init__()
81
+
82
+ def forward(self, image):
83
+ """
84
+ Args:
85
+ image(tensor): batch x height x width x 3
86
+
87
+ Returns:
88
+ y(tensor): batch x height x width
89
+ cb(tensor): batch x height/2 x width/2
90
+ cr(tensor): batch x height/2 x width/2
91
+ """
92
+ image_2 = image.permute(0, 3, 1, 2).clone()
93
+ cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
94
+ cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False)
95
+ cb = cb.permute(0, 2, 3, 1)
96
+ cr = cr.permute(0, 2, 3, 1)
97
+ return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
98
+
99
+
100
+ class BlockSplitting(nn.Module):
101
+ """ Splitting image into patches
102
+ """
103
+
104
+ def __init__(self):
105
+ super(BlockSplitting, self).__init__()
106
+ self.k = 8
107
+
108
+ def forward(self, image):
109
+ """
110
+ Args:
111
+ image(tensor): batch x height x width
112
+
113
+ Returns:
114
+ Tensor: batch x h*w/64 x h x w
115
+ """
116
+ height, _ = image.shape[1:3]
117
+ batch_size = image.shape[0]
118
+ image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
119
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
120
+ return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
121
+
122
+
123
+ class DCT8x8(nn.Module):
124
+ """ Discrete Cosine Transformation
125
+ """
126
+
127
+ def __init__(self):
128
+ super(DCT8x8, self).__init__()
129
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
130
+ for x, y, u, v in itertools.product(range(8), repeat=4):
131
+ tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
132
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
133
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
134
+ self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
135
+
136
+ def forward(self, image):
137
+ """
138
+ Args:
139
+ image(tensor): batch x height x width
140
+
141
+ Returns:
142
+ Tensor: batch x height x width
143
+ """
144
+ image = image - 128
145
+ result = self.scale * torch.tensordot(image, self.tensor, dims=2)
146
+ result.view(image.shape)
147
+ return result
148
+
149
+
150
+ class YQuantize(nn.Module):
151
+ """ JPEG Quantization for Y channel
152
+
153
+ Args:
154
+ rounding(function): rounding function to use
155
+ """
156
+
157
+ def __init__(self, rounding):
158
+ super(YQuantize, self).__init__()
159
+ self.rounding = rounding
160
+ self.y_table = y_table
161
+
162
+ def forward(self, image, factor=1):
163
+ """
164
+ Args:
165
+ image(tensor): batch x height x width
166
+
167
+ Returns:
168
+ Tensor: batch x height x width
169
+ """
170
+ if isinstance(factor, (int, float)):
171
+ image = image.float() / (self.y_table * factor)
172
+ else:
173
+ b = factor.size(0)
174
+ table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
175
+ image = image.float() / table
176
+ image = self.rounding(image)
177
+ return image
178
+
179
+
180
+ class CQuantize(nn.Module):
181
+ """ JPEG Quantization for CbCr channels
182
+
183
+ Args:
184
+ rounding(function): rounding function to use
185
+ """
186
+
187
+ def __init__(self, rounding):
188
+ super(CQuantize, self).__init__()
189
+ self.rounding = rounding
190
+ self.c_table = c_table
191
+
192
+ def forward(self, image, factor=1):
193
+ """
194
+ Args:
195
+ image(tensor): batch x height x width
196
+
197
+ Returns:
198
+ Tensor: batch x height x width
199
+ """
200
+ if isinstance(factor, (int, float)):
201
+ image = image.float() / (self.c_table * factor)
202
+ else:
203
+ b = factor.size(0)
204
+ table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
205
+ image = image.float() / table
206
+ image = self.rounding(image)
207
+ return image
208
+
209
+
210
+ class CompressJpeg(nn.Module):
211
+ """Full JPEG compression algorithm
212
+
213
+ Args:
214
+ rounding(function): rounding function to use
215
+ """
216
+
217
+ def __init__(self, rounding=torch.round):
218
+ super(CompressJpeg, self).__init__()
219
+ self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling())
220
+ self.l2 = nn.Sequential(BlockSplitting(), DCT8x8())
221
+ self.c_quantize = CQuantize(rounding=rounding)
222
+ self.y_quantize = YQuantize(rounding=rounding)
223
+
224
+ def forward(self, image, factor=1):
225
+ """
226
+ Args:
227
+ image(tensor): batch x 3 x height x width
228
+
229
+ Returns:
230
+ y, cb, cr (tensor): Quantized DCT coefficients, each with shape batch x h*w/64 x 8 x 8.
231
+ """
232
+ y, cb, cr = self.l1(image * 255)
233
+ components = {'y': y, 'cb': cb, 'cr': cr}
234
+ for k in components.keys():
235
+ comp = self.l2(components[k])
236
+ if k in ('cb', 'cr'):
237
+ comp = self.c_quantize(comp, factor=factor)
238
+ else:
239
+ comp = self.y_quantize(comp, factor=factor)
240
+
241
+ components[k] = comp
242
+
243
+ return components['y'], components['cb'], components['cr']
244
+
245
+
246
+ # ------------------------ decompression ------------------------#
247
+
248
+
249
+ class YDequantize(nn.Module):
250
+ """Dequantize Y channel
251
+ """
252
+
253
+ def __init__(self):
254
+ super(YDequantize, self).__init__()
255
+ self.y_table = y_table
256
+
257
+ def forward(self, image, factor=1):
258
+ """
259
+ Args:
260
+ image(tensor): batch x height x width
261
+
262
+ Returns:
263
+ Tensor: batch x height x width
264
+ """
265
+ if isinstance(factor, (int, float)):
266
+ out = image * (self.y_table * factor)
267
+ else:
268
+ b = factor.size(0)
269
+ table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
270
+ out = image * table
271
+ return out
272
+
273
+
274
+ class CDequantize(nn.Module):
275
+ """Dequantize CbCr channel
276
+ """
277
+
278
+ def __init__(self):
279
+ super(CDequantize, self).__init__()
280
+ self.c_table = c_table
281
+
282
+ def forward(self, image, factor=1):
283
+ """
284
+ Args:
285
+ image(tensor): batch x height x width
286
+
287
+ Returns:
288
+ Tensor: batch x height x width
289
+ """
290
+ if isinstance(factor, (int, float)):
291
+ out = image * (self.c_table * factor)
292
+ else:
293
+ b = factor.size(0)
294
+ table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1)
295
+ out = image * table
296
+ return out
297
+
298
+
299
+ class iDCT8x8(nn.Module):
300
+ """Inverse discrete Cosine Transformation
301
+ """
302
+
303
+ def __init__(self):
304
+ super(iDCT8x8, self).__init__()
305
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
306
+ self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
307
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
308
+ for x, y, u, v in itertools.product(range(8), repeat=4):
309
+ tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
310
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
311
+
312
+ def forward(self, image):
313
+ """
314
+ Args:
315
+ image(tensor): batch x height x width
316
+
317
+ Returns:
318
+ Tensor: batch x height x width
319
+ """
320
+ image = image * self.alpha
321
+ result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
322
+ result = result.view(image.shape)
323
+ return result
324
+
325
+
326
+ class BlockMerging(nn.Module):
327
+ """Merge patches into image
328
+ """
329
+
330
+ def __init__(self):
331
+ super(BlockMerging, self).__init__()
332
+
333
+ def forward(self, patches, height, width):
334
+ """
335
+ Args:
336
+ patches(tensor): batch x height*width/64 x 8 x 8
337
+ height(int)
338
+ width(int)
339
+
340
+ Returns:
341
+ Tensor: batch x height x width
342
+ """
343
+ k = 8
344
+ batch_size = patches.shape[0]
345
+ image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
346
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
347
+ return image_transposed.contiguous().view(batch_size, height, width)
348
+
349
+
350
+ class ChromaUpsampling(nn.Module):
351
+ """Upsample chroma layers
352
+ """
353
+
354
+ def __init__(self):
355
+ super(ChromaUpsampling, self).__init__()
356
+
357
+ def forward(self, y, cb, cr):
358
+ """
359
+ Args:
360
+ y(tensor): y channel image
361
+ cb(tensor): cb channel
362
+ cr(tensor): cr channel
363
+
364
+ Returns:
365
+ Tensor: batch x height x width x 3
366
+ """
367
+
368
+ def repeat(x, k=2):
369
+ height, width = x.shape[1:3]
370
+ x = x.unsqueeze(-1)
371
+ x = x.repeat(1, 1, k, k)
372
+ x = x.view(-1, height * k, width * k)
373
+ return x
374
+
375
+ cb = repeat(cb)
376
+ cr = repeat(cr)
377
+ return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
378
+
379
+
380
+ class YCbCr2RGBJpeg(nn.Module):
381
+ """Converts YCbCr image to RGB JPEG
382
+ """
383
+
384
+ def __init__(self):
385
+ super(YCbCr2RGBJpeg, self).__init__()
386
+
387
+ matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
388
+ self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
389
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
390
+
391
+ def forward(self, image):
392
+ """
393
+ Args:
394
+ image(tensor): batch x height x width x 3
395
+
396
+ Returns:
397
+ Tensor: batch x 3 x height x width
398
+ """
399
+ result = torch.tensordot(image + self.shift, self.matrix, dims=1)
400
+ return result.view(image.shape).permute(0, 3, 1, 2)
401
+
402
+
403
+ class DeCompressJpeg(nn.Module):
404
+ """Full JPEG decompression algorithm
405
+
406
+ Args:
407
+ rounding(function): rounding function to use
408
+ """
409
+
410
+ def __init__(self, rounding=torch.round):
411
+ super(DeCompressJpeg, self).__init__()
412
+ self.c_dequantize = CDequantize()
413
+ self.y_dequantize = YDequantize()
414
+ self.idct = iDCT8x8()
415
+ self.merging = BlockMerging()
416
+ self.chroma = ChromaUpsampling()
417
+ self.colors = YCbCr2RGBJpeg()
418
+
419
+ def forward(self, y, cb, cr, imgh, imgw, factor=1):
420
+ """
421
+ Args:
422
+ y, cb, cr (tensor): quantized DCT coefficients, each batch x h*w/64 x 8 x 8
423
+ imgh(int)
424
+ imgw(int)
425
+ factor(float)
426
+
427
+ Returns:
428
+ Tensor: batch x 3 x height x width
429
+ """
430
+ components = {'y': y, 'cb': cb, 'cr': cr}
431
+ for k in components.keys():
432
+ if k in ('cb', 'cr'):
433
+ comp = self.c_dequantize(components[k], factor=factor)
434
+ height, width = int(imgh / 2), int(imgw / 2)
435
+ else:
436
+ comp = self.y_dequantize(components[k], factor=factor)
437
+ height, width = imgh, imgw
438
+ comp = self.idct(comp)
439
+ components[k] = self.merging(comp, height, width)
440
+ #
441
+ image = self.chroma(components['y'], components['cb'], components['cr'])
442
+ image = self.colors(image)
443
+
444
+ image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image))
445
+ return image / 255
446
+
447
+
448
+ # ------------------------ main DiffJPEG ------------------------ #
449
+
450
+
451
+ class DiffJPEG(nn.Module):
452
+ """This JPEG algorithm result is slightly different from cv2.
453
+ DiffJPEG supports batch processing.
454
+
455
+ Args:
456
+ differentiable(bool): If True, use the custom differentiable rounding function; otherwise, use the standard torch.round.
457
+ """
458
+
459
+ def __init__(self, differentiable=True):
460
+ super(DiffJPEG, self).__init__()
461
+ if differentiable:
462
+ rounding = diff_round
463
+ else:
464
+ rounding = torch.round
465
+
466
+ self.compress = CompressJpeg(rounding=rounding)
467
+ self.decompress = DeCompressJpeg(rounding=rounding)
468
+
469
+ def forward(self, x, quality):
470
+ """
471
+ Args:
472
+ x (Tensor): Input image, bchw, rgb, [0, 1]
473
+ quality(float): Quality factor for jpeg compression scheme.
474
+ """
475
+ factor = quality
476
+ if isinstance(factor, (int, float)):
477
+ factor = quality_to_factor(factor)
478
+ else:
479
+ for i in range(factor.size(0)):
480
+ factor[i] = quality_to_factor(factor[i])
481
+ h, w = x.size()[-2:]
482
+ h_pad, w_pad = 0, 0
483
+ # Pad to a multiple of 16: the 8x8 DCT blocks plus 2x chroma subsampling require dimensions divisible by 16
484
+ if h % 16 != 0:
485
+ h_pad = 16 - h % 16
486
+ if w % 16 != 0:
487
+ w_pad = 16 - w % 16
488
+ x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)
489
+
490
+ y, cb, cr = self.compress(x, factor=factor)
491
+ recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
492
+ recovered = recovered[:, :, 0:h, 0:w]
493
+ return recovered
494
+
495
+
496
+ if __name__ == '__main__':
497
+ import cv2
498
+
499
+ from basicsr.utils import img2tensor, tensor2img
500
+
501
+ img_gt = cv2.imread('test.png') / 255.
502
+
503
+ # -------------- cv2 -------------- #
504
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
505
+ _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
506
+ img_lq = np.float32(cv2.imdecode(encimg, 1))
507
+ cv2.imwrite('cv2_JPEG_20.png', img_lq)
508
+
509
+ # -------------- DiffJPEG -------------- #
510
+ jpeger = DiffJPEG(differentiable=False).cuda()
511
+ img_gt = img2tensor(img_gt)
512
+ img_gt = torch.stack([img_gt, img_gt]).cuda()
513
+ quality = img_gt.new_tensor([20, 40])
514
+ out = jpeger(img_gt, quality=quality)
515
+
516
+ cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0]))
517
+ cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1]))
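As a complement to the cv2 comparison above, a minimal sketch (not part of the repo) of the differentiable path: with differentiable=True the JPEG round trip stays inside the autograd graph, which is what makes it usable as an on-the-fly degradation during training. It assumes the repo root is on sys.path so the module imports as degradation.ESR.diffjpeg, and that a CUDA device is available.

    import torch
    from degradation.ESR.diffjpeg import DiffJPEG

    jpeger = DiffJPEG(differentiable=True).cuda()
    x = torch.rand(2, 3, 64, 64, device='cuda', requires_grad=True)   # bchw, rgb, [0, 1]
    quality = x.new_tensor([30., 70.])                                # one JPEG quality per sample
    out = jpeger(x, quality=quality)                                  # same shape, still in [0, 1]
    out.mean().backward()                                             # gradients flow through the fake-JPEG round trip
    print(x.grad.shape)                                               # torch.Size([2, 3, 64, 64])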
degradation/ESR/usm_sharp.py ADDED
@@ -0,0 +1,114 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from torch.nn import functional as F
7
+
8
+ import os, sys
9
+ root_path = os.path.abspath('.')
10
+ sys.path.append(root_path)
11
+ from degradation.ESR.utils import filter2D, np2tensor, tensor2np
12
+
13
+
14
+ def usm_sharp_func(img, weight=0.5, radius=50, threshold=10):
15
+ """USM sharpening.
16
+
17
+ Input image: I; Blurry image: B.
18
+ 1. sharp = I + weight * (I - B)
19
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
20
+ 3. Blur the binary mask with a Gaussian kernel to get a soft mask
21
+ 4. Out = soft_mask * sharp + (1 - soft_mask) * I
22
+
23
+
24
+ Args:
25
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
26
+ weight (float): Sharp weight. Default: 0.5.
27
+ radius (float): Kernel size of Gaussian blur. Default: 50.
28
+ threshold (int): Mask threshold on |I - B|, measured on the 0-255 scale. Default: 10.
29
+ """
30
+ if radius % 2 == 0:
31
+ radius += 1
32
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
33
+ residual = img - blur
34
+ mask = np.abs(residual) * 255 > threshold
35
+ mask = mask.astype('float32')
36
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
37
+
38
+ sharp = img + weight * residual
39
+ sharp = np.clip(sharp, 0, 1)
40
+ return soft_mask * sharp + (1 - soft_mask) * img
41
+
42
+
43
+
44
+ class USMSharp(torch.nn.Module):
45
+
46
+ def __init__(self, type, radius=50, sigma=0):
47
+ super(USMSharp, self).__init__()
48
+ if radius % 2 == 0:
49
+ radius += 1
50
+ self.radius = radius
51
+ kernel = cv2.getGaussianKernel(radius, sigma)
52
+ kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0).cuda()
53
+ self.register_buffer('kernel', kernel)
54
+
55
+ self.type = type
56
+
57
+
58
+ def forward(self, img, weight=0.5, threshold=10, store=False):
59
+
60
+ if self.type == "cv2":
61
+ # pre-process cv2 type
62
+ img = np2tensor(img)
63
+
64
+ blur = filter2D(img, self.kernel.cuda())
65
+ if store:
66
+ cv2.imwrite("blur.png", tensor2np(blur))
67
+
68
+ residual = img - blur
69
+ if store:
70
+ cv2.imwrite("residual.png", tensor2np(residual))
71
+
72
+ mask = torch.abs(residual) * 255 > threshold
73
+ if store:
74
+ cv2.imwrite("mask.png", tensor2np(mask))
75
+
76
+
77
+ mask = mask.float()
78
+ soft_mask = filter2D(mask, self.kernel.cuda())
79
+ if store:
80
+ cv2.imwrite("soft_mask.png", tensor2np(soft_mask))
81
+
82
+ sharp = img + weight * residual
83
+ sharp = torch.clip(sharp, 0, 1)
84
+ if store:
85
+ cv2.imwrite("sharp.png", tensor2np(sharp))
86
+
87
+ output = soft_mask * sharp + (1 - soft_mask) * img
88
+ if self.type == "cv2":
89
+ output = tensor2np(output)
90
+
91
+ return output
92
+
93
+
94
+
95
+ if __name__ == "__main__":
96
+
97
+ usm_sharper = USMSharp(type="cv2")
98
+ img = cv2.imread("sample3.png")
99
+ print(img.shape)
100
+ sharp_output = usm_sharper(img, store=False, threshold=10)
101
+ cv2.imwrite(os.path.join("output.png"), sharp_output)
102
+
103
+
104
+ # dir = r"C:\Users\HikariDawn\Desktop\Real-CUGAN\datasets\sample"
105
+ # output_dir = r"C:\Users\HikariDawn\Desktop\Real-CUGAN\datasets\sharp_regular"
106
+ # if not os.path.exists(output_dir):
107
+ # os.makedirs(output_dir)
108
+
109
+ # for file_name in sorted(os.listdir(dir)):
110
+ # print(file_name)
111
+ # file = os.path.join(dir, file_name)
112
+ # img = cv2.imread(file)
113
+ # sharp_output = usm_sharper(img)
114
+ # cv2.imwrite(os.path.join(output_dir, file_name), sharp_output)
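Besides the cv2 path exercised in __main__ above, a minimal sketch of the tensor path (any `type` other than "cv2" keeps tensors in and out). The batch below is a placeholder, and a CUDA device is assumed because the Gaussian kernel buffer is created on the GPU in __init__.

    import torch
    from degradation.ESR.usm_sharp import USMSharp

    sharpener = USMSharp(type="tensor")              # 51x51 Gaussian kernel buffer, built on the GPU
    batch = torch.rand(2, 3, 256, 256).cuda()        # bchw, [0, 1]
    sharp_batch = sharpener(batch, weight=0.5, threshold=10)
    print(sharp_batch.shape)                         # torch.Size([2, 3, 256, 256])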
degradation/ESR/utils.py ADDED
@@ -0,0 +1,126 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ '''
4
+ From ESRGAN
5
+ '''
6
+
7
+
8
+ import os, sys
9
+ import cv2
10
+ import numpy as np
11
+ import torch
12
+ from torch.nn import functional as F
13
+ from scipy import special
14
+ import random
15
+ import math
16
+ from torchvision.utils import make_grid
17
+
18
+ from degradation.ESR.degradations_functionality import *
19
+
20
+ root_path = os.path.abspath('.')
21
+ sys.path.append(root_path)
22
+
23
+
24
+ def np2tensor(np_frame):
25
+ return torch.from_numpy(np.transpose(np_frame, (2, 0, 1))).unsqueeze(0).cuda().float()/255
26
+
27
+ def tensor2np(tensor):
28
+ # tensor must have batch size 1 and must not be a grayscale input
29
+ return (np.transpose(tensor.detach().squeeze(0).cpu().numpy(), (1, 2, 0))) * 255
30
+
31
+ def mass_tensor2np(tensor):
32
+ ''' The input tensor is massive tensor
33
+ '''
34
+ return (np.transpose(tensor.detach().squeeze(0).cpu().numpy(), (0, 2, 3, 1))) * 255
35
+
36
+ def save_img(tensor, save_name):
37
+ np_img = tensor2np(tensor)[:,:,16]
38
+ # np_img = np.expand_dims(np_img, axis=2)
39
+ cv2.imwrite(save_name, np_img)
40
+
41
+
42
+ def filter2D(img, kernel):
43
+ """PyTorch version of cv2.filter2D
44
+
45
+ Args:
46
+ img (Tensor): (b, c, h, w)
47
+ kernel (Tensor): (b, k, k)
48
+ """
49
+ k = kernel.size(-1)
50
+ b, c, h, w = img.size()
51
+ if k % 2 == 1:
52
+ img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
53
+ else:
54
+ raise ValueError('Wrong kernel size')
55
+
56
+ ph, pw = img.size()[-2:]
57
+
58
+ if kernel.size(0) == 1:
59
+ # apply the same kernel to all batch images
60
+ img = img.view(b * c, 1, ph, pw)
61
+ kernel = kernel.view(1, 1, k, k)
62
+ return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
63
+ else:
64
+ img = img.view(1, b * c, ph, pw)
65
+ kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
66
+ return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
67
+
68
+
69
+ def generate_kernels(opt):
70
+
71
+ kernel_range = [2 * v + 1 for v in range(opt["kernel_range"][0], opt["kernel_range"][1])]
72
+
73
+ # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
74
+ kernel_size = random.choice(kernel_range)
75
+ if np.random.uniform() < opt['sinc_prob']:
76
+ # Add a sinc filter here, with probability sinc_prob (about 10%)
77
+ # this sinc filter setting is for kernels ranging from [7, 21]
78
+ if kernel_size < 13:
79
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
80
+ else:
81
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
82
+ kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
83
+ else:
84
+ kernel = random_mixed_kernels(
85
+ opt['kernel_list'],
86
+ opt['kernel_prob'],
87
+ kernel_size,
88
+ opt['blur_sigma'],
89
+ opt['blur_sigma'], [-math.pi, math.pi],
90
+ opt['betag_range'],
91
+ opt['betap_range'],
92
+ noise_range=None)
93
+ # pad kernel (note: in v2 this padding was simply skipped)
94
+ pad_size = (21 - kernel_size) // 2
95
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
96
+
97
+
98
+ # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
99
+ kernel_size = random.choice(kernel_range)
100
+ if np.random.uniform() < opt['sinc_prob2']:
101
+ # Add a sinc filter here, with probability sinc_prob2 (about 10%)
102
+ if kernel_size < 13:
103
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
104
+ else:
105
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
106
+ kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
107
+ else:
108
+ kernel2 = random_mixed_kernels(
109
+ opt['kernel_list2'],
110
+ opt['kernel_prob2'],
111
+ kernel_size,
112
+ opt['blur_sigma2'],
113
+ opt['blur_sigma2'], [-math.pi, math.pi],
114
+ opt['betag_range2'],
115
+ opt['betap_range2'],
116
+ noise_range=None)
117
+
118
+ # pad kernel
119
+ pad_size = (21 - kernel_size) // 2
120
+ kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
121
+
122
+ kernel = torch.FloatTensor(kernel)
123
+ kernel2 = torch.FloatTensor(kernel2)
124
+ return (kernel, kernel2)
125
+
126
+
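For reference, a sketch of the option keys generate_kernels reads. The key names match the lookups above, but the values below are illustrative placeholders only; the real settings live in opt.py and may differ.

    opt_example = {
        'kernel_range': [3, 11],                     # expands to odd kernel sizes 7, 9, ..., 21
        'sinc_prob': 0.1, 'sinc_prob2': 0.1,
        'kernel_list':  ['iso', 'aniso'], 'kernel_prob':  [0.5, 0.5],
        'kernel_list2': ['iso', 'aniso'], 'kernel_prob2': [0.5, 0.5],
        'blur_sigma':  [0.2, 3.0], 'blur_sigma2': [0.2, 1.5],
        'betag_range':  [0.5, 4.0], 'betag_range2': [0.5, 4.0],
        'betap_range':  [1.0, 2.0], 'betap_range2': [1.0, 2.0],
    }
    kernel1, kernel2 = generate_kernels(opt_example)
    print(kernel1.shape, kernel2.shape)              # both padded to 21 x 21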
degradation/degradation_esr.py ADDED
@@ -0,0 +1,110 @@
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import os
4
+ import sys
5
+ import torch.nn.functional as F
6
+
7
+ root_path = os.path.abspath('.')
8
+ sys.path.append(root_path)
9
+ # Import files from the local folder
10
+ from opt import opt
11
+ from degradation.ESR.utils import generate_kernels, mass_tensor2np, tensor2np
12
+ from degradation.ESR.degradations_functionality import *
13
+ from degradation.ESR.degradation_esr_shared import common_degradation as regular_common_degradation
14
+ from degradation.image_compression.jpeg import JPEG # TODO: unify these codec classes behind a common base class later
15
+ from degradation.image_compression.webp import WEBP
16
+ from degradation.image_compression.heif import HEIF
17
+ from degradation.image_compression.avif import AVIF
18
+ from degradation.video_compression.h264 import H264
19
+ from degradation.video_compression.h265 import H265
20
+ from degradation.video_compression.mpeg2 import MPEG2
21
+ from degradation.video_compression.mpeg4 import MPEG4
22
+
23
+
24
+ class degradation_v1:
25
+ def __init__(self):
26
+ self.kernel1, self.kernel2, self.sinc_kernel = None, None, None
27
+ self.queue_size = 160
28
+
29
+ # Init the compression instance
30
+ self.jpeg_instance = JPEG()
31
+ self.webp_instance = WEBP()
32
+ # self.heif_instance = HEIF()
33
+ self.avif_instance = AVIF()
34
+ self.H264_instance = H264()
35
+ self.H265_instance = H265()
36
+ self.MPEG2_instance = MPEG2()
37
+ self.MPEG4_instance = MPEG4()
38
+
39
+
40
+ def reset_kernels(self, opt):
41
+ kernel1, kernel2 = generate_kernels(opt)
42
+ self.kernel1 = kernel1.unsqueeze(0).cuda()
43
+ self.kernel2 = kernel2.unsqueeze(0).cuda()
44
+
45
+
46
+ @torch.no_grad()
47
+ def degradate_process(self, out, opt, store_path, process_id, verbose = False):
48
+ ''' ESR Degradation V1 mode (Same as the original paper)
49
+ Args:
50
+ out (tensor): BxCxHxW All input images as tensor
51
+ opt (dict): All configuration we need to process
52
+ store_path (str): Store Directory
53
+ process_id (int): The id we used to store temporary file
54
+ verbose (bool): Whether print some information for auxiliary log (default: False)
55
+ '''
56
+
57
+ batch_size, _, ori_h, ori_w = out.size()
58
+
59
+ # Shared degradation until the last step
60
+ resize_mode = random.choice(opt['resize_options'])
61
+ out = regular_common_degradation(out, opt, [self.kernel1, self.kernel2], process_id, verbose=verbose)
62
+
63
+
64
+ # Resize back
65
+ out = F.interpolate(out, size=(ori_h // opt['scale'], ori_w // opt['scale']), mode = resize_mode)
66
+ out = torch.clamp(out, 0, 1)
67
+ # TODO: the tensor-to-numpy conversion could be moved earlier and done in one batch to save time
68
+
69
+ # Tensor2np
70
+ np_frame = tensor2np(out)
71
+
72
+ # Choose an image compression codec (All degradation batch use the same codec)
73
+ compression_codec = random.choices(opt['compression_codec2'], opt['compression_codec_prob2'])[0] # All lower case
74
+
75
+ if compression_codec == "jpeg":
76
+ self.jpeg_instance.compress_and_store(np_frame, store_path, process_id)
77
+
78
+ elif compression_codec == "webp":
79
+ try:
80
+ self.webp_instance.compress_and_store(np_frame, store_path, process_id)
81
+ except Exception:
82
+ print("There appears to be exception in webp again!")
83
+ if os.path.exists(store_path):
84
+ os.remove(store_path)
85
+ self.webp_instance.compress_and_store(np_frame, store_path, process_id)
86
+
87
+ elif compression_codec == "avif":
88
+ self.avif_instance.compress_and_store(np_frame, store_path, process_id)
89
+
90
+ elif compression_codec == "h264":
91
+ self.H264_instance.compress_and_store(np_frame, store_path, process_id)
92
+
93
+ elif compression_codec == "h265":
94
+ self.H265_instance.compress_and_store(np_frame, store_path, process_id)
95
+
96
+ elif compression_codec == "mpeg2":
97
+ self.MPEG2_instance.compress_and_store(np_frame, store_path, process_id)
98
+
99
+ elif compression_codec == "mpeg4":
100
+ self.MPEG4_instance.compress_and_store(np_frame, store_path, process_id)
101
+
102
+ else:
103
+ raise NotImplementedError("This compression codec is not supported! Please check the implementation!")
104
+
105
+
106
+
107
+
108
+
109
+
110
+
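A minimal sketch of how degradation_v1 is typically driven per batch. The paths and the random HR crop are placeholders; it assumes the repo's opt.py settings, a tmp/ working directory for intermediate files, and ffmpeg on PATH in case a video codec is sampled.

    import os, torch
    from opt import opt
    from degradation.degradation_esr import degradation_v1

    os.makedirs("tmp", exist_ok=True)                # temporary encode/decode files go here
    degrader = degradation_v1()
    hr_batch = torch.rand(1, 3, 256, 256).cuda()     # one HR crop in [0, 1]
    degrader.reset_kernels(opt)                      # draw fresh blur kernels for this batch
    degrader.degradate_process(hr_batch, opt, "tmp/lr_sample.png", process_id=0, verbose=False)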
degradation/image_compression/avif.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch, sys, os, random
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ import cv2
5
+ from multiprocessing import Process, Queue
6
+ from PIL import Image
7
+ import pillow_heif
8
+
9
+ root_path = os.path.abspath('.')
10
+ sys.path.append(root_path)
11
+ # Import files from the local folder
12
+ from opt import opt
13
+ from degradation.ESR.utils import tensor2np, np2tensor
14
+
15
+
16
+
17
+ class AVIF():
18
+ def __init__(self) -> None:
19
+ # Choose an image compression degradation
20
+ pass
21
+
22
+ def compress_and_store(self, np_frames, store_path, idx):
23
+ ''' Compress and Store the whole batch as AVIF (~ AV1)
24
+ Args:
25
+ np_frames (numpy): The numpy format of the data (Shape:?)
26
+ store_path (str): The store path
27
+ Return:
28
+ None
29
+ '''
30
+ # Init call for avif
31
+ pillow_heif.register_avif_opener()
32
+
33
+
34
+ single_frame = np_frames
35
+
36
+ # Prepare
37
+ essential_name = "tmp/temp_"+str(idx)
38
+
39
+ # Choose the quality
40
+ quality = random.randint(*opt['avif_quality_range2'])
41
+ method = random.randint(*opt['avif_encode_speed2'])
42
+
43
+ # Transform to PIL and then compress
44
+ PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB')
45
+ PIL_image.save(essential_name+'.avif', quality=quality, method=method)
46
+
47
+ # Read as png
48
+ avif_file = pillow_heif.open_heif(essential_name+'.avif', convert_hdr_to_8bit=False, bgr_mode=True)
49
+ np_array = np.asarray(avif_file)
50
+ cv2.imwrite(store_path, np_array)
51
+
52
+ os.remove(essential_name+'.avif')
53
+
54
+
55
+
56
+ @staticmethod
57
+ def compress_tensor(tensor_frames, idx=0):
58
+ ''' Compress tensor input to AVIF and then return it
59
+ Args:
60
+ tensor_frame (tensor): Tensor inputs
61
+ Returns:
62
+ result (tensor): Tensor outputs (same shape as input)
63
+ '''
64
+ # Init call for avif
65
+ pillow_heif.register_avif_opener()
66
+
67
+ # Prepare
68
+ single_frame = tensor2np(tensor_frames)
69
+ essential_name = "tmp/temp_"+str(idx)
70
+
71
+ # Choose the quality
72
+ quality = random.randint(*opt['avif_quality_range1'])
73
+ method = random.randint(*opt['avif_encode_speed1'])
74
+
75
+ # Transform to PIL and then compress
76
+ PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB')
77
+ PIL_image.save(essential_name+'.avif', quality=quality, method=method)
78
+
79
+ # Transform as png format
80
+ avif_file = pillow_heif.open_heif(essential_name+'.avif', convert_hdr_to_8bit=False, bgr_mode=True)
81
+ decimg = np.asarray(avif_file)
82
+ os.remove(essential_name+'.avif')
83
+
84
+ # Read back
85
+ result = np2tensor(decimg)
86
+
87
+
88
+ return result
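A minimal round-trip sketch for the tensor path. The input tensor is a placeholder; a tmp/ directory is assumed because the intermediate .avif file is written there and deleted afterwards, and the quality/speed ranges come from the repo's opt.py.

    import os, torch
    from degradation.image_compression.avif import AVIF

    os.makedirs("tmp", exist_ok=True)
    frame = torch.rand(1, 3, 128, 128).cuda()        # single frame, bchw, [0, 1]
    degraded = AVIF.compress_tensor(frame, idx=0)    # encodes to tmp/temp_0.avif, reads it back, removes the file
    print(degraded.shape)                            # torch.Size([1, 3, 128, 128])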
degradation/image_compression/heif.py ADDED
@@ -0,0 +1,90 @@
1
+ import torch, sys, os, random
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ import cv2
5
+ from multiprocessing import Process, Queue
6
+ from PIL import Image
7
+ from pillow_heif import register_heif_opener
8
+ import pillow_heif
9
+
10
+ root_path = os.path.abspath('.')
11
+ sys.path.append(root_path)
12
+ # Import files from the local folder
13
+ from opt import opt
14
+ from degradation.ESR.utils import tensor2np, np2tensor
15
+
16
+
17
+
18
+
19
+ class HEIF():
20
+ def __init__(self) -> None:
21
+ # Choose an image compression degradation
22
+ pass
23
+
24
+ def compress_and_store(self, np_frames, store_path):
25
+ ''' Compress and Store the whole batch as HEIF (~ HEVC)
26
+ Args:
27
+ np_frames (numpy): The numpy format of the data (Shape:?)
28
+ store_path (str): The store path
29
+ Return:
30
+ None
31
+ '''
32
+ # Init call for heif
33
+ register_heif_opener()
34
+
35
+ single_frame = np_frames
36
+
37
+ # Prepare
38
+ essential_name = store_path.split('.')[0]
39
+
40
+ # Choose the quality
41
+ quality = random.randint(*opt['heif_quality_range1'])
42
+ method = random.randint(*opt['heif_encode_speed1'])
43
+
44
+ # Transform to PIL and then compress
45
+ PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB')
46
+ PIL_image.save(essential_name+'.heic', quality=quality, method=method)
47
+
48
+ # Transform as png format
49
+ heif_file = pillow_heif.open_heif(essential_name+'.heic', convert_hdr_to_8bit=False, bgr_mode=True)
50
+ np_array = np.asarray(heif_file)
51
+ cv2.imwrite(store_path, np_array)
52
+
53
+ os.remove(essential_name+'.heic')
54
+
55
+
56
+ @staticmethod
57
+ def compress_tensor(tensor_frames, idx=0):
58
+ ''' Compress tensor input to HEIF and then return it
59
+ Args:
60
+ tensor_frame (tensor): Tensor inputs
61
+ Returns:
62
+ result (tensor): Tensor outputs (same shape as input)
63
+ '''
64
+
65
+ # Init call for heif
66
+ register_heif_opener()
67
+
68
+ # Prepare
69
+ single_frame = tensor2np(tensor_frames)
70
+ essential_name = "tmp/temp_"+str(idx)
71
+
72
+ # Choose the quality
73
+ quality = random.randint(*opt['heif_quality_range1'])
74
+ method = random.randint(*opt['heif_encode_speed1'])
75
+
76
+ # Transform to PIL and then compress
77
+ PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB')
78
+ PIL_image.save(essential_name+'.heic', quality=quality, method=method)
79
+
80
+ # Transform as png format
81
+ heif_file = pillow_heif.open_heif(essential_name+'.heic', convert_hdr_to_8bit=False, bgr_mode=True)
82
+ decimg = np.asarray(heif_file)
83
+ os.remove(essential_name+'.heic')
84
+
85
+ # Read back
86
+ result = np2tensor(decimg)
87
+
88
+ return result
89
+
90
+
degradation/image_compression/jpeg.py ADDED
@@ -0,0 +1,68 @@
1
+ import sys, os, random
2
+ import cv2, torch
3
+ from multiprocessing import Process, Queue
4
+
5
+ root_path = os.path.abspath('.')
6
+ sys.path.append(root_path)
7
+ # Import files from the local folder
8
+ from opt import opt
9
+ from degradation.ESR.utils import tensor2np, np2tensor
10
+
11
+
12
+
13
+ class JPEG():
14
+ def __init__(self) -> None:
15
+ # Choose an image compression degradation
16
+ # self.jpeger = DiffJPEG(differentiable=False).cuda()
17
+ pass
18
+
19
+ def compress_and_store(self, np_frames, store_path, idx):
20
+ ''' Compress and Store the whole batch as JPEG
21
+ Args:
22
+ np_frames (numpy): The numpy format of the data (Shape:?)
23
+ store_path (str): The store path
24
+ Return:
25
+ None
26
+ '''
27
+
28
+ # Preparation
29
+ single_frame = np_frames
30
+
31
+ # Compress as JPEG
32
+ jpeg_quality = random.randint(*opt['jpeg_quality_range2'])
33
+
34
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]
35
+ _, encimg = cv2.imencode('.jpg', single_frame, encode_param)
36
+ decimg = cv2.imdecode(encimg, 1)
37
+
38
+ # Store the image with quality
39
+ cv2.imwrite(store_path, decimg)
40
+
41
+
42
+
43
+ @staticmethod
44
+ def compress_tensor(tensor_frames):
45
+ ''' Compress tensor input to JPEG and then return it
46
+ Args:
47
+ tensor_frame (tensor): Tensor inputs
48
+ Returns:
49
+ result (tensor): Tensor outputs (same shape as input)
50
+ '''
51
+
52
+ single_frame = tensor2np(tensor_frames)
53
+
54
+ # Compress as JPEG
55
+ jpeg_quality = random.randint(*opt['jpeg_quality_range1'])
56
+
57
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality]
58
+ _, encimg = cv2.imencode('.jpg', single_frame, encode_param)
59
+ decimg = cv2.imdecode(encimg, 1)
60
+
61
+ # Store the image with quality
62
+ # cv2.imwrite(store_name, decimg)
63
+ result = np2tensor(decimg)
64
+
65
+ return result
66
+
67
+
68
+
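Unlike the PIL-based codecs above, the JPEG path stays in memory via cv2.imencode/imdecode, so no temporary files are needed. A minimal sketch with a placeholder frame; the quality is drawn from opt['jpeg_quality_range1'] in the repo's opt.py.

    import torch
    from degradation.image_compression.jpeg import JPEG

    frame = torch.rand(1, 3, 128, 128).cuda()        # single frame, bchw, [0, 1]
    degraded = JPEG.compress_tensor(frame)           # in-memory JPEG encode/decode round trip
    print(degraded.shape)                            # torch.Size([1, 3, 128, 128])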
degradation/image_compression/webp.py ADDED
@@ -0,0 +1,65 @@
1
+ import torch, sys, os, random
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ import cv2
5
+ from multiprocessing import Process, Queue
6
+ from PIL import Image
7
+
8
+ root_path = os.path.abspath('.')
9
+ sys.path.append(root_path)
10
+ # Import files from the local folder
11
+ from opt import opt
12
+ from degradation.ESR.utils import tensor2np, np2tensor
13
+
14
+
15
+
16
+
17
+ class WEBP():
18
+ def __init__(self) -> None:
19
+ # Choose an image compression degradation
20
+ pass
21
+
22
+ def compress_and_store(self, np_frames, store_path, idx):
23
+ ''' Compress and Store the whole batch as WebP (~ VP8)
24
+ Args:
25
+ np_frames (numpy): The numpy format of the data (Shape:?)
26
+ store_path (str): The store path
27
+ Return:
28
+ None
29
+ '''
30
+ single_frame = np_frames
31
+
32
+ # Choose the quality
33
+ quality = random.randint(*opt['webp_quality_range2'])
34
+ method = random.randint(*opt['webp_encode_speed2'])
35
+
36
+ # Transform to PIL and then compress
37
+ PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB')
38
+ PIL_image.save(store_path, 'webp', quality=quality, method=method)
39
+
40
+
41
+ @staticmethod
42
+ def compress_tensor(tensor_frames, idx = 0):
43
+ ''' Compress tensor input to WEBP and then return it
44
+ Args:
45
+ tensor_frame (tensor): Tensor inputs
46
+ Returns:
47
+ result (tensor): Tensor outputs (same shape as input)
48
+ '''
49
+ single_frame = tensor2np(tensor_frames)
50
+
51
+ # Choose the quality
52
+ quality = random.randint(*opt['webp_quality_range1'])
53
+ method = random.randint(*opt['webp_encode_speed1'])
54
+
55
+ # Transform to PIL and then compress
56
+ PIL_image = Image.fromarray(np.uint8(single_frame[...,::-1])).convert('RGB')
57
+ store_path = os.path.join("tmp", "temp_"+str(idx)+".webp")
58
+ PIL_image.save(store_path, 'webp', quality=quality, method=method)
59
+
60
+ # Read back
61
+ decimg = cv2.imread(store_path)
62
+ result = np2tensor(decimg)
63
+ os.remove(store_path)
64
+
65
+ return result
degradation/video_compression/h264.py ADDED
@@ -0,0 +1,73 @@
1
+ import torch, sys, os, random
2
+ import cv2
3
+ import shutil
4
+
5
+ root_path = os.path.abspath('.')
6
+ sys.path.append(root_path)
7
+ # Import files from the local folder
8
+ from opt import opt
9
+
10
+
11
+
12
+ class H264():
13
+ def __init__(self) -> None:
14
+ # Choose an image compression degradation
15
+ pass
16
+
17
+ def compress_and_store(self, single_frame, store_path, idx):
18
+ ''' Compress and Store the whole batch as H.264 (for 2nd stage)
19
+ Args:
20
+ single_frame (numpy): The numpy format of the data (Shape:?)
21
+ store_path (str): The store path
22
+ idx (int): A unique process idx
23
+ Return:
24
+ None
25
+ '''
26
+
27
+ # Prepare
28
+ temp_input_path = "tmp/input_"+str(idx)
29
+ video_store_dir = "tmp/encoded_"+str(idx)+".mp4"
30
+ temp_store_path = "tmp/output_"+str(idx)
31
+ os.makedirs(temp_input_path)
32
+ os.makedirs(temp_store_path)
33
+
34
+ # Move frame
35
+ cv2.imwrite(os.path.join(temp_input_path, "1.png"), single_frame)
36
+
37
+
38
+ # Decide the quality
39
+ crf = str(random.randint(*opt['h264_crf_range2']))
40
+ preset = random.choices(opt['h264_preset_mode2'], opt['h264_preset_prob2'])[0]
41
+
42
+ # Encode
43
+ ffmpeg_encode_cmd = "ffmpeg -i " + temp_input_path + "/%d.png -vcodec libx264 -crf " + crf + " -preset " + preset + " -pix_fmt yuv420p " + video_store_dir + " -loglevel 0"
44
+ os.system(ffmpeg_encode_cmd)
45
+
46
+
47
+ # Decode
48
+ ffmpeg_decode_cmd = "ffmpeg -i " + video_store_dir + " " + temp_store_path + "/%d.png -loglevel 0"
49
+ os.system(ffmpeg_decode_cmd)
50
+ if len(os.listdir(temp_store_path)) != 1:
51
+ print("This is strange")
52
+ assert(len(os.listdir(temp_store_path)) == 1)
53
+
54
+ # Move frame to the target places
55
+ shutil.copy(os.path.join(temp_store_path, "1.png"), store_path)
56
+
57
+ # Clean temp files
58
+ os.remove(video_store_dir)
59
+ shutil.rmtree(temp_input_path)
60
+ shutil.rmtree(temp_store_path)
61
+
62
+
63
+
64
+ @staticmethod
65
+ def compress_tensor(tensor_frames, idx=0):
66
+ ''' Compress tensor input to H.264 and then return it (for 1st stage)
67
+ Args:
68
+ tensor_frame (tensor): Tensor inputs
69
+ Returns:
70
+ result (tensor): Tensor outputs (same shape as input)
71
+ '''
72
+
73
+ # Tensor-path H.264 compression is left unimplemented; only the file-based path above is used.
+ pass
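A minimal sketch of the file-based H.264 round trip. It shells out to ffmpeg with libx264, so ffmpeg must be on PATH and a tmp/ directory must exist; the frame and paths below are placeholders (frame dimensions should be even for yuv420p).

    import os, cv2, numpy as np
    from degradation.video_compression.h264 import H264

    os.makedirs("tmp", exist_ok=True)
    frame = (np.random.rand(128, 128, 3) * 255).astype(np.uint8)   # HWC, BGR, [0, 255]
    H264().compress_and_store(frame, "tmp/h264_degraded.png", idx=0)
    print(cv2.imread("tmp/h264_degraded.png").shape)               # (128, 128, 3)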