DS committed
Commit e5b70eb
1 Parent(s): b19f11c
dump shiet
- .gitignore +1 -0
- Dockerfile +21 -0
- MeasureV1.py +123 -0
- README.md +3 -5
- app.py +99 -0
- compare.py +47 -0
- crop_test.py +5 -0
- flagged/Input/tmplp_isgr5.jpg +0 -0
- flagged/log.csv +2 -0
- images_uploaded/0805.png +0 -0
- images_uploaded/0821.png +0 -0
- images_uploaded/0873.png +0 -0
- images_uploaded/1.png +0 -0
- inference.py +97 -0
- model/__pycache__/decoder.cpython-38.pyc +0 -0
- model/__pycache__/discriminator.cpython-38.pyc +0 -0
- model/__pycache__/encoder.cpython-38.pyc +0 -0
- model/decoder.py +240 -0
- model/discriminator.py +121 -0
- model/encoder.py +67 -0
- opt/__pycache__/option.cpython-38.pyc +0 -0
- opt/option.py +69 -0
- requirements.txt +10 -0
- test/1.png +0 -0
- testsets/0848.png +0 -0
- testsets/0851.png +0 -0
- testsets/0855.png +0 -0
- testsets/0879.png +0 -0
- train.py +474 -0
- util/__pycache__/utils.cpython-38.pyc +0 -0
- util/utils.py +133 -0
.gitignore
ADDED
@@ -0,0 +1 @@
weights/best_weight.pth
Dockerfile
ADDED
@@ -0,0 +1,21 @@
FROM python:3.9

USER root

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

RUN apt-get update
RUN apt-get install ffmpeg libsm6 libxext6 -y

RUN pip install --upgrade pip

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY . .

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
MeasureV1.py
ADDED
@@ -0,0 +1,123 @@
import glob
import os
import time
from collections import OrderedDict

import numpy as np
import torch
import cv2
import argparse

from natsort import natsort
from skimage.metrics import structural_similarity as compare_ssim
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
import lpips


class Measure():
    def __init__(self, net='alex', use_gpu=False):
        self.device = 'cuda' if use_gpu else 'cpu'
        self.model = lpips.LPIPS(net=net)
        self.model.to(self.device)

    def measure(self, imgA, imgB):
        if not all([s1 == s2 for s1, s2 in zip(imgA.shape, imgB.shape)]):
            raise RuntimeError("Image sizes not the same.")
        return [float(f(imgA, imgB)) for f in [self.psnr, self.ssim, self.lpips]]

    def lpips(self, imgA, imgB, model=None):
        tA = t(imgA).to(self.device)
        tB = t(imgB).to(self.device)
        dist01 = self.model.forward(tA, tB).item()
        return dist01

    def ssim(self, imgA, imgB):
        # multichannel: If True, treat the last dimension of the array as channels. Similarity calculations are done independently for each channel then averaged.
        score, diff = compare_ssim(imgA, imgB, full=True, multichannel=True)
        return score

    def psnr(self, imgA, imgB):
        psnr = compare_psnr(imgA, imgB)
        return psnr


def t(img):
    def to_4d(img):
        assert len(img.shape) == 3
        assert img.dtype == np.uint8
        img_new = np.expand_dims(img, axis=0)
        assert len(img_new.shape) == 4
        return img_new

    def to_CHW(img):
        return np.transpose(img, [2, 0, 1])

    def to_tensor(img):
        return torch.Tensor(img)

    return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1


def fiFindByWildcard(wildcard):
    return natsort.natsorted(glob.glob(wildcard, recursive=True))


def imread(path):
    return cv2.imread(path)[:, :, [2, 1, 0]]


def format_result(psnr, ssim, lpips):
    return f'{psnr:0.2f}, {ssim:0.3f}, {lpips:0.3f}'


def measure_dirs(dirA, dirB, use_gpu, verbose=False):
    if verbose:
        vprint = lambda x: print(x)
    else:
        vprint = lambda x: None

    t_init = time.time()

    paths_A = fiFindByWildcard(os.path.join(dirA, f'*.{type}'))
    paths_B = fiFindByWildcard(os.path.join(dirB, f'*.{type}'))

    vprint("Comparing: ")
    vprint(dirA)
    vprint(dirB)

    measure = Measure(use_gpu=use_gpu)

    results = []
    for pathA, pathB in zip(paths_A, paths_B):
        result = OrderedDict()

        t = time.time()
        result['psnr'], result['ssim'], result['lpips'] = measure.measure(imread(pathA), imread(pathB))
        d = time.time() - t
        vprint(f"{pathA.split('/')[-1]}, {pathB.split('/')[-1]}, {format_result(**result)}, {d:0.1f}")

        results.append(result)

    psnr = np.mean([result['psnr'] for result in results])
    ssim = np.mean([result['ssim'] for result in results])
    lpips = np.mean([result['lpips'] for result in results])

    vprint(f"Final Result: {format_result(psnr, ssim, lpips)}, {time.time() - t_init:0.1f}s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-dirA', default='', type=str)
    parser.add_argument('-dirB', default='', type=str)
    parser.add_argument('-type', default='png')
    parser.add_argument('--use_gpu', action='store_true', default=False)
    args = parser.parse_args()

    dirA = args.dirA
    dirB = args.dirB
    type = args.type
    use_gpu = args.use_gpu

    if len(dirA) > 0 and len(dirB) > 0:
        measure_dirs(dirA, dirB, use_gpu=use_gpu, verbose=True)
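A minimal usage sketch for the Measure class above (added for illustration, not part of the commit; the file paths are hypothetical placeholders and both images must have identical shapes, otherwise Measure.measure() raises RuntimeError):

# Hypothetical quick check of PSNR / SSIM / LPIPS between two aligned uint8 RGB images.
from MeasureV1 import Measure, imread

measure = Measure(net='alex', use_gpu=False)       # LPIPS backbone on CPU
imgA = imread('path/to/ground_truth.png')          # placeholder path
imgB = imread('path/to/restored.png')              # placeholder path
psnr, ssim, lpips_dist = measure.measure(imgA, imgB)
print(f"PSNR={psnr:0.2f}  SSIM={ssim:0.3f}  LPIPS={lpips_dist:0.3f}")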
README.md
CHANGED
@@ -1,11 +1,9 @@
 ---
 title: USR DA
-emoji:
+emoji: 😻
 colorFrom: gray
-colorTo:
+colorTo: indigo
-sdk:
-sdk_version: 3.14.0
-app_file: app.py
+sdk: docker
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,99 @@
import glob
import io
import os

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import wget
from torchvision.transforms import Compose, ToTensor

from model import decoder, encoder

WEIGHT_PATH = './weights/best_weight.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Model(object):
    def __init__(self) -> None:
        self.model_Enc = encoder.Encoder_RRDB(num_feat=64).to(device=DEVICE)
        self.model_Dec_SR = decoder.Decoder_SR_RRDB(num_in_ch=64).to(device=DEVICE)
        self.preprocess = Compose([ToTensor()])
        self.load_model()

    def load_model(self, weight_path=WEIGHT_PATH):
        if not os.path.isfile("./weights/best_weight.pth"):
            response = wget.download("https://raw.githubusercontent.com/hungnguyen2611/super-resolution/master/weights/best_weight.pth", "./weights/best_weight.pth")
        weight = torch.load(weight_path)
        print("[LOADING] Loading encoder...")
        self.model_Enc.load_state_dict(weight['model_Enc'])
        print("[LOADING] Loading decoder...")
        self.model_Dec_SR.load_state_dict(weight['model_Dec_SR'])
        print("[LOADING] Loading done!")
        self.model_Enc.eval()
        self.model_Dec_SR.eval()

    def predict(self, img):
        with torch.no_grad():
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = self.preprocess(img)
            img = img.unsqueeze(0)
            img = img.to(DEVICE)

            feat = self.model_Enc(img)
            out = self.model_Dec_SR(feat)
            min_max = (0, 1)
            out = out.detach()[0].float().cpu()

            out = out.squeeze().float().cpu().clamp_(*min_max)
            out = (out - min_max[0]) / (min_max[1] - min_max[0])
            out = out.numpy()
            out = np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))

            out = (out * 255.0).round()
            out = out.astype(np.uint8)
            return out


model = Model()


def predict(img):
    global model
    img.save("test/1.png", "PNG")
    image = cv2.imread("test/1.png", cv2.IMREAD_COLOR)
    out = model.predict(img=image)

    cv2.imwrite('images_uploaded/1.png', out)
    return "images_uploaded/1.png"


if __name__ == '__main__':
    title = "Super-Resolution Demo USR-DA Unofficial 🚀🚀🔥"
    description = '''
    <br>
    **This Demo expects low-quality and low-resolution images**
    **We are looking for collaborators!**
    </br>
    '''
    article = "<p style='text-align: center'><a href='https://openaccess.thecvf.com/content/ICCV2021/papers/Wang_Unsupervised_Real-World_Super-Resolution_A_Domain_Adaptation_Perspective_ICCV_2021_paper.pdf' target='_blank'>Unsupervised Real-World Super-Resolution: A Domain Adaptation Perspective</a> | <a href='https://github.com/hungnguyen2611/super-resolution.git' target='_blank'>Github Repo</a></p>"
    examples = glob.glob("testsets/*.png")
    gr.Interface(
        predict,
        gr.inputs.Image(type="pil", label="Input").style(height=260),
        gr.inputs.Image(type="pil", label="Output").style(height=240),
        title=title,
        description=description,
        article=article,
        examples=examples,
    ).launch(enable_queue=True)
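A small sketch of exercising the Gradio handler above without launching the web UI (added for illustration, not part of the commit; it assumes the repo root is the working directory, the test/ and images_uploaded/ folders exist as in this commit, and the weight download in Model.load_model succeeds):

# Hypothetical local smoke test for app.predict(); importing app builds Model()
# and therefore loads (or downloads) ./weights/best_weight.pth.
from PIL import Image
import app

out_path = app.predict(Image.open("testsets/0848.png"))
print("super-resolved image written to", out_path)   # images_uploaded/1.png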
compare.py
ADDED
@@ -0,0 +1,47 @@
import glob
import os
import time
from collections import OrderedDict
import numpy as np
import cv2
import matplotlib.pyplot as plt
from natsort import natsort
from tqdm import tqdm


def fiFindByWildcard(wildcard):
    return natsort.natsorted(glob.glob(wildcard, recursive=True))


if __name__ == "__main__":
    out_data_path = fiFindByWildcard("./results_crop (1)/out/*")
    gt_data_path = fiFindByWildcard("./results_crop (1)/target/*")
    source_data_path = fiFindByWildcard("./results_crop (1)/source/*")

    for src_path, out_path, gt_path in tqdm(list(zip(source_data_path, out_data_path, gt_data_path))):
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
        ax1.set_title("Bicubic")
        ax2.set_title("Baseline")
        ax3.set_title("Ground truth")

        src = cv2.imread(src_path)[:, :, [2, 1, 0]]
        out = cv2.imread(out_path)[:, :, [2, 1, 0]]
        gt = cv2.imread(gt_path)[:, :, [2, 1, 0]]
        src = cv2.resize(src, None, fx=4, fy=4, interpolation=cv2.INTER_CUBIC)
        ax1.set_yticklabels([])
        ax1.set_xticklabels([])
        ax2.set_yticklabels([])
        ax2.set_xticklabels([])
        ax3.set_yticklabels([])
        ax3.set_xticklabels([])
        ax1.imshow(src)
        ax2.imshow(out)
        ax3.imshow(gt)
        fig.savefig(f"./result_compare_crop_new/{os.path.basename(gt_path)}", bbox_inches='tight', dpi=1200)
        plt.close()
crop_test.py
ADDED
@@ -0,0 +1,5 @@
import cv2
flagged/Input/tmplp_isgr5.jpg
ADDED
flagged/log.csv
ADDED
@@ -0,0 +1,2 @@
Input,Ouput,flag,username,timestamp
/home/ds/Documents/SR/USR_DA/USR-DA/USR-DA/flagged/Input/tmplp_isgr5.jpg,,,,2022-12-18 14:45:29.533016
images_uploaded/0805.png
ADDED
images_uploaded/0821.png
ADDED
images_uploaded/0873.png
ADDED
images_uploaded/1.png
ADDED
inference.py
ADDED
@@ -0,0 +1,97 @@
import os
import numpy as np
import cv2
import torch.nn as nn
from tqdm import tqdm
import torch
import torchvision

from model import encoder, decoder
from opt.option import args


# device setting
if args.gpu_id is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
    print('using GPU 0')
else:
    print('use --gpu_id to specify GPU ID to use')
    exit()


# make directory for saving results
if not os.path.exists(args.results):
    os.mkdir(args.results)


# numpy array -> torch tensor
class ToTensor(object):
    def __call__(self, sample):
        sample = np.transpose(sample, (2, 0, 1))
        sample = torch.from_numpy(sample)
        return sample


# create model
# model_Enc = encoder.Encoder().cuda()
# model_Dec_SR = decoder.Decoder_SR().cuda()
model_Enc = encoder.Encoder_RRDB(num_feat=args.n_hidden_feats).cuda()
model_Dec_SR = decoder.Decoder_SR_RRDB(num_in_ch=args.n_hidden_feats).cuda()

model_Enc = nn.DataParallel(model_Enc)
# model_Dec_Id = nn.DataParallel(model_Dec_Id)
model_Dec_SR = nn.DataParallel(model_Dec_SR)

# load weights
checkpoint = torch.load(args.weights)
model_Enc.load_state_dict(checkpoint['model_Enc'])
model_Dec_SR.load_state_dict(checkpoint['model_Dec_SR'])
model_Enc.eval()
model_Dec_SR.eval()

# input transform
transforms = torchvision.transforms.Compose([ToTensor()])


filenames = os.listdir(args.dir_test)
filenames.sort()
with torch.no_grad():
    for filename in tqdm(filenames):
        img_name = os.path.join(args.dir_test, filename)
        ext = os.path.splitext(img_name)[-1]
        if ext in ['.png', '.jpg']:
            img = cv2.imread(img_name)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # img = cv2.resize(img, ((img.shape[1] // 4), (img.shape[0] // 4)))
            img = np.array(img).astype('float32') / 255
            # img = img[0:256, 0:256, :]

            img = transforms(img)
            img = torch.tensor(img.cuda()).unsqueeze(0)

            # inference output
            feat = model_Enc(img)
            out = model_Dec_SR(feat)

            min_max = (0, 1)
            out = out.detach()[0].float().cpu()

            out = out.squeeze().float().cpu().clamp_(*min_max)
            out = (out - min_max[0]) / (min_max[1] - min_max[0])
            out = out.numpy()
            out = np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))

            out = (out * 255.0).round()
            out = out.astype(np.uint8)

            # result image save (b x c x h x w (torch tensor) -> h x w x c (numpy array))
            # out = out.data.cpu().squeeze().numpy()
            # out = np.clip(out, 0, 1)
            # out = np.transpose(out, (1, 2, 0))
            print(args.results, filename)
            cv2.imwrite('%s_out.png' % (os.path.join(args.results, filename)[:-4]), out)
model/__pycache__/decoder.cpython-38.pyc
ADDED
Binary file (7.59 kB).
model/__pycache__/discriminator.cpython-38.pyc
ADDED
Binary file (3.32 kB).
model/__pycache__/encoder.cpython-38.pyc
ADDED
Binary file (2.07 kB).
model/decoder.py
ADDED
@@ -0,0 +1,240 @@
import torch
import torch.nn as nn
from torch.nn import init as init
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm


class Decoder_Identity(nn.Module):
    def __init__(self):
        super(Decoder_Identity, self).__init__()

        self.conv_up_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True),
            nn.ReLU()
        )

        self.conv_up_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True),
            nn.ReLU()
        )

        self.conv_last = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=3, kernel_size=1, bias=True),
            nn.ReLU()
        )

    def forward(self, feat):
        featmap_2 = self.conv_up_2(feat)
        featmap_1 = self.conv_up_1(featmap_2)
        out = self.conv_last(featmap_1)

        return out


class Decoder_SR(nn.Module):
    def __init__(self, scale=4):
        super(Decoder_SR, self).__init__()

        self.scale = scale

        self.conv_up_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True),
            nn.ReLU()
        )

        self.conv_up_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True),
            nn.ReLU()
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # upsampling
        self.upsample_1 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True)
        self.upsample_2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True)

        self.HR_conv = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True)
        self.conv_last = nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3, padding=1, bias=True)

    def forward(self, feat):
        featmap_2 = self.conv_up_2(feat)
        featmap_1 = self.conv_up_1(featmap_2)

        if self.scale == 4:
            featmap = self.lrelu(self.upsample_1(F.interpolate(featmap_1, scale_factor=2, mode='nearest')))
            featmap = self.lrelu(self.upsample_2(F.interpolate(featmap, scale_factor=2, mode='nearest')))
        elif self.scale == 2:
            featmap = self.lrelu(self.upsample_1(F.interpolate(featmap_1, scale_factor=2, mode='nearest')))

        out = self.conv_last(self.lrelu(self.HR_conv(featmap)))

        return out


def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x


class Decoder_Id_RRDB(nn.Module):
    def __init__(self, num_in_ch, num_out_ch=3, scale=4, num_feat=64, num_block=10, num_grow_ch=32):
        super(Decoder_Id_RRDB, self).__init__()

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        feat = self.conv_first(x)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat

        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out


class Decoder_SR_RRDB(nn.Module):
    def __init__(self, num_in_ch, num_out_ch=3, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(Decoder_SR_RRDB, self).__init__()
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        feat = self.conv_first(x)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
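A short shape-check sketch for the RRDB building block defined above (added for illustration, not in the commit): a residual-in-residual dense block preserves both channel count and spatial size, which is why make_layer can stack it freely inside the decoders.

# Illustration: RRDB is shape-preserving, so stacking via make_layer only
# deepens the trunk without changing the feature-map geometry.
import torch
from model.decoder import RRDB, make_layer

trunk = make_layer(RRDB, num_basic_block=3, num_feat=64, num_grow_ch=32)
x = torch.randn(1, 64, 32, 32)
print(trunk(x).shape)   # torch.Size([1, 64, 32, 32])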
model/discriminator.py
ADDED
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class DiscriminatorVGG(nn.Module):
    def __init__(self, in_ch=3, image_size=128, d=64):
        super(DiscriminatorVGG, self).__init__()
        self.feature_map_size = image_size // 32
        self.d = d

        self.features = nn.Sequential(
            nn.Conv2d(in_ch, d, kernel_size=3, stride=1, padding=1),  # input is 3 x 128 x 128
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv2d(d, d, kernel_size=3, stride=2, padding=1, bias=False),  # state size. 64 x 64 x 64
            nn.BatchNorm2d(d),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv2d(d, d*2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(d*2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv2d(d*2, d*2, kernel_size=3, stride=2, padding=1, bias=False),  # state size. 128 x 32 x 32
            nn.BatchNorm2d(d*2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv2d(d*2, d*4, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(d*4),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv2d(d*4, d*4, kernel_size=3, stride=2, padding=1, bias=False),  # state size. 256 x 16 x 16
            nn.BatchNorm2d(d*4),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv2d(d*4, d*8, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(d*8),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv2d(d*8, d*8, kernel_size=3, stride=2, padding=1, bias=False),  # state size. 512 x 8 x 8
            nn.BatchNorm2d(d*8),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv2d(d*8, d*8, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(d*8),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv2d(d*8, d*8, kernel_size=3, stride=2, padding=1, bias=False),  # state size. 512 x 4 x 4
            nn.BatchNorm2d(d*8),
            nn.LeakyReLU(negative_slope=0.2, inplace=True)
        )

        self.classifier = nn.Sequential(
            nn.Linear((self.d*8) * self.feature_map_size * self.feature_map_size, 100),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(100, 1)
        )

    def forward(self, x):
        out = self.features(x)
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out


class UNetDiscriminator(nn.Module):
    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminator, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm
        self.num_in_ch = num_in_ch
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)

        self.conv1 = norm(nn.Conv2d(num_feat, num_feat*2, kernel_size=4, stride=2, padding=1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat*2, num_feat*4, kernel_size=4, stride=2, padding=1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat*4, num_feat*8, kernel_size=4, stride=2, padding=1, bias=False))

        # upsample
        self.conv4 = norm(nn.Conv2d(num_feat*8, num_feat*4, kernel_size=3, stride=1, padding=1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat*4, num_feat*2, kernel_size=3, stride=1, padding=1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat*2, num_feat, kernel_size=3, stride=1, padding=1, bias=False))

        # extra
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=False))

        self.conv9 = nn.Conv2d(num_feat, 1, kernel_size=3, stride=1, padding=1)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear')
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x4 = x4 + x2
        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear')
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x5 = x5 + x1
        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear')
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x6 = x6 + x0

        # extra
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)
        out = self.avg_pool(out)
        out = torch.flatten(out, 1)
        return out
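A quick output-shape sketch for the two discriminators (added for illustration, not part of the commit): DiscriminatorVGG flattens after five stride-2 convolutions, so image_size should match the actual input resolution and be divisible by 32, while UNetDiscriminator pools down to a single logit per image regardless of size.

# Illustration: both discriminators produce one logit per sample.
import torch
from model.discriminator import DiscriminatorVGG, UNetDiscriminator

d_vgg = DiscriminatorVGG(in_ch=3, image_size=128)   # 128 // 32 = 4 -> 512*4*4 features before the classifier
d_unet = UNetDiscriminator(num_in_ch=3)

x = torch.randn(2, 3, 128, 128)
print(d_vgg(x).shape, d_unet(x).shape)              # torch.Size([2, 1]) for both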
model/encoder.py
ADDED
@@ -0,0 +1,67 @@
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)

        self.conv_featmap_1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
        )

        self.conv_featmap_2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
        )

        self.conv_featmap_3 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, bias=True),
            nn.ReLU(),
        )

    def forward(self, img):
        featmap_1 = self.conv_featmap_1(img)
        featmap_1_down = self.maxpool(featmap_1)

        featmap_2 = self.conv_featmap_2(featmap_1_down)
        featmap_2_down = self.maxpool(featmap_2)

        featmap_3 = self.conv_featmap_3(featmap_2_down)

        return featmap_3


class Encoder_RRDB(nn.Module):
    def __init__(self, num_feat=16):
        super(Encoder_RRDB, self).__init__()
        self.conv_featmap = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=num_feat, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(in_channels=num_feat, out_channels=num_feat, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(in_channels=num_feat, out_channels=num_feat, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(in_channels=num_feat, out_channels=num_feat, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(in_channels=num_feat, out_channels=num_feat, kernel_size=3, padding=1, bias=True),
        )

    def forward(self, img):
        featmap = self.conv_featmap(img)

        return featmap
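For reference, a sketch (not part of the commit) of how the pieces above are chained in app.py and inference.py: Encoder_RRDB maps an RGB image to a 64-channel feature map at the input resolution, and Decoder_SR_RRDB then upsamples it 4x back to RGB.

# Illustration of the encoder -> SR-decoder chain used by app.py / inference.py.
import torch
from model.encoder import Encoder_RRDB
from model.decoder import Decoder_SR_RRDB

enc = Encoder_RRDB(num_feat=64)
dec = Decoder_SR_RRDB(num_in_ch=64)

lr = torch.randn(1, 3, 64, 64)          # low-resolution RGB patch
with torch.no_grad():
    sr = dec(enc(lr))
print(sr.shape)                          # torch.Size([1, 3, 256, 256]) -> 4x upscale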
opt/__pycache__/option.cpython-38.pyc
ADDED
Binary file (2.81 kB).
opt/option.py
ADDED
@@ -0,0 +1,69 @@
import argparse


parser = argparse.ArgumentParser(description='BebyGAN')

# Hardware specifications
parser.add_argument('--gpu_id', type=str, default="0", help='specify GPU ID to use')
parser.add_argument('--num_workers', type=int, default=4)

# Data specifications
parser.add_argument('--dir_data', type=str, default='./dataset', help='dataset root directory')
parser.add_argument('--scale', type=int, default=4, help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=64, help='LR patch size')  # default = 128 (in the paper)

# Train specifications
parser.add_argument('--epochs', type=int, default=35000, help='total epochs')
parser.add_argument('--batch_size', type=int, default=1, help='size of each batch')  # default = 8 (in the paper)

# Optimizer specifications
parser.add_argument('--lr_G', type=float, default=1e-4, help='initial learning rate of generator')
parser.add_argument('--lr_D', type=float, default=1e-4, help='initial learning rate of discriminator')
parser.add_argument('--beta1', type=float, default=0.9, help='ADAM beta1')
parser.add_argument('--beta2', type=float, default=0.99, help='ADAM beta2')
parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay')

# Scheduler specifications
parser.add_argument('--interval1', type=int, default=2.5e5, help='1st step size (iteration)')
parser.add_argument('--interval2', type=int, default=3.5e5, help='2nd step size (iteration)')
parser.add_argument('--interval3', type=int, default=4.5e5, help='3rd step size (iteration)')
parser.add_argument('--interval4', type=int, default=5.5e5, help='4th step size (iteration)')
parser.add_argument('--gamma_G', type=float, default=0.5, help='generator learning rate decay ratio')
parser.add_argument('--gamma_D', type=float, default=0.5, help='discriminator learning rate decay ratio')

# Train specifications
parser.add_argument('--snap_path', type=str, default='./weights', help='path to save model weights')
parser.add_argument('--save_freq', type=str, default=5, help='save model frequency (epoch)')
# Logger
parser.add_argument('--log_interval', type=int, default=20)
# checkpoint
parser.add_argument('--checkpoint', type=str, default=None, help='load checkpoint')
# pretrained
parser.add_argument('--pretrained', type=str, default=None)
# Loss weight specifications
parser.add_argument('--lambda_align', type=float, default=0.01, help='L1 loss weight')
parser.add_argument('--lambda_rec', type=float, default=1.0, help='back-projection loss weight')
parser.add_argument('--lambda_res', type=float, default=1.0, help='perceptual loss weight')
parser.add_argument('--lambda_sty', type=float, default=0.01, help='style loss weight')
parser.add_argument('--lambda_idt', type=float, default=0.01, help='identity loss weight')
parser.add_argument('--lambda_cyc', type=float, default=1, help='cycle loss weight')

parser.add_argument('--lambda_percept', type=float, default=0.01, help='perceptual loss weight')
parser.add_argument('--lambda_adv', type=float, default=0.01, help='adversarial loss weight')

# generator & discriminator specifications
parser.add_argument('--n_disc', type=int, default=1, help='number of iterations for discriminator update in one epoch')
parser.add_argument('--n_gen', type=int, default=2, help='number of iterations for generator update in one epoch')

# encoder & decoder specifications
parser.add_argument('--n_hidden_feats', type=int, default=64, help='number of feature vectors in hidden layer')
parser.add_argument('--n_sr_feats', type=int, default=64, help='number of feature vectors in RRDB layer')
# eval spec
parser.add_argument('--phase', type=str, default='train')

# test specifications
parser.add_argument('--weights', type=str, default="/data4/anhdh4/SR2/USR_DA-main/weights/epoch_1660.pth", help='load weights for test')
parser.add_argument('--dir_test', type=str, default="/data4/anhdh4/SR2/NTIRE2020/valid_source_crop", help='directory of test images')
parser.add_argument('--results', type=str, default='./results1660/', help='directory of test results')

args = parser.parse_args()
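Because args is parsed at import time, every entry point (train.py, inference.py) reads its settings straight from the command line. A hedged sketch of overriding them programmatically, e.g. from a notebook, by patching sys.argv before the first import (the flag values here are placeholders, not recommendations):

# Hypothetical: set CLI flags before opt.option runs parse_args() at import time.
import sys
sys.argv = [
    "inference.py",
    "--gpu_id", "0",
    "--weights", "./weights/best_weight.pth",
    "--dir_test", "./testsets",
    "--results", "./results/",
]
from opt.option import args
print(args.weights, args.dir_test, args.scale)   # scale falls back to its default of 4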
requirements.txt
ADDED
@@ -0,0 +1,10 @@
fastapi
opencv-python
scipy
wget
scikit-image
torch==1.13.0
torchmetrics==0.11.0
torchvision==0.14.0
tqdm
uvicorn
test/1.png
ADDED
testsets/0848.png
ADDED
testsets/0851.png
ADDED
testsets/0855.png
ADDED
testsets/0879.png
ADDED
train.py
ADDED
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.optim as optim
|
6 |
+
import wandb
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from torchvision import transforms
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from data.LQGT_dataset import LQGTDataset, LQGTValDataset
|
12 |
+
from model import decoder, discriminator, encoder
|
13 |
+
from opt.option import args
|
14 |
+
from util.utils import (RandCrop, RandHorizontalFlip, RandRotate, ToTensor, RandCrop_pair,
|
15 |
+
VGG19PerceptualLoss)
|
16 |
+
|
17 |
+
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
|
18 |
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
19 |
+
|
20 |
+
wandb.init(project='SR', config=args)
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
# device setting
|
25 |
+
if args.gpu_id is not None:
|
26 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
|
27 |
+
print('using GPU 0')
|
28 |
+
else:
|
29 |
+
print('use --gpu_id to specify GPU ID to use')
|
30 |
+
exit()
|
31 |
+
|
32 |
+
device = torch.device('cuda')
|
33 |
+
|
34 |
+
# make directory for saving weights
|
35 |
+
if not os.path.exists(args.snap_path):
|
36 |
+
os.mkdir(args.snap_path)
|
37 |
+
|
38 |
+
print("Loading dataset...")
|
39 |
+
# load training dataset
|
40 |
+
train_dataset = LQGTDataset(
|
41 |
+
db_path=args.dir_data,
|
42 |
+
transform=transforms.Compose([RandCrop(args.patch_size, args.scale), RandHorizontalFlip(), RandRotate(), ToTensor()])
|
43 |
+
)
|
44 |
+
|
45 |
+
val_dataset = LQGTValDataset(
|
46 |
+
db_path=args.dir_data,
|
47 |
+
transform=transforms.Compose([RandCrop_pair(args.patch_size, args.scale), ToTensor()])
|
48 |
+
)
|
49 |
+
|
50 |
+
train_loader = DataLoader(
|
51 |
+
train_dataset,
|
52 |
+
batch_size=args.batch_size,
|
53 |
+
num_workers=args.num_workers,
|
54 |
+
drop_last=True,
|
55 |
+
shuffle=True
|
56 |
+
)
|
57 |
+
|
58 |
+
val_loader = DataLoader(
|
59 |
+
val_dataset,
|
60 |
+
batch_size=args.batch_size,
|
61 |
+
num_workers=args.num_workers,
|
62 |
+
shuffle=False
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
print("Create model")
|
67 |
+
model_Disc_feat = discriminator.DiscriminatorVGG(in_ch=args.n_hidden_feats, image_size=args.patch_size).to(device)
|
68 |
+
model_Disc_img_LR = discriminator.DiscriminatorVGG(in_ch=3, image_size=args.patch_size).to(device)
|
69 |
+
model_Disc_img_HR = discriminator.DiscriminatorVGG(in_ch=3, image_size=args.scale*args.patch_size).to(device)
|
70 |
+
# define model (generator)
|
71 |
+
model_Enc = encoder.Encoder_RRDB(num_feat=args.n_hidden_feats).to(device)
|
72 |
+
model_Dec_Id = decoder.Decoder_Id_RRDB(num_in_ch=args.n_hidden_feats).to(device)
|
73 |
+
model_Dec_SR = decoder.Decoder_SR_RRDB(num_in_ch=args.n_hidden_feats).to(device)
|
74 |
+
|
75 |
+
# define model (discriminator)
|
76 |
+
|
77 |
+
# model_Disc_feat = discriminator.UNetDiscriminator(num_in_ch=64).to(device)
|
78 |
+
# model_Disc_img_LR = discriminator.UNetDiscriminator(num_in_ch=3).to(device)
|
79 |
+
# model_Disc_img_HR = discriminator.UNetDiscriminator(num_in_ch=3).to(device)
|
80 |
+
|
81 |
+
# wandb logging
|
82 |
+
wandb.watch(model_Disc_feat)
|
83 |
+
wandb.watch(model_Disc_img_LR)
|
84 |
+
wandb.watch(model_Enc)
|
85 |
+
wandb.watch(model_Dec_Id)
|
86 |
+
wandb.watch(model_Dec_SR)
|
87 |
+
|
88 |
+
|
89 |
+
print("Define Loss")
|
90 |
+
# loss
|
91 |
+
loss_L1 = nn.L1Loss().to(device)
|
92 |
+
loss_MSE = nn.MSELoss().to(device)
|
93 |
+
loss_adversarial = nn.BCEWithLogitsLoss().to(device)
|
94 |
+
loss_percept = VGG19PerceptualLoss().to(device)
|
95 |
+
|
96 |
+
|
97 |
+
print("Define Optimizer")
|
98 |
+
# optimizer
|
99 |
+
params_G = list(model_Enc.parameters()) + list(model_Dec_Id.parameters()) + list(model_Dec_SR.parameters())
|
100 |
+
optimizer_G = optim.Adam(
|
101 |
+
params_G,
|
102 |
+
lr=args.lr_G,
|
103 |
+
betas=(args.beta1, args.beta2),
|
104 |
+
weight_decay=args.weight_decay,
|
105 |
+
amsgrad=True
|
106 |
+
)
|
107 |
+
params_D = list(model_Disc_feat.parameters()) + list(model_Disc_img_LR.parameters()) + list(model_Disc_img_HR.parameters())
|
108 |
+
optimizer_D = optim.Adam(
|
109 |
+
params_D,
|
110 |
+
lr=args.lr_D,
|
111 |
+
betas=(args.beta1, args.beta2),
|
112 |
+
weight_decay=args.weight_decay,
|
113 |
+
amsgrad=True
|
114 |
+
)
|
115 |
+
|
116 |
+
print("Define Scheduler")
|
117 |
+
# Scheduler
|
118 |
+
iter_indices = [args.interval1, args.interval2, args.interval3]
|
119 |
+
scheduler_G = optim.lr_scheduler.MultiStepLR(
|
120 |
+
optimizer=optimizer_G,
|
121 |
+
milestones=iter_indices,
|
122 |
+
gamma=0.5
|
123 |
+
)
|
124 |
+
scheduler_D = optim.lr_scheduler.MultiStepLR(
|
125 |
+
optimizer=optimizer_D,
|
126 |
+
milestones=iter_indices,
|
127 |
+
gamma=0.5
|
128 |
+
)
|
129 |
+
|
130 |
+
# print("Data Parallel")
|
131 |
+
model_Enc = nn.DataParallel(model_Enc)
|
132 |
+
model_Dec_Id = nn.DataParallel(model_Dec_Id)
|
133 |
+
model_Dec_SR = nn.DataParallel(model_Dec_SR)
|
134 |
+
|
135 |
+
# define model (discriminator)
|
136 |
+
#model_Disc_feat = nn.DataParallel(model_Disc_feat)
|
137 |
+
#model_Disc_img_LR = nn.DataParallel(model_Disc_img_LR)
|
138 |
+
#model_Disc_img_HR = nn.DataParallel(model_Disc_img_HR)
|
139 |
+
|
140 |
+
print("Load model weight")
|
141 |
+
# load model weights & optimzer % scheduler
|
142 |
+
if args.checkpoint is not None:
|
143 |
+
checkpoint = torch.load(args.checkpoint)
|
144 |
+
|
145 |
+
model_Enc.load_state_dict(checkpoint['model_Enc'])
|
146 |
+
model_Dec_Id.load_state_dict(checkpoint['model_Dec_Id'])
|
147 |
+
model_Dec_SR.load_state_dict(checkpoint['model_Dec_SR'])
|
148 |
+
model_Disc_feat.load_state_dict(checkpoint['model_Disc_feat'])
|
149 |
+
model_Disc_img_LR.load_state_dict(checkpoint['model_Disc_img_LR'])
|
150 |
+
model_Disc_img_HR.load_state_dict(checkpoint['model_Disc_img_HR'])
|
151 |
+
|
152 |
+
optimizer_D.load_state_dict(checkpoint['optimizer_D'])
|
153 |
+
optimizer_G.load_state_dict(checkpoint['optimizer_G'])
|
154 |
+
|
155 |
+
scheduler_D.load_state_dict(checkpoint['scheduler_D'])
|
156 |
+
scheduler_G.load_state_dict(checkpoint['scheduler_G'])
|
157 |
+
|
158 |
+
start_epoch = checkpoint['epoch']
|
159 |
+
else:
|
160 |
+
start_epoch = 0
|
161 |
+
|
162 |
+
|
163 |
+
if args.pretrained is not None:
|
164 |
+
ckpt = torch.load(args.pretrained)
|
165 |
+
ckpt["params"]["conv_first.weight"] = ckpt["params"]["conv_first.weight"][:,0,:,:].expand(64,64,3,3)
|
166 |
+
model_Dec_SR.load_state_dict(ckpt["params"])
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
|
171 |
+
|
172 |
+
|
173 |
+
|
174 |
+
# model_Enc = model_Enc.to(device)
|
175 |
+
# model_Dec_Id = model_Dec_Id.to(device)
|
176 |
+
# model_Dec_SR = model_Dec_SR.to(device)
|
177 |
+
|
178 |
+
# # define model (discriminator)
|
179 |
+
# model_Disc_feat = model_Disc_feat.to(device)
|
180 |
+
# model_Disc_img_LR = model_Disc_img_LR.to(device)
|
181 |
+
# model_Disc_img_HR =model_Disc_img_HR.to(device)
|
182 |
+
# training
|
183 |
+
|
184 |
+
PSNR = PeakSignalNoiseRatio().to(device)
|
185 |
+
SSIM = StructuralSimilarityIndexMeasure().to(device)
|
186 |
+
LPIPS = LearnedPerceptualImagePatchSimilarity().to(device)
|
187 |
+
|
188 |
+
if args.phase == "train":
|
189 |
+
for epoch in range(start_epoch, args.epochs):
|
190 |
+
# generator
|
191 |
+
model_Enc.train()
|
192 |
+
model_Dec_Id.train()
|
193 |
+
model_Dec_SR.train()
|
194 |
+
|
195 |
+
# discriminator
|
196 |
+
model_Disc_feat.train()
|
197 |
+
model_Disc_img_LR.train()
|
198 |
+
model_Disc_img_HR.train()
|
199 |
+
|
200 |
+
running_loss_D_total = 0.0
|
201 |
+
running_loss_G_total = 0.0
|
202 |
+
|
203 |
+
running_loss_align = 0.0
|
204 |
+
running_loss_rec = 0.0
|
205 |
+
running_loss_res = 0.0
|
206 |
+
running_loss_sty = 0.0
|
207 |
+
running_loss_idt = 0.0
|
208 |
+
running_loss_cyc = 0.0
|
209 |
+
|
210 |
+
iter = 0
|
211 |
+
|
212 |
+
for data in tqdm(train_loader):
|
213 |
+
iter += 1
|
214 |
+
|
215 |
+
########################
|
216 |
+
# data load #
|
217 |
+
########################
|
218 |
+
X_t, Y_s = data['img_LQ'], data['img_GT']
|
219 |
+
|
220 |
+
ds4 = nn.Upsample(scale_factor=1/args.scale, mode='bicubic')
|
221 |
+
X_s = ds4(Y_s)
|
222 |
+
|
223 |
+
X_t = X_t.cuda(non_blocking=True)
|
224 |
+
X_s = X_s.cuda(non_blocking=True)
|
225 |
+
Y_s = Y_s.cuda(non_blocking=True)
|
226 |
+
|
227 |
+
# real label and fake label
|
228 |
+
batch_size = X_t.size(0)
|
229 |
+
real_label = torch.full((batch_size, 1), 1, dtype=X_t.dtype).cuda(non_blocking=True)
|
230 |
+
fake_label = torch.full((batch_size, 1), 0, dtype=X_t.dtype).cuda(non_blocking=True)
|
231 |
+
|
232 |
+
|
233 |
+
########################
|
234 |
+
# (1) Update D network #
|
235 |
+
########################
|
236 |
+
model_Disc_feat.zero_grad()
|
237 |
+
model_Disc_img_LR.zero_grad()
|
238 |
+
model_Disc_img_HR.zero_grad()
|
239 |
+
|
240 |
+
for i in range(args.n_disc):
|
241 |
+
# generator output (feature domain)
|
242 |
+
F_t = model_Enc(X_t)
|
243 |
+
F_s = model_Enc(X_s)
|
244 |
+
|
245 |
+
# 1. feature aligment loss (discriminator)
|
246 |
+
# output of discriminator (feature domain) (b x c(=1) x h x w)
|
247 |
+
output_Disc_F_t = model_Disc_feat(F_t.detach())
|
248 |
+
output_Disc_F_s = model_Disc_feat(F_s.detach())
|
249 |
+
# discriminator loss (feature domain)
|
250 |
+
loss_Disc_F_t = loss_MSE(output_Disc_F_t, fake_label)
|
251 |
+
loss_Disc_F_s = loss_MSE(output_Disc_F_s, real_label)
|
252 |
+
loss_Disc_feat_align = (loss_Disc_F_t + loss_Disc_F_s) / 2
|
253 |
+
|
254 |
+
# 2. SR reconstruction loss (discriminator)
|
255 |
+
# generator output (image domain)
|
256 |
+
Y_s_s = model_Dec_SR(F_s)
|
257 |
+
# output of discriminator (image domain)
|
258 |
+
output_Disc_Y_s_s = model_Disc_img_HR(Y_s_s.detach())
|
259 |
+
output_Disc_Y_s = model_Disc_img_HR(Y_s)
|
260 |
+
# discriminator loss (image domain)
|
261 |
+
loss_Disc_Y_s_s = loss_MSE(output_Disc_Y_s_s, fake_label)
|
262 |
+
loss_Disc_Y_s = loss_MSE(output_Disc_Y_s, real_label)
|
263 |
+
loss_Disc_img_rec = (loss_Disc_Y_s_s + loss_Disc_Y_s) / 2
|
264 |
+
|
265 |
+
# 4. Target degradation style loss
|
266 |
+
# generator output (image domain)
|
267 |
+
X_s_t = model_Dec_Id(F_s)
|
268 |
+
# output of discriminator (image domain)
|
269 |
+
output_Disc_X_s_t = model_Disc_img_LR(X_s_t.detach())
|
270 |
+
output_Disc_X_t = model_Disc_img_LR(X_t)
|
271 |
+
# discriminator loss (image domain)
|
272 |
+
loss_Disc_X_s_t = loss_MSE(output_Disc_X_s_t, fake_label)
|
273 |
+
loss_Disc_X_t = loss_MSE(output_Disc_X_t, real_label)
|
274 |
+
loss_Disc_img_sty = (loss_Disc_X_s_t + loss_Disc_X_t) / 2
|
275 |
+
|
276 |
+
# 6. Cycle loss
|
277 |
+
# generator output (image domain)
|
278 |
+
Y_s_t_s = model_Dec_SR(model_Enc(model_Dec_Id(F_s)))
|
279 |
+
# output of discriminator (image domain)
|
280 |
+
output_Disc_Y_s_t_s = model_Disc_img_HR(Y_s_t_s.detach())
|
281 |
+
output_Disc_Y_s = model_Disc_img_HR(Y_s)
|
282 |
+
# discriminator loss (image domain)
|
283 |
+
loss_Disc_Y_s_t_s = loss_MSE(output_Disc_Y_s_t_s, fake_label)
|
284 |
+
loss_Disc_Y_s = loss_MSE(output_Disc_Y_s, real_label)
|
285 |
+
loss_Disc_img_cyc = (loss_Disc_Y_s_t_s + loss_Disc_Y_s) / 2
|
286 |
+
|
287 |
+
# discriminator weight update
|
288 |
+
loss_D_total = loss_Disc_feat_align + loss_Disc_img_rec + loss_Disc_img_sty + loss_Disc_img_cyc
|
289 |
+
loss_D_total.backward()
|
290 |
+
optimizer_D.step()
|
291 |
+
|
292 |
+
|
293 |
+
|
294 |
+
scheduler_D.step()
|
295 |
+
|
296 |
+
|
297 |
+
+########################
+# (2) Update G network #
+########################
+model_Enc.zero_grad()
+model_Dec_Id.zero_grad()
+model_Dec_SR.zero_grad()
+
+for i in range(args.n_gen):
+    # generator output (feature domain)
+    F_t = model_Enc(X_t)
+    F_s = model_Enc(X_s)
+
+    # 1. feature alignment loss (generator)
+    # output of discriminator (feature domain)
+    output_Disc_F_t = model_Disc_feat(F_t)
+    output_Disc_F_s = model_Disc_feat(F_s)
+    # generator loss (feature domain)
+    loss_G_F_t = loss_MSE(output_Disc_F_t, (real_label + fake_label)/2)
+    loss_G_F_s = loss_MSE(output_Disc_F_s, (real_label + fake_label)/2)
+    L_align_E = loss_G_F_t + loss_G_F_s
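+    # NOTE: the alignment target is (real_label + fake_label) / 2 = 0.5, so the encoder is
+    # pushed to make the feature discriminator maximally uncertain about whether a feature
+    # map came from the source or the target domain (a least-squares domain-confusion
+    # objective), rather than to fool it all the way toward the "real" label.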
+
+    # 2. SR reconstruction loss
+    # generator output (image domain)
+    Y_s_s = model_Dec_SR(F_s)
+    # output of discriminator (image domain)
+    output_Disc_Y_s_s = model_Disc_img_HR(Y_s_s)
+    # L1 loss
+    loss_L1_rec = loss_L1(Y_s.detach(), Y_s_s)
+    # perceptual loss
+    loss_percept_rec = loss_percept(Y_s.detach(), Y_s_s)
+    # adversarial loss
+    loss_G_Y_s_s = loss_MSE(output_Disc_Y_s_s, real_label)
+    L_rec_G_SR = loss_L1_rec + args.lambda_percept*loss_percept_rec + args.lambda_adv*loss_G_Y_s_s
+
+    # 3. Target LR restoration loss
+    X_t_t = model_Dec_Id(F_t)
+    L_res_G_t = loss_L1(X_t, X_t_t)
+
+    # 4. Target degradation style loss
+    # generator output (image domain)
+    X_s_t = model_Dec_Id(F_s)
+    # output of discriminator (image domain)
+    output_Disc_X_s_t = model_Disc_img_LR(X_s_t)
+    # generator loss (image domain)
+    loss_G_X_s_t = loss_MSE(output_Disc_X_s_t, real_label)
+    L_sty_G_t = loss_G_X_s_t
+
+    # 5. Feature identity loss
+    F_s_tilda = model_Enc(model_Dec_Id(F_s))
+    L_idt_G_t = loss_L1(F_s, F_s_tilda)
+
+    # 6. Cycle loss
+    # generator output (image domain)
+    Y_s_t_s = model_Dec_SR(model_Enc(model_Dec_Id(F_s)))
+    # output of discriminator (image domain)
+    output_Disc_Y_s_t_s = model_Disc_img_HR(Y_s_t_s)
+    # L1 loss
+    loss_L1_cyc = loss_L1(Y_s.detach(), Y_s_t_s)
+    # perceptual loss
+    loss_percept_cyc = loss_percept(Y_s.detach(), Y_s_t_s)
+    # adversarial loss
+    loss_Y_s_t_s = loss_MSE(output_Disc_Y_s_t_s, real_label)
+    L_cyc_G_t_G_SR = loss_L1_cyc + args.lambda_percept*loss_percept_cyc + args.lambda_adv*loss_Y_s_t_s
+
+    # generator weight update
+    loss_G_total = args.lambda_align*L_align_E + args.lambda_rec*L_rec_G_SR + args.lambda_res*L_res_G_t + args.lambda_sty*L_sty_G_t + args.lambda_idt*L_idt_G_t + args.lambda_cyc*L_cyc_G_t_G_SR
+    loss_G_total.backward()
+    optimizer_G.step()
+
+scheduler_G.step()
+
+
+########################
+#     compute loss     #
+########################
+running_loss_D_total += loss_D_total.item()
+running_loss_G_total += loss_G_total.item()
+
+running_loss_align += L_align_E.item()
+running_loss_rec += L_rec_G_SR.item()
+running_loss_res += L_res_G_t.item()
+running_loss_sty += L_sty_G_t.item()
+running_loss_idt += L_idt_G_t.item()
+running_loss_cyc += L_cyc_G_t_G_SR.item()
+if iter % args.log_interval == 0:
+    wandb.log(
+        {
+            "loss_D_total_step": running_loss_D_total/iter,
+            "loss_G_total_step": running_loss_G_total/iter,
+            "loss_align_step": running_loss_align/iter,
+            "loss_rec_step": running_loss_rec/iter,
+            "loss_res_step": running_loss_res/iter,
+            "loss_sty_step": running_loss_sty/iter,
+            "loss_idt_step": running_loss_idt/iter,
+            "loss_cyc_step": running_loss_cyc/iter,
+        }
+    )
+### EVALUATE ###
+total_PSNR = 0
+total_SSIM = 0
+total_LPIPS = 0
+val_iter = 0
+with torch.no_grad():
+    model_Enc.eval()
+    model_Dec_SR.eval()
+    for batch_idx, batch in enumerate(val_loader):
+        val_iter += 1
+        source = batch["img_LQ"].to(device)
+        target = batch["img_GT"].to(device)
+
+        feat = model_Enc(source)
+        out = model_Dec_SR(feat)
+
+        total_PSNR += PSNR(out, target)
+        total_SSIM += SSIM(out, target)
+        total_LPIPS += LPIPS(out, target)
+
+wandb.log(
+    {
+        "epoch": epoch,
+        "lr": optimizer_G.param_groups[0]['lr'],
+        "loss_D_total_epoch": running_loss_D_total/iter,
+        "loss_G_total_epoch": running_loss_G_total/iter,
+        "loss_align_epoch": running_loss_align/iter,
+        "loss_rec_epoch": running_loss_rec/iter,
+        "loss_res_epoch": running_loss_res/iter,
+        "loss_sty_epoch": running_loss_sty/iter,
+        "loss_idt_epoch": running_loss_idt/iter,
+        "loss_cyc_epoch": running_loss_cyc/iter,
+        "PSNR_val": total_PSNR/val_iter,
+        "SSIM_val": total_SSIM/val_iter,
+        "LPIPS_val": total_LPIPS/val_iter
+    }
+)
+
+
+if (epoch+1) % args.save_freq == 0:
+    weights_file_name = 'epoch_%d.pth' % (epoch+1)
+    weights_file = os.path.join(args.snap_path, weights_file_name)
+    torch.save({
+        'epoch': epoch,
+
+        'model_Enc': model_Enc.state_dict(),
+        'model_Dec_Id': model_Dec_Id.state_dict(),
+        'model_Dec_SR': model_Dec_SR.state_dict(),
+        'model_Disc_feat': model_Disc_feat.state_dict(),
+        'model_Disc_img_LR': model_Disc_img_LR.state_dict(),
+        'model_Disc_img_HR': model_Disc_img_HR.state_dict(),
+
+        'optimizer_D': optimizer_D.state_dict(),
+        'optimizer_G': optimizer_G.state_dict(),
+
+        'scheduler_D': scheduler_D.state_dict(),
+        'scheduler_G': scheduler_G.state_dict(),
+    }, weights_file)
+    print('save weights of epoch %d' % (epoch+1))
+else:
+    ### EVALUATE ###
+    total_PSNR = 0
+    total_SSIM = 0
+    total_LPIPS = 0
+    val_iter = 0
+    with torch.no_grad():
+        model_Enc.eval()
+        model_Dec_SR.eval()
+        for batch_idx, batch in enumerate(val_loader):
+            val_iter += 1
+            source = batch["img_LQ"].to(device)
+            target = batch["img_GT"].to(device)
+
+            feat = model_Enc(source)
+            out = model_Dec_SR(feat)
+
+            total_PSNR += PSNR(out, target)
+            total_SSIM += SSIM(out, target)
+            total_LPIPS += LPIPS(out, target)
+    print("PSNR_val: ", total_PSNR/val_iter)
+    print("SSIM_val: ", total_SSIM/val_iter)
+    print("LPIPS_val: ", total_LPIPS/val_iter)
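Usage note (not part of the committed train.py): a checkpoint written by the save block above can be restored with matching load_state_dict calls. A minimal sketch, assuming the encoder, decoders, discriminators, optimizers and schedulers have already been constructed exactly as in train.py; the checkpoint path below is hypothetical (train.py writes to args.snap_path).

import torch

ckpt_path = 'snapshots/epoch_100.pth'   # hypothetical path
checkpoint = torch.load(ckpt_path, map_location='cpu')

model_Enc.load_state_dict(checkpoint['model_Enc'])
model_Dec_Id.load_state_dict(checkpoint['model_Dec_Id'])
model_Dec_SR.load_state_dict(checkpoint['model_Dec_SR'])
model_Disc_feat.load_state_dict(checkpoint['model_Disc_feat'])
model_Disc_img_LR.load_state_dict(checkpoint['model_Disc_img_LR'])
model_Disc_img_HR.load_state_dict(checkpoint['model_Disc_img_HR'])

optimizer_D.load_state_dict(checkpoint['optimizer_D'])
optimizer_G.load_state_dict(checkpoint['optimizer_G'])
scheduler_D.load_state_dict(checkpoint['scheduler_D'])
scheduler_G.load_state_dict(checkpoint['scheduler_G'])

start_epoch = checkpoint['epoch'] + 1   # resume from the next epoch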
util/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (4.02 kB)
util/utils.py
ADDED
@@ -0,0 +1,133 @@
+import torch
+import torchvision
+
+import math
+import cv2
+import numpy as np
+from scipy.ndimage import rotate
+
+
+class RandCrop(object):
+    def __init__(self, crop_size, scale):
+        # if output size is tuple -> (height, width)
+        assert isinstance(crop_size, (int, tuple))
+        if isinstance(crop_size, int):
+            self.crop_size = (crop_size, crop_size)
+        else:
+            assert len(crop_size) == 2
+            self.crop_size = crop_size
+
+        self.scale = scale
+
+    def __call__(self, sample):
+        # img_LQ: H x W x C (numpy array)
+        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
+
+        h, w, c = img_LQ.shape
+        new_h, new_w = self.crop_size
+        top = np.random.randint(0, h - new_h)
+        left = np.random.randint(0, w - new_w)
+        img_LQ_crop = img_LQ[top: top+new_h, left: left+new_w, :]
+
+        h, w, c = img_GT.shape
+        top = np.random.randint(0, h - self.scale*new_h)
+        left = np.random.randint(0, w - self.scale*new_w)
+        img_GT_crop = img_GT[top: top + self.scale*new_h, left: left + self.scale*new_w, :]
+
+        sample = {'img_LQ': img_LQ_crop, 'img_GT': img_GT_crop}
+        return sample
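+# NOTE: RandCrop draws the LQ and GT crop windows at independent random positions, which
+# suits unpaired LQ/GT training data; RandCrop_pair (defined below) instead crops
+# spatially aligned patches by reusing the LQ offsets scaled by `scale`.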
+
+
+class RandRotate(object):
+    def __call__(self, sample):
+        # img_LQ: H x W x C (numpy array)
+        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
+
+        # rotate by 90 / 180 / 270 degrees with probability 1/4 each; otherwise no rotation
+        prob_rotate = np.random.random()
+        if prob_rotate < 0.25:
+            img_LQ = rotate(img_LQ, 90).copy()
+            img_GT = rotate(img_GT, 90).copy()
+        elif prob_rotate < 0.5:
+            img_LQ = rotate(img_LQ, 180).copy()
+            img_GT = rotate(img_GT, 180).copy()
+        elif prob_rotate < 0.75:
+            img_LQ = rotate(img_LQ, 270).copy()
+            img_GT = rotate(img_GT, 270).copy()
+
+        sample = {'img_LQ': img_LQ, 'img_GT': img_GT}
+        return sample
+
+
+class RandHorizontalFlip(object):
+    def __call__(self, sample):
+        # img_LQ: H x W x C (numpy array)
+        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
+
+        prob_lr = np.random.random()
+        if prob_lr < 0.5:
+            img_LQ = np.fliplr(img_LQ).copy()
+            img_GT = np.fliplr(img_GT).copy()
+
+        sample = {'img_LQ': img_LQ, 'img_GT': img_GT}
+        return sample
+
+
+class ToTensor(object):
+    def __call__(self, sample):
+        # img_LQ : H x W x C (numpy array) -> C x H x W (torch tensor)
+        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
+
+        img_LQ = img_LQ.transpose((2, 0, 1))
+        img_GT = img_GT.transpose((2, 0, 1))
+
+        img_LQ = torch.from_numpy(img_LQ)
+        img_GT = torch.from_numpy(img_GT)
+
+        sample = {'img_LQ': img_LQ, 'img_GT': img_GT}
+        return sample
+
+
+class VGG19PerceptualLoss(torch.nn.Module):
+    def __init__(self, feature_layer=35):
+        super(VGG19PerceptualLoss, self).__init__()
+        model = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT)
+        self.features = torch.nn.Sequential(*list(model.features.children())[:feature_layer]).eval()
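+        # NOTE: with the default feature_layer=35, the truncated network keeps torchvision's
+        # `features` modules 0-34 and therefore ends at conv5_4 of VGG19 (before its ReLU),
+        # so the loss compares deep conv5_4 feature maps of the two images.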
+        # Freeze parameters
+        for name, param in self.features.named_parameters():
+            param.requires_grad = False
+
+    def forward(self, source, target):
+        vgg_loss = torch.nn.functional.l1_loss(self.features(source), self.features(target))
+
+        return vgg_loss
+
+
+class RandCrop_pair(object):
+    def __init__(self, crop_size, scale):
+        # if output size is tuple -> (height, width)
+        assert isinstance(crop_size, (int, tuple))
+        if isinstance(crop_size, int):
+            self.crop_size = (crop_size, crop_size)
+        else:
+            assert len(crop_size) == 2
+            self.crop_size = crop_size
+
+        self.scale = scale
+
+    def __call__(self, sample):
+        # img_LQ: H x W x C (numpy array)
+        img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
+
+        h, w, c = img_LQ.shape
+        new_h, new_w = self.crop_size
+        top = np.random.randint(0, h - new_h)
+        left = np.random.randint(0, w - new_w)
+        img_LQ_crop = img_LQ[top: top+new_h, left: left+new_w, :]
+
+        h, w, c = img_GT.shape
+        top = self.scale*top
+        left = self.scale*left
+        img_GT_crop = img_GT[top: top + self.scale*new_h, left: left + self.scale*new_w, :]
+
+        sample = {'img_LQ': img_LQ_crop, 'img_GT': img_GT_crop}
+        return sample
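Usage note (not part of the commit): all of these transforms consume and return a sample dict {'img_LQ': H x W x C array, 'img_GT': H x W x C array}, so they compose with torchvision.transforms.Compose inside a dataset's __getitem__. A minimal sketch with illustrative crop size, scale factor and file paths; the GT path and the [0, 1] float conversion are assumptions, not taken from this repo.

import cv2
import numpy as np
from torchvision import transforms
from util.utils import RandCrop_pair, RandRotate, RandHorizontalFlip, ToTensor

# illustrative values: 48x48 LQ patches with x4 super-resolution -> 192x192 GT patches
pair_transform = transforms.Compose([
    RandCrop_pair(crop_size=48, scale=4),
    RandRotate(),
    RandHorizontalFlip(),
    ToTensor(),
])

img_LQ = cv2.imread('testsets/0848.png').astype(np.float32) / 255.0    # LQ input (example file)
img_GT = cv2.imread('path/to/0848_GT.png').astype(np.float32) / 255.0  # hypothetical paired GT

sample = pair_transform({'img_LQ': img_LQ, 'img_GT': img_GT})
# sample['img_LQ']: float tensor of shape C x 48 x 48
# sample['img_GT']: float tensor of shape C x 192 x 192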