This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- .gradio/certificate.pem +31 -0
- __pycache__/inference.cpython-310.pyc +0 -0
- app.py +42 -4
- ckpt/epoch_287.pth +3 -0
- inference.py +59 -0
- text_net/DGRN.py +232 -0
- text_net/__pycache__/DGRN.cpython-310.pyc +0 -0
- text_net/__pycache__/DGRN.cpython-38.pyc +0 -0
- text_net/__pycache__/deform_conv.cpython-310.pyc +0 -0
- text_net/__pycache__/deform_conv.cpython-36.pyc +0 -0
- text_net/__pycache__/deform_conv.cpython-38.pyc +0 -0
- text_net/__pycache__/encoder.cpython-310.pyc +0 -0
- text_net/__pycache__/encoder.cpython-36.pyc +0 -0
- text_net/__pycache__/encoder.cpython-38.pyc +0 -0
- text_net/__pycache__/moco.cpython-310.pyc +0 -0
- text_net/__pycache__/moco.cpython-36.pyc +0 -0
- text_net/__pycache__/moco.cpython-38.pyc +0 -0
- text_net/__pycache__/model.cpython-310.pyc +0 -0
- text_net/__pycache__/model.cpython-36.pyc +0 -0
- text_net/__pycache__/model.cpython-38.pyc +0 -0
- text_net/deform_conv.py +65 -0
- text_net/encoder.py +67 -0
- text_net/moco.py +166 -0
- text_net/model.py +29 -0
- utils/.DS_Store +0 -0
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/__pycache__/__init__.cpython-36.pyc +0 -0
- utils/__pycache__/__init__.cpython-38.pyc +0 -0
- utils/__pycache__/dataset_utils.cpython-310.pyc +0 -0
- utils/__pycache__/dataset_utils.cpython-36.pyc +0 -0
- utils/__pycache__/dataset_utils.cpython-38.pyc +0 -0
- utils/__pycache__/dataset_utils_CDD.cpython-310.pyc +0 -0
- utils/__pycache__/degradation_utils.cpython-310.pyc +0 -0
- utils/__pycache__/degradation_utils.cpython-36.pyc +0 -0
- utils/__pycache__/degradation_utils.cpython-38.pyc +0 -0
- utils/__pycache__/image_io.cpython-310.pyc +0 -0
- utils/__pycache__/image_io.cpython-36.pyc +0 -0
- utils/__pycache__/image_io.cpython-38.pyc +0 -0
- utils/__pycache__/image_utils.cpython-310.pyc +0 -0
- utils/__pycache__/image_utils.cpython-36.pyc +0 -0
- utils/__pycache__/image_utils.cpython-38.pyc +0 -0
- utils/__pycache__/imresize.cpython-36.pyc +0 -0
- utils/__pycache__/imresize.cpython-38.pyc +0 -0
- utils/__pycache__/loss_utils.cpython-38.pyc +0 -0
- utils/__pycache__/val_utils.cpython-310.pyc +0 -0
- utils/__pycache__/val_utils.cpython-36.pyc +0 -0
- utils/__pycache__/val_utils.cpython-38.pyc +0 -0
- utils/dataset_utils.py +309 -0
.gitattributes
CHANGED
@@ -19,6 +19,7 @@
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
__pycache__/inference.cpython-310.pyc
ADDED
Binary file (1.99 kB).
app.py
CHANGED
@@ -1,7 +1,45 @@
 import gradio as gr
+from inference import infer
 
-def greet(
+def greet(image, prompt):
+    restore_img = infer(img=image, text_prompt=prompt)
+    return restore_img
 
+
+title = "🖼️ ICDR 🖼️"
+description = ''' ## ICDR: Image Restoration Framework for Composite Degradation following Human Instructions
+Our Github : https://github.com/
+
+Siwon Kim, Donghyeon Yoon
+
+Ajou Univ
+'''
+
+
+article = "<p style='text-align: center'><a href='https://github.com/' target='_blank'>ICDR</a></p>"
+
+#### Image / prompt examples
+examples = [['input/00010.png', "I love this photo, could you remove the haze and more brighter?"],
+            ['input/00058.png', "I have to post an emotional shot on Instagram, but it was shot too foggy and too dark. Change it like a sunny day and brighten it up!"]]
+
+css = """
+.image-frame img, .image-container img {
+    width: auto;
+    height: auto;
+    max-width: none;
+}
+"""
+
+
+demo = gr.Interface(
+    fn=greet,
+    inputs=[gr.Image(type="pil", label="Input"),
+            gr.Text(label="Prompt")],
+    outputs=[gr.Image(type="pil", label="Output")],
+    title=title,
+    description=description,
+    article=article,
+    examples=examples,
+    css=css,
+)
+demo.launch(share=True)
ckpt/epoch_287.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db279692728bd4614759c08a0478d9d07200768e5fb7fa893e78aaa05f3ca707
size 48705338
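
The checkpoint is stored as a Git LFS pointer: the repo tracks only this small text stub (spec version, sha256 oid, byte size), while the 48.7 MB weights live in LFS storage. A tiny illustrative parser for the pointer format, assuming the file on disk is still the stub and not the fetched weights:

import pathlib

# Parse a Git LFS pointer stub into a dict (hypothetical helper, not part of the repo)
text = pathlib.Path("ckpt/epoch_287.pth").read_text()
pointer = dict(line.split(' ', 1) for line in text.strip().splitlines())
print(pointer["oid"], pointer["size"])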
inference.py
ADDED
@@ -0,0 +1,59 @@
import argparse
import subprocess
from tqdm import tqdm
import numpy as np

import torch
from torch.utils.data import DataLoader

from utils.dataset_utils_CDD import DerainDehazeDataset
from utils.val_utils import AverageMeter, compute_psnr_ssim
from utils.image_io import save_image_tensor

from text_net.model import AirNet

def test_Derain_Dehaze(opt, net, dataset, task="derain"):
    output_path = opt.output_path + task + '/'
    subprocess.check_output(['mkdir', '-p', output_path])

    # dataset.set_dataset(task)
    testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0)
    print(len(testloader))

    with torch.no_grad():
        for ([degraded_name], degradation, degrad_patch, clean_patch, text_prompt) in tqdm(testloader):
            degrad_patch, clean_patch = degrad_patch.cuda(), clean_patch.cuda()
            restored = net(x_query=degrad_patch, x_key=degrad_patch, text_prompt=text_prompt)

    return save_image_tensor(restored)


def infer(text_prompt="", img=None):
    parser = argparse.ArgumentParser()
    # Input Parameters
    parser.add_argument('--cuda', type=int, default=0)
    parser.add_argument('--derain_path', type=str, default="data/Test_prompting/", help='path of the test rainy images')
    parser.add_argument('--output_path', type=str, default="output/demo11", help='output save path')
    parser.add_argument('--ckpt_path', type=str, default="ckpt/epoch_287.pth", help='checkpoint save path')
    # parser.add_argument('--text_prompt', type=str, default="derain")

    opt = parser.parse_args()
    # opt.text_prompt = text_prompt

    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.set_device(opt.cuda)

    opt.batch_size = 7
    ckpt_path = opt.ckpt_path

    derain_set = DerainDehazeDataset(opt, img=img, text_prompt=text_prompt)

    # Make network
    net = AirNet(opt).cuda()
    net.eval()
    net.load_state_dict(torch.load(ckpt_path, map_location=torch.device(opt.cuda)))

    restored = test_Derain_Dehaze(opt, net, derain_set, task="derain")

    return restored
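
For reference, a minimal sketch of how infer() is driven outside of Gradio. This is an assumption based on the wiring in app.py: it presumes a CUDA device, the LFS checkpoint, and the example image are present, and that save_image_tensor returns a PIL image (as gr.Image(type="pil") expects).

from PIL import Image

from inference import infer

img = Image.open("input/00010.png").convert("RGB")  # example image shipped with the Space
restored = infer(text_prompt="Please remove the haze and brighten it up.", img=img)
restored.save("restored_00010.png")  # assumes save_image_tensor returns a PIL image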
text_net/DGRN.py
ADDED
@@ -0,0 +1,232 @@
import torch.nn as nn
import torch
from .deform_conv import DCN_layer
import clip

clip_model, preprocess = clip.load("ViT-B/32", device='cuda')

# Dynamically fetch the text embedding dimension
text_embed_dim = clip_model.text_projection.shape[1]


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)


class DGM(nn.Module):
    def __init__(self, channels_in, channels_out, kernel_size):
        super(DGM, self).__init__()
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.kernel_size = kernel_size

        self.dcn = DCN_layer(self.channels_in, self.channels_out, kernel_size,
                             padding=(kernel_size - 1) // 2, bias=False)
        self.sft = SFT_layer(self.channels_in, self.channels_out)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x, inter, text_prompt):
        '''
        :param x: feature map: B * C * H * W
        :param inter: degradation map: B * C * H * W
        '''
        dcn_out = self.dcn(x, inter)
        sft_out = self.sft(x, inter, text_prompt)
        out = dcn_out + sft_out
        out = x + out

        return out

# Define the text projection head
class TextProjectionHead(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(TextProjectionHead, self).__init__()
        self.proj = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim)
        ).float()

    def forward(self, x):
        return self.proj(x.float())


class SFT_layer(nn.Module):
    def __init__(self, channels_in, channels_out):
        super(SFT_layer, self).__init__()
        self.conv_gamma = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
        )
        self.conv_beta = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
        )

        self.text_proj_head = TextProjectionHead(text_embed_dim, channels_out)

        '''
        self.text_gamma = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
        ).float()
        self.text_beta = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_out, channels_out, 1, 1, 0, bias=False),
        ).float()
        '''

        self.cross_attention = nn.MultiheadAttention(embed_dim=channels_out, num_heads=2)

    def forward(self, x, inter, text_prompt):
        '''
        :param x: degradation representation: B * C
        :param inter: degradation intermediate representation map: B * C * H * W
        '''
        # img_gamma = self.conv_gamma(inter)
        # img_beta = self.conv_beta(inter)

        B, C, H, W = inter.shape  # for cross attention

        text_tokens = clip.tokenize(text_prompt).to(x.device)  # Tokenize the text prompts (batch)
        with torch.no_grad():
            text_embed = clip_model.encode_text(text_tokens)

        text_proj = self.text_proj_head(text_embed).float()

        # Expand the text embedding to (B, C, H, W)
        # text_proj_expanded = text_proj.unsqueeze(-1).unsqueeze(-1).expand(B, self.conv_gamma[0].out_channels, H, W)
        text_proj_expanded = text_proj.unsqueeze(-1).unsqueeze(-1).expand(B, C, H, W)

        # Fuse the image intermediate representation with the text embedding
        combined = inter * text_proj_expanded
        # combined = torch.cat([inter, text_proj_expanded], dim=1)

        # Compute gamma and beta from the fused image/text features
        img_gamma = self.conv_gamma(combined)
        img_beta = self.conv_beta(combined)

        ''' simple concat
        text_gamma = self.text_gamma(text_proj.unsqueeze(-1).unsqueeze(-1))  # Reshape to match (B, C, H, W)
        text_beta = self.text_beta(text_proj.unsqueeze(-1).unsqueeze(-1))  # Reshape to match (B, C, H, W)
        '''

        '''
        text_proj = text_proj.unsqueeze(1).expand(-1, H*W, -1)  # B * (H*W) * C

        # Reshape the image intermediate representation to B * (H*W) * C
        inter_flat = inter.view(B, C, -1).permute(2, 0, 1)  # (H*W) * B * C

        # Apply cross-attention
        attn_output, _ = self.cross_attention(text_proj.permute(1, 0, 2), inter_flat, inter_flat)
        attn_output = attn_output.permute(1, 2, 0).view(B, C, H, W)  # B * C * H * W

        # Compute gamma and beta
        img_gamma = self.conv_gamma(attn_output)
        img_beta = self.conv_beta(attn_output)
        '''
        # Experiment: fusing the text via concat instead
        return x * img_gamma + img_beta


class DGB(nn.Module):
    def __init__(self, conv, n_feat, kernel_size):
        super(DGB, self).__init__()

        # self.da_conv1 = DGM(n_feat, n_feat, kernel_size)
        # self.da_conv2 = DGM(n_feat, n_feat, kernel_size)
        self.dgm1 = DGM(n_feat, n_feat, kernel_size)
        self.dgm2 = DGM(n_feat, n_feat, kernel_size)
        self.conv1 = conv(n_feat, n_feat, kernel_size)
        self.conv2 = conv(n_feat, n_feat, kernel_size)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x, inter, text_prompt):
        '''
        :param x: feature map: B * C * H * W
        :param inter: degradation representation: B * C * H * W
        '''
        out = self.relu(self.dgm1(x, inter, text_prompt))
        out = self.relu(self.conv1(out))
        out = self.relu(self.dgm2(out, inter, text_prompt))
        out = self.conv2(out) + x

        return out


class DGG(nn.Module):
    def __init__(self, conv, n_feat, kernel_size, n_blocks):
        super(DGG, self).__init__()
        self.n_blocks = n_blocks
        modules_body = [
            DGB(conv, n_feat, kernel_size)
            for _ in range(n_blocks)
        ]
        modules_body.append(conv(n_feat, n_feat, kernel_size))

        self.body = nn.Sequential(*modules_body)

    def forward(self, x, inter, text_prompt):
        '''
        :param x: feature map: B * C * H * W
        :param inter: degradation representation: B * C * H * W
        '''
        res = x
        for i in range(self.n_blocks):
            res = self.body[i](res, inter, text_prompt)
        res = self.body[-1](res)
        res = res + x

        return res


class DGRN(nn.Module):
    def __init__(self, opt, conv=default_conv):
        super(DGRN, self).__init__()

        self.n_groups = 5
        n_blocks = 5
        n_feats = 64
        kernel_size = 3

        # head module
        modules_head = [conv(3, n_feats, kernel_size)]
        self.head = nn.Sequential(*modules_head)

        # body
        modules_body = [
            DGG(default_conv, n_feats, kernel_size, n_blocks)
            for _ in range(self.n_groups)
        ]
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*modules_body)

        # tail
        modules_tail = [conv(n_feats, 3, kernel_size)]
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x, inter, text_prompt):
        # head
        x = self.head(x)

        # body
        res = x
        for i in range(self.n_groups):
            res = self.body[i](res, inter, text_prompt)
        res = self.body[-1](res)
        res = res + x

        # tail
        x = self.tail(res)

        return x
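
At its core, SFT_layer is a FiLM-style modulation: the projected CLIP embedding gates the degradation map, and two 1x1 conv stacks turn the gated map into per-pixel gamma/beta applied to the feature map. A self-contained sketch of just that step, with random tensors standing in for the CLIP pipeline (dimensions are illustrative; no GPU or checkpoint needed):

import torch
import torch.nn as nn

B, C, H, W = 2, 64, 32, 32
x = torch.randn(B, C, H, W)      # feature map
inter = torch.randn(B, C, H, W)  # degradation map
text_proj = torch.randn(B, C)    # stands in for TextProjectionHead(clip embedding)

# Gate the degradation map with the broadcast text projection
combined = inter * text_proj.unsqueeze(-1).unsqueeze(-1).expand(B, C, H, W)

conv_gamma = nn.Conv2d(C, C, 1, bias=False)  # stands in for the two-layer stacks above
conv_beta = nn.Conv2d(C, C, 1, bias=False)

out = x * conv_gamma(combined) + conv_beta(combined)
print(out.shape)  # torch.Size([2, 64, 32, 32])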
text_net/__pycache__/DGRN.cpython-310.pyc
ADDED
Binary file (5.61 kB).
text_net/__pycache__/DGRN.cpython-38.pyc
ADDED
Binary file (4.53 kB).
text_net/__pycache__/deform_conv.cpython-310.pyc
ADDED
Binary file (2.2 kB).
text_net/__pycache__/deform_conv.cpython-36.pyc
ADDED
Binary file (2.14 kB).
text_net/__pycache__/deform_conv.cpython-38.pyc
ADDED
Binary file (2.21 kB).
text_net/__pycache__/encoder.cpython-310.pyc
ADDED
Binary file (2.33 kB).
text_net/__pycache__/encoder.cpython-36.pyc
ADDED
Binary file (2.38 kB).
text_net/__pycache__/encoder.cpython-38.pyc
ADDED
Binary file (2.36 kB).
text_net/__pycache__/moco.cpython-310.pyc
ADDED
Binary file (4.43 kB).
text_net/__pycache__/moco.cpython-36.pyc
ADDED
Binary file (4.39 kB).
text_net/__pycache__/moco.cpython-38.pyc
ADDED
Binary file (4.43 kB).
text_net/__pycache__/model.cpython-310.pyc
ADDED
Binary file (936 Bytes).
text_net/__pycache__/model.cpython-36.pyc
ADDED
Binary file (914 Bytes).
text_net/__pycache__/model.cpython-38.pyc
ADDED
Binary file (916 Bytes).
text_net/deform_conv.py
ADDED
@@ -0,0 +1,65 @@
import math

import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair

from mmcv.ops import modulated_deform_conv2d


class DCN_layer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, deformable_groups=1, bias=True, extra_offset_mask=True):
        super(DCN_layer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))

        self.extra_offset_mask = extra_offset_mask
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels * 2,
            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
            bias=True
        )

        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.init_offset()
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

    def init_offset(self):
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input_feat, inter):
        feat_degradation = torch.cat([input_feat, inter], dim=1)

        out = self.conv_offset_mask(feat_degradation)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        return modulated_deform_conv2d(input_feat.contiguous(), offset, mask, self.weight, self.bias, self.stride,
                                       self.padding, self.dilation, self.groups, self.deformable_groups)
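
The offset/mask head concatenates the feature map with the degradation map (hence in_channels * 2) and predicts deformable_groups * 3 * k * k channels: 2*k*k for the (y, x) offsets plus k*k for the modulation mask. A quick shape check of that channel arithmetic with plain torch (no mmcv required; sizes are illustrative):

import torch
import torch.nn as nn

in_channels, k, deformable_groups = 64, 3, 1
conv_offset_mask = nn.Conv2d(in_channels * 2, deformable_groups * 3 * k * k,
                             kernel_size=k, padding=1)

feat = torch.randn(1, in_channels, 16, 16)
inter = torch.randn(1, in_channels, 16, 16)

out = conv_offset_mask(torch.cat([feat, inter], dim=1))
o1, o2, mask = torch.chunk(out, 3, dim=1)   # each: 1 x (g*k*k) x 16 x 16
offset = torch.cat((o1, o2), dim=1)         # 1 x (g*2*k*k) x 16 x 16
mask = torch.sigmoid(mask)                  # modulation scalars in (0, 1)
print(offset.shape, mask.shape)  # torch.Size([1, 18, 16, 16]) torch.Size([1, 9, 16, 16])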
text_net/encoder.py
ADDED
@@ -0,0 +1,67 @@
from torch import nn
from text_net.moco import MoCo


class ResBlock(nn.Module):
    def __init__(self, in_feat, out_feat, stride=1):
        super(ResBlock, self).__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_feat),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(out_feat, out_feat, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_feat),
        )
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_feat, out_feat, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_feat)
        )

    def forward(self, x):
        return nn.LeakyReLU(0.1, True)(self.backbone(x) + self.shortcut(x))


class ResEncoder(nn.Module):
    def __init__(self):
        super(ResEncoder, self).__init__()

        self.E_pre = ResBlock(in_feat=3, out_feat=64, stride=1)
        self.E = nn.Sequential(
            ResBlock(in_feat=64, out_feat=128, stride=2),
            ResBlock(in_feat=128, out_feat=256, stride=2),
            nn.AdaptiveAvgPool2d(1)
        )

        self.mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.LeakyReLU(0.1, True),
            nn.Linear(256, 256),
        )

    def forward(self, x):
        inter = self.E_pre(x)
        fea = self.E(inter).squeeze(-1).squeeze(-1)
        out = self.mlp(fea)

        return fea, out, inter


class CBDE(nn.Module):
    def __init__(self, opt):
        super(CBDE, self).__init__()

        dim = 256

        # Encoder
        self.E = MoCo(base_encoder=ResEncoder, dim=dim, K=opt.batch_size * dim)

    def forward(self, x_query, x_key):
        if self.training:
            # degradation-aware representation learning
            fea, logits, labels, inter = self.E(x_query, x_key)

            return fea, logits, labels, inter
        else:
            # degradation-aware representation learning
            fea, inter = self.E(x_query, x_query)
            return fea, inter
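
ResEncoder returns three things: the pooled 256-d embedding fea, its MLP projection out (the MoCo query/key), and the stride-1 intermediate map inter, which is what DGRN consumes as the degradation map. A quick shape check, assuming the repo imports cleanly (the 64x64 input is illustrative; it must be divisible by 4 for the two stride-2 blocks):

import torch

from text_net.encoder import ResEncoder

enc = ResEncoder()
x = torch.randn(2, 3, 64, 64)
fea, out, inter = enc(x)
print(fea.shape)    # torch.Size([2, 256])        pooled embedding
print(out.shape)    # torch.Size([2, 256])        MLP projection used by MoCo
print(inter.shape)  # torch.Size([2, 64, 64, 64]) degradation map fed to DGRN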
text_net/moco.py
ADDED
@@ -0,0 +1,166 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import torch.nn as nn


class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722
    """
    def __init__(self, base_encoder, dim=256, K=3*256, m=0.999, T=0.07, mlp=False):
        """
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        """
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder()
        self.encoder_k = base_encoder()

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not updated by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        # keys = concat_all_gather(keys)
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1)
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            logits, targets
        """
        if self.training:
            # compute query features
            embedding, q, inter = self.encoder_q(im_q)  # queries: NxC
            q = nn.functional.normalize(q, dim=1)

            # compute key features
            with torch.no_grad():  # no gradient to keys
                self._momentum_update_key_encoder()  # update the key encoder

                _, k, _ = self.encoder_k(im_k)  # keys: NxC
                k = nn.functional.normalize(k, dim=1)

            # compute logits
            # Einstein sum is more intuitive
            # positive logits: Nx1
            l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
            # negative logits: NxK
            l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

            # logits: Nx(1+K)
            logits = torch.cat([l_pos, l_neg], dim=1)

            # apply temperature
            logits /= self.T

            # labels: positive key indicators
            labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
            # dequeue and enqueue
            self._dequeue_and_enqueue(k)

            return embedding, logits, labels, inter
        else:
            embedding, _, inter = self.encoder_q(im_q)

            return embedding, inter


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
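
The training branch builds a standard InfoNCE problem: one positive similarity per query (l_pos) concatenated with K negatives from the queue (l_neg), scaled by the temperature T, with label 0 marking the positive column. A minimal standalone illustration of just the logits construction (random tensors; K = batch_size * dim matches the CBDE setting above):

import torch
import torch.nn.functional as F

N, dim, K, T = 4, 256, 1792, 0.07  # K = 7 * 256 as in CBDE
q = F.normalize(torch.randn(N, dim), dim=1)
k = F.normalize(torch.randn(N, dim), dim=1)
queue = F.normalize(torch.randn(dim, K), dim=0)

l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)  # N x 1 positive similarities
l_neg = torch.einsum('nc,ck->nk', [q, queue])           # N x K negative similarities
logits = torch.cat([l_pos, l_neg], dim=1) / T           # N x (1+K)
labels = torch.zeros(N, dtype=torch.long)               # positive is always column 0

loss = F.cross_entropy(logits, labels)
print(logits.shape, loss.item())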
text_net/model.py
ADDED
@@ -0,0 +1,29 @@
from torch import nn

from text_net.encoder import CBDE
from text_net.DGRN import DGRN


class AirNet(nn.Module):
    def __init__(self, opt):
        super(AirNet, self).__init__()

        # Restorer
        self.R = DGRN(opt)

        # Encoder
        self.E = CBDE(opt)

    def forward(self, x_query, x_key, text_prompt):
        if self.training:
            fea, logits, labels, inter = self.E(x_query, x_key)

            restored = self.R(x_query, inter, text_prompt)

            return restored, logits, labels
        else:
            fea, inter = self.E(x_query, x_query)

            restored = self.R(x_query, inter, text_prompt)

            return restored
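
At inference time AirNet only needs x_query (the eval branch of CBDE ignores x_key). A sketch of the eval-mode call, assuming clip and mmcv are installed and a CUDA device is available, since importing DGRN loads CLIP onto the GPU; AirNet only reads opt.batch_size, which sizes the MoCo queue:

import argparse

import torch

from text_net.model import AirNet

opt = argparse.Namespace(batch_size=7)  # the only field AirNet actually reads
net = AirNet(opt).cuda().eval()

x = torch.randn(1, 3, 64, 64).cuda()
with torch.no_grad():
    restored = net(x_query=x, x_key=x, text_prompt=["remove the haze"])
print(restored.shape)  # torch.Size([1, 3, 64, 64])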
utils/.DS_Store
ADDED
Binary file (6.15 kB).
utils/__init__.py
ADDED
File without changes
utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (138 Bytes).
utils/__pycache__/__init__.cpython-36.pyc
ADDED
Binary file (123 Bytes).
utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (148 Bytes).
utils/__pycache__/dataset_utils.cpython-310.pyc
ADDED
Binary file (11.2 kB).
utils/__pycache__/dataset_utils.cpython-36.pyc
ADDED
Binary file (30.8 kB).
utils/__pycache__/dataset_utils.cpython-38.pyc
ADDED
Binary file (10.6 kB).
utils/__pycache__/dataset_utils_CDD.cpython-310.pyc
ADDED
Binary file (1.58 kB).
utils/__pycache__/degradation_utils.cpython-310.pyc
ADDED
Binary file (1.8 kB).
utils/__pycache__/degradation_utils.cpython-36.pyc
ADDED
Binary file (3.4 kB).
utils/__pycache__/degradation_utils.cpython-38.pyc
ADDED
Binary file (1.79 kB).
utils/__pycache__/image_io.cpython-310.pyc
ADDED
Binary file (11 kB).
utils/__pycache__/image_io.cpython-36.pyc
ADDED
Binary file (11.3 kB).
utils/__pycache__/image_io.cpython-38.pyc
ADDED
Binary file (11.1 kB).
utils/__pycache__/image_utils.cpython-310.pyc
ADDED
Binary file (7.48 kB).
utils/__pycache__/image_utils.cpython-36.pyc
ADDED
Binary file (7.61 kB).
utils/__pycache__/image_utils.cpython-38.pyc
ADDED
Binary file (7.46 kB).
utils/__pycache__/imresize.cpython-36.pyc
ADDED
Binary file (4.75 kB).
utils/__pycache__/imresize.cpython-38.pyc
ADDED
Binary file (4.75 kB).
utils/__pycache__/loss_utils.cpython-38.pyc
ADDED
Binary file (1.43 kB).
utils/__pycache__/val_utils.cpython-310.pyc
ADDED
Binary file (3.35 kB).
utils/__pycache__/val_utils.cpython-36.pyc
ADDED
Binary file (2.34 kB).
utils/__pycache__/val_utils.cpython-38.pyc
ADDED
Binary file (3.27 kB).
utils/dataset_utils.py
ADDED
@@ -0,0 +1,309 @@
import os
import random
import copy
from PIL import Image
import numpy as np

from torch.utils.data import Dataset
from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor

from utils.image_utils import random_augmentation, crop_img
from utils.degradation_utils import Degradation


class TrainDataset(Dataset):
    def __init__(self, args):
        super(TrainDataset, self).__init__()
        self.args = args
        self.rs_ids = []
        self.hazy_ids = []
        self.D = Degradation(args)
        self.de_temp = 0
        self.de_type = self.args.de_type
        self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']

        self.de_dict = {'denoise_15': 0, 'denoise_25': 1, 'denoise_50': 2, 'derain': 3, 'dehaze': 4}

        self._init_ids()

        self.crop_transform = Compose([
            ToPILImage(),
            RandomCrop(args.patch_size),
        ])

        self.toTensor = ToTensor()

    def _init_ids(self):
        if 'denoise_15' in self.de_type or 'denoise_25' in self.de_type or 'denoise_50' in self.de_type:
            self._init_clean_ids()
        if 'derain' in self.de_type:
            self._init_rs_ids()
        if 'dehaze' in self.de_type:
            self._init_hazy_ids()

        random.shuffle(self.de_type)

    def _init_clean_ids(self):
        clean_ids = []
        # Keep only image files from the directory listing
        name_list = os.listdir(self.args.denoise_dir)
        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]

        clean_ids += [self.args.denoise_dir + id_ for id_ in name_list]

        if 'denoise_15' in self.de_type:
            self.s15_ids = copy.deepcopy(clean_ids)
            random.shuffle(self.s15_ids)
            self.s15_counter = 0
        if 'denoise_25' in self.de_type:
            self.s25_ids = copy.deepcopy(clean_ids)
            random.shuffle(self.s25_ids)
            self.s25_counter = 0
        if 'denoise_50' in self.de_type:
            self.s50_ids = copy.deepcopy(clean_ids)
            random.shuffle(self.s50_ids)
            self.s50_counter = 0

        # print(clean_ids)

        self.num_clean = len(clean_ids)

    def _init_hazy_ids(self):
        # Keep only image files from the directory listing
        dehaze_ids = []
        name_list = os.listdir(self.args.dehaze_dir)
        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
        dehaze_ids += [self.args.dehaze_dir + id_ for id_ in name_list]
        self.hazy_ids = dehaze_ids

        self.hazy_counter = 0
        self.num_hazy = len(self.hazy_ids)

    def _init_rs_ids(self):
        # Keep only image files from the directory listing
        derain_ids = []
        name_list = os.listdir(self.args.derain_dir)
        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
        derain_ids += [self.args.derain_dir + id_ for id_ in name_list]
        self.rs_ids = derain_ids

        self.rl_counter = 0
        # print(derain_ids)

        self.num_rl = len(self.rs_ids)

    def _crop_patch(self, img_1, img_2):
        H = img_1.shape[0]
        W = img_1.shape[1]
        ind_H = random.randint(0, H - self.args.patch_size)
        ind_W = random.randint(0, W - self.args.patch_size)

        patch_1 = img_1[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
        patch_2 = img_2[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]

        return patch_1, patch_2

    def _get_gt_name(self, rainy_name):
        gt_name = 'data/' + 'Target/Derain/norain-' + rainy_name.split('rain-')[-1]
        return gt_name

    def _get_nonhazy_name(self, hazy_name):
        # Fixed: the original referenced an undefined `rainy_name` here
        gt_name = 'data/' + 'Target/Dehaze/nohaze-' + hazy_name.split('haze-')[-1]
        return gt_name

    def __getitem__(self, _):
        de_id = self.de_dict[self.de_type[self.de_temp]]

        if de_id < 3:
            if de_id == 0:
                clean_id = self.s15_ids[self.s15_counter]
                self.s15_counter = (self.s15_counter + 1) % self.num_clean
                if self.s15_counter == 0:
                    random.shuffle(self.s15_ids)
            elif de_id == 1:
                clean_id = self.s25_ids[self.s25_counter]
                self.s25_counter = (self.s25_counter + 1) % self.num_clean
                if self.s25_counter == 0:
                    random.shuffle(self.s25_ids)
            elif de_id == 2:
                clean_id = self.s50_ids[self.s50_counter]
                self.s50_counter = (self.s50_counter + 1) % self.num_clean
                if self.s50_counter == 0:
                    random.shuffle(self.s50_ids)

            # clean_id = random.randint(0, len(self.clean_ids) - 1)
            clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)
            clean_patch_1, clean_patch_2 = self.crop_transform(clean_img), self.crop_transform(clean_img)
            clean_patch_1, clean_patch_2 = np.array(clean_patch_1), np.array(clean_patch_2)

            # clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]
            clean_name = clean_id.split("/")[-1].split('.')[0]

            clean_patch_1, clean_patch_2 = random_augmentation(clean_patch_1, clean_patch_2)
            degrad_patch_1, degrad_patch_2 = self.D.degrade(clean_patch_1, clean_patch_2, de_id)
        else:
            if de_id == 3:
                # Rain Streak Removal
                # rl_id = random.randint(0, len(self.rl_ids) - 1)
                degrad_img = crop_img(np.array(Image.open(self.rs_ids[self.rl_counter]).convert('RGB')), base=16)
                clean_name = self._get_gt_name(self.rs_ids[self.rl_counter])
                clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)

                self.rl_counter = (self.rl_counter + 1) % self.num_rl
                if self.rl_counter == 0:
                    random.shuffle(self.rs_ids)
            elif de_id == 4:
                # Dehazing with SOTS outdoor training set
                # hazy_id = random.randint(0, len(self.hazy_ids) - 1)
                degrad_img = crop_img(np.array(Image.open(self.hazy_ids[self.hazy_counter]).convert('RGB')), base=16)
                clean_name = self._get_nonhazy_name(self.hazy_ids[self.hazy_counter])
                clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)

                self.hazy_counter = (self.hazy_counter + 1) % self.num_hazy
                if self.hazy_counter == 0:
                    random.shuffle(self.hazy_ids)
            degrad_patch_1, clean_patch_1 = random_augmentation(*self._crop_patch(degrad_img, clean_img))
            degrad_patch_2, clean_patch_2 = random_augmentation(*self._crop_patch(degrad_img, clean_img))

        clean_patch_1, clean_patch_2 = self.toTensor(clean_patch_1), self.toTensor(clean_patch_2)
        degrad_patch_1, degrad_patch_2 = self.toTensor(degrad_patch_1), self.toTensor(degrad_patch_2)

        self.de_temp = (self.de_temp + 1) % len(self.de_type)
        if self.de_temp == 0:
            random.shuffle(self.de_type)

        return [clean_name, de_id], degrad_patch_1, degrad_patch_2, clean_patch_1, clean_patch_2

    def __len__(self):
        return 400 * len(self.args.de_type)


class DenoiseTestDataset(Dataset):
    def __init__(self, args):
        super(DenoiseTestDataset, self).__init__()
        self.args = args
        self.clean_ids = []
        self.sigma = 15
        self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']

        self._init_clean_ids()

        self.toTensor = ToTensor()

    def _init_clean_ids(self):
        # Keep only image files from the directory listing
        name_list = os.listdir(self.args.denoise_path)
        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
        self.clean_ids += [self.args.denoise_path + id_ for id_ in name_list]

        self.num_clean = len(self.clean_ids)

    def _add_gaussian_noise(self, clean_patch):
        noise = np.random.randn(*clean_patch.shape)
        noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255).astype(np.uint8)
        return noisy_patch, clean_patch

    def set_sigma(self, sigma):
        self.sigma = sigma

    def __getitem__(self, clean_id):
        clean_img = crop_img(np.array(Image.open(self.clean_ids[clean_id]).convert('RGB')), base=16)
        clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]

        noisy_img, _ = self._add_gaussian_noise(clean_img)
        clean_img, noisy_img = self.toTensor(clean_img), self.toTensor(noisy_img)

        return [clean_name], noisy_img, clean_img

    def __len__(self):
        return self.num_clean


class DerainDehazeDataset(Dataset):
    def __init__(self, args, task="derain"):
        super(DerainDehazeDataset, self).__init__()
        self.ids = []
        self.task_idx = 0
        self.args = args
        self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']

        self.task_dict = {'derain': 0, 'dehaze': 1}
        self.toTensor = ToTensor()

        self.set_dataset(task)

    def _init_input_ids(self):
        if self.task_idx == 0:
            self.ids = []
            # Keep only image files from the directory listing
            name_list = os.listdir(self.args.derain_path + 'input/')
            name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
            self.ids += [self.args.derain_path + 'input/' + id_ for id_ in name_list]
        elif self.task_idx == 1:
            self.ids = []
            # Keep only image files from the directory listing
            name_list = os.listdir(self.args.dehaze_path + 'input/')
            name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
            self.ids += [self.args.dehaze_path + 'input/' + id_ for id_ in name_list]

        self.length = len(self.ids)

    def _get_gt_path(self, degraded_name):
        if self.task_idx == 0:
            gt_name = '/'.join(degraded_name.replace("input", "target").split('/')[:-1] + degraded_name.replace("input", "target").replace("rain", "norain").split('/')[-1:])
            print(gt_name)
        elif self.task_idx == 1:
            dir_name = degraded_name.split("input")[0] + 'target/'
            name = degraded_name.split('/')[-1].split('_')[0] + '.png'
            gt_name = dir_name + name
        return gt_name

    def set_dataset(self, task):
        self.task_idx = self.task_dict[task]
        self._init_input_ids()

    def __getitem__(self, idx):
        degraded_path = self.ids[idx]
        clean_path = self._get_gt_path(degraded_path)

        degraded_img = crop_img(np.array(Image.open(degraded_path).convert('RGB')), base=16)
        clean_img = crop_img(np.array(Image.open(clean_path).convert('RGB')), base=16)

        clean_img, degraded_img = self.toTensor(clean_img), self.toTensor(degraded_img)
        degraded_name = degraded_path.split('/')[-1][:-4]

        return [degraded_name], degraded_img, clean_img

    def __len__(self):
        return self.length


class TestSpecificDataset(Dataset):
    def __init__(self, args):
        super(TestSpecificDataset, self).__init__()
        self.args = args
        self.degraded_ids = []
        self.image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif']  # Fixed: used below but missing in the original
        self._init_clean_ids(args.test_path)

        self.toTensor = ToTensor()

    def _init_clean_ids(self, root):
        degraded_ids = []
        # Keep only image files from the directory listing
        name_list = os.listdir(root)
        name_list = [file for file in name_list if os.path.splitext(file)[1].lower() in self.image_extensions]
        self.degraded_ids += [root + id_ for id_ in name_list]

        self.num_img = len(self.degraded_ids)

    def __getitem__(self, idx):
        degraded_img = crop_img(np.array(Image.open(self.degraded_ids[idx]).convert('RGB')), base=16)
        name = self.degraded_ids[idx].split('/')[-1][:-4]

        degraded_img = self.toTensor(degraded_img)

        return [name], degraded_img

    def __len__(self):
        return self.num_img
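
TrainDataset samples tasks round-robin: de_temp walks a shuffled de_type list, each task keeps its own counter into its own shuffled id list, and every wrap-around reshuffles. The pattern in isolation (pure Python; names are illustrative stand-ins):

import random

ids = ['a.png', 'b.png', 'c.png']  # stand-in for e.g. self.s15_ids
random.shuffle(ids)
counter = 0  # stand-in for e.g. self.s15_counter

def next_id():
    # Advance the per-task cursor, reshuffling whenever the list wraps around
    global counter
    item = ids[counter]
    counter = (counter + 1) % len(ids)
    if counter == 0:
        random.shuffle(ids)
    return item

print([next_id() for _ in range(7)])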