import argparse
import logging
import os
import os.path as osp
import random

import cv2
import data.util as data_util
import lmdb
import numpy as np
import torch
import utils.util as util
import yaml
from models.kernel_encoding.kernel_wizard import KernelWizard


def read_image(env, key, x, y, h, w, size=(3, 720, 1280)):
    """Read one image from an LMDB environment and return a cropped, channel-first array.

    Args:
        env: open LMDB environment holding the images.
        key: LMDB key of the image to read.
        x: top row of the crop.
        y: left column of the crop.
        h: crop height.
        w: crop width.
        size: ``(C, H, W)`` of the stored image, forwarded to
            ``data_util.read_img``. Defaults to the previously hard-coded
            ``(3, 720, 1280)`` so existing callers keep working.

    Returns:
        The crop as a ``(C, h', w')`` array with the channel order reversed
        (presumably BGR -> RGB; depends on how ``data_util.read_img`` decodes).
        ``h'``/``w'`` are smaller than ``h``/``w`` only if the crop runs past
        the image border.
    """
    img = data_util.read_img(env, key, size)
    # Crop in HWC layout, reverse the channel axis, then move channels first.
    img = np.transpose(img[x : x + h, y : y + w, [2, 1, 0]], (2, 0, 1))
    return img


def _parse_args():
    """Build the command-line parser and return the parsed, validated arguments."""
    parser = argparse.ArgumentParser(description="Kernel extractor testing")

    parser.add_argument("--source_H", action="store", help="source image height", type=int, required=True)
    parser.add_argument("--source_W", action="store", help="source image width", type=int, required=True)
    parser.add_argument("--target_H", action="store", help="target image height", type=int, required=True)
    parser.add_argument("--target_W", action="store", help="target image width", type=int, required=True)
    parser.add_argument(
        "--augmented_H", action="store", help="desired height of the augmented images", type=int, required=True
    )
    parser.add_argument(
        "--augmented_W", action="store", help="desired width of the augmented images", type=int, required=True
    )
    parser.add_argument(
        "--source_LQ_root", action="store", help="source low-quality dataroot", type=str, required=True
    )
    parser.add_argument(
        "--source_HQ_root", action="store", help="source high-quality dataroot", type=str, required=True
    )
    parser.add_argument(
        "--target_HQ_root", action="store", help="target high-quality dataroot", type=str, required=True
    )
    parser.add_argument("--save_path", action="store", help="save path", type=str, required=True)
    parser.add_argument("--yml_path", action="store", help="yml path", type=str, required=True)
    parser.add_argument(
        "--num_images", action="store", help="number of desire augmented images", type=int, required=True
    )

    args = parser.parse_args()
    if args.num_images < 1:
        # The average PSNR reported at the end divides by num_images.
        parser.error("--num_images must be a positive integer")
    return args


def _load_model(yml_path, device):
    """Build the KernelWizard described by the ``KernelWizard`` section of *yml_path*.

    The network is put in eval mode, its weights are loaded from the path in
    the section's ``pretrained`` entry, and it is returned moved to *device*.
    """
    with open(yml_path, "r") as f:
        # yaml.load() without an explicit Loader raises TypeError on PyYAML >= 6;
        # safe_load is sufficient for a plain options mapping.
        opt = yaml.safe_load(f)["KernelWizard"]
    model_path = opt["pretrained"]
    model = KernelWizard(opt)
    model.eval()
    # map_location lets a checkpoint saved on a GPU be loaded on the chosen device.
    model.load_state_dict(torch.load(model_path, map_location=device))
    return model.to(device)


def _open_lmdb(root):
    """Open the LMDB environment at *root* read-only, without locking, readahead or meminit."""
    return lmdb.open(root, readonly=True, lock=False, readahead=False, meminit=False)


def _to_batch(img, device):
    """Wrap a single CHW array as a float tensor with a leading batch axis of 1 on *device*."""
    return torch.Tensor(img).unsqueeze(0).to(device)


