Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.autograd import Variable | |
| import torch | |
| import numpy as np | |
| import os, time, random | |
| import argparse | |
| from torch.utils.data import Dataset, DataLoader | |
| from PIL import Image as PILImage | |
| from glob import glob | |
| from tqdm import tqdm | |
| from model.model import InvISPNet | |
| from dataset.FiveK_dataset import FiveKDatasetTest | |
| from config.config import get_arguments | |
| from utils.JPEG import DiffJPEG | |
| from utils.commons import denorm, preprocess_test_patch | |
def _parse_free_memory_mib(text):
    """Parse per-GPU free memory from ``nvidia-smi --query-gpu=memory.free`` output.

    Args:
        text: stdout of
            ``nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits``;
            one integer (MiB) per line, one line per GPU.

    Returns:
        Free-memory values in nvidia-smi's GPU order. Returns ``[]`` if any
        non-blank line is not an integer, because the positions would then no
        longer line up with GPU ordinals.
    """
    values = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            values.append(int(line))
        except ValueError:
            return []
    return values


def _query_free_gpu_memory_mib():
    """Return the free memory (MiB) of each GPU as reported by ``nvidia-smi``.

    Returns ``[]`` when ``nvidia-smi`` is missing, exits with an error, or
    times out. The caller can then skip GPU auto-selection instead of crashing
    on an empty ``argmax``.
    """
    import subprocess

    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"],
            capture_output=True,
            text=True,
            check=True,
            timeout=30,
        )
    except (OSError, subprocess.SubprocessError):
        return []
    return _parse_free_memory_mib(result.stdout)


# Restrict the process to the GPU with the most free memory. This has to happen
# before CUDA is initialised, which is why it sits at import time.
# nvidia-smi enumerates GPUs in PCI bus order. Ask CUDA to use the same order
# (unless the user already chose one) so the index refers to the same device.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
_free_mem_mib = _query_free_gpu_memory_mib()
if _free_mem_mib:
    os.environ["CUDA_VISIBLE_DEVICES"] = str(int(np.argmax(_free_mem_mib)))
# To pin a specific GPU by hand instead, set e.g. CUDA_VISIBLE_DEVICES = "7".
# NOTE(review): this rebinds the imported ``DiffJPEG`` class name to an instance
# (quality=90, differentiable=True) and moves it to the GPU at import time.
# Nothing else in this file appears to use it; confirm before removing.
DiffJPEG = DiffJPEG(quality=90, differentiable=True).cuda()
# Extend the shared project parser with the options specific to this test script.
parser = get_arguments()
# Required: the output directory name is derived from this path, and the script
# previously crashed with an AttributeError on ``None`` when it was omitted.
parser.add_argument("--ckpt", type=str, required=True, help="Checkpoint path.")
parser.add_argument(
    "--out_path",
    type=str,
    default="./exps/",
    help="Root directory for test results (joined by plain string "
    "concatenation, so keep the trailing '/').",
)
parser.add_argument(
    "--split_to_patch",
    action="store_true",
    help="Test on patch. ",
)
args = parser.parse_args()
print("Parsed arguments: {}".format(args))
# The checkpoint's file name is cut at its first "." (e.g. "net.v2.pth" -> "net").
# That name selects the per-task results directory, which must already hold the
# "pred*jpg" images that main() consumes.
ckpt_name = args.ckpt.rsplit("/", 1)[-1].partition(".")[0]
out_path = "{}{}/{}{}/".format(
    args.out_path,
    args.task,
    "results_metric_" if args.split_to_patch else "results_",
    ckpt_name,
)
os.makedirs(out_path, exist_ok=True)
def _load_checkpoint(net, ckpt_path, device):
    """Load *ckpt_path* into *net* non-strictly and report anything that did not match.

    Args:
        net: the model to populate.
        ckpt_path: path to a file readable by ``torch.load``.
        device: device the stored tensors are mapped onto.

    A missing file is reported, and *net* keeps the weights it was constructed
    with (same as before, but no longer silent). Keys skipped by
    ``strict=False`` are printed instead of being ignored.
    """
    if not os.path.isfile(ckpt_path):
        print("[WARN] Checkpoint not found: {}; keeping initial weights".format(ckpt_path))
        return
    # map_location lets a checkpoint saved on another GPU index load on `device`.
    state_dict = torch.load(ckpt_path, map_location=device)
    incompatible = net.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print("[WARN] Keys missing from checkpoint: {}".format(incompatible.missing_keys))
    if incompatible.unexpected_keys:
        print("[WARN] Unexpected keys in checkpoint: {}".format(incompatible.unexpected_keys))
    print("[INFO] Loaded checkpoint: {}".format(ckpt_path))


def _index_predictions(pred_dir):
    """Map each patch name to the first matching ``pred*jpg`` file in *pred_dir*.

    The name is the file's basename up to its first ".", minus the first 5
    characters (presumably a "pred_" prefix; confirm against the writer).
    """
    index = {}
    for path in sorted(glob(pred_dir + "pred*jpg")):
        name = os.path.basename(path).split(".")[0][5:]
        index.setdefault(name, path)  # keep the first match in sorted order
    return index


