Siwon123 committed on
Commit 7f43945 · 1 Parent(s): f9a4268
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gradio/certificate.pem +31 -0
  3. __pycache__/inference.cpython-310.pyc +0 -0
  4. app.py +42 -4
  5. ckpt/epoch_287.pth +3 -0
  6. inference.py +59 -0
  7. text_net/DGRN.py +232 -0
  8. text_net/__pycache__/DGRN.cpython-310.pyc +0 -0
  9. text_net/__pycache__/DGRN.cpython-38.pyc +0 -0
  10. text_net/__pycache__/deform_conv.cpython-310.pyc +0 -0
  11. text_net/__pycache__/deform_conv.cpython-36.pyc +0 -0
  12. text_net/__pycache__/deform_conv.cpython-38.pyc +0 -0
  13. text_net/__pycache__/encoder.cpython-310.pyc +0 -0
  14. text_net/__pycache__/encoder.cpython-36.pyc +0 -0
  15. text_net/__pycache__/encoder.cpython-38.pyc +0 -0
  16. text_net/__pycache__/moco.cpython-310.pyc +0 -0
  17. text_net/__pycache__/moco.cpython-36.pyc +0 -0
  18. text_net/__pycache__/moco.cpython-38.pyc +0 -0
  19. text_net/__pycache__/model.cpython-310.pyc +0 -0
  20. text_net/__pycache__/model.cpython-36.pyc +0 -0
  21. text_net/__pycache__/model.cpython-38.pyc +0 -0
  22. text_net/deform_conv.py +65 -0
  23. text_net/encoder.py +67 -0
  24. text_net/moco.py +166 -0
  25. text_net/model.py +29 -0
  26. utils/.DS_Store +0 -0
  27. utils/__init__.py +0 -0
  28. utils/__pycache__/__init__.cpython-310.pyc +0 -0
  29. utils/__pycache__/__init__.cpython-36.pyc +0 -0
  30. utils/__pycache__/__init__.cpython-38.pyc +0 -0
  31. utils/__pycache__/dataset_utils.cpython-310.pyc +0 -0
  32. utils/__pycache__/dataset_utils.cpython-36.pyc +0 -0
  33. utils/__pycache__/dataset_utils.cpython-38.pyc +0 -0
  34. utils/__pycache__/dataset_utils_CDD.cpython-310.pyc +0 -0
  35. utils/__pycache__/degradation_utils.cpython-310.pyc +0 -0
  36. utils/__pycache__/degradation_utils.cpython-36.pyc +0 -0
  37. utils/__pycache__/degradation_utils.cpython-38.pyc +0 -0
  38. utils/__pycache__/image_io.cpython-310.pyc +0 -0
  39. utils/__pycache__/image_io.cpython-36.pyc +0 -0
  40. utils/__pycache__/image_io.cpython-38.pyc +0 -0
  41. utils/__pycache__/image_utils.cpython-310.pyc +0 -0
  42. utils/__pycache__/image_utils.cpython-36.pyc +0 -0
  43. utils/__pycache__/image_utils.cpython-38.pyc +0 -0
  44. utils/__pycache__/imresize.cpython-36.pyc +0 -0
  45. utils/__pycache__/imresize.cpython-38.pyc +0 -0
  46. utils/__pycache__/loss_utils.cpython-38.pyc +0 -0
  47. utils/__pycache__/val_utils.cpython-310.pyc +0 -0
  48. utils/__pycache__/val_utils.cpython-36.pyc +0 -0
  49. utils/__pycache__/val_utils.cpython-38.pyc +0 -0
  50. utils/dataset_utils.py +309 -0
.gitattributes CHANGED
@@ -19,6 +19,7 @@
  *.pb filter=lfs diff=lfs merge=lfs -text
  *.pickle filter=lfs diff=lfs merge=lfs -text
  *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
__pycache__/inference.cpython-310.pyc ADDED
Binary file (1.99 kB).
app.py CHANGED
@@ -1,7 +1,45 @@
  import gradio as gr
- 
- def greet(name):
-     return "Hello " + name + "!!"
- 
- demo = gr.Interface(fn=greet, inputs=[gr.components.Image(), "Text Instruction"], outputs=gr.components.Image())
- demo.launch()
+ from inference import infer
+ 
+ def greet(image, prompt):
+     restore_img = infer(img=image, text_prompt=prompt)
+     return restore_img
+ 
+ 
+ title = "🖼️ ICDR 🖼️"
+ description = ''' ## ICDR: Image Restoration Framework for Composite Degradation following Human Instructions
+ Our GitHub: https://github.com/
+ 
+ Siwon Kim, Donghyeon Yoon
+ 
+ Ajou Univ
+ '''
+ 
+ 
+ article = "<p style='text-align: center'><a href='https://github.com/' target='_blank'>ICDR</a></p>"
+ 
+ #### Image/prompt examples
+ examples = [['input/00010.png', "I love this photo, could you remove the haze and more brighter?"],
+             ['input/00058.png', "I have to post an emotional shot on Instagram, but it was shot too foggy and too dark. Change it like a sunny day and brighten it up!"]]
+ 
+ css = """
+ .image-frame img, .image-container img {
+     width: auto;
+     height: auto;
+     max-width: none;
+ }
+ """
+ 
+ 
+ demo = gr.Interface(
+     fn=greet,
+     inputs=[gr.Image(type="pil", label="Input"),
+             gr.Text(label="Prompt")],
+     outputs=[gr.Image(type="pil", label="Output")],
+     title=title,
+     description=description,
+     article=article,
+     examples=examples,
+     css=css,
+ )
+ demo.launch(share=True)
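Note: a minimal sketch of exercising the same pipeline without the Gradio UI. Importing app would call demo.launch(), so this goes through inference.infer directly; the image path is taken from the examples list above, and it assumes the checkpoint and a CUDA device are available, as app.py itself does.

    from PIL import Image
    from inference import infer

    # Hypothetical offline run of the handler logic; mirrors greet() above.
    image = Image.open("input/00010.png").convert("RGB")
    restored = infer(img=image, text_prompt="Please remove the haze and brighten it.")
    # gr.Image(type="pil") in app.py implies restored is PIL-compatible.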
ckpt/epoch_287.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:db279692728bd4614759c08a0478d9d07200768e5fb7fa893e78aaa05f3ca707
+ size 48705338
inference.py ADDED
@@ -0,0 +1,59 @@
+ import argparse
+ import subprocess
+ from tqdm import tqdm
+ import numpy as np
+ 
+ import torch
+ from torch.utils.data import DataLoader
+ 
+ from utils.dataset_utils_CDD import DerainDehazeDataset
+ from utils.val_utils import AverageMeter, compute_psnr_ssim
+ from utils.image_io import save_image_tensor
+ 
+ from text_net.model import AirNet
+ 
+ def test_Derain_Dehaze(opt, net, dataset, task="derain"):
+     output_path = opt.output_path + task + '/'
+     subprocess.check_output(['mkdir', '-p', output_path])
+ 
+     # dataset.set_dataset(task)
+     testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0)
+     print(len(testloader))
+ 
+     with torch.no_grad():
+         for ([degraded_name], degradation, degrad_patch, clean_patch, text_prompt) in tqdm(testloader):
+             degrad_patch, clean_patch = degrad_patch.cuda(), clean_patch.cuda()
+             restored = net(x_query=degrad_patch, x_key=degrad_patch, text_prompt=text_prompt)
+ 
+     return save_image_tensor(restored)
+ 
+ 
+ def infer(text_prompt="", img=None):
+     parser = argparse.ArgumentParser()
+     # Input parameters
+     parser.add_argument('--cuda', type=int, default=0)
+     parser.add_argument('--derain_path', type=str, default="data/Test_prompting/", help='path of the test rain images')
+     parser.add_argument('--output_path', type=str, default="output/demo11", help='output save path')
+     parser.add_argument('--ckpt_path', type=str, default="ckpt/epoch_287.pth", help='checkpoint save path')
+     # parser.add_argument('--text_prompt', type=str, default="derain")
+ 
+     opt = parser.parse_args()
+     # opt.text_prompt = text_prompt
+ 
+     np.random.seed(0)
+     torch.manual_seed(0)
+     torch.cuda.set_device(opt.cuda)
+ 
+     opt.batch_size = 7
+     ckpt_path = opt.ckpt_path
+ 
+     derain_set = DerainDehazeDataset(opt, img=img, text_prompt=text_prompt)
+ 
+     # Build the network
+     net = AirNet(opt).cuda()
+     net.eval()
+     net.load_state_dict(torch.load(ckpt_path, map_location=torch.device(opt.cuda)))
+ 
+     restored = test_Derain_Dehaze(opt, net, derain_set, task="derain")
+ 
+     return restored
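Note: infer() builds its configuration with argparse inside the function, so its defaults come from the parser above. A sketch of driving the same components with an explicit Namespace instead of CLI parsing; the field names mirror those defaults, and this mapping is an assumption for illustration, not a committed API.

    import argparse
    import torch
    from text_net.model import AirNet

    # Hypothetical stand-in for the opt object parsed inside infer()
    opt = argparse.Namespace(cuda=0,
                             derain_path="data/Test_prompting/",
                             output_path="output/demo11",
                             ckpt_path="ckpt/epoch_287.pth",
                             batch_size=7)
    net = AirNet(opt).cuda().eval()
    net.load_state_dict(torch.load(opt.ckpt_path, map_location=f"cuda:{opt.cuda}"))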
text_net/DGRN.py ADDED
@@ -0,0 +1,232 @@
+ import torch.nn as nn
+ import torch
+ from .deform_conv import DCN_layer
+ import clip
+ 
+ clip_model, preprocess = clip.load("ViT-B/32", device='cuda')
+ 
+ # Get the text embedding dimension dynamically
+ text_embed_dim = clip_model.text_projection.shape[1]
+ 
+ 
+ def default_conv(in_channels, out_channels, kernel_size, bias=True):
+     return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)
+ 
+ 
+ class DGM(nn.Module):
+     def __init__(self, channels_in, channels_out, kernel_size):
+         super(DGM, self).__init__()
+         self.channels_out = channels_out
+         self.channels_in = channels_in
+         self.kernel_size = kernel_size
+ 
+         self.dcn = DCN_layer(self.channels_in, self.channels_out, kernel_size,
+                              padding=(kernel_size - 1) // 2, bias=False)
+         self.sft = SFT_layer(self.channels_in, self.channels_out)
+ 
+         self.relu = nn.LeakyReLU(0.1, True)
+ 
+     def forward(self, x, inter, text_prompt):
+         '''
+         :param x: feature map: B * C * H * W
+         :param inter: degradation map: B * C * H * W
+         '''
+         dcn_out = self.dcn(x, inter)
+         sft_out = self.sft(x, inter, text_prompt)
+         out = dcn_out + sft_out
+         out = x + out
+ 
+         return out
+ 
+ # Projection head definition
+ class TextProjectionHead(nn.Module):
+     def __init__(self, input_dim, output_dim):
+         super(TextProjectionHead, self).__init__()
+         self.proj = nn.Sequential(
+             nn.Linear(input_dim, output_dim),
+             nn.ReLU(),
+             nn.Linear(output_dim, output_dim)
+         ).float()
+ 
+     def forward(self, x):
+         return self.proj(x.float())
+ 
+ 
+ 
+ class SFT_layer(nn.Module):
+     def __init__(self, channels_in, channels_out):
+         super(SFT_layer, self).__init__()
+         self.conv_gamma = nn.Sequential(
+             nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=False),
+             nn.LeakyReLU(0.1, True),
+             nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+         )
+         self.conv_beta = nn.Sequential(
+             nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=False),
+             nn.LeakyReLU(0.1, True),
+             nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+         )
+ 
+         self.text_proj_head = TextProjectionHead(text_embed_dim, channels_out)
+ 
+         '''
+         self.text_gamma = nn.Sequential(
+             nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+             nn.LeakyReLU(0.1, True),
+             nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+         ).float()
+         self.text_beta = nn.Sequential(
+             nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+             nn.LeakyReLU(0.1, True),
+             nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
+         ).float()
+         '''
+ 
+         self.cross_attention = nn.MultiheadAttention(embed_dim=channels_out, num_heads=2)
+ 
+ 
+     def forward(self, x, inter, text_prompt):
+         '''
+         :param x: degradation representation: B * C
+         :param inter: degradation intermediate representation map: B * C * H * W
+         '''
+         # img_gamma = self.conv_gamma(inter)
+         # img_beta = self.conv_beta(inter)
+ 
+         B, C, H, W = inter.shape  # cross attention
+ 
+ 
+         text_tokens = clip.tokenize(text_prompt).to(x.device)  # Tokenize the text prompts (batch size)
+         with torch.no_grad():
+             text_embed = clip_model.encode_text(text_tokens)
+ 
+         text_proj = self.text_proj_head(text_embed).float()
+ 
+         # Expand the text embedding to (B, C, H, W)
+         # text_proj_expanded = text_proj.unsqueeze(-1).unsqueeze(-1).expand(B, self.conv_gamma[0].out_channels, H, W)
+         text_proj_expanded = text_proj.unsqueeze(-1).unsqueeze(-1).expand(B, C, H, W)
+ 
+         # Fuse the intermediate image representation with the text embedding
+         combined = inter * text_proj_expanded
+         # combined = torch.cat([inter, text_proj_expanded], dim=1)
+ 
+         # Compute gamma and beta from the fused image and text features
+         img_gamma = self.conv_gamma(combined)
+         img_beta = self.conv_beta(combined)
+ 
+         ''' simple concat
+         text_gamma = self.text_gamma(text_proj.unsqueeze(-1).unsqueeze(-1))  # Reshape to match (B, C, H, W)
+         text_beta = self.text_beta(text_proj.unsqueeze(-1).unsqueeze(-1))  # Reshape to match (B, C, H, W)
+         '''
+ 
+         '''
+         text_proj = text_proj.unsqueeze(1).expand(-1, H*W, -1)  # B * (H*W) * C
+ 
+         # Reshape the intermediate image representation to (H*W) * B * C
+         inter_flat = inter.view(B, C, -1).permute(2, 0, 1)  # (H*W) * B * C
+ 
+         # Apply cross-attention
+         attn_output, _ = self.cross_attention(text_proj.permute(1, 0, 2), inter_flat, inter_flat)
+         attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)  # B * C * H * W
+ 
+         # Compute gamma and beta
+         img_gamma = self.conv_gamma(attn_output)
+         img_beta = self.conv_beta(attn_output)
+         '''
+         # Experiment: combining the text via concat
+         return x * img_gamma + img_beta
+ 
+ 
+ class DGB(nn.Module):
+     def __init__(self, conv, n_feat, kernel_size):
+         super(DGB, self).__init__()
+ 
+         # self.da_conv1 = DGM(n_feat, n_feat, kernel_size)
+         # self.da_conv2 = DGM(n_feat, n_feat, kernel_size)
+         self.dgm1 = DGM(n_feat, n_feat, kernel_size)
+         self.dgm2 = DGM(n_feat, n_feat, kernel_size)
+         self.conv1 = conv(n_feat, n_feat, kernel_size)
+         self.conv2 = conv(n_feat, n_feat, kernel_size)
+ 
+         self.relu = nn.LeakyReLU(0.1, True)
+ 
+     def forward(self, x, inter, text_prompt):
+         '''
+         :param x: feature map: B * C * H * W
+         :param inter: degradation representation: B * C * H * W
+         '''
+ 
+         out = self.relu(self.dgm1(x, inter, text_prompt))
+         out = self.relu(self.conv1(out))
+         out = self.relu(self.dgm2(out, inter, text_prompt))
+         out = self.conv2(out) + x
+ 
+         return out
+ 
+ 
+ class DGG(nn.Module):
+     def __init__(self, conv, n_feat, kernel_size, n_blocks):
+         super(DGG, self).__init__()
+         self.n_blocks = n_blocks
+         modules_body = [
+             DGB(conv, n_feat, kernel_size) \
+             for _ in range(n_blocks)
+         ]
+         modules_body.append(conv(n_feat, n_feat, kernel_size))
+ 
+         self.body = nn.Sequential(*modules_body)
+ 
+     def forward(self, x, inter, text_prompt):
+         '''
+         :param x: feature map: B * C * H * W
+         :param inter: degradation representation: B * C * H * W
+         '''
+         res = x
+         for i in range(self.n_blocks):
+             res = self.body[i](res, inter, text_prompt)
+         res = self.body[-1](res)
+         res = res + x
+ 
+         return res
+ 
+ 
+ class DGRN(nn.Module):
+     def __init__(self, opt, conv=default_conv):
+         super(DGRN, self).__init__()
+ 
+         self.n_groups = 5
+         n_blocks = 5
+         n_feats = 64
+         kernel_size = 3
+ 
+         # head module
+         modules_head = [conv(3, n_feats, kernel_size)]
+         self.head = nn.Sequential(*modules_head)
+ 
+         # body
+         modules_body = [
+             DGG(default_conv, n_feats, kernel_size, n_blocks) \
+             for _ in range(self.n_groups)
+         ]
+         modules_body.append(conv(n_feats, n_feats, kernel_size))
+         self.body = nn.Sequential(*modules_body)
+ 
+         # tail
+         modules_tail = [conv(n_feats, 3, kernel_size)]
+         self.tail = nn.Sequential(*modules_tail)
+ 
+     def forward(self, x, inter, text_prompt):
+         # head
+         x = self.head(x)
+ 
+         # body
+         res = x
+         for i in range(self.n_groups):
+             res = self.body[i](res, inter, text_prompt)
+         res = self.body[-1](res)
+         res = res + x
+ 
+         # tail
+         x = self.tail(res)
+ 
+         return x
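Note: to make the text conditioning in SFT_layer concrete, here is a CPU-only shape sketch with the CLIP embedding replaced by a random 512-d tensor (ViT-B/32's text width); everything else mirrors the element-wise fusion and gamma/beta modulation in the forward pass above.

    import torch
    import torch.nn as nn

    B, C, H, W = 2, 64, 32, 32
    x = torch.randn(B, C, H, W)            # feature map to modulate
    inter = torch.randn(B, C, H, W)        # degradation intermediate map
    text_embed = torch.randn(B, 512)       # stand-in for clip_model.encode_text

    proj = nn.Sequential(nn.Linear(512, C), nn.ReLU(), nn.Linear(C, C))
    text_map = proj(text_embed).unsqueeze(-1).unsqueeze(-1).expand(B, C, H, W)
    combined = inter * text_map            # element-wise image/text fusion

    conv_gamma = nn.Conv2d(C, C, 1, bias=False)
    conv_beta = nn.Conv2d(C, C, 1, bias=False)
    out = x * conv_gamma(combined) + conv_beta(combined)   # SFT-style modulation
    print(out.shape)                       # torch.Size([2, 64, 32, 32])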
text_net/__pycache__/DGRN.cpython-310.pyc ADDED
Binary file (5.61 kB).
text_net/__pycache__/DGRN.cpython-38.pyc ADDED
Binary file (4.53 kB).
text_net/__pycache__/deform_conv.cpython-310.pyc ADDED
Binary file (2.2 kB).
text_net/__pycache__/deform_conv.cpython-36.pyc ADDED
Binary file (2.14 kB).
text_net/__pycache__/deform_conv.cpython-38.pyc ADDED
Binary file (2.21 kB).
text_net/__pycache__/encoder.cpython-310.pyc ADDED
Binary file (2.33 kB).
text_net/__pycache__/encoder.cpython-36.pyc ADDED
Binary file (2.38 kB).
text_net/__pycache__/encoder.cpython-38.pyc ADDED
Binary file (2.36 kB).
text_net/__pycache__/moco.cpython-310.pyc ADDED
Binary file (4.43 kB).
text_net/__pycache__/moco.cpython-36.pyc ADDED
Binary file (4.39 kB).
text_net/__pycache__/moco.cpython-38.pyc ADDED
Binary file (4.43 kB).
text_net/__pycache__/model.cpython-310.pyc ADDED
Binary file (936 Bytes).
text_net/__pycache__/model.cpython-36.pyc ADDED
Binary file (914 Bytes).
text_net/__pycache__/model.cpython-38.pyc ADDED
Binary file (916 Bytes).
text_net/deform_conv.py ADDED
@@ -0,0 +1,65 @@
+ import math
+ 
+ import torch
+ import torch.nn as nn
+ from torch.nn.modules.utils import _pair
+ 
+ from mmcv.ops import modulated_deform_conv2d
+ 
+ 
+ class DCN_layer(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
+                  groups=1, deformable_groups=1, bias=True, extra_offset_mask=True):
+         super(DCN_layer, self).__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.kernel_size = _pair(kernel_size)
+         self.stride = stride
+         self.padding = padding
+         self.dilation = dilation
+         self.groups = groups
+         self.deformable_groups = deformable_groups
+         self.with_bias = bias
+ 
+         self.weight = nn.Parameter(
+             torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
+ 
+         self.extra_offset_mask = extra_offset_mask
+         self.conv_offset_mask = nn.Conv2d(
+             self.in_channels * 2,
+             self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+             kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
+             bias=True
+         )
+ 
+         if bias:
+             self.bias = nn.Parameter(torch.Tensor(out_channels))
+         else:
+             self.register_parameter('bias', None)
+ 
+         self.init_offset()
+         self.reset_parameters()
+ 
+     def reset_parameters(self):
+         n = self.in_channels
+         for k in self.kernel_size:
+             n *= k
+         stdv = 1. / math.sqrt(n)
+         self.weight.data.uniform_(-stdv, stdv)
+         if self.bias is not None:
+             self.bias.data.zero_()
+ 
+     def init_offset(self):
+         self.conv_offset_mask.weight.data.zero_()
+         self.conv_offset_mask.bias.data.zero_()
+ 
+     def forward(self, input_feat, inter):
+         feat_degradation = torch.cat([input_feat, inter], dim=1)
+ 
+         out = self.conv_offset_mask(feat_degradation)
+         o1, o2, mask = torch.chunk(out, 3, dim=1)
+         offset = torch.cat((o1, o2), dim=1)
+         mask = torch.sigmoid(mask)
+ 
+         return modulated_deform_conv2d(input_feat.contiguous(), offset, mask, self.weight, self.bias, self.stride,
+                                        self.padding, self.dilation, self.groups, self.deformable_groups)
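Note: a usage sketch for DCN_layer; it requires mmcv with its compiled deform-conv ops and a CUDA device, and the channel counts mirror the 64-feature DGM above. The offsets and modulation mask are predicted from the concatenation of the two inputs.

    import torch
    from text_net.deform_conv import DCN_layer

    dcn = DCN_layer(64, 64, kernel_size=3, padding=1, bias=False).cuda()
    feat = torch.randn(2, 64, 32, 32, device="cuda")    # input features
    inter = torch.randn(2, 64, 32, 32, device="cuda")   # degradation map, same shape
    out = dcn(feat, inter)   # offsets/mask come from cat([feat, inter], dim=1)
    print(out.shape)         # torch.Size([2, 64, 32, 32])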
text_net/encoder.py ADDED
@@ -0,0 +1,67 @@
+ from torch import nn
+ from text_net.moco import MoCo
+ 
+ 
+ class ResBlock(nn.Module):
+     def __init__(self, in_feat, out_feat, stride=1):
+         super(ResBlock, self).__init__()
+         self.backbone = nn.Sequential(
+             nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=stride, padding=1, bias=False),
+             nn.BatchNorm2d(out_feat),
+             nn.LeakyReLU(0.1, True),
+             nn.Conv2d(out_feat, out_feat, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(out_feat),
+         )
+         self.shortcut = nn.Sequential(
+             nn.Conv2d(in_feat, out_feat, kernel_size=1, stride=stride, bias=False),
+             nn.BatchNorm2d(out_feat)
+         )
+ 
+     def forward(self, x):
+         return nn.LeakyReLU(0.1, True)(self.backbone(x) + self.shortcut(x))
+ 
+ 
+ class ResEncoder(nn.Module):
+     def __init__(self):
+         super(ResEncoder, self).__init__()
+ 
+         self.E_pre = ResBlock(in_feat=3, out_feat=64, stride=1)
+         self.E = nn.Sequential(
+             ResBlock(in_feat=64, out_feat=128, stride=2),
+             ResBlock(in_feat=128, out_feat=256, stride=2),
+             nn.AdaptiveAvgPool2d(1)
+         )
+ 
+         self.mlp = nn.Sequential(
+             nn.Linear(256, 256),
+             nn.LeakyReLU(0.1, True),
+             nn.Linear(256, 256),
+         )
+ 
+     def forward(self, x):
+         inter = self.E_pre(x)
+         fea = self.E(inter).squeeze(-1).squeeze(-1)
+         out = self.mlp(fea)
+ 
+         return fea, out, inter
+ 
+ 
+ class CBDE(nn.Module):
+     def __init__(self, opt):
+         super(CBDE, self).__init__()
+ 
+         dim = 256
+ 
+         # Encoder
+         self.E = MoCo(base_encoder=ResEncoder, dim=dim, K=opt.batch_size * dim)
+ 
+     def forward(self, x_query, x_key):
+         if self.training:
+             # degradation-aware representation learning
+             fea, logits, labels, inter = self.E(x_query, x_key)
+ 
+             return fea, logits, labels, inter
+         else:
+             # degradation-aware representation learning
+             fea, inter = self.E(x_query, x_query)
+             return fea, inter
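Note: in eval mode CBDE reduces to a single pass through the query encoder. A CPU-only shape check (the opt object is a hypothetical minimal stand-in; only opt.batch_size is read, to size the MoCo queue):

    import argparse
    import torch
    from text_net.encoder import CBDE

    opt = argparse.Namespace(batch_size=7)   # hypothetical minimal opt
    enc = CBDE(opt).eval()
    x = torch.randn(2, 3, 64, 64)
    with torch.no_grad():
        fea, inter = enc(x, x)
    # 256-d degradation embedding plus the 64-channel intermediate map used by DGRN
    print(fea.shape, inter.shape)   # torch.Size([2, 256]) torch.Size([2, 64, 64, 64])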
text_net/moco.py ADDED
@@ -0,0 +1,166 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ import torch
+ import torch.nn as nn
+ 
+ 
+ class MoCo(nn.Module):
+     """
+     Build a MoCo model with: a query encoder, a key encoder, and a queue
+     https://arxiv.org/abs/1911.05722
+     """
+     def __init__(self, base_encoder, dim=256, K=3*256, m=0.999, T=0.07, mlp=False):
+         """
+         dim: feature dimension (default: 256)
+         K: queue size; number of negative keys (default: 3*256)
+         m: moco momentum of updating key encoder (default: 0.999)
+         T: softmax temperature (default: 0.07)
+         """
+         super(MoCo, self).__init__()
+ 
+         self.K = K
+         self.m = m
+         self.T = T
+ 
+         # create the encoders
+         # num_classes is the output fc dimension
+         self.encoder_q = base_encoder()
+         self.encoder_k = base_encoder()
+ 
+         for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
+             param_k.data.copy_(param_q.data)  # initialize
+             param_k.requires_grad = False  # not updated by gradient
+ 
+         # create the queue
+         self.register_buffer("queue", torch.randn(dim, K))
+         self.queue = nn.functional.normalize(self.queue, dim=0)
+ 
+         self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
+ 
+     @torch.no_grad()
+     def _momentum_update_key_encoder(self):
+         """
+         Momentum update of the key encoder
+         """
+         for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
+             param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
+ 
+     @torch.no_grad()
+     def _dequeue_and_enqueue(self, keys):
+         # gather keys before updating queue
+         # keys = concat_all_gather(keys)
+         batch_size = keys.shape[0]
+ 
+         ptr = int(self.queue_ptr)
+         assert self.K % batch_size == 0  # for simplicity
+ 
+         # replace the keys at ptr (dequeue and enqueue)
+         self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1)
+         ptr = (ptr + batch_size) % self.K  # move pointer
+ 
+         self.queue_ptr[0] = ptr
+ 
+     @torch.no_grad()
+     def _batch_shuffle_ddp(self, x):
+         """
+         Batch shuffle, for making use of BatchNorm.
+         *** Only supports DistributedDataParallel (DDP) model. ***
+         """
+         # gather from all gpus
+         batch_size_this = x.shape[0]
+         x_gather = concat_all_gather(x)
+         batch_size_all = x_gather.shape[0]
+ 
+         num_gpus = batch_size_all // batch_size_this
+ 
+         # random shuffle index
+         idx_shuffle = torch.randperm(batch_size_all).cuda()
+ 
+         # broadcast to all gpus
+         torch.distributed.broadcast(idx_shuffle, src=0)
+ 
+         # index for restoring
+         idx_unshuffle = torch.argsort(idx_shuffle)
+ 
+         # shuffled index for this gpu
+         gpu_idx = torch.distributed.get_rank()
+         idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
+ 
+         return x_gather[idx_this], idx_unshuffle
+ 
+     @torch.no_grad()
+     def _batch_unshuffle_ddp(self, x, idx_unshuffle):
+         """
+         Undo batch shuffle.
+         *** Only supports DistributedDataParallel (DDP) model. ***
+         """
+         # gather from all gpus
+         batch_size_this = x.shape[0]
+         x_gather = concat_all_gather(x)
+         batch_size_all = x_gather.shape[0]
+ 
+         num_gpus = batch_size_all // batch_size_this
+ 
+         # restored index for this gpu
+         gpu_idx = torch.distributed.get_rank()
+         idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
+ 
+         return x_gather[idx_this]
+ 
+     def forward(self, im_q, im_k):
+         """
+         Input:
+             im_q: a batch of query images
+             im_k: a batch of key images
+         Output:
+             logits, targets
+         """
+         if self.training:
+             # compute query features
+             embedding, q, inter = self.encoder_q(im_q)  # queries: NxC
+             q = nn.functional.normalize(q, dim=1)
+ 
+             # compute key features
+             with torch.no_grad():  # no gradient to keys
+                 self._momentum_update_key_encoder()  # update the key encoder
+ 
+                 _, k, _ = self.encoder_k(im_k)  # keys: NxC
+                 k = nn.functional.normalize(k, dim=1)
+ 
+             # compute logits
+             # Einstein sum is more intuitive
+             # positive logits: Nx1
+             l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
+             # negative logits: NxK
+             l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
+ 
+             # logits: Nx(1+K)
+             logits = torch.cat([l_pos, l_neg], dim=1)
+ 
+             # apply temperature
+             logits /= self.T
+ 
+             # labels: positive key indicators
+             labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
+             # dequeue and enqueue
+             self._dequeue_and_enqueue(k)
+ 
+             return embedding, logits, labels, inter
+         else:
+             embedding, _, inter = self.encoder_q(im_q)
+ 
+             return embedding, inter
+ 
+ 
+ # utils
+ @torch.no_grad()
+ def concat_all_gather(tensor):
+     """
+     Performs all_gather operation on the provided tensors.
+     *** Warning ***: torch.distributed.all_gather has no gradient.
+     """
+     tensors_gather = [torch.ones_like(tensor)
+                       for _ in range(torch.distributed.get_world_size())]
+     torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+ 
+     output = torch.cat(tensors_gather, dim=0)
+     return output
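Note: the training branch of forward() builds a (1+K)-way classification problem with the positive key at index 0. A standalone CPU sketch of that logits/labels construction (random tensors stand in for encoder outputs):

    import torch
    import torch.nn.functional as F

    N, C, K, T = 4, 256, 1024, 0.07
    q = F.normalize(torch.randn(N, C), dim=1)        # query features
    k = F.normalize(torch.randn(N, C), dim=1)        # key features
    queue = F.normalize(torch.randn(C, K), dim=0)    # negative queue

    l_pos = torch.einsum('nc,nc->n', q, k).unsqueeze(-1)   # N x 1
    l_neg = torch.einsum('nc,ck->nk', q, queue)            # N x K
    logits = torch.cat([l_pos, l_neg], dim=1) / T          # N x (1+K)
    labels = torch.zeros(N, dtype=torch.long)              # positives at column 0
    loss = F.cross_entropy(logits, labels)                 # InfoNCE objective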
text_net/model.py ADDED
@@ -0,0 +1,29 @@
+ from torch import nn
+ 
+ from text_net.encoder import CBDE
+ from text_net.DGRN import DGRN
+ 
+ 
+ class AirNet(nn.Module):
+     def __init__(self, opt):
+         super(AirNet, self).__init__()
+ 
+         # Restorer
+         self.R = DGRN(opt)
+ 
+         # Encoder
+         self.E = CBDE(opt)
+ 
+     def forward(self, x_query, x_key, text_prompt):
+         if self.training:
+             fea, logits, labels, inter = self.E(x_query, x_key)
+ 
+             restored = self.R(x_query, inter, text_prompt)
+ 
+             return restored, logits, labels
+         else:
+             fea, inter = self.E(x_query, x_query)
+ 
+             restored = self.R(x_query, inter, text_prompt)
+ 
+             return restored
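Note: a wiring sketch for AirNet in eval mode. Importing text_net.DGRN loads CLIP ViT-B/32 onto CUDA at import time, so the clip package and a GPU are required even for this toy forward pass; the opt Namespace is a hypothetical minimal stand-in (only batch_size is consumed).

    import argparse
    import torch
    from text_net.model import AirNet

    opt = argparse.Namespace(batch_size=7)   # hypothetical minimal opt
    net = AirNet(opt).cuda().eval()
    x = torch.randn(1, 3, 64, 64, device="cuda")
    with torch.no_grad():
        restored = net(x_query=x, x_key=x, text_prompt=["remove the haze"])
    print(restored.shape)   # torch.Size([1, 3, 64, 64])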
utils/.DS_Store ADDED
Binary file (6.15 kB).
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (138 Bytes).
utils/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (123 Bytes).
utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (148 Bytes).
utils/__pycache__/dataset_utils.cpython-310.pyc ADDED
Binary file (11.2 kB).
utils/__pycache__/dataset_utils.cpython-36.pyc ADDED
Binary file (30.8 kB).
utils/__pycache__/dataset_utils.cpython-38.pyc ADDED
Binary file (10.6 kB).
utils/__pycache__/dataset_utils_CDD.cpython-310.pyc ADDED
Binary file (1.58 kB).
utils/__pycache__/degradation_utils.cpython-310.pyc ADDED
Binary file (1.8 kB).
utils/__pycache__/degradation_utils.cpython-36.pyc ADDED
Binary file (3.4 kB).
utils/__pycache__/degradation_utils.cpython-38.pyc ADDED
Binary file (1.79 kB).
utils/__pycache__/image_io.cpython-310.pyc ADDED
Binary file (11 kB).
utils/__pycache__/image_io.cpython-36.pyc ADDED
Binary file (11.3 kB).
utils/__pycache__/image_io.cpython-38.pyc ADDED
Binary file (11.1 kB).
utils/__pycache__/image_utils.cpython-310.pyc ADDED
Binary file (7.48 kB).
utils/__pycache__/image_utils.cpython-36.pyc ADDED
Binary file (7.61 kB).
utils/__pycache__/image_utils.cpython-38.pyc ADDED
Binary file (7.46 kB).
utils/__pycache__/imresize.cpython-36.pyc ADDED
Binary file (4.75 kB).
utils/__pycache__/imresize.cpython-38.pyc ADDED
Binary file (4.75 kB).
utils/__pycache__/loss_utils.cpython-38.pyc ADDED
Binary file (1.43 kB).
utils/__pycache__/val_utils.cpython-310.pyc ADDED
Binary file (3.35 kB).
utils/__pycache__/val_utils.cpython-36.pyc ADDED
Binary file (2.34 kB).
utils/__pycache__/val_utils.cpython-38.pyc ADDED
Binary file (3.27 kB).
utils/dataset_utils.py ADDED
@@ -0,0 +1,309 @@
+ import os
+ import random
+ import copy
+ from PIL import Image
+ import numpy as np
+ 
+ from torch.utils.data import Dataset
+ from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor
+ 
+ from utils.image_utils import random_augmentation, crop_img
+ from utils.degradation_utils import Degradation
+ 
+ 
+ class TrainDataset(Dataset):
+     def __init__(self, args):
+         super(TrainDataset, self).__init__()
+         self.args = args
+         self.rs_ids = []
+         self.hazy_ids = []
+         self.D = Degradation(args)
+         self.de_temp = 0
+         self.de_type = self.args.de_type
+         self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']
+ 
+         self.de_dict = {'denoise_15': 0, 'denoise_25': 1, 'denoise_50': 2, 'derain': 3, 'dehaze': 4}
+ 
+         self._init_ids()
+ 
+         self.crop_transform = Compose([
+             ToPILImage(),
+             RandomCrop(args.patch_size),
+         ])
+ 
+         self.toTensor = ToTensor()
+ 
+     def _init_ids(self):
+         if 'denoise_15' in self.de_type or 'denoise_25' in self.de_type or 'denoise_50' in self.de_type:
+             self._init_clean_ids()
+         if 'derain' in self.de_type:
+             self._init_rs_ids()
+         if 'dehaze' in self.de_type:
+             self._init_hazy_ids()
+ 
+         random.shuffle(self.de_type)
+ 
+     def _init_clean_ids(self):
+         clean_ids = []
+         # Keep only image files from the directory listing
+         name_list = os.listdir(self.args.denoise_dir)
+         name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+ 
+         clean_ids += [self.args.denoise_dir + id_ for id_ in name_list]
+ 
+         if 'denoise_15' in self.de_type:
+             self.s15_ids = copy.deepcopy(clean_ids)
+             random.shuffle(self.s15_ids)
+             self.s15_counter = 0
+         if 'denoise_25' in self.de_type:
+             self.s25_ids = copy.deepcopy(clean_ids)
+             random.shuffle(self.s25_ids)
+             self.s25_counter = 0
+         if 'denoise_50' in self.de_type:
+             self.s50_ids = copy.deepcopy(clean_ids)
+             random.shuffle(self.s50_ids)
+             self.s50_counter = 0
+ 
+         # print(clean_ids)
+ 
+         self.num_clean = len(clean_ids)
+ 
+     def _init_hazy_ids(self):
+         # Keep only image files from the directory listing
+         dehaze_ids = []
+         name_list = os.listdir(self.args.dehaze_dir)
+         name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+         dehaze_ids += [self.args.dehaze_dir + id_ for id_ in name_list]
+         self.hazy_ids = dehaze_ids
+ 
+         self.hazy_counter = 0
+         self.num_hazy = len(self.hazy_ids)
+ 
+     def _init_rs_ids(self):
+         # Keep only image files from the directory listing
+         derain_ids = []
+         name_list = os.listdir(self.args.derain_dir)
+         name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+         derain_ids += [self.args.derain_dir + id_ for id_ in name_list]
+         self.rs_ids = derain_ids
+ 
+         self.rl_counter = 0
+         # print(derain_ids)
+ 
+         self.num_rl = len(self.rs_ids)
+ 
+     def _crop_patch(self, img_1, img_2):
+         H = img_1.shape[0]
+         W = img_1.shape[1]
+         ind_H = random.randint(0, H - self.args.patch_size)
+         ind_W = random.randint(0, W - self.args.patch_size)
+ 
+         patch_1 = img_1[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
+         patch_2 = img_2[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
+ 
+         return patch_1, patch_2
+ 
+     def _get_gt_name(self, rainy_name):
+         gt_name = 'data/' + 'Target/Derain/norain-' + rainy_name.split('rain-')[-1]
+         return gt_name
+ 
+     def _get_nonhazy_name(self, hazy_name):
+         gt_name = 'data/' + 'Target/Dehaze/nohaze-' + hazy_name.split('haze-')[-1]
+         return gt_name
+ 
+     def __getitem__(self, _):
+         de_id = self.de_dict[self.de_type[self.de_temp]]
+ 
+         if de_id < 3:
+             if de_id == 0:
+                 clean_id = self.s15_ids[self.s15_counter]
+                 self.s15_counter = (self.s15_counter + 1) % self.num_clean
+                 if self.s15_counter == 0:
+                     random.shuffle(self.s15_ids)
+             elif de_id == 1:
+                 clean_id = self.s25_ids[self.s25_counter]
+                 self.s25_counter = (self.s25_counter + 1) % self.num_clean
+                 if self.s25_counter == 0:
+                     random.shuffle(self.s25_ids)
+             elif de_id == 2:
+                 clean_id = self.s50_ids[self.s50_counter]
+                 self.s50_counter = (self.s50_counter + 1) % self.num_clean
+                 if self.s50_counter == 0:
+                     random.shuffle(self.s50_ids)
+ 
+             # clean_id = random.randint(0, len(self.clean_ids) - 1)
+             clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)
+             clean_patch_1, clean_patch_2 = self.crop_transform(clean_img), self.crop_transform(clean_img)
+             clean_patch_1, clean_patch_2 = np.array(clean_patch_1), np.array(clean_patch_2)
+ 
+             # clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]
+             clean_name = clean_id.split("/")[-1].split('.')[0]
+ 
+             clean_patch_1, clean_patch_2 = random_augmentation(clean_patch_1, clean_patch_2)
+             degrad_patch_1, degrad_patch_2 = self.D.degrade(clean_patch_1, clean_patch_2, de_id)
+         else:
+             if de_id == 3:
+                 # Rain streak removal
+                 # rl_id = random.randint(0, len(self.rl_ids) - 1)
+                 degrad_img = crop_img(np.array(Image.open(self.rs_ids[self.rl_counter]).convert('RGB')), base=16)
+                 clean_name = self._get_gt_name(self.rs_ids[self.rl_counter])
+                 clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)
+ 
+                 self.rl_counter = (self.rl_counter + 1) % self.num_rl
+                 if self.rl_counter == 0:
+                     random.shuffle(self.rs_ids)
+             elif de_id == 4:
+                 # Dehazing with the SOTS outdoor training set
+                 # hazy_id = random.randint(0, len(self.hazy_ids) - 1)
+                 degrad_img = crop_img(np.array(Image.open(self.hazy_ids[self.hazy_counter]).convert('RGB')), base=16)
+                 clean_name = self._get_nonhazy_name(self.hazy_ids[self.hazy_counter])
+                 clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)
+ 
+                 self.hazy_counter = (self.hazy_counter + 1) % self.num_hazy
+                 if self.hazy_counter == 0:
+                     random.shuffle(self.hazy_ids)
+             degrad_patch_1, clean_patch_1 = random_augmentation(*self._crop_patch(degrad_img, clean_img))
+             degrad_patch_2, clean_patch_2 = random_augmentation(*self._crop_patch(degrad_img, clean_img))
+ 
+         clean_patch_1, clean_patch_2 = self.toTensor(clean_patch_1), self.toTensor(clean_patch_2)
+         degrad_patch_1, degrad_patch_2 = self.toTensor(degrad_patch_1), self.toTensor(degrad_patch_2)
+ 
+         self.de_temp = (self.de_temp + 1) % len(self.de_type)
+         if self.de_temp == 0:
+             random.shuffle(self.de_type)
+ 
+         return [clean_name, de_id], degrad_patch_1, degrad_patch_2, clean_patch_1, clean_patch_2
+ 
+     def __len__(self):
+         return 400 * len(self.args.de_type)
+ 
+ 
+ class DenoiseTestDataset(Dataset):
+     def __init__(self, args):
+         super(DenoiseTestDataset, self).__init__()
+         self.args = args
+         self.clean_ids = []
+         self.sigma = 15
+         self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']
+ 
+         self._init_clean_ids()
+ 
+         self.toTensor = ToTensor()
+ 
+     def _init_clean_ids(self):
+         # Keep only image files from the directory listing
+         name_list = os.listdir(self.args.denoise_path)
+         name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+         self.clean_ids += [self.args.denoise_path + id_ for id_ in name_list]
+ 
+         self.num_clean = len(self.clean_ids)
+ 
+     def _add_gaussian_noise(self, clean_patch):
+         noise = np.random.randn(*clean_patch.shape)
+         noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255).astype(np.uint8)
+         return noisy_patch, clean_patch
+ 
+     def set_sigma(self, sigma):
+         self.sigma = sigma
+ 
+     def __getitem__(self, clean_id):
+         clean_img = crop_img(np.array(Image.open(self.clean_ids[clean_id]).convert('RGB')), base=16)
+         clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]
+ 
+         noisy_img, _ = self._add_gaussian_noise(clean_img)
+         clean_img, noisy_img = self.toTensor(clean_img), self.toTensor(noisy_img)
+ 
+         return [clean_name], noisy_img, clean_img
+ 
+     def __len__(self):
+         return self.num_clean
+ 
+ 
+ class DerainDehazeDataset(Dataset):
+     def __init__(self, args, task="derain"):
+         super(DerainDehazeDataset, self).__init__()
+         self.ids = []
+         self.task_idx = 0
+         self.args = args
+         self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']
+ 
+         self.task_dict = {'derain': 0, 'dehaze': 1}
+         self.toTensor = ToTensor()
+ 
+         self.set_dataset(task)
+ 
+     def _init_input_ids(self):
+         if self.task_idx == 0:
+             self.ids = []
+             # Keep only image files from the directory listing
+             name_list = os.listdir(self.args.derain_path + 'input/')
+             name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+             self.ids += [self.args.derain_path + 'input/' + id_ for id_ in name_list]
+         elif self.task_idx == 1:
+             self.ids = []
+             # Keep only image files from the directory listing
+             name_list = os.listdir(self.args.dehaze_path + 'input/')
+             name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+             self.ids += [self.args.dehaze_path + 'input/' + id_ for id_ in name_list]
+ 
+         self.length = len(self.ids)
+ 
+     def _get_gt_path(self, degraded_name):
+         if self.task_idx == 0:
+             gt_name = '/'.join(degraded_name.replace("input", "target").split('/')[:-1] + degraded_name.replace("input", "target").replace("rain", "norain").split('/')[-1:])
+             print(gt_name)
+         elif self.task_idx == 1:
+             dir_name = degraded_name.split("input")[0] + 'target/'
+             name = degraded_name.split('/')[-1].split('_')[0] + '.png'
+             gt_name = dir_name + name
+         return gt_name
+ 
+     def set_dataset(self, task):
+         self.task_idx = self.task_dict[task]
+         self._init_input_ids()
+ 
+     def __getitem__(self, idx):
+         degraded_path = self.ids[idx]
+         clean_path = self._get_gt_path(degraded_path)
+ 
+         degraded_img = crop_img(np.array(Image.open(degraded_path).convert('RGB')), base=16)
+         clean_img = crop_img(np.array(Image.open(clean_path).convert('RGB')), base=16)
+ 
+         clean_img, degraded_img = self.toTensor(clean_img), self.toTensor(degraded_img)
+         degraded_name = degraded_path.split('/')[-1][:-4]
+ 
+         return [degraded_name], degraded_img, clean_img
+ 
+     def __len__(self):
+         return self.length
+ 
+ 
+ class TestSpecificDataset(Dataset):
+     def __init__(self, args):
+         super(TestSpecificDataset, self).__init__()
+         self.args = args
+         self.degraded_ids = []
+         self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']  # must be set before _init_clean_ids uses it
+         self._init_clean_ids(args.test_path)
+ 
+         self.toTensor = ToTensor()
+ 
+     def _init_clean_ids(self, root):
+         degraded_ids = []
+         # Keep only image files from the directory listing
+         name_list = os.listdir(root)
+         name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
+         self.degraded_ids += [root + id_ for id_ in name_list]
+ 
+         self.num_img = len(self.degraded_ids)
+ 
+     def __getitem__(self, idx):
+         degraded_img = crop_img(np.array(Image.open(self.degraded_ids[idx]).convert('RGB')), base=16)
+         name = self.degraded_ids[idx].split('/')[-1][:-4]
+ 
+         degraded_img = self.toTensor(degraded_img)
+ 
+         return [name], degraded_img
+ 
+     def __len__(self):
+         return self.num_img
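Note: a usage sketch for DenoiseTestDataset. The folder path is hypothetical; the class only reads args.denoise_path, and the string concatenation above implies the path needs a trailing slash.

    import argparse
    from utils.dataset_utils import DenoiseTestDataset

    args = argparse.Namespace(denoise_path="data/Test/denoise/")  # hypothetical folder
    dataset = DenoiseTestDataset(args)
    dataset.set_sigma(25)                      # Gaussian noise level
    [name], noisy, clean = dataset[0]          # matches __getitem__ above
    print(name, noisy.shape, clean.shape)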