def main():
    """Synthesize blurred/sharp image pairs by transferring sampled blur kernels.

    For each of ``--num_images`` iterations:

    1. Pick a random source key and a random target key.
    2. Take random ``augmented_H x augmented_W`` crops.
    3. Have the KernelWizard estimate a kernel mean and sigma from the source
       (HQ, LQ) pair, and sample one kernel from them.
    4. Apply that kernel to the target HQ crop. The target HQ crop is saved
       under ``save_path/sharp/`` and the result under ``save_path/blur/``.
    5. Log the PSNR between the real source LQ crop and the source HQ crop
       re-blurred with the same kernel.
    """
    args = _parse_args()
    # NOTE(review): the CPU fallback assumes KernelWizard itself does not hard-code CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    source_LQ_root = args.source_LQ_root
    source_HQ_root = args.source_HQ_root
    target_HQ_root = args.target_HQ_root
    save_path = args.save_path
    source_H, source_W = args.source_H, args.source_W
    target_H, target_W = args.target_H, args.target_W
    augmented_H, augmented_W = args.augmented_H, args.augmented_W
    yml_path = args.yml_path
    num_images = args.num_images

    # Initializing logger
    logger = logging.getLogger("base")
    os.makedirs(save_path, exist_ok=True)
    util.setup_logger("base", save_path, "test", level=logging.INFO, screen=True, tofile=True)
    logger.info("source LQ root: {}".format(source_LQ_root))
    logger.info("source HQ root: {}".format(source_HQ_root))
    logger.info("target HQ root: {}".format(target_HQ_root))
    logger.info("augmented height: {}".format(augmented_H))
    logger.info("augmented width: {}".format(augmented_W))
    logger.info("Number of augmented images: {}".format(num_images))

    # Initializing model
    logger.info("Loading model...")
    logger.info("yml path: {}".format(yml_path))
    model = _load_model(yml_path, device)
    logger.info("Done")

    # Stored (C, H, W) of each dataset, needed by read_image to decode LMDB entries.
    source_size = (3, source_H, source_W)
    target_size = (3, target_H, target_W)

    # processing data
    source_HQ_env = _open_lmdb(source_HQ_root)
    source_LQ_env = _open_lmdb(source_LQ_root)
    target_HQ_env = _open_lmdb(target_HQ_root)
    try:
        paths_source_HQ, _ = data_util.get_image_paths("lmdb", source_HQ_root)
        paths_target_HQ, _ = data_util.get_image_paths("lmdb", target_HQ_root)

        psnr_sum = 0
        for i in range(num_images):
            # Source LQ and HQ share the same key, so they are read as an aligned pair.
            source_key = np.random.choice(paths_source_HQ)
            target_key = np.random.choice(paths_target_HQ)
            # Random crop offsets; collapse to 0 when the crop is not smaller than the image.
            source_rnd_h = random.randint(0, max(0, source_H - augmented_H))
            source_rnd_w = random.randint(0, max(0, source_W - augmented_W))
            target_rnd_h = random.randint(0, max(0, target_H - augmented_H))
            target_rnd_w = random.randint(0, max(0, target_W - augmented_W))

            source_crop = (source_rnd_h, source_rnd_w, augmented_H, augmented_W)
            target_crop = (target_rnd_h, target_rnd_w, augmented_H, augmented_W)
            source_LQ = _to_batch(read_image(source_LQ_env, source_key, *source_crop, size=source_size), device)
            source_HQ = _to_batch(read_image(source_HQ_env, source_key, *source_crop, size=source_size), device)
            target_HQ = _to_batch(read_image(target_HQ_env, target_key, *target_crop, size=target_size), device)

            with torch.no_grad():
                kernel_mean, kernel_sigma = model(source_HQ, source_LQ)
                # Sample a single kernel from the predicted mean/sigma.
                kernel = kernel_mean + kernel_sigma * torch.randn_like(kernel_mean)
                fake_source_LQ = model.adaptKernel(source_HQ, kernel)
                target_LQ = model.adaptKernel(target_HQ, kernel)

            LQ_img = util.tensor2img(source_LQ)
            fake_LQ_img = util.tensor2img(fake_source_LQ)
            target_LQ_img = util.tensor2img(target_LQ)
            target_HQ_img = util.tensor2img(target_HQ)

            # Group outputs into sub-folders of 100 images each.
            target_HQ_dst = osp.join(save_path, "sharp/{:03d}/{:08d}.png".format(i // 100, i % 100))
            target_LQ_dst = osp.join(save_path, "blur/{:03d}/{:08d}.png".format(i // 100, i % 100))

            os.makedirs(osp.dirname(target_HQ_dst), exist_ok=True)
            os.makedirs(osp.dirname(target_LQ_dst), exist_ok=True)

            cv2.imwrite(target_HQ_dst, target_HQ_img)
            cv2.imwrite(target_LQ_dst, target_LQ_img)

            psnr = util.calculate_psnr(LQ_img, fake_LQ_img)

            logger.info("Reconstruction PSNR of image #{:03d}/{:03d}: {:.2f}db".format(i, num_images, psnr))
            psnr_sum += psnr

        logger.info("Average reconstruction PSNR: {:.2f}db".format(psnr_sum / num_images))
    finally:
        for env in (source_HQ_env, source_LQ_env, target_HQ_env):
            env.close()


if __name__ == "__main__":
    main()