def _load_rgb(path, device):
    """Read the image at *path* as a float tensor of shape (1, C, H, W), divided by 255."""
    with PILImage.open(path) as img:  # closes the file handle once pixels are read
        pixels = np.array(img)
    return (
        torch.from_numpy(pixels / 255.0)
        .unsqueeze(0)
        .permute(0, 3, 1, 2)
        .float()
        .to(device)
    )


def _save_raw_outputs(save_dir, stem, pred_raw, target_raw):
    """Write the RAW prediction/target outputs for one patch under *save_dir*.

    *pred_raw* and *target_raw* are NumPy arrays indexed ``[batch, h, w, channel]``.
    Channel 0 of the first item is saved as uint8 JPEG previews: the prediction,
    the target, and both side by side with the target on the left. The whole
    first item is saved as ``.npy`` after dividing by 255.
    """
    pred_u8 = np.uint8(pred_raw[0, :, :, 0])
    tar_u8 = np.uint8(target_raw[0, :, :, 0])
    PILImage.fromarray(pred_u8).save(save_dir + "raw_pred_%s.jpg" % stem)
    PILImage.fromarray(tar_u8).save(save_dir + "raw_tar_%s.jpg" % stem)
    PILImage.fromarray(np.hstack((tar_u8, pred_u8))).save(
        save_dir + "raw_gt_pred_%s.jpg" % stem
    )
    np.save(save_dir + "raw_pred_%s.npy" % stem, pred_raw[0] / 255.0)
    np.save(save_dir + "raw_tar_%s.npy" % stem, target_raw[0] / 255.0)


def main(args):
    """Reconstruct RAW images from saved RGB predictions and write them to ``out_path``.

    For each sample of ``FiveKDatasetTest(opt=args)``, the data is either passed
    through ``preprocess_test_patch`` (when ``args.split_to_patch`` is set) or
    subsampled by 2 in both spatial dimensions. For every resulting patch, the
    matching ``pred*jpg`` image in the module-level ``out_path`` is fed to
    ``InvISPNet`` with ``rev=True``. The clamped, denormalised prediction and
    the target are then saved via ``_save_raw_outputs``.

    Args:
        args: parsed command-line namespace. Reads ``ckpt`` and
            ``split_to_patch``, and is passed to ``FiveKDatasetTest``.

    Raises:
        FileNotFoundError: if no predicted RGB image exists for a patch.
    """
    # ====================== define the model ======================
    net = InvISPNet(channel_in=3, channel_out=3, block_num=8)
    device = torch.device("cuda:0")
    net.to(device)
    net.eval()
    _load_checkpoint(net, args.ckpt, device)

    print("[INFO] Start data load and preprocessing")
    raw_dataset = FiveKDatasetTest(opt=args)
    dataloader = DataLoader(
        raw_dataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True
    )
    # Build the name -> path lookup once instead of a linear list.index per patch.
    pred_rgb_paths = _index_predictions(out_path)

    print("[INFO] Start test...")
    for sample_batched in tqdm(dataloader):
        input_raw = sample_batched["input_raw"].to(device)
        target_rgb = sample_batched["target_rgb"].to(device)
        target_raw = sample_batched["target_raw"].to(device)
        file_name = sample_batched["file_name"][0]
        if args.split_to_patch:
            input_list, _, target_raw_list = preprocess_test_patch(
                input_raw, target_rgb, target_raw
            )
        else:
            # Half resolution to save GPU memory; remove [:, :, ::2, ::2] if you
            # have enough GPU memory to test the full resolution.
            input_list = [input_raw[:, :, ::2, ::2]]
            target_raw_list = [target_raw[:, :, ::2, ::2]]

        for i_patch, (_, target_raw_patch) in enumerate(zip(input_list, target_raw_list)):
            stem = "%s_%05d" % (file_name, i_patch)
            rgb_path = pred_rgb_paths.get(stem)
            if rgb_path is None:
                raise FileNotFoundError(
                    "No predicted RGB image for '{}' in {}".format(stem, out_path)
                )
            input_rgb = _load_rgb(rgb_path, device)
            with torch.no_grad():
                reconstruct_raw = net(input_rgb, rev=True)
            pred_raw = torch.clamp(reconstruct_raw.detach().permute(0, 2, 3, 1), 0, 1)
            pred_raw = denorm(pred_raw, 255).cpu().numpy()
            target_np = (
                denorm(target_raw_patch.permute(0, 2, 3, 1), 255)
                .cpu()
                .numpy()
                .astype(np.float32)
            )
            _save_raw_outputs(out_path, stem, pred_raw, target_np)
            del reconstruct_raw
if __name__ == "__main__":
    # Cap PyTorch's intra-op CPU thread pool before running the evaluation.
    _TORCH_NUM_THREADS = 4
    torch.set_num_threads(_TORCH_NUM_THREADS)
    main(args)