Spaces:
Runtime error
Runtime error
File size: 4,861 Bytes
810c8ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import torch
import yaml
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.data.paired_image_dataset import PairedImageDataset
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
from realesrgan.models.realesrgan_model import RealESRGANModel
from realesrgan.models.realesrnet_model import RealESRNetModel
def test_realesrnet_model():
    """Smoke-test ``RealESRNetModel`` construction, degradation pipeline and validation.

    Steps:
      1. Build the model from ``tests/data/test_realesrnet_model.yml``. The path is
         resolved against the current working directory.
      2. Check the generator, pixel-loss and optimizer types.
      3. Feed a random 32x32 ground-truth image plus three random 5x5 kernels
         through ``feed_data`` twice (the second call is the dequeue check).
         Assert the LQ output is 8x8 and GT stays 32x32.
      4. Force the noise / second-blur probabilities to 0 so ``feed_data`` takes
         the alternate branches, then re-check the shapes.
      5. Run ``nondist_validation`` over the paired demo images and assert that
         ``is_train`` is True both before and after.
    """
    cfg_path = 'tests/data/test_realesrnet_model.yml'
    with open(cfg_path) as cfg_file:  # default mode is 'r', same as before
        opt = yaml.load(cfg_file, Loader=yaml.FullLoader)

    model = RealESRNetModel(opt)

    # --- model attributes ---
    assert type(model).__name__ == 'RealESRNetModel'
    assert isinstance(model.net_g, RRDBNet)
    assert isinstance(model.cri_pix, L1Loss)
    assert isinstance(model.optimizers[0], torch.optim.Adam)

    # --- synthetic training batch ---
    # Random tensors are drawn in the order gt, kernel1, kernel2, sinc_kernel.
    batch = {'gt': torch.rand((1, 3, 32, 32), dtype=torch.float32)}
    for kernel_name in ('kernel1', 'kernel2', 'sinc_kernel'):
        batch[kernel_name] = torch.rand((1, 5, 5), dtype=torch.float32)

    lq_shape = (1, 3, 8, 8)
    gt_shape = (1, 3, 32, 32)

    # Feed twice: the second call exercises the dequeue path.
    model.feed_data(batch)
    model.feed_data(batch)
    assert model.lq.shape == lq_shape
    assert model.gt.shape == gt_shape

    # Disable the stochastic degradations so feed_data walks its other if/else branches.
    for prob_key in ('gaussian_noise_prob', 'gray_noise_prob', 'second_blur_prob',
                     'gaussian_noise_prob2', 'gray_noise_prob2'):
        model.opt[prob_key] = 0
    model.feed_data(batch)
    assert model.lq.shape == lq_shape
    assert model.gt.shape == gt_shape

    # --- nondist_validation on the paired demo images ---
    val_set = PairedImageDataset({
        'name': 'Demo',
        'dataroot_gt': 'tests/data/gt',
        'dataroot_lq': 'tests/data/lq',
        'io_backend': {'type': 'disk'},
        'scale': 4,
        'phase': 'val',
    })
    val_loader = torch.utils.data.DataLoader(
        dataset=val_set, batch_size=1, shuffle=False, num_workers=0)

    assert model.is_train is True
    model.nondist_validation(val_loader, 1, None, False)
    # Validation must leave the model still flagged as training.
    assert model.is_train is True
def test_realesrgan_model():
    """Smoke-test ``RealESRGANModel``: construction, degradation, validation and one train step.

    Steps:
      1. Build the model from ``tests/data/test_realesrgan_model.yml``. The path is
         resolved against the current working directory.
      2. Check the generator and discriminator types, the pixel / perceptual / GAN
         losses, and both Adam optimizers.
      3. Drive ``feed_data`` with random inputs (the second call is the dequeue
         check), then again with the noise / second-blur probabilities forced to 0.
         Assert the LQ output is 8x8 and GT stays 32x32 each time.
      4. Run ``nondist_validation`` on the paired demo images and check that
         ``is_train`` is True both before and after.
      5. Perform one ``optimize_parameters`` step. Check the output shape and that
         the generator and discriminator loss entries appear in ``log_dict``.
    """
    cfg_path = 'tests/data/test_realesrgan_model.yml'
    with open(cfg_path) as cfg_file:  # default mode is 'r', same as before
        opt = yaml.load(cfg_file, Loader=yaml.FullLoader)

    model = RealESRGANModel(opt)

    # --- model attributes ---
    assert type(model).__name__ == 'RealESRGANModel'
    assert isinstance(model.net_g, RRDBNet)  # generator
    assert isinstance(model.net_d, UNetDiscriminatorSN)  # discriminator
    assert isinstance(model.cri_pix, L1Loss)
    assert isinstance(model.cri_perceptual, PerceptualLoss)
    assert isinstance(model.cri_gan, GANLoss)
    assert isinstance(model.optimizers[0], torch.optim.Adam)
    assert isinstance(model.optimizers[1], torch.optim.Adam)

    # --- synthetic training batch ---
    # Random tensors are drawn in the order gt, kernel1, kernel2, sinc_kernel.
    batch = {'gt': torch.rand((1, 3, 32, 32), dtype=torch.float32)}
    for kernel_name in ('kernel1', 'kernel2', 'sinc_kernel'):
        batch[kernel_name] = torch.rand((1, 5, 5), dtype=torch.float32)

    def _check_pair_shapes():
        # LQ is 4x smaller than the 32x32 GT; GT is passed through unchanged.
        assert model.lq.shape == (1, 3, 8, 8)
        assert model.gt.shape == (1, 3, 32, 32)

    # Feed twice: the second call exercises the dequeue path.
    model.feed_data(batch)
    model.feed_data(batch)
    _check_pair_shapes()

    # Disable the stochastic degradations so feed_data walks its other if/else branches.
    for prob_key in ('gaussian_noise_prob', 'gray_noise_prob', 'second_blur_prob',
                     'gaussian_noise_prob2', 'gray_noise_prob2'):
        model.opt[prob_key] = 0
    model.feed_data(batch)
    _check_pair_shapes()

    # --- nondist_validation on the paired demo images ---
    val_set = PairedImageDataset({
        'name': 'Demo',
        'dataroot_gt': 'tests/data/gt',
        'dataroot_lq': 'tests/data/lq',
        'io_backend': {'type': 'disk'},
        'scale': 4,
        'phase': 'val',
    })
    val_loader = torch.utils.data.DataLoader(
        dataset=val_set, batch_size=1, shuffle=False, num_workers=0)

    assert model.is_train is True
    model.nondist_validation(val_loader, 1, None, False)
    # Validation must leave the model still flagged as training.
    assert model.is_train is True

    # --- one optimisation step ---
    model.feed_data(batch)
    model.optimize_parameters(1)
    assert model.output.shape == (1, 3, 32, 32)
    assert isinstance(model.log_dict, dict)

    # Generator and discriminator loss/score entries that must be logged.
    required_log_keys = {
        'l_g_pix', 'l_g_percep', 'l_g_gan',
        'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake',
    }
    assert required_log_keys.issubset(model.log_dict)
|