DS committed on
Commit
e5b70eb
1 Parent(s): b19f11c

dump shiet

.gitignore ADDED
@@ -0,0 +1 @@
+ weights/best_weight.pth
Dockerfile ADDED
@@ -0,0 +1,21 @@
+ FROM python:3.9
+
+
+
+ USER root
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ RUN apt-get update
+ RUN apt-get install ffmpeg libsm6 libxext6 -y
+
+ RUN pip install --upgrade pip
+
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ COPY . .
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
+
MeasureV1.py ADDED
@@ -0,0 +1,123 @@
+ import glob
+ import os
+ import time
+ from collections import OrderedDict
+
+ import numpy as np
+ import torch
+ import cv2
+ import argparse
+
+ from natsort import natsort
+ from skimage.metrics import structural_similarity as compare_ssim
+ from skimage.metrics import peak_signal_noise_ratio as compare_psnr
+ import lpips
+
+
+ class Measure():
+     def __init__(self, net='alex', use_gpu=False):
+         self.device = 'cuda' if use_gpu else 'cpu'
+         self.model = lpips.LPIPS(net=net)
+         self.model.to(self.device)
+
+     def measure(self, imgA, imgB):
+         if not all([s1 == s2 for s1, s2 in zip(imgA.shape, imgB.shape)]):
+             raise RuntimeError("Image sizes not the same.")
+         return [float(f(imgA, imgB)) for f in [self.psnr, self.ssim, self.lpips]]
+
+     def lpips(self, imgA, imgB, model=None):
+         tA = t(imgA).to(self.device)
+         tB = t(imgB).to(self.device)
+         dist01 = self.model.forward(tA, tB).item()
+         return dist01
+
+     def ssim(self, imgA, imgB):
+         # multichannel: If True, treat the last dimension of the array as channels. Similarity calculations are done independently for each channel then averaged.
+         score, diff = compare_ssim(imgA, imgB, full=True, multichannel=True)
+         return score
+
+     def psnr(self, imgA, imgB):
+         psnr = compare_psnr(imgA, imgB)
+         return psnr
+
+
+ def t(img):
+     def to_4d(img):
+         assert len(img.shape) == 3
+         assert img.dtype == np.uint8
+         img_new = np.expand_dims(img, axis=0)
+         assert len(img_new.shape) == 4
+         return img_new
+
+     def to_CHW(img):
+         return np.transpose(img, [2, 0, 1])
+
+     def to_tensor(img):
+         return torch.Tensor(img)
+
+     return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1
+
+
+ def fiFindByWildcard(wildcard):
+     return natsort.natsorted(glob.glob(wildcard, recursive=True))
+
+
+ def imread(path):
+     return cv2.imread(path)[:, :, [2, 1, 0]]
+
+
+ def format_result(psnr, ssim, lpips):
+     return f'{psnr:0.2f}, {ssim:0.3f}, {lpips:0.3f}'
+
+ def measure_dirs(dirA, dirB, use_gpu, verbose=False):
+     if verbose:
+         vprint = lambda x: print(x)
+     else:
+         vprint = lambda x: None
+
+
+     t_init = time.time()
+
+     paths_A = fiFindByWildcard(os.path.join(dirA, f'*.{type}'))
+     paths_B = fiFindByWildcard(os.path.join(dirB, f'*.{type}'))
+
+     vprint("Comparing: ")
+     vprint(dirA)
+     vprint(dirB)
+
+     measure = Measure(use_gpu=use_gpu)
+
+     results = []
+     for pathA, pathB in zip(paths_A, paths_B):
+         result = OrderedDict()
+
+         t = time.time()
+         result['psnr'], result['ssim'], result['lpips'] = measure.measure(imread(pathA), imread(pathB))
+         d = time.time() - t
+         vprint(f"{pathA.split('/')[-1]}, {pathB.split('/')[-1]}, {format_result(**result)}, {d:0.1f}")
+
+         results.append(result)
+
+     psnr = np.mean([result['psnr'] for result in results])
+     ssim = np.mean([result['ssim'] for result in results])
+     lpips = np.mean([result['lpips'] for result in results])
+
+     vprint(f"Final Result: {format_result(psnr, ssim, lpips)}, {time.time() - t_init:0.1f}s")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-dirA', default='', type=str)
+     parser.add_argument('-dirB', default='', type=str)
+     parser.add_argument('-type', default='png')
+     parser.add_argument('--use_gpu', action='store_true', default=False)
+     args = parser.parse_args()
+
+     dirA = args.dirA
+     dirB = args.dirB
+     type = args.type
+     use_gpu = args.use_gpu
+
+     if len(dirA) > 0 and len(dirB) > 0:
+         measure_dirs(dirA, dirB, use_gpu=use_gpu, verbose=True)
+
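A minimal usage sketch for the Measure class above (an illustration, not part of the commit): it assumes the repository root is on the Python path, that the lpips package is installed, and that the scikit-image release in use still accepts the older multichannel= keyword called in ssim(). The two random arrays stand in for a real restored/ground-truth image pair of identical shape.

import numpy as np
from MeasureV1 import Measure

m = Measure(net='alex', use_gpu=False)

# Dummy uint8 H x W x C images of the same shape; swap in a real image pair.
img_a = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
img_b = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)

psnr, ssim, lpips_score = m.measure(img_a, img_b)
print(f"PSNR={psnr:0.2f}, SSIM={ssim:0.3f}, LPIPS={lpips_score:0.3f}")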
README.md CHANGED
@@ -1,11 +1,9 @@
  ---
  title: USR DA
- emoji: 📊
+ emoji: 😻
  colorFrom: gray
- colorTo: red
- sdk: gradio
- sdk_version: 3.14.0
- app_file: app.py
+ colorTo: indigo
+ sdk: docker
  pinned: false
  ---
 
app.py ADDED
@@ -0,0 +1,99 @@
+ import glob
+ import io
+ import os
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import wget
+ from torchvision.transforms import Compose, ToTensor
+
+ from model import decoder, encoder
+
+ WEIGHT_PATH = './weights/best_weight.pth'
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ class Model(object):
+     def __init__(self) -> None:
+         self.model_Enc = encoder.Encoder_RRDB(num_feat=64).to(device=DEVICE)
+         self.model_Dec_SR = decoder.Decoder_SR_RRDB(num_in_ch=64).to(device=DEVICE)
+         self.preprocess = Compose([ToTensor()])
+         self.load_model()
+
+     def load_model(self, weight_path=WEIGHT_PATH):
+         if not os.path.isfile("./weights/best_weight.pth"):
+             response = wget.download("https://raw.githubusercontent.com/hungnguyen2611/super-resolution/master/weights/best_weight.pth", "./weights/best_weight.pth")
+         weight = torch.load(weight_path)
+         print("[LOADING] Loading encoder...")
+         self.model_Enc.load_state_dict(weight['model_Enc'])
+         print("[LOADING] Loading decoder...")
+         self.model_Dec_SR.load_state_dict(weight['model_Dec_SR'])
+         print("[LOADING] Loading done!")
+         self.model_Enc.eval()
+         self.model_Dec_SR.eval()
+
+     def predict(self, img):
+         with torch.no_grad():
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+             img = self.preprocess(img)
+             img = img.unsqueeze(0)
+             img = img.to(DEVICE)
+
+             feat = self.model_Enc(img)
+             out = self.model_Dec_SR(feat)
+             min_max = (0, 1)
+             out = out.detach()[0].float().cpu()
+
+             out = out.squeeze().float().cpu().clamp_(*min_max)
+             out = (out - min_max[0]) / (min_max[1] - min_max[0])
+             out = out.numpy()
+             out = np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))
+
+             out = (out*255.0).round()
+             out = out.astype(np.uint8)
+         return out
+
+ model = Model()
+
+ def predict(img):
+     global model
+     img.save("test/1.png", "PNG")
+     image = cv2.imread("test/1.png", cv2.IMREAD_COLOR)
+     out = model.predict(img=image)
+
+     cv2.imwrite(f'images_uploaded/1.png', out)
+     return f"images_uploaded/1.png"
+
+
+
+
+ if __name__ == '__main__':
+     title = "Super-Resolution Demo USR-DA Unofficial 🚀🚀🔥"
+     description = '''
+     <br>
+     **This Demo expects low-quality and low-resolution images**
+     **We are looking for collaborators! Collaborator**
+     </br>
+     '''
+     article = "<p style='text-align: center'><a href='https://openaccess.thecvf.com/content/ICCV2021/papers/Wang_Unsupervised_Real-World_Super-Resolution_A_Domain_Adaptation_Perspective_ICCV_2021_paper.pdf' target='_blank'>Unsupervised Real-World Super-Resolution: A Domain Adaptation Perspective</a> | <a href='https://github.com/hungnguyen2611/super-resolution.git' target='_blank'>Github Repo</a></p>"
+     examples= glob.glob("testsets/*.png")
+     gr.Interface(
+         predict,
+         gr.inputs.Image(type="pil", label="Input").style(height=260),
+         gr.inputs.Image(type="pil", label="Output").style(height=240),
+         title=title,
+         description=description,
+         article=article,
+         examples=examples,
+     ).launch(enable_queue=True)
+
+
+
+
+
+
+
+
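For reference, a sketch of running the restored model from app.py without the Gradio UI (illustrative only; the file name is one of the test images added in this commit, and importing app triggers the module-level Model(), which downloads best_weight.pth if it is not already present):

import cv2
from app import model   # module-level Model(); weights are loaded (or downloaded) on import

lr = cv2.imread("testsets/0848.png", cv2.IMREAD_COLOR)   # BGR uint8 low-resolution input
sr = model.predict(lr)                                   # BGR uint8, 4x upscaled output
cv2.imwrite("sr_0848.png", sr)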
compare.py ADDED
@@ -0,0 +1,47 @@
+ import glob
+ import os
+ import time
+ from collections import OrderedDict
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+ from natsort import natsort
+ from tqdm import tqdm
+
+
+ def fiFindByWildcard(wildcard):
+     return natsort.natsorted(glob.glob(wildcard, recursive=True))
+
+
+
+
+ if __name__ == "__main__":
+     out_data_path = fiFindByWildcard("./results_crop (1)/out/*")
+     gt_data_path = fiFindByWildcard("./results_crop (1)/target/*")
+     source_data_path = fiFindByWildcard("./results_crop (1)/source/*")
+
+     for src_path, out_path, gt_path in tqdm(list(zip(source_data_path, out_data_path, gt_data_path))):
+         fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
+         ax1.set_title("Bicubic")
+         ax2.set_title("Baseline")
+         ax3.set_title("Ground truth")
+
+         src = cv2.imread(src_path)[:, :, [2, 1, 0]]
+         out = cv2.imread(out_path)[:, :, [2, 1, 0]]
+         gt = cv2.imread(gt_path)[:, :, [2, 1, 0]]
+         src = cv2.resize(src, None, fx=4, fy=4, interpolation=cv2.INTER_CUBIC)
+         ax1.set_yticklabels([])
+         ax1.set_xticklabels([])
+         ax2.set_yticklabels([])
+         ax2.set_xticklabels([])
+         ax3.set_yticklabels([])
+         ax3.set_xticklabels([])
+         ax1.imshow(src)
+         ax2.imshow(out)
+         ax3.imshow(gt)
+         fig.savefig(f"./result_compare_crop_new/{os.path.basename(gt_path)}", bbox_inches='tight' , dpi=1200)
+         plt.close()
+
+
+
+
crop_test.py ADDED
@@ -0,0 +1,5 @@
+ import cv2
+
+
+
+
flagged/Input/tmplp_isgr5.jpg ADDED
flagged/log.csv ADDED
@@ -0,0 +1,2 @@
+ Input,Ouput,flag,username,timestamp
+ /home/ds/Documents/SR/USR_DA/USR-DA/USR-DA/flagged/Input/tmplp_isgr5.jpg,,,,2022-12-18 14:45:29.533016
images_uploaded/0805.png ADDED
images_uploaded/0821.png ADDED
images_uploaded/0873.png ADDED
images_uploaded/1.png ADDED
inference.py ADDED
@@ -0,0 +1,97 @@
+ import os
+ import numpy as np
+ import cv2
+ import torch.nn as nn
+ from tqdm import tqdm
+ import torch
+ import torchvision
+
+ from model import encoder, decoder
+ from opt.option import args
+
+
+ # device setting
+ if args.gpu_id is not None:
+     os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+     print('using GPU 0')
+ else:
+     print('use --gpu_id to specify GPU ID to use')
+     exit()
+
+
+ # make directory for saving results
+ if not os.path.exists(args.results):
+     os.mkdir(args.results)
+
+
+ # numpy array -> torch tensor
+ class ToTensor(object):
+     def __call__(self, sample):
+         sample = np.transpose(sample, (2, 0, 1))
+         sample = torch.from_numpy(sample)
+         return sample
+
+
+ # create model
+ # model_Enc = encoder.Encoder().cuda()
+ # model_Dec_SR = decoder.Decoder_SR().cuda()
+ model_Enc = encoder.Encoder_RRDB(num_feat=args.n_hidden_feats).cuda()
+ model_Dec_SR = decoder.Decoder_SR_RRDB(num_in_ch=args.n_hidden_feats).cuda()
+
+ model_Enc = nn.DataParallel(model_Enc)
+ #model_Dec_Id = nn.DataParallel(model_Dec_Id)
+ model_Dec_SR = nn.DataParallel(model_Dec_SR)
+
+ # load weights
+ checkpoint = torch.load(args.weights)
+ model_Enc.load_state_dict(checkpoint['model_Enc'])
+ model_Dec_SR.load_state_dict(checkpoint['model_Dec_SR'])
+ model_Enc.eval()
+ model_Dec_SR.eval()
+
+ # input transform
+ transforms = torchvision.transforms.Compose([ToTensor()])
+
+
+ filenames = os.listdir(args.dir_test)
+ filenames.sort()
+ with torch.no_grad():
+     for filename in tqdm(filenames):
+         img_name = os.path.join(args.dir_test, filename)
+         ext = os.path.splitext(img_name)[-1]
+         if ext in ['.png', '.jpg']:
+             img = cv2.imread(img_name)
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+             #img = cv2.resize(img, ((img.shape[1] // 4),(img.shape[0] // 4)))
+             img = np.array(img).astype('float32') / 255
+             # img = img[0:256, 0:256, :]
+
+             img = transforms(img)
+             img = torch.tensor(img.cuda()).unsqueeze(0)
+
+             # inference output
+             feat = model_Enc(img)
+             out = model_Dec_SR(feat)
+
+             min_max = (0, 1)
+             out = out.detach()[0].float().cpu()
+
+             out = out.squeeze().float().cpu().clamp_(*min_max)
+             out = (out - min_max[0]) / (min_max[1] - min_max[0])
+             out = out.numpy()
+             out = np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))
+
+             out = (out*255.0).round()
+             out = out.astype(np.uint8)
+
+             # result image save (b x c x h x w (torch tensor) -> h x w x c (numpy array))
+             # out = out.data.cpu().squeeze().numpy()
+             # out = np.clip(out, 0, 1)
+             # out = np.transpose(out, (1, 2, 0))
+             print(args.results, filename)
+             cv2.imwrite('%s_out.png' %(os.path.join(args.results, filename)[:-4]), out)
+
+
+
+
+
model/__pycache__/decoder.cpython-38.pyc ADDED
Binary file (7.59 kB).
 
model/__pycache__/discriminator.cpython-38.pyc ADDED
Binary file (3.32 kB).
 
model/__pycache__/encoder.cpython-38.pyc ADDED
Binary file (2.07 kB).
 
model/decoder.py ADDED
@@ -0,0 +1,240 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn import init as init
+ import torch.nn.functional as F
+ from torch.nn.modules.batchnorm import _BatchNorm
+
+ class Decoder_Identity(nn.Module):
+     def __init__(self):
+         super(Decoder_Identity, self).__init__()
+
+         self.conv_up_2 = nn.Sequential(
+             nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True),
+             nn.ReLU()
+         )
+
+         self.conv_up_1 = nn.Sequential(
+             nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True),
+             nn.ReLU()
+         )
+
+         self.conv_last = nn.Sequential(
+             nn.Conv2d(in_channels=16, out_channels=3, kernel_size=1, bias=True),
+             nn.ReLU()
+         )
+
+     def forward(self, feat):
+         featmap_2 = self.conv_up_2(feat)
+         featmap_1 = self.conv_up_1(featmap_2)
+         out = self.conv_last(featmap_1)
+
+         return out
+
+
+ class Decoder_SR(nn.Module):
+     def __init__(self, scale=4):
+         super(Decoder_SR, self).__init__()
+
+         self.scale = scale
+
+         self.conv_up_2 = nn.Sequential(
+             nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True),
+             nn.ReLU()
+         )
+
+         self.conv_up_1 = nn.Sequential(
+             nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2, padding=1, output_padding=1, dilation=1),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True),
+             nn.ReLU()
+         )
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         # upsampling
+         self.upsample_1 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True)
+         self.upsample_2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True)
+
+         self.HR_conv = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True)
+         self.conv_last = nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3, padding=1, bias=True)
+
+     def forward(self, feat):
+         featmap_2 = self.conv_up_2(feat)
+         featmap_1 = self.conv_up_1(featmap_2)
+
+         if self.scale == 4:
+             featmap = self.lrelu(self.upsample_1(F.interpolate(featmap_1, scale_factor=2, mode='nearest')))
+             featmap = self.lrelu(self.upsample_2(F.interpolate(featmap, scale_factor=2, mode='nearest')))
+         elif self.scale == 2:
+             featmap = self.lrelu(self.upsample_1(F.interpolate(featmap_1, scale_factor=2, mode='nearest')))
+
+
+         out = self.conv_last(self.lrelu(self.HR_conv(featmap)))
+
+         return out
+
+
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+     """Initialize network weights.
+
+     Args:
+         module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+         scale (float): Scale initialized weights, especially for residual
+             blocks. Default: 1.
+         bias_fill (float): The value to fill bias. Default: 0
+         kwargs (dict): Other arguments for initialization function.
+     """
+     if not isinstance(module_list, list):
+         module_list = [module_list]
+     for module in module_list:
+         for m in module.modules():
+             if isinstance(m, nn.Conv2d):
+                 init.kaiming_normal_(m.weight, **kwargs)
+                 m.weight.data *= scale
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+             elif isinstance(m, nn.Linear):
+                 init.kaiming_normal_(m.weight, **kwargs)
+                 m.weight.data *= scale
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+             elif isinstance(m, _BatchNorm):
+                 init.constant_(m.weight, 1)
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+
+
+ def make_layer(basic_block, num_basic_block, **kwarg):
+     """Make layers by stacking the same blocks.
+
+     Args:
+         basic_block (nn.module): nn.module class for basic block.
+         num_basic_block (int): number of blocks.
+
+     Returns:
+         nn.Sequential: Stacked blocks in nn.Sequential.
+     """
+     layers = []
+     for _ in range(num_basic_block):
+         layers.append(basic_block(**kwarg))
+     return nn.Sequential(*layers)
+
+
+ class ResidualDenseBlock(nn.Module):
+     """Residual Dense Block.
+
+     Used in RRDB block in ESRGAN.
+
+     Args:
+         num_feat (int): Channel number of intermediate features.
+         num_grow_ch (int): Channels for each growth.
+     """
+
+     def __init__(self, num_feat=64, num_grow_ch=32):
+         super(ResidualDenseBlock, self).__init__()
+         self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+         self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         # initialization
+         default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+     def forward(self, x):
+         x1 = self.lrelu(self.conv1(x))
+         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+         # Empirically, we use 0.2 to scale the residual for better performance
+         return x5 * 0.2 + x
+
+
+ class RRDB(nn.Module):
+     """Residual in Residual Dense Block.
+
+     Used in RRDB-Net in ESRGAN.
+
+     Args:
+         num_feat (int): Channel number of intermediate features.
+         num_grow_ch (int): Channels for each growth.
+     """
+
+     def __init__(self, num_feat, num_grow_ch=32):
+         super(RRDB, self).__init__()
+         self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+         self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+         self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+     def forward(self, x):
+         out = self.rdb1(x)
+         out = self.rdb2(out)
+         out = self.rdb3(out)
+         # Empirically, we use 0.2 to scale the residual for better performance
+         return out * 0.2 + x
+
+
+ class Decoder_Id_RRDB(nn.Module):
+     def __init__(self, num_in_ch, num_out_ch=3, scale=4, num_feat=64, num_block=10, num_grow_ch=32):
+         super(Decoder_Id_RRDB, self).__init__()
+
+         self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+         self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+         self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+
+         self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+     def forward(self, x):
+
+         feat = self.conv_first(x)
+         body_feat = self.conv_body(self.body(feat))
+         feat = feat + body_feat
+
+         out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+         return out
+
+
+ class Decoder_SR_RRDB(nn.Module):
+     def __init__(self, num_in_ch, num_out_ch=3, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+         super(Decoder_SR_RRDB, self).__init__()
+         self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+         self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+         self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         # upsample
+         self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+     def forward(self, x):
+
+         feat = self.conv_first(x)
+         body_feat = self.conv_body(self.body(feat))
+         feat = feat + body_feat
+         # upsample
+         feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+         feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+         out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+         return out
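A quick shape check for the SR decoder above (a sketch; num_block is reduced from the default 23 only so it runs quickly on CPU). The two nearest-neighbour upsampling steps in Decoder_SR_RRDB give a fixed 4x spatial factor:

import torch
from model import decoder

dec = decoder.Decoder_SR_RRDB(num_in_ch=64, num_block=2)
feat = torch.randn(1, 64, 64, 64)    # B x C x H x W feature map, e.g. from Encoder_RRDB
out = dec(feat)
print(out.shape)                      # expected: torch.Size([1, 3, 256, 256])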
model/discriminator.py ADDED
@@ -0,0 +1,121 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.utils import spectral_norm
+
+ class DiscriminatorVGG(nn.Module):
+     def __init__(self, in_ch=3, image_size=128, d=64):
+         super(DiscriminatorVGG, self).__init__()
+         self.feature_map_size = image_size // 32
+         self.d = d
+
+         self.features = nn.Sequential(
+             nn.Conv2d(in_ch, d, kernel_size=3, stride=1, padding=1),  # input is 3 x 128 x 128
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+
+             nn.Conv2d(d, d, kernel_size=3, stride=2, padding=1, bias=False),  # state size. 64 x 64 x 64
+             nn.BatchNorm2d(d),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+
+             nn.Conv2d(d, d*2, kernel_size=3, stride=1, padding=1, bias=False),
+             nn.BatchNorm2d(d*2),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+
+             nn.Conv2d(d*2, d*2, kernel_size=3, stride=2, padding=1, bias=False),  # state size. 128 x 32 x 32
+             nn.BatchNorm2d(d*2),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+
+             nn.Conv2d(d*2, d*4, kernel_size=3, stride=1, padding=1, bias=False),
+             nn.BatchNorm2d(d*4),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+
+             nn.Conv2d(d*4, d*4, kernel_size=3, stride=2, padding=1, bias=False),  # state size. 256 x 16 x 16
+             nn.BatchNorm2d(d*4),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+
+             nn.Conv2d(d*4, d*8, kernel_size=3, stride=1, padding=1, bias=False),
+             nn.BatchNorm2d(d*8),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+
+             nn.Conv2d(d*8, d*8, kernel_size=3, stride=2, padding=1, bias=False),  # state size. 512 x 8 x 8
+             nn.BatchNorm2d(d*8),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+
+             nn.Conv2d(d*8, d*8, kernel_size=3, stride=1, padding=1, bias=False),
+             nn.BatchNorm2d(d*8),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+
+             nn.Conv2d(d*8, d*8, kernel_size=3, stride=2, padding=1, bias=False),  # state size. 512 x 4 x 4
+             nn.BatchNorm2d(d*8),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True)
+         )
+
+         self.classifier = nn.Sequential(
+             nn.Linear((self.d*8) * self.feature_map_size * self.feature_map_size, 100),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+             nn.Linear(100, 1)
+         )
+
+     def forward(self, x):
+         out = self.features(x)
+         out = torch.flatten(out, 1)
+         out = self.classifier(out)
+
+         return out
+
+
+ class UNetDiscriminator(nn.Module):
+     def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
+         super(UNetDiscriminator, self).__init__()
+         self.skip_connection = skip_connection
+         norm = spectral_norm
+         self.num_in_ch = num_in_ch
+         self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
+
+         self.conv1 = norm(nn.Conv2d(num_feat, num_feat*2, kernel_size=4, stride=2, padding=1, bias=False))
+         self.conv2 = norm(nn.Conv2d(num_feat*2, num_feat*4, kernel_size=4, stride=2, padding=1, bias=False))
+         self.conv3 = norm(nn.Conv2d(num_feat*4, num_feat*8, kernel_size=4, stride=2, padding=1, bias=False))
+
+         # upsample
+         self.conv4 = norm(nn.Conv2d(num_feat*8, num_feat*4, kernel_size=3, stride=1, padding=1, bias=False))
+         self.conv5 = norm(nn.Conv2d(num_feat*4, num_feat*2, kernel_size=3, stride=1, padding=1, bias=False))
+         self.conv6 = norm(nn.Conv2d(num_feat*2, num_feat, kernel_size=3, stride=1, padding=1, bias=False))
+
+         # extra
+         self.conv7 = norm(nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=False))
+         self.conv8 = norm(nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=False))
+
+         self.conv9 = nn.Conv2d(num_feat, 1, kernel_size=3, stride=1, padding=1)
+
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+
+     def forward(self, x):
+         x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
+         x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
+         x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
+         x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
+
+         # upsample
+         x3 = F.interpolate(x3, scale_factor=2, mode='bilinear')
+         x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
+
+         if self.skip_connection:
+             x4 = x4 + x2
+         x4 = F.interpolate(x4, scale_factor=2, mode='bilinear')
+         x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
+
+         if self.skip_connection:
+             x5 = x5 + x1
+         x5 = F.interpolate(x5, scale_factor=2, mode='bilinear')
+         x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
+
+         if self.skip_connection:
+             x6 = x6 + x0
+
+         # extra
+         out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
+         out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
+         out = self.conv9(out)
+         out = self.avg_pool(out)
+         out = torch.flatten(out, 1)
+         return out
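A sketch showing how DiscriminatorVGG ties image_size to the input resolution: the classifier flattens a (d*8) x (image_size/32) x (image_size/32) map, so the constructor argument must match the tensors the network will actually see (train.py passes scale*patch_size for the HR discriminator, i.e. 256 with the defaults in opt/option.py):

import torch
from model import discriminator

disc_hr = discriminator.DiscriminatorVGG(in_ch=3, image_size=256)
fake_hr = torch.randn(1, 3, 256, 256)   # input resolution must equal image_size
print(disc_hr(fake_hr).shape)           # expected: torch.Size([1, 1])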
model/encoder.py ADDED
@@ -0,0 +1,67 @@
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class Encoder(nn.Module):
+     def __init__(self):
+         super(Encoder, self).__init__()
+
+         self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
+
+         self.conv_featmap_1 = nn.Sequential(
+             nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1, bias=True),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(),
+         )
+
+         self.conv_featmap_2 = nn.Sequential(
+             nn.Conv2d(in_channels=16, out_channels=32, kernel_size=1, bias=True),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(),
+         )
+
+         self.conv_featmap_3 = nn.Sequential(
+             nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, bias=True),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(),
+             nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, bias=True),
+             nn.ReLU(),
+         )
+
+     def forward(self, img):
+         featmap_1 = self.conv_featmap_1(img)
+         featmap_1_down = self.maxpool(featmap_1)
+
+         featmap_2 = self.conv_featmap_2(featmap_1_down)
+         featmap_2_down = self.maxpool(featmap_2)
+
+         featmap_3 = self.conv_featmap_3(featmap_2_down)
+
+         return featmap_3
+
+
+ class Encoder_RRDB(nn.Module):
+     def __init__(self, num_feat=16):
+         super(Encoder_RRDB, self).__init__()
+         self.conv_featmap = nn.Sequential(
+             nn.Conv2d(in_channels=3, out_channels=num_feat, kernel_size=3, padding=1, bias=True),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+             nn.Conv2d(in_channels=num_feat, out_channels=num_feat, kernel_size=3, padding=1, bias=True),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+             nn.Conv2d(in_channels=num_feat, out_channels=num_feat, kernel_size=3, padding=1, bias=True),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+             nn.Conv2d(in_channels=num_feat, out_channels=num_feat, kernel_size=3, padding=1, bias=True),
+             nn.LeakyReLU(negative_slope=0.2, inplace=True),
+             nn.Conv2d(in_channels=num_feat, out_channels=num_feat, kernel_size=3, padding=1, bias=True),
+         )
+
+     def forward(self, img):
+         featmap = self.conv_featmap(img)
+
+         return featmap
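A matching shape check for the encoder (a sketch): Encoder_RRDB keeps the spatial resolution and only lifts the image to num_feat channels, which is why the RRDB decoders above take num_in_ch feature maps at LR resolution:

import torch
from model import encoder

enc = encoder.Encoder_RRDB(num_feat=64)
img = torch.randn(1, 3, 64, 64)     # B x 3 x H x W low-resolution input
feat = enc(img)
print(feat.shape)                    # expected: torch.Size([1, 64, 64, 64])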
opt/__pycache__/option.cpython-38.pyc ADDED
Binary file (2.81 kB).
 
opt/option.py ADDED
@@ -0,0 +1,69 @@
+ import argparse
+
+
+ parser = argparse.ArgumentParser(description='BebyGAN')
+
+ # Hardware specifications
+ parser.add_argument('--gpu_id', type=str, default = "0", help='specify GPU ID to use')
+ parser.add_argument('--num_workers', type=int, default=4)
+
+ # Data specifications
+ parser.add_argument('--dir_data', type=str, default='./dataset', help='dataset root directory')
+ parser.add_argument('--scale', type=int, default=4, help='super resolution scale')
+ parser.add_argument('--patch_size', type=int, default=64, help='LR patch size') # default = 128 (in the paper)
+
+ # Train specifications
+ parser.add_argument('--epochs', type=int, default=35000, help='total epochs')
+ parser.add_argument('--batch_size', type=int, default=1, help='size of each batch') # default = 8 (in the paper)
+
+ # Optimizer specifications
+ parser.add_argument('--lr_G', type=float, default=1e-4, help='initial learning rate of generator')
+ parser.add_argument('--lr_D', type=float, default=1e-4, help='initial learning rate of discriminator')
+ parser.add_argument('--beta1', type=float, default=0.9, help='ADAM beta1')
+ parser.add_argument('--beta2', type=float, default=0.99, help='ADAM beta2')
+ parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay')
+
+ # Scheduler specifications
+ parser.add_argument('--interval1', type=int, default=2.5e5, help='1st step size (iteration)')
+ parser.add_argument('--interval2', type=int, default=3.5e5, help='2nd step size (iteration)')
+ parser.add_argument('--interval3', type=int, default=4.5e5, help='3rd step size (iteration)')
+ parser.add_argument('--interval4', type=int, default=5.5e5, help='4th step size (iteration)')
+ parser.add_argument('--gamma_G', type=float, default=0.5, help='generator learning rate decay ratio')
+ parser.add_argument('--gamma_D', type=float, default=0.5, help='discriminator learning rate decay ratio')
+
+ # Train specifications
+ parser.add_argument('--snap_path', type=str, default='./weights', help='path to save model weights')
+ parser.add_argument('--save_freq', type=int, default=5, help='save model frequency (epoch)')
+ # Logger
+ parser.add_argument('--log_interval', type=int, default=20)
+ # checkpoint
+ parser.add_argument('--checkpoint', type=str, default=None, help='load checkpoint')
+ # pretrained
+ parser.add_argument('--pretrained', type=str, default=None)
+ # Optimizer specifications
+ parser.add_argument('--lambda_align', type=float, default=0.01, help='L1 loss weight')
+ parser.add_argument('--lambda_rec', type=float, default=1.0, help='back-projection loss weight')
+ parser.add_argument('--lambda_res', type=float, default=1.0, help='perceptual loss weight')
+ parser.add_argument('--lambda_sty', type=float, default=0.01, help='style loss weight')
+ parser.add_argument('--lambda_idt', type=float, default=0.01, help='identity loss weight')
+ parser.add_argument('--lambda_cyc', type=float, default=1, help='cycle loss weight')
+
+ parser.add_argument('--lambda_percept', type=float, default=0.01, help='perceptual loss weight')
+ parser.add_argument('--lambda_adv', type=float, default=0.01, help='adversarial loss weight')
+
+ # generator & discriminator specifications
+ parser.add_argument('--n_disc', type=int, default=1, help='number of iteration for discriminator update in one epoch')
+ parser.add_argument('--n_gen', type=int, default=2, help='number of iteration for generator update in one epoch')
+
+ # encoder & decoder specifications
+ parser.add_argument('--n_hidden_feats', type=int, default=64, help='number of feature vectors in hidden layer')
+ parser.add_argument('--n_sr_feats', type=int, default=64, help='number of feature vectors in RRDB layer')
+ # eval spec
+ parser.add_argument('--phase', type=str, default='train')
+
+ # test specifications
+ parser.add_argument('--weights', type=str, default = "/data4/anhdh4/SR2/USR_DA-main/weights/epoch_1660.pth",help='load weights for test')
+ parser.add_argument('--dir_test', type=str, default = "/data4/anhdh4/SR2/NTIRE2020/valid_source_crop",help='directory of test images')
+ parser.add_argument('--results', type=str, default='./results1660/', help='directory of test results')
+
+ args = parser.parse_args()
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ fastapi
+ opencv-python
+ scipy
+ wget
+ scikit-image
+ torch==1.13.0
+ torchmetrics==0.11.0
+ torchvision==0.14.0
+ tqdm
+ uvicorn
test/1.png ADDED
testsets/0848.png ADDED
testsets/0851.png ADDED
testsets/0855.png ADDED
testsets/0879.png ADDED
train.py ADDED
@@ -0,0 +1,474 @@
+ import os
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import wandb
+ from torch.utils.data import DataLoader
+ from torchvision import transforms
+ from tqdm import tqdm
+
+ from data.LQGT_dataset import LQGTDataset, LQGTValDataset
+ from model import decoder, discriminator, encoder
+ from opt.option import args
+ from util.utils import (RandCrop, RandHorizontalFlip, RandRotate, ToTensor, RandCrop_pair,
+                         VGG19PerceptualLoss)
+
+ from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+
+ wandb.init(project='SR', config=args)
+
+
+
+ # device setting
+ if args.gpu_id is not None:
+     os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+     print('using GPU 0')
+ else:
+     print('use --gpu_id to specify GPU ID to use')
+     exit()
+
+ device = torch.device('cuda')
+
+ # make directory for saving weights
+ if not os.path.exists(args.snap_path):
+     os.mkdir(args.snap_path)
+
+ print("Loading dataset...")
+ # load training dataset
+ train_dataset = LQGTDataset(
+     db_path=args.dir_data,
+     transform=transforms.Compose([RandCrop(args.patch_size, args.scale), RandHorizontalFlip(), RandRotate(), ToTensor()])
+ )
+
+ val_dataset = LQGTValDataset(
+     db_path=args.dir_data,
+     transform=transforms.Compose([RandCrop_pair(args.patch_size, args.scale), ToTensor()])
+ )
+
+ train_loader = DataLoader(
+     train_dataset,
+     batch_size=args.batch_size,
+     num_workers=args.num_workers,
+     drop_last=True,
+     shuffle=True
+ )
+
+ val_loader = DataLoader(
+     val_dataset,
+     batch_size=args.batch_size,
+     num_workers=args.num_workers,
+     shuffle=False
+ )
+
+
+ print("Create model")
+ model_Disc_feat = discriminator.DiscriminatorVGG(in_ch=args.n_hidden_feats, image_size=args.patch_size).to(device)
+ model_Disc_img_LR = discriminator.DiscriminatorVGG(in_ch=3, image_size=args.patch_size).to(device)
+ model_Disc_img_HR = discriminator.DiscriminatorVGG(in_ch=3, image_size=args.scale*args.patch_size).to(device)
+ # define model (generator)
+ model_Enc = encoder.Encoder_RRDB(num_feat=args.n_hidden_feats).to(device)
+ model_Dec_Id = decoder.Decoder_Id_RRDB(num_in_ch=args.n_hidden_feats).to(device)
+ model_Dec_SR = decoder.Decoder_SR_RRDB(num_in_ch=args.n_hidden_feats).to(device)
+
+ # define model (discriminator)
+
+ # model_Disc_feat = discriminator.UNetDiscriminator(num_in_ch=64).to(device)
+ # model_Disc_img_LR = discriminator.UNetDiscriminator(num_in_ch=3).to(device)
+ # model_Disc_img_HR = discriminator.UNetDiscriminator(num_in_ch=3).to(device)
+
+ # wandb logging
+ wandb.watch(model_Disc_feat)
+ wandb.watch(model_Disc_img_LR)
+ wandb.watch(model_Enc)
+ wandb.watch(model_Dec_Id)
+ wandb.watch(model_Dec_SR)
+
+
+ print("Define Loss")
+ # loss
+ loss_L1 = nn.L1Loss().to(device)
+ loss_MSE = nn.MSELoss().to(device)
+ loss_adversarial = nn.BCEWithLogitsLoss().to(device)
+ loss_percept = VGG19PerceptualLoss().to(device)
+
+
+ print("Define Optimizer")
+ # optimizer
+ params_G = list(model_Enc.parameters()) + list(model_Dec_Id.parameters()) + list(model_Dec_SR.parameters())
+ optimizer_G = optim.Adam(
+     params_G,
+     lr=args.lr_G,
+     betas=(args.beta1, args.beta2),
+     weight_decay=args.weight_decay,
+     amsgrad=True
+ )
+ params_D = list(model_Disc_feat.parameters()) + list(model_Disc_img_LR.parameters()) + list(model_Disc_img_HR.parameters())
+ optimizer_D = optim.Adam(
+     params_D,
+     lr=args.lr_D,
+     betas=(args.beta1, args.beta2),
+     weight_decay=args.weight_decay,
+     amsgrad=True
+ )
+
+ print("Define Scheduler")
+ # Scheduler
+ iter_indices = [args.interval1, args.interval2, args.interval3]
+ scheduler_G = optim.lr_scheduler.MultiStepLR(
+     optimizer=optimizer_G,
+     milestones=iter_indices,
+     gamma=0.5
+ )
+ scheduler_D = optim.lr_scheduler.MultiStepLR(
+     optimizer=optimizer_D,
+     milestones=iter_indices,
+     gamma=0.5
+ )
+
+ # print("Data Parallel")
+ model_Enc = nn.DataParallel(model_Enc)
+ model_Dec_Id = nn.DataParallel(model_Dec_Id)
+ model_Dec_SR = nn.DataParallel(model_Dec_SR)
+
+ # define model (discriminator)
+ #model_Disc_feat = nn.DataParallel(model_Disc_feat)
+ #model_Disc_img_LR = nn.DataParallel(model_Disc_img_LR)
+ #model_Disc_img_HR = nn.DataParallel(model_Disc_img_HR)
+
+ print("Load model weight")
+ # load model weights & optimizer & scheduler
+ if args.checkpoint is not None:
+     checkpoint = torch.load(args.checkpoint)
+
+     model_Enc.load_state_dict(checkpoint['model_Enc'])
+     model_Dec_Id.load_state_dict(checkpoint['model_Dec_Id'])
+     model_Dec_SR.load_state_dict(checkpoint['model_Dec_SR'])
+     model_Disc_feat.load_state_dict(checkpoint['model_Disc_feat'])
+     model_Disc_img_LR.load_state_dict(checkpoint['model_Disc_img_LR'])
+     model_Disc_img_HR.load_state_dict(checkpoint['model_Disc_img_HR'])
+
+     optimizer_D.load_state_dict(checkpoint['optimizer_D'])
+     optimizer_G.load_state_dict(checkpoint['optimizer_G'])
+
+     scheduler_D.load_state_dict(checkpoint['scheduler_D'])
+     scheduler_G.load_state_dict(checkpoint['scheduler_G'])
+
+     start_epoch = checkpoint['epoch']
+ else:
+     start_epoch = 0
+
+
+ if args.pretrained is not None:
+     ckpt = torch.load(args.pretrained)
+     ckpt["params"]["conv_first.weight"] = ckpt["params"]["conv_first.weight"][:,0,:,:].expand(64,64,3,3)
+     model_Dec_SR.load_state_dict(ckpt["params"])
+
+
+
+
+
+
+
+ # model_Enc = model_Enc.to(device)
+ # model_Dec_Id = model_Dec_Id.to(device)
+ # model_Dec_SR = model_Dec_SR.to(device)
+
+ # # define model (discriminator)
+ # model_Disc_feat = model_Disc_feat.to(device)
+ # model_Disc_img_LR = model_Disc_img_LR.to(device)
+ # model_Disc_img_HR =model_Disc_img_HR.to(device)
+ # training
+
+ PSNR = PeakSignalNoiseRatio().to(device)
+ SSIM = StructuralSimilarityIndexMeasure().to(device)
+ LPIPS = LearnedPerceptualImagePatchSimilarity().to(device)
+
+ if args.phase == "train":
+     for epoch in range(start_epoch, args.epochs):
+         # generator
+         model_Enc.train()
+         model_Dec_Id.train()
+         model_Dec_SR.train()
+
+         # discriminator
+         model_Disc_feat.train()
+         model_Disc_img_LR.train()
+         model_Disc_img_HR.train()
+
+         running_loss_D_total = 0.0
+         running_loss_G_total = 0.0
+
+         running_loss_align = 0.0
+         running_loss_rec = 0.0
+         running_loss_res = 0.0
+         running_loss_sty = 0.0
+         running_loss_idt = 0.0
+         running_loss_cyc = 0.0
+
+         iter = 0
+
+         for data in tqdm(train_loader):
+             iter += 1
+
+             ########################
+             #      data load       #
+             ########################
+             X_t, Y_s = data['img_LQ'], data['img_GT']
+
+             ds4 = nn.Upsample(scale_factor=1/args.scale, mode='bicubic')
+             X_s = ds4(Y_s)
+
+             X_t = X_t.cuda(non_blocking=True)
+             X_s = X_s.cuda(non_blocking=True)
+             Y_s = Y_s.cuda(non_blocking=True)
+
+             # real label and fake label
+             batch_size = X_t.size(0)
+             real_label = torch.full((batch_size, 1), 1, dtype=X_t.dtype).cuda(non_blocking=True)
+             fake_label = torch.full((batch_size, 1), 0, dtype=X_t.dtype).cuda(non_blocking=True)
+
+
+             ########################
+             # (1) Update D network #
+             ########################
+             model_Disc_feat.zero_grad()
+             model_Disc_img_LR.zero_grad()
+             model_Disc_img_HR.zero_grad()
+
+             for i in range(args.n_disc):
+                 # generator output (feature domain)
+                 F_t = model_Enc(X_t)
+                 F_s = model_Enc(X_s)
+
+                 # 1. feature alignment loss (discriminator)
+                 # output of discriminator (feature domain) (b x c(=1) x h x w)
+                 output_Disc_F_t = model_Disc_feat(F_t.detach())
+                 output_Disc_F_s = model_Disc_feat(F_s.detach())
+                 # discriminator loss (feature domain)
+                 loss_Disc_F_t = loss_MSE(output_Disc_F_t, fake_label)
+                 loss_Disc_F_s = loss_MSE(output_Disc_F_s, real_label)
+                 loss_Disc_feat_align = (loss_Disc_F_t + loss_Disc_F_s) / 2
+
+                 # 2. SR reconstruction loss (discriminator)
+                 # generator output (image domain)
+                 Y_s_s = model_Dec_SR(F_s)
+                 # output of discriminator (image domain)
+                 output_Disc_Y_s_s = model_Disc_img_HR(Y_s_s.detach())
+                 output_Disc_Y_s = model_Disc_img_HR(Y_s)
+                 # discriminator loss (image domain)
+                 loss_Disc_Y_s_s = loss_MSE(output_Disc_Y_s_s, fake_label)
+                 loss_Disc_Y_s = loss_MSE(output_Disc_Y_s, real_label)
+                 loss_Disc_img_rec = (loss_Disc_Y_s_s + loss_Disc_Y_s) / 2
+
+                 # 4. Target degradation style loss
+                 # generator output (image domain)
+                 X_s_t = model_Dec_Id(F_s)
+                 # output of discriminator (image domain)
+                 output_Disc_X_s_t = model_Disc_img_LR(X_s_t.detach())
+                 output_Disc_X_t = model_Disc_img_LR(X_t)
+                 # discriminator loss (image domain)
+                 loss_Disc_X_s_t = loss_MSE(output_Disc_X_s_t, fake_label)
+                 loss_Disc_X_t = loss_MSE(output_Disc_X_t, real_label)
+                 loss_Disc_img_sty = (loss_Disc_X_s_t + loss_Disc_X_t) / 2
+
+                 # 6. Cycle loss
+                 # generator output (image domain)
+                 Y_s_t_s = model_Dec_SR(model_Enc(model_Dec_Id(F_s)))
+                 # output of discriminator (image domain)
+                 output_Disc_Y_s_t_s = model_Disc_img_HR(Y_s_t_s.detach())
+                 output_Disc_Y_s = model_Disc_img_HR(Y_s)
+                 # discriminator loss (image domain)
+                 loss_Disc_Y_s_t_s = loss_MSE(output_Disc_Y_s_t_s, fake_label)
+                 loss_Disc_Y_s = loss_MSE(output_Disc_Y_s, real_label)
+                 loss_Disc_img_cyc = (loss_Disc_Y_s_t_s + loss_Disc_Y_s) / 2
+
+                 # discriminator weight update
+                 loss_D_total = loss_Disc_feat_align + loss_Disc_img_rec + loss_Disc_img_sty + loss_Disc_img_cyc
+                 loss_D_total.backward()
+                 optimizer_D.step()
+
+
+
+             scheduler_D.step()
+
+
+             ########################
+             # (2) Update G network #
+             ########################
+             model_Enc.zero_grad()
+             model_Dec_Id.zero_grad()
+             model_Dec_SR.zero_grad()
+
+             for i in range(args.n_gen):
+                 # generator output (feature domain)
+                 F_t = model_Enc(X_t)
+                 F_s = model_Enc(X_s)
+
+                 # 1. feature alignment loss (generator)
+                 # output of discriminator (feature domain)
+                 output_Disc_F_t = model_Disc_feat(F_t)
+                 output_Disc_F_s = model_Disc_feat(F_s)
+                 # generator loss (feature domain)
+                 loss_G_F_t = loss_MSE(output_Disc_F_t, (real_label + fake_label)/2)
+                 loss_G_F_s = loss_MSE(output_Disc_F_s, (real_label + fake_label)/2)
+                 L_align_E = loss_G_F_t + loss_G_F_s
+
+                 # 2. SR reconstruction loss
+                 # generator output (image domain)
+                 Y_s_s = model_Dec_SR(F_s)
+                 # output of discriminator (image domain)
+                 output_Disc_Y_s_s = model_Disc_img_HR(Y_s_s)
+                 # L1 loss
+                 loss_L1_rec = loss_L1(Y_s.detach(), Y_s_s)
+                 # perceptual loss
+                 loss_percept_rec = loss_percept(Y_s.detach(), Y_s_s)
+                 # adversarial loss
+                 loss_G_Y_s_s = loss_MSE(output_Disc_Y_s_s, real_label)
+                 L_rec_G_SR = loss_L1_rec + args.lambda_percept*loss_percept_rec + args.lambda_adv*loss_G_Y_s_s
+
+                 # 3. Target LR restoration loss
+                 X_t_t = model_Dec_Id(F_t)
+                 L_res_G_t = loss_L1(X_t, X_t_t)
+
+                 # 4. Target degradation style loss
+                 # generator output (image domain)
+                 X_s_t = model_Dec_Id(F_s)
+                 # output of discriminator (img domain)
+                 output_Disc_X_s_t = model_Disc_img_LR(X_s_t)
+                 # generator loss (feature domain)
+                 loss_G_X_s_t = loss_MSE(output_Disc_X_s_t, real_label)
+                 L_sty_G_t = loss_G_X_s_t
+
+                 # 5. Feature identity loss
+                 F_s_tilda = model_Enc(model_Dec_Id(F_s))
+                 L_idt_G_t = loss_L1(F_s, F_s_tilda)
+
+                 # 6. Cycle loss
+                 # generator output (image domain)
+                 Y_s_t_s = model_Dec_SR(model_Enc(model_Dec_Id(F_s)))
+                 # output of discriminator (image domain)
+                 output_Disc_Y_s_t_s = model_Disc_img_HR(Y_s_t_s)
+                 # L1 loss
+                 loss_L1_cyc = loss_L1(Y_s.detach(), Y_s_t_s)
+                 # perceptual loss
+                 loss_percept_cyc = loss_percept(Y_s.detach(), Y_s_t_s)
+                 # adversarial loss
+                 loss_Y_s_t_s = loss_MSE(output_Disc_Y_s_t_s, real_label)
+                 L_cyc_G_t_G_SR = loss_L1_cyc + args.lambda_percept*loss_percept_cyc + args.lambda_adv*loss_Y_s_t_s
+
+                 # generator weight update
+                 loss_G_total = args.lambda_align*L_align_E + args.lambda_rec*L_rec_G_SR + args.lambda_res*L_res_G_t + args.lambda_sty*L_sty_G_t + args.lambda_idt*L_idt_G_t + args.lambda_cyc*L_cyc_G_t_G_SR
+                 loss_G_total.backward()
+                 optimizer_G.step()
+                 scheduler_G.step()
+
+
+             ########################
+             #     compute loss     #
+             ########################
+             running_loss_D_total += loss_D_total.item()
+             running_loss_G_total += loss_G_total.item()
+
+             running_loss_align += L_align_E.item()
+             running_loss_rec += L_rec_G_SR.item()
+             running_loss_res += L_res_G_t.item()
+             running_loss_sty += L_sty_G_t.item()
+             running_loss_idt += L_idt_G_t.item()
+             running_loss_cyc += L_cyc_G_t_G_SR.item()
+             if iter % args.log_interval == 0:
+                 wandb.log(
+                     {
+                         "loss_D_total_step": running_loss_D_total/iter,
+                         "loss_G_total_step": running_loss_G_total/iter,
+                         "loss_align_step": running_loss_align/iter,
+                         "loss_rec_step": running_loss_rec/iter,
+                         "loss_res_step": running_loss_res/iter,
+                         "loss_sty_step": running_loss_sty/iter,
+                         "loss_idt_step": running_loss_idt/iter,
+                         "loss_cyc_step": running_loss_cyc/iter,
+                     }
+                 )
+         ### EVALUATE ###
+         total_PSNR = 0
+         total_SSIM = 0
+         total_LPIPS = 0
+         val_iter = 0
+         with torch.no_grad():
+             model_Enc.eval()
+             model_Dec_SR.eval()
+             for batch_idx, batch in enumerate(val_loader):
+                 val_iter += 1
+                 source = batch["img_LQ"].to(device)
+                 target = batch["img_GT"].to(device)
+
+                 feat = model_Enc(source)
+                 out = model_Dec_SR(feat)
+
+                 total_PSNR += PSNR(out, target)
+                 total_SSIM += SSIM(out, target)
+                 total_LPIPS += LPIPS(out, target)
+
+         wandb.log(
+             {
+                 "epoch": epoch,
+                 "lr": optimizer_G.param_groups[0]['lr'],
+                 "loss_D_total_epoch": running_loss_D_total/iter,
+                 "loss_G_total_epoch": running_loss_G_total/iter,
+                 "loss_align_epoch": running_loss_align/iter,
+                 "loss_rec_epoch": running_loss_rec/iter,
+                 "loss_res_epoch": running_loss_res/iter,
+                 "loss_sty_epoch": running_loss_sty/iter,
+                 "loss_idt_epoch": running_loss_idt/iter,
+                 "loss_cyc_epoch": running_loss_cyc/iter,
+                 "PSNR_val": total_PSNR/val_iter,
+                 "SSIM_val": total_SSIM/val_iter,
+                 "LPIPS_val": total_LPIPS/val_iter
+             }
+         )
+
+
+         if (epoch+1) % args.save_freq == 0:
+             weights_file_name = 'epoch_%d.pth' % (epoch+1)
+             weights_file = os.path.join(args.snap_path, weights_file_name)
+             torch.save({
+                 'epoch': epoch,
+
+                 'model_Enc': model_Enc.state_dict(),
+                 'model_Dec_Id': model_Dec_Id.state_dict(),
+                 'model_Dec_SR': model_Dec_SR.state_dict(),
+                 'model_Disc_feat': model_Disc_feat.state_dict(),
+                 'model_Disc_img_LR': model_Disc_img_LR.state_dict(),
+                 'model_Disc_img_HR': model_Disc_img_HR.state_dict(),
+
+                 'optimizer_D': optimizer_D.state_dict(),
+                 'optimizer_G': optimizer_G.state_dict(),
+
+                 'scheduler_D': scheduler_D.state_dict(),
+                 'scheduler_G': scheduler_G.state_dict(),
+             }, weights_file)
+             print('save weights of epoch %d' % (epoch+1))
+ else:
+     ### EVALUATE ###
+     total_PSNR = 0
+     total_SSIM = 0
+     total_LPIPS = 0
+     val_iter = 0
+     with torch.no_grad():
+         model_Enc.eval()
+         model_Dec_SR.eval()
+         for batch_idx, batch in enumerate(val_loader):
+             val_iter += 1
+             source = batch["img_LQ"].to(device)
+             target = batch["img_GT"].to(device)
+
+             feat = model_Enc(source)
+             out = model_Dec_SR(feat)
+
+             total_PSNR += PSNR(out, target)
+             total_SSIM += SSIM(out, target)
+             total_LPIPS += LPIPS(out, target)
+     print("PSNR_val: ", total_PSNR/val_iter)
+     print("SSIM_val: ", total_SSIM/val_iter)
+     print("LPIPS_val: ", total_LPIPS/val_iter)
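A minimal sketch of reading back one of the epoch_*.pth checkpoints that train.py saves (the file name is illustrative; the dictionary keys come from the torch.save call above, and n_hidden_feats is assumed to be the default 64):

import torch
from model import encoder, decoder

ckpt = torch.load("weights/epoch_5.pth", map_location="cpu")   # illustrative path under args.snap_path

def strip_module_prefix(state_dict):
    # train.py wraps the generators in nn.DataParallel, so keys may carry a "module." prefix
    return {k.replace("module.", "", 1): v for k, v in state_dict.items()}

enc = encoder.Encoder_RRDB(num_feat=64)
dec = decoder.Decoder_SR_RRDB(num_in_ch=64)
enc.load_state_dict(strip_module_prefix(ckpt['model_Enc']))
dec.load_state_dict(strip_module_prefix(ckpt['model_Dec_SR']))
enc.eval(); dec.eval()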
util/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.02 kB).
 
util/utils.py ADDED
@@ -0,0 +1,133 @@
+ import torch
+ import torchvision
+
+ import math
+ import cv2
+ import numpy as np
+ from scipy.ndimage import rotate
+
+
+ class RandCrop(object):
+     def __init__(self, crop_size, scale):
+         # if output size is tuple -> (height, width)
+         assert isinstance(crop_size, (int, tuple))
+         if isinstance(crop_size, int):
+             self.crop_size = (crop_size, crop_size)
+         else:
+             assert len(crop_size) == 2
+             self.crop_size = crop_size
+
+         self.scale = scale
+
+     def __call__(self, sample):
+         # img_LQ: H x W x C (numpy array)
+         img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
+
+         h, w, c = img_LQ.shape
+         new_h, new_w = self.crop_size
+         top = np.random.randint(0, h - new_h)
+         left = np.random.randint(0, w - new_w)
+         img_LQ_crop = img_LQ[top: top+new_h, left: left+new_w, :]
+
+         h, w, c = img_GT.shape
+         top = np.random.randint(0, h - self.scale*new_h)
+         left = np.random.randint(0, w - self.scale*new_w)
+         img_GT_crop = img_GT[top: top + self.scale*new_h, left: left + self.scale*new_w, :]
+
+         sample = {'img_LQ': img_LQ_crop, 'img_GT': img_GT_crop}
+         return sample
+
+
+ class RandRotate(object):
+     def __call__(self, sample):
+         # img_LQ: H x W x C (numpy array)
+         img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
+
+         prob_rotate = np.random.random()
+         if prob_rotate < 0.25:
+             img_LQ = rotate(img_LQ, 90).copy()
+             img_GT = rotate(img_GT, 90).copy()
+         elif prob_rotate < 0.5:
+             img_LQ = rotate(img_LQ, 90).copy()
+             img_GT = rotate(img_GT, 90).copy()
+         elif prob_rotate < 0.75:
+             img_LQ = rotate(img_LQ, 90).copy()
+             img_GT = rotate(img_GT, 90).copy()
+
+         sample = {'img_LQ': img_LQ, 'img_GT': img_GT}
+         return sample
+
+
+ class RandHorizontalFlip(object):
+     def __call__(self, sample):
+         # img_LQ: H x W x C (numpy array)
+         img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
+
+         prob_lr = np.random.random()
+         if prob_lr < 0.5:
+             img_LQ = np.fliplr(img_LQ).copy()
+             img_GT = np.fliplr(img_GT).copy()
+
+         sample = {'img_LQ': img_LQ, 'img_GT': img_GT}
+         return sample
+
+
+ class ToTensor(object):
+     def __call__(self, sample):
+         # img_LQ : H x W x C (numpy array) -> C x H x W (torch tensor)
+         img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
+
+         img_LQ = img_LQ.transpose((2, 0, 1))
+         img_GT = img_GT.transpose((2, 0, 1))
+
+         img_LQ = torch.from_numpy(img_LQ)
+         img_GT = torch.from_numpy(img_GT)
+
+         sample = {'img_LQ': img_LQ, 'img_GT': img_GT}
+         return sample
+
+
+ class VGG19PerceptualLoss(torch.nn.Module):
+     def __init__(self, feature_layer=35):
+         super(VGG19PerceptualLoss, self).__init__()
+         model = torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT)
+         self.features = torch.nn.Sequential(*list(model.features.children())[:feature_layer]).eval()
+         # Freeze parameters
+         for name, param in self.features.named_parameters():
+             param.requires_grad = False
+
+     def forward(self, source, target):
+         vgg_loss = torch.nn.functional.l1_loss(self.features(source), self.features(target))
+
+         return vgg_loss
+
+
+ class RandCrop_pair(object):
+     def __init__(self, crop_size, scale):
+         # if output size is tuple -> (height, width)
+         assert isinstance(crop_size, (int, tuple))
+         if isinstance(crop_size, int):
+             self.crop_size = (crop_size, crop_size)
+         else:
+             assert len(crop_size) == 2
+             self.crop_size = crop_size
+
+         self.scale = scale
+
+     def __call__(self, sample):
+         # img_LQ: H x W x C (numpy array)
+         img_LQ, img_GT = sample['img_LQ'], sample['img_GT']
+
+         h, w, c = img_LQ.shape
+         new_h, new_w = self.crop_size
+         top = np.random.randint(0, h - new_h)
+         left = np.random.randint(0, w - new_w)
+         img_LQ_crop = img_LQ[top: top+new_h, left: left+new_w, :]
+
+         h, w, c = img_GT.shape
+         top = self.scale*top
+         left = self.scale*left
+         img_GT_crop = img_GT[top: top + self.scale*new_h, left: left + self.scale*new_w, :]
+
+         sample = {'img_LQ': img_LQ_crop, 'img_GT': img_GT_crop}
+         return sample
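Finally, a small sketch of how these transforms are chained in train.py (the float32 H x W x C inputs and their sizes are arbitrary placeholders used only for illustration):

import numpy as np
from torchvision import transforms
from util.utils import RandCrop, RandHorizontalFlip, RandRotate, ToTensor

t = transforms.Compose([RandCrop(64, 4), RandHorizontalFlip(), RandRotate(), ToTensor()])
sample = {
    'img_LQ': np.random.rand(128, 128, 3).astype('float32'),   # source of unpaired LR patches
    'img_GT': np.random.rand(512, 512, 3).astype('float32'),   # HR source image
}
out = t(sample)
print(out['img_LQ'].shape, out['img_GT'].shape)   # e.g. torch.Size([3, 64, 64]) torch.Size([3, 256, 256])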