guetLzy commited on
Commit
08dcbea
·
1 Parent(s): d935f64

Upload 20 files

Browse files
tests/data/gt.lmdb/data.mdb ADDED
Binary file (758 kB). View file
 
tests/data/gt.lmdb/lock.mdb ADDED
Binary file (8.19 kB). View file
 
tests/data/gt.lmdb/meta_info.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ baboon.png (480,500,3) 1
2
+ comic.png (360,240,3) 1
tests/data/gt/baboon.png ADDED
tests/data/gt/comic.png ADDED
tests/data/lq.lmdb/data.mdb ADDED
Binary file (65.5 kB). View file
 
tests/data/lq.lmdb/lock.mdb ADDED
Binary file (8.19 kB). View file
 
tests/data/lq.lmdb/meta_info.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ baboon.png (120,125,3) 1
2
+ comic.png (80,60,3) 1
tests/data/lq/baboon.png ADDED
tests/data/lq/comic.png ADDED
tests/data/meta_info_gt.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ baboon.png
2
+ comic.png
tests/data/meta_info_pair.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gt/baboon.png, lq/baboon.png
2
+ gt/comic.png, lq/comic.png
tests/data/test_realesrgan_dataset.yml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Demo
2
+ type: RealESRGANDataset
3
+ dataroot_gt: tests/data/gt
4
+ meta_info: tests/data/meta_info_gt.txt
5
+ io_backend:
6
+ type: disk
7
+
8
+ blur_kernel_size: 21
9
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
10
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
11
+ sinc_prob: 1
12
+ blur_sigma: [0.2, 3]
13
+ betag_range: [0.5, 4]
14
+ betap_range: [1, 2]
15
+
16
+ blur_kernel_size2: 21
17
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
18
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
19
+ sinc_prob2: 1
20
+ blur_sigma2: [0.2, 1.5]
21
+ betag_range2: [0.5, 4]
22
+ betap_range2: [1, 2]
23
+
24
+ final_sinc_prob: 1
25
+
26
+ gt_size: 128
27
+ use_hflip: True
28
+ use_rot: False
tests/data/test_realesrgan_model.yml ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scale: 4
2
+ num_gpu: 1
3
+ manual_seed: 0
4
+ is_train: True
5
+ dist: False
6
+
7
+ # ----------------- options for synthesizing training data ----------------- #
8
+ # USM the ground-truth
9
+ l1_gt_usm: True
10
+ percep_gt_usm: True
11
+ gan_gt_usm: False
12
+
13
+ # the first degradation process
14
+ resize_prob: [0.2, 0.7, 0.1] # up, down, keep
15
+ resize_range: [0.15, 1.5]
16
+ gaussian_noise_prob: 1
17
+ noise_range: [1, 30]
18
+ poisson_scale_range: [0.05, 3]
19
+ gray_noise_prob: 1
20
+ jpeg_range: [30, 95]
21
+
22
+ # the second degradation process
23
+ second_blur_prob: 1
24
+ resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
25
+ resize_range2: [0.3, 1.2]
26
+ gaussian_noise_prob2: 1
27
+ noise_range2: [1, 25]
28
+ poisson_scale_range2: [0.05, 2.5]
29
+ gray_noise_prob2: 1
30
+ jpeg_range2: [30, 95]
31
+
32
+ gt_size: 32
33
+ queue_size: 1
34
+
35
+ # network structures
36
+ network_g:
37
+ type: RRDBNet
38
+ num_in_ch: 3
39
+ num_out_ch: 3
40
+ num_feat: 4
41
+ num_block: 1
42
+ num_grow_ch: 2
43
+
44
+ network_d:
45
+ type: UNetDiscriminatorSN
46
+ num_in_ch: 3
47
+ num_feat: 2
48
+ skip_connection: True
49
+
50
+ # path
51
+ path:
52
+ pretrain_network_g: ~
53
+ param_key_g: params_ema
54
+ strict_load_g: true
55
+ resume_state: ~
56
+
57
+ # training settings
58
+ train:
59
+ ema_decay: 0.999
60
+ optim_g:
61
+ type: Adam
62
+ lr: !!float 1e-4
63
+ weight_decay: 0
64
+ betas: [0.9, 0.99]
65
+ optim_d:
66
+ type: Adam
67
+ lr: !!float 1e-4
68
+ weight_decay: 0
69
+ betas: [0.9, 0.99]
70
+
71
+ scheduler:
72
+ type: MultiStepLR
73
+ milestones: [400000]
74
+ gamma: 0.5
75
+
76
+ total_iter: 400000
77
+ warmup_iter: -1 # no warm up
78
+
79
+ # losses
80
+ pixel_opt:
81
+ type: L1Loss
82
+ loss_weight: 1.0
83
+ reduction: mean
84
+ # perceptual loss (content and style losses)
85
+ perceptual_opt:
86
+ type: PerceptualLoss
87
+ layer_weights:
88
+ # before relu
89
+ 'conv1_2': 0.1
90
+ 'conv2_2': 0.1
91
+ 'conv3_4': 1
92
+ 'conv4_4': 1
93
+ 'conv5_4': 1
94
+ vgg_type: vgg19
95
+ use_input_norm: true
96
+ perceptual_weight: !!float 1.0
97
+ style_weight: 0
98
+ range_norm: false
99
+ criterion: l1
100
+ # gan loss
101
+ gan_opt:
102
+ type: GANLoss
103
+ gan_type: vanilla
104
+ real_label_val: 1.0
105
+ fake_label_val: 0.0
106
+ loss_weight: !!float 1e-1
107
+
108
+ net_d_iters: 1
109
+ net_d_init_iters: 0
110
+
111
+
112
+ # validation settings
113
+ val:
114
+ val_freq: !!float 5e3
115
+ save_img: False
tests/data/test_realesrgan_paired_dataset.yml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Demo
2
+ type: RealESRGANPairedDataset
3
+ scale: 4
4
+ dataroot_gt: tests/data
5
+ dataroot_lq: tests/data
6
+ meta_info: tests/data/meta_info_pair.txt
7
+ io_backend:
8
+ type: disk
9
+
10
+ phase: train
11
+ gt_size: 128
12
+ use_hflip: True
13
+ use_rot: False
tests/data/test_realesrnet_model.yml ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ scale: 4
2
+ num_gpu: 1
3
+ manual_seed: 0
4
+ is_train: True
5
+ dist: False
6
+
7
+ # ----------------- options for synthesizing training data ----------------- #
8
+ gt_usm: True # USM the ground-truth
9
+
10
+ # the first degradation process
11
+ resize_prob: [0.2, 0.7, 0.1] # up, down, keep
12
+ resize_range: [0.15, 1.5]
13
+ gaussian_noise_prob: 1
14
+ noise_range: [1, 30]
15
+ poisson_scale_range: [0.05, 3]
16
+ gray_noise_prob: 1
17
+ jpeg_range: [30, 95]
18
+
19
+ # the second degradation process
20
+ second_blur_prob: 1
21
+ resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
22
+ resize_range2: [0.3, 1.2]
23
+ gaussian_noise_prob2: 1
24
+ noise_range2: [1, 25]
25
+ poisson_scale_range2: [0.05, 2.5]
26
+ gray_noise_prob2: 1
27
+ jpeg_range2: [30, 95]
28
+
29
+ gt_size: 32
30
+ queue_size: 1
31
+
32
+ # network structures
33
+ network_g:
34
+ type: RRDBNet
35
+ num_in_ch: 3
36
+ num_out_ch: 3
37
+ num_feat: 4
38
+ num_block: 1
39
+ num_grow_ch: 2
40
+
41
+ # path
42
+ path:
43
+ pretrain_network_g: ~
44
+ param_key_g: params_ema
45
+ strict_load_g: true
46
+ resume_state: ~
47
+
48
+ # training settings
49
+ train:
50
+ ema_decay: 0.999
51
+ optim_g:
52
+ type: Adam
53
+ lr: !!float 2e-4
54
+ weight_decay: 0
55
+ betas: [0.9, 0.99]
56
+
57
+ scheduler:
58
+ type: MultiStepLR
59
+ milestones: [1000000]
60
+ gamma: 0.5
61
+
62
+ total_iter: 1000000
63
+ warmup_iter: -1 # no warm up
64
+
65
+ # losses
66
+ pixel_opt:
67
+ type: L1Loss
68
+ loss_weight: 1.0
69
+ reduction: mean
70
+
71
+
72
+ # validation settings
73
+ val:
74
+ val_freq: !!float 5e3
75
+ save_img: False
tests/test_dataset.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import yaml
3
+
4
+ from realesrgan.data.realesrgan_dataset import RealESRGANDataset
5
+ from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset
6
+
7
+
8
+ def test_realesrgan_dataset():
9
+
10
+ with open('tests/data/test_realesrgan_dataset.yml', mode='r') as f:
11
+ opt = yaml.load(f, Loader=yaml.FullLoader)
12
+
13
+ dataset = RealESRGANDataset(opt)
14
+ assert dataset.io_backend_opt['type'] == 'disk' # io backend
15
+ assert len(dataset) == 2 # whether to read correct meta info
16
+ assert dataset.kernel_list == [
17
+ 'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
18
+ ] # correct initialization the degradation configurations
19
+ assert dataset.betag_range2 == [0.5, 4]
20
+
21
+ # test __getitem__
22
+ result = dataset.__getitem__(0)
23
+ # check returned keys
24
+ expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
25
+ assert set(expected_keys).issubset(set(result.keys()))
26
+ # check shape and contents
27
+ assert result['gt'].shape == (3, 400, 400)
28
+ assert result['kernel1'].shape == (21, 21)
29
+ assert result['kernel2'].shape == (21, 21)
30
+ assert result['sinc_kernel'].shape == (21, 21)
31
+ assert result['gt_path'] == 'tests/data/gt/baboon.png'
32
+
33
+ # ------------------ test lmdb backend -------------------- #
34
+ opt['dataroot_gt'] = 'tests/data/gt.lmdb'
35
+ opt['io_backend']['type'] = 'lmdb'
36
+
37
+ dataset = RealESRGANDataset(opt)
38
+ assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
39
+ assert len(dataset.paths) == 2 # whether to read correct meta info
40
+ assert dataset.kernel_list == [
41
+ 'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
42
+ ] # correct initialization the degradation configurations
43
+ assert dataset.betag_range2 == [0.5, 4]
44
+
45
+ # test __getitem__
46
+ result = dataset.__getitem__(1)
47
+ # check returned keys
48
+ expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
49
+ assert set(expected_keys).issubset(set(result.keys()))
50
+ # check shape and contents
51
+ assert result['gt'].shape == (3, 400, 400)
52
+ assert result['kernel1'].shape == (21, 21)
53
+ assert result['kernel2'].shape == (21, 21)
54
+ assert result['sinc_kernel'].shape == (21, 21)
55
+ assert result['gt_path'] == 'comic'
56
+
57
+ # ------------------ test with sinc_prob = 0 -------------------- #
58
+ opt['dataroot_gt'] = 'tests/data/gt.lmdb'
59
+ opt['io_backend']['type'] = 'lmdb'
60
+ opt['sinc_prob'] = 0
61
+ opt['sinc_prob2'] = 0
62
+ opt['final_sinc_prob'] = 0
63
+ dataset = RealESRGANDataset(opt)
64
+ result = dataset.__getitem__(0)
65
+ # check returned keys
66
+ expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
67
+ assert set(expected_keys).issubset(set(result.keys()))
68
+ # check shape and contents
69
+ assert result['gt'].shape == (3, 400, 400)
70
+ assert result['kernel1'].shape == (21, 21)
71
+ assert result['kernel2'].shape == (21, 21)
72
+ assert result['sinc_kernel'].shape == (21, 21)
73
+ assert result['gt_path'] == 'baboon'
74
+
75
+ # ------------------ lmdb backend should have paths ends with lmdb -------------------- #
76
+ with pytest.raises(ValueError):
77
+ opt['dataroot_gt'] = 'tests/data/gt'
78
+ opt['io_backend']['type'] = 'lmdb'
79
+ dataset = RealESRGANDataset(opt)
80
+
81
+
82
+ def test_realesrgan_paired_dataset():
83
+
84
+ with open('tests/data/test_realesrgan_paired_dataset.yml', mode='r') as f:
85
+ opt = yaml.load(f, Loader=yaml.FullLoader)
86
+
87
+ dataset = RealESRGANPairedDataset(opt)
88
+ assert dataset.io_backend_opt['type'] == 'disk' # io backend
89
+ assert len(dataset) == 2 # whether to read correct meta info
90
+
91
+ # test __getitem__
92
+ result = dataset.__getitem__(0)
93
+ # check returned keys
94
+ expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
95
+ assert set(expected_keys).issubset(set(result.keys()))
96
+ # check shape and contents
97
+ assert result['gt'].shape == (3, 128, 128)
98
+ assert result['lq'].shape == (3, 32, 32)
99
+ assert result['gt_path'] == 'tests/data/gt/baboon.png'
100
+ assert result['lq_path'] == 'tests/data/lq/baboon.png'
101
+
102
+ # ------------------ test lmdb backend -------------------- #
103
+ opt['dataroot_gt'] = 'tests/data/gt.lmdb'
104
+ opt['dataroot_lq'] = 'tests/data/lq.lmdb'
105
+ opt['io_backend']['type'] = 'lmdb'
106
+
107
+ dataset = RealESRGANPairedDataset(opt)
108
+ assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
109
+ assert len(dataset) == 2 # whether to read correct meta info
110
+
111
+ # test __getitem__
112
+ result = dataset.__getitem__(1)
113
+ # check returned keys
114
+ expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
115
+ assert set(expected_keys).issubset(set(result.keys()))
116
+ # check shape and contents
117
+ assert result['gt'].shape == (3, 128, 128)
118
+ assert result['lq'].shape == (3, 32, 32)
119
+ assert result['gt_path'] == 'comic'
120
+ assert result['lq_path'] == 'comic'
121
+
122
+ # ------------------ test paired_paths_from_folder -------------------- #
123
+ opt['dataroot_gt'] = 'tests/data/gt'
124
+ opt['dataroot_lq'] = 'tests/data/lq'
125
+ opt['io_backend'] = dict(type='disk')
126
+ opt['meta_info'] = None
127
+
128
+ dataset = RealESRGANPairedDataset(opt)
129
+ assert dataset.io_backend_opt['type'] == 'disk' # io backend
130
+ assert len(dataset) == 2 # whether to read correct meta info
131
+
132
+ # test __getitem__
133
+ result = dataset.__getitem__(0)
134
+ # check returned keys
135
+ expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
136
+ assert set(expected_keys).issubset(set(result.keys()))
137
+ # check shape and contents
138
+ assert result['gt'].shape == (3, 128, 128)
139
+ assert result['lq'].shape == (3, 32, 32)
140
+
141
+ # ------------------ test normalization -------------------- #
142
+ dataset.mean = [0.5, 0.5, 0.5]
143
+ dataset.std = [0.5, 0.5, 0.5]
144
+ # test __getitem__
145
+ result = dataset.__getitem__(0)
146
+ # check returned keys
147
+ expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
148
+ assert set(expected_keys).issubset(set(result.keys()))
149
+ # check shape and contents
150
+ assert result['gt'].shape == (3, 128, 128)
151
+ assert result['lq'].shape == (3, 32, 32)
tests/test_discriminator_arch.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
4
+
5
+
6
+ def test_unetdiscriminatorsn():
7
+ """Test arch: UNetDiscriminatorSN."""
8
+
9
+ # model init and forward (cpu)
10
+ net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
11
+ img = torch.rand((1, 3, 32, 32), dtype=torch.float32)
12
+ output = net(img)
13
+ assert output.shape == (1, 1, 32, 32)
14
+
15
+ # model init and forward (gpu)
16
+ if torch.cuda.is_available():
17
+ net.cuda()
18
+ output = net(img.cuda())
19
+ assert output.shape == (1, 1, 32, 32)
tests/test_model.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import yaml
3
+ from basicsr.archs.rrdbnet_arch import RRDBNet
4
+ from basicsr.data.paired_image_dataset import PairedImageDataset
5
+ from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
6
+
7
+ from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
8
+ from realesrgan.models.realesrgan_model import RealESRGANModel
9
+ from realesrgan.models.realesrnet_model import RealESRNetModel
10
+
11
+
12
+ def test_realesrnet_model():
13
+ with open('tests/data/test_realesrnet_model.yml', mode='r') as f:
14
+ opt = yaml.load(f, Loader=yaml.FullLoader)
15
+
16
+ # build model
17
+ model = RealESRNetModel(opt)
18
+ # test attributes
19
+ assert model.__class__.__name__ == 'RealESRNetModel'
20
+ assert isinstance(model.net_g, RRDBNet)
21
+ assert isinstance(model.cri_pix, L1Loss)
22
+ assert isinstance(model.optimizers[0], torch.optim.Adam)
23
+
24
+ # prepare data
25
+ gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
26
+ kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
27
+ kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
28
+ sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
29
+ data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
30
+ model.feed_data(data)
31
+ # check dequeue
32
+ model.feed_data(data)
33
+ # check data shape
34
+ assert model.lq.shape == (1, 3, 8, 8)
35
+ assert model.gt.shape == (1, 3, 32, 32)
36
+
37
+ # change probability to test if-else
38
+ model.opt['gaussian_noise_prob'] = 0
39
+ model.opt['gray_noise_prob'] = 0
40
+ model.opt['second_blur_prob'] = 0
41
+ model.opt['gaussian_noise_prob2'] = 0
42
+ model.opt['gray_noise_prob2'] = 0
43
+ model.feed_data(data)
44
+ # check data shape
45
+ assert model.lq.shape == (1, 3, 8, 8)
46
+ assert model.gt.shape == (1, 3, 32, 32)
47
+
48
+ # ----------------- test nondist_validation -------------------- #
49
+ # construct dataloader
50
+ dataset_opt = dict(
51
+ name='Demo',
52
+ dataroot_gt='tests/data/gt',
53
+ dataroot_lq='tests/data/lq',
54
+ io_backend=dict(type='disk'),
55
+ scale=4,
56
+ phase='val')
57
+ dataset = PairedImageDataset(dataset_opt)
58
+ dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
59
+ assert model.is_train is True
60
+ model.nondist_validation(dataloader, 1, None, False)
61
+ assert model.is_train is True
62
+
63
+
64
+ def test_realesrgan_model():
65
+ with open('tests/data/test_realesrgan_model.yml', mode='r') as f:
66
+ opt = yaml.load(f, Loader=yaml.FullLoader)
67
+
68
+ # build model
69
+ model = RealESRGANModel(opt)
70
+ # test attributes
71
+ assert model.__class__.__name__ == 'RealESRGANModel'
72
+ assert isinstance(model.net_g, RRDBNet) # generator
73
+ assert isinstance(model.net_d, UNetDiscriminatorSN) # discriminator
74
+ assert isinstance(model.cri_pix, L1Loss)
75
+ assert isinstance(model.cri_perceptual, PerceptualLoss)
76
+ assert isinstance(model.cri_gan, GANLoss)
77
+ assert isinstance(model.optimizers[0], torch.optim.Adam)
78
+ assert isinstance(model.optimizers[1], torch.optim.Adam)
79
+
80
+ # prepare data
81
+ gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
82
+ kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
83
+ kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
84
+ sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
85
+ data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
86
+ model.feed_data(data)
87
+ # check dequeue
88
+ model.feed_data(data)
89
+ # check data shape
90
+ assert model.lq.shape == (1, 3, 8, 8)
91
+ assert model.gt.shape == (1, 3, 32, 32)
92
+
93
+ # change probability to test if-else
94
+ model.opt['gaussian_noise_prob'] = 0
95
+ model.opt['gray_noise_prob'] = 0
96
+ model.opt['second_blur_prob'] = 0
97
+ model.opt['gaussian_noise_prob2'] = 0
98
+ model.opt['gray_noise_prob2'] = 0
99
+ model.feed_data(data)
100
+ # check data shape
101
+ assert model.lq.shape == (1, 3, 8, 8)
102
+ assert model.gt.shape == (1, 3, 32, 32)
103
+
104
+ # ----------------- test nondist_validation -------------------- #
105
+ # construct dataloader
106
+ dataset_opt = dict(
107
+ name='Demo',
108
+ dataroot_gt='tests/data/gt',
109
+ dataroot_lq='tests/data/lq',
110
+ io_backend=dict(type='disk'),
111
+ scale=4,
112
+ phase='val')
113
+ dataset = PairedImageDataset(dataset_opt)
114
+ dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
115
+ assert model.is_train is True
116
+ model.nondist_validation(dataloader, 1, None, False)
117
+ assert model.is_train is True
118
+
119
+ # ----------------- test optimize_parameters -------------------- #
120
+ model.feed_data(data)
121
+ model.optimize_parameters(1)
122
+ assert model.output.shape == (1, 3, 32, 32)
123
+ assert isinstance(model.log_dict, dict)
124
+ # check returned keys
125
+ expected_keys = ['l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake']
126
+ assert set(expected_keys).issubset(set(model.log_dict.keys()))
tests/test_utils.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from basicsr.archs.rrdbnet_arch import RRDBNet
3
+
4
+ from realesrgan.utils import RealESRGANer
5
+
6
+
7
+ def test_realesrganer():
8
+ # initialize with default model
9
+ restorer = RealESRGANer(
10
+ scale=4,
11
+ model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth',
12
+ model=None,
13
+ tile=10,
14
+ tile_pad=10,
15
+ pre_pad=2,
16
+ half=False)
17
+ assert isinstance(restorer.model, RRDBNet)
18
+ assert restorer.half is False
19
+ # initialize with user-defined model
20
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
21
+ restorer = RealESRGANer(
22
+ scale=4,
23
+ model_path='experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth',
24
+ model=model,
25
+ tile=10,
26
+ tile_pad=10,
27
+ pre_pad=2,
28
+ half=True)
29
+ # test attribute
30
+ assert isinstance(restorer.model, RRDBNet)
31
+ assert restorer.half is True
32
+
33
+ # ------------------ test pre_process ---------------- #
34
+ img = np.random.random((12, 12, 3)).astype(np.float32)
35
+ restorer.pre_process(img)
36
+ assert restorer.img.shape == (1, 3, 14, 14)
37
+ # with modcrop
38
+ restorer.scale = 1
39
+ restorer.pre_process(img)
40
+ assert restorer.img.shape == (1, 3, 16, 16)
41
+
42
+ # ------------------ test process ---------------- #
43
+ restorer.process()
44
+ assert restorer.output.shape == (1, 3, 64, 64)
45
+
46
+ # ------------------ test post_process ---------------- #
47
+ restorer.mod_scale = 4
48
+ output = restorer.post_process()
49
+ assert output.shape == (1, 3, 60, 60)
50
+
51
+ # ------------------ test tile_process ---------------- #
52
+ restorer.scale = 4
53
+ img = np.random.random((12, 12, 3)).astype(np.float32)
54
+ restorer.pre_process(img)
55
+ restorer.tile_process()
56
+ assert restorer.output.shape == (1, 3, 64, 64)
57
+
58
+ # ------------------ test enhance ---------------- #
59
+ img = np.random.random((12, 12, 3)).astype(np.float32)
60
+ result = restorer.enhance(img, outscale=2)
61
+ assert result[0].shape == (24, 24, 3)
62
+ assert result[1] == 'RGB'
63
+
64
+ # ------------------ test enhance with 16-bit image---------------- #
65
+ img = np.random.random((4, 4, 3)).astype(np.uint16) + 512
66
+ result = restorer.enhance(img, outscale=2)
67
+ assert result[0].shape == (8, 8, 3)
68
+ assert result[1] == 'RGB'
69
+
70
+ # ------------------ test enhance with gray image---------------- #
71
+ img = np.random.random((4, 4)).astype(np.float32)
72
+ result = restorer.enhance(img, outscale=2)
73
+ assert result[0].shape == (8, 8)
74
+ assert result[1] == 'L'
75
+
76
+ # ------------------ test enhance with RGBA---------------- #
77
+ img = np.random.random((4, 4, 4)).astype(np.float32)
78
+ result = restorer.enhance(img, outscale=2)
79
+ assert result[0].shape == (8, 8, 4)
80
+ assert result[1] == 'RGBA'
81
+
82
+ # ------------------ test enhance with RGBA, alpha_upsampler---------------- #
83
+ restorer.tile_size = 0
84
+ img = np.random.random((4, 4, 4)).astype(np.float32)
85
+ result = restorer.enhance(img, outscale=2, alpha_upsampler=None)
86
+ assert result[0].shape == (8, 8, 4)
87
+ assert result[1] == 'RGBA'