|
import torch |
|
import yaml |
|
from basicsr.archs.rrdbnet_arch import RRDBNet |
|
from basicsr.data.paired_image_dataset import PairedImageDataset |
|
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss |
|
|
|
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN |
|
from realesrgan.models.realesrgan_model import RealESRGANModel |
|
from realesrgan.models.realesrnet_model import RealESRNetModel |
|
|
|
|
|
def test_realesrnet_model():
    """Smoke-test RealESRNetModel: construction, feed_data and validation.

    The model is built from the ``tests/data/test_realesrnet_model.yml``
    fixture. Synthetic tensors are pushed through ``feed_data`` and
    ``nondist_validation`` is run over a small paired dataset read from
    ``tests/data/gt`` / ``tests/data/lq``.
    """
    with open('tests/data/test_realesrnet_model.yml', mode='r') as cfg_file:
        options = yaml.load(cfg_file, Loader=yaml.FullLoader)

    net = RealESRNetModel(options)

    # Components the model is expected to create from the config.
    assert type(net).__name__ == 'RealESRNetModel'
    assert isinstance(net.net_g, RRDBNet)
    assert isinstance(net.cri_pix, L1Loss)
    assert isinstance(net.optimizers[0], torch.optim.Adam)

    # A random 3-channel 32x32 ground-truth image plus three random 5x5
    # kernels. Dict-literal values are evaluated top to bottom, so the
    # tensors are drawn in the order gt, kernel1, kernel2, sinc_kernel.
    batch = {
        'gt': torch.rand((1, 3, 32, 32), dtype=torch.float32),
        'kernel1': torch.rand((1, 5, 5), dtype=torch.float32),
        'kernel2': torch.rand((1, 5, 5), dtype=torch.float32),
        'sinc_kernel': torch.rand((1, 5, 5), dtype=torch.float32),
    }

    # feed_data is deliberately called twice; the repeated call presumably
    # exercises an internal training-pair queue/dequeue path -- confirm
    # against RealESRNetModel.feed_data.
    net.feed_data(batch)
    net.feed_data(batch)
    # The low-quality input is 4x smaller than the 32x32 ground truth.
    assert net.lq.shape == (1, 3, 8, 8)
    assert net.gt.shape == (1, 3, 32, 32)

    # Switch the random degradations off so feed_data takes its other
    # branches; output shapes must stay the same.
    for prob_key in ('gaussian_noise_prob', 'gray_noise_prob', 'second_blur_prob',
                     'gaussian_noise_prob2', 'gray_noise_prob2'):
        net.opt[prob_key] = 0
    net.feed_data(batch)
    assert net.lq.shape == (1, 3, 8, 8)
    assert net.gt.shape == (1, 3, 32, 32)

    # nondist_validation over a tiny paired dataset; the model must still
    # report training mode once validation has finished.
    val_opt = {
        'name': 'Demo',
        'dataroot_gt': 'tests/data/gt',
        'dataroot_lq': 'tests/data/lq',
        'io_backend': {'type': 'disk'},
        'scale': 4,
        'phase': 'val',
    }
    val_loader = torch.utils.data.DataLoader(
        PairedImageDataset(val_opt), batch_size=1, shuffle=False, num_workers=0)
    assert net.is_train is True
    net.nondist_validation(val_loader, 1, None, False)
    assert net.is_train is True
|
|
|
|
|
def test_realesrgan_model():
    """Smoke-test RealESRGANModel: construction, data, validation, training step.

    The model is built from the ``tests/data/test_realesrgan_model.yml``
    fixture. Synthetic tensors are pushed through ``feed_data``,
    ``nondist_validation`` is run on a small paired dataset, and one
    ``optimize_parameters`` step is taken to check the output shape and the
    logged loss keys.
    """
    with open('tests/data/test_realesrgan_model.yml', mode='r') as cfg_file:
        options = yaml.load(cfg_file, Loader=yaml.FullLoader)

    net = RealESRGANModel(options)

    # Generator, discriminator, losses and the two optimizers from the config.
    assert type(net).__name__ == 'RealESRGANModel'
    assert isinstance(net.net_g, RRDBNet)
    assert isinstance(net.net_d, UNetDiscriminatorSN)
    assert isinstance(net.cri_pix, L1Loss)
    assert isinstance(net.cri_perceptual, PerceptualLoss)
    assert isinstance(net.cri_gan, GANLoss)
    assert isinstance(net.optimizers[0], torch.optim.Adam)
    assert isinstance(net.optimizers[1], torch.optim.Adam)

    # A random 3-channel 32x32 ground-truth image plus three random 5x5
    # kernels, drawn in the order gt, kernel1, kernel2, sinc_kernel.
    batch = {
        'gt': torch.rand((1, 3, 32, 32), dtype=torch.float32),
        'kernel1': torch.rand((1, 5, 5), dtype=torch.float32),
        'kernel2': torch.rand((1, 5, 5), dtype=torch.float32),
        'sinc_kernel': torch.rand((1, 5, 5), dtype=torch.float32),
    }

    # feed_data is deliberately called twice; the repeated call presumably
    # exercises an internal training-pair queue/dequeue path -- confirm
    # against RealESRGANModel.feed_data.
    net.feed_data(batch)
    net.feed_data(batch)
    # The low-quality input is 4x smaller than the 32x32 ground truth.
    assert net.lq.shape == (1, 3, 8, 8)
    assert net.gt.shape == (1, 3, 32, 32)

    # Switch the random degradations off so feed_data takes its other
    # branches; output shapes must stay the same.
    for prob_key in ('gaussian_noise_prob', 'gray_noise_prob', 'second_blur_prob',
                     'gaussian_noise_prob2', 'gray_noise_prob2'):
        net.opt[prob_key] = 0
    net.feed_data(batch)
    assert net.lq.shape == (1, 3, 8, 8)
    assert net.gt.shape == (1, 3, 32, 32)

    # nondist_validation over a tiny paired dataset; the model must still
    # report training mode once validation has finished.
    val_opt = {
        'name': 'Demo',
        'dataroot_gt': 'tests/data/gt',
        'dataroot_lq': 'tests/data/lq',
        'io_backend': {'type': 'disk'},
        'scale': 4,
        'phase': 'val',
    }
    val_loader = torch.utils.data.DataLoader(
        PairedImageDataset(val_opt), batch_size=1, shuffle=False, num_workers=0)
    assert net.is_train is True
    net.nondist_validation(val_loader, 1, None, False)
    assert net.is_train is True

    # One optimisation step: the generator output matches the ground-truth
    # size and every generator/discriminator loss entry gets logged.
    net.feed_data(batch)
    net.optimize_parameters(1)
    assert net.output.shape == (1, 3, 32, 32)
    assert isinstance(net.log_dict, dict)

    required_log_keys = {
        'l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake'
    }
    # set.difference accepts any iterable; iterating a dict yields its keys.
    assert not required_log_keys.difference(net.log_dict)
|
|