"""Run InvISPNet in reverse on saved RGB predictions and dump the RAW results.

For every sample yielded by ``FiveKDatasetTest`` the script looks up the
matching ``pred*jpg`` image in the result directory, feeds it through the
network with ``rev=True`` and stores the reconstructed RAW next to the target
RAW, both as JPEG previews (channel 0 only) and as ``.npy`` arrays.
"""
import argparse
import os, time, random
import subprocess
from glob import glob

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image as PILImage
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from config.config import get_arguments
from dataset.FiveK_dataset import FiveKDatasetTest
from model.model import InvISPNet
from utils.commons import denorm, preprocess_test_patch
from utils.JPEG import DiffJPEG


def _select_freest_gpu():
    """Return the index of the GPU with the most free memory, as a string.

    The same ``nvidia-smi`` pipeline as before is used, but its output is
    captured directly instead of going through a scratch file named ``tmp`` in
    the working directory.  The third whitespace-separated token of every
    ``Free`` line is parsed as an integer.

    Returns:
        The index of the largest value, or ``None`` when the query produced
        nothing usable (e.g. ``nvidia-smi`` is missing or prints a different
        layout).
    """
    query = subprocess.run(
        "nvidia-smi -q -d Memory |grep -A4 GPU|grep Free",
        shell=True,  # constant command string, no external input involved
        stdout=subprocess.PIPE,
        universal_newlines=True,
    )
    try:
        free_memory = [int(line.split()[2]) for line in query.stdout.splitlines()]
    except (IndexError, ValueError):
        return None
    if not free_memory:
        return None
    return str(np.argmax(free_memory))


# Must be set before the first CUDA call below so the choice takes effect.
# To pin a specific GPU instead, assign os.environ['CUDA_VISIBLE_DEVICES'] by hand.
_gpu_index = _select_freest_gpu()
if _gpu_index is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = _gpu_index
else:
    print("[WARN] Could not query free GPU memory; leaving CUDA_VISIBLE_DEVICES unchanged")

# NOTE(review): this instance shadows the imported class and is not referenced
# anywhere else in this script as shown; kept to preserve module-level behaviour.
DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda()

parser = get_arguments()
# --ckpt is required: the result directory name is derived from it below.
parser.add_argument("--ckpt", type=str, required=True, help="Checkpoint path.")
parser.add_argument("--out_path", type=str, default="./exps/", help="Path to save checkpoint. ")
parser.add_argument("--split_to_patch", dest='split_to_patch', action='store_true', help="Test on patch. ")
args = parser.parse_args()
print("Parsed arguments: {}".format(args))

# Kept as plain string splitting on purpose: the directory name has to match
# the one under which the ``pred*jpg`` inputs were written.
ckpt_name = args.ckpt.split("/")[-1].split(".")[0]
_result_dir_format = "%s/results_metric_%s/" if args.split_to_patch else "%s/results_%s/"
out_path = args.out_path + _result_dir_format % (args.task, ckpt_name)
os.makedirs(out_path, exist_ok=True)


def _index_predictions(directory):
    """Map ``<file_name>_<patch>`` keys to the ``pred*jpg`` paths in *directory*.

    The key is the file stem with its first five characters dropped (the
    ``pred`` prefix plus one more character).  When several paths yield the
    same key, the first one in sorted order wins.
    """
    predictions = {}
    for path in sorted(glob(directory + "pred*jpg")):
        stem = os.path.splitext(os.path.basename(path))[0]
        predictions.setdefault(stem[5:], path)
    return predictions


def _load_rgb(path, device):
    """Load the image at *path* as a float tensor scaled by 1/255 on *device*.

    The array is given a leading batch axis and permuted with ``(0, 3, 1, 2)``,
    so the image is expected to decode to a three-dimensional array.
    """
    with PILImage.open(path) as image:
        pixels = np.array(image) / 255.0
    return torch.from_numpy(pixels).unsqueeze(0).permute(0, 3, 1, 2).float().to(device)


def main(args):
    """Reconstruct RAW images from the saved predictions and write them out.

    Args:
        args: Parsed command-line namespace.  ``ckpt`` and ``split_to_patch``
            are read here; the whole namespace is also handed to
            ``FiveKDatasetTest`` as ``opt``.

    Raises:
        FileNotFoundError: If no ``pred*jpg`` image exists in ``out_path`` for
            a sample/patch produced by the dataset.

    Results are written into the module-level ``out_path`` directory.
    """
    # ======================================define the model============================================
    net = InvISPNet(channel_in=3, channel_out=3, block_num=8)
    device = torch.device("cuda:0")

    net.to(device)
    net.eval()
    # load the pretrained weight if there exists one
    if os.path.isfile(args.ckpt):
        net.load_state_dict(torch.load(args.ckpt), strict=False)
        print("[INFO] Loaded checkpoint: {}".format(args.ckpt))
    else:
        # Previously silent; the run still continues with the model's initial weights.
        print("[WARN] Checkpoint not found, nothing loaded: {}".format(args.ckpt))

    print("[INFO] Start data load and preprocessing")
    RAWDataset = FiveKDatasetTest(opt=args)
    dataloader = DataLoader(RAWDataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True)

    predictions = _index_predictions(out_path)

    print("[INFO] Start test...")
    for sample_batched in tqdm(dataloader):
        input_raw = sample_batched['input_raw'].to(device)
        target_rgb = sample_batched['target_rgb'].to(device)
        target_raw = sample_batched['target_raw'].to(device)
        file_name = sample_batched['file_name'][0]

        if args.split_to_patch:
            input_list, target_rgb_list, target_raw_list = preprocess_test_patch(input_raw, target_rgb, target_raw)
        else:
            # remove [:,:,::2,::2] if you have enough GPU memory to test the full resolution
            input_list = [input_raw[:, :, ::2, ::2]]
            target_rgb_list = [target_rgb[:, :, ::2, ::2]]
            target_raw_list = [target_raw[:, :, ::2, ::2]]

        for i_patch in range(len(input_list)):
            patch_name = "%s_%05d" % (file_name, i_patch)
            try:
                input_RGB_path = predictions[patch_name]
            except KeyError:
                raise FileNotFoundError(
                    "No pred*jpg image for '%s' found in %s" % (patch_name, out_path)
                )
            input_RGB = _load_rgb(input_RGB_path, device)

            with torch.no_grad():
                reconstruct_raw = net(input_RGB, rev=True)

            # Both tensors are permuted with (0, 2, 3, 1) before denorm.
            pred_raw = torch.clamp(reconstruct_raw.detach().permute(0, 2, 3, 1), 0, 1)
            target_raw_patch = target_raw_list[i_patch].permute(0, 2, 3, 1)

            pred_raw = denorm(pred_raw, 255).cpu().numpy()
            target_raw_patch = denorm(target_raw_patch, 255).cpu().numpy().astype(np.float32)

            # Only channel 0 goes into the JPEG previews.
            pred_u8 = np.uint8(pred_raw[0, :, :, 0])
            target_u8 = np.uint8(target_raw_patch[0, :, :, 0])

            PILImage.fromarray(pred_u8).save(out_path + "raw_pred_" + patch_name + ".jpg")
            PILImage.fromarray(target_u8).save(out_path + "raw_tar_" + patch_name + ".jpg")
            # Target on the left, prediction on the right.
            PILImage.fromarray(np.hstack((target_u8, pred_u8))).save(
                out_path + "raw_gt_pred_" + patch_name + ".jpg"
            )

            # The arrays keep every channel and are divided by 255 again.
            np.save(out_path + "raw_pred_" + patch_name + ".npy", pred_raw[0, :, :, :] / 255.0)
            np.save(out_path + "raw_tar_" + patch_name + ".npy", target_raw_patch[0, :, :, :] / 255.0)

            del reconstruct_raw


if __name__ == '__main__':
    torch.set_num_threads(4)
    main(args)