import numpy as np
import cv2
import os, argparse, json
from os.path import join
from glob import glob
import torch
import torch.nn.functional as F
from model.model import ResHalf
from model.model import Quantize
from model.loss import l1_loss
from utils import util
from utils.dct import DCT_Lowfrequency
from utils.filters_tensor import bgr2gray
from collections import OrderedDict

# Prefix that torch.nn.DataParallel prepends to every parameter/buffer name.
_DP_PREFIX = 'module.'


def _strip_dp_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``'module.'`` removed from each key.

    Keys without the prefix are kept unchanged, so checkpoints saved with or
    without DataParallel can both be loaded into an unwrapped model.
    """
    return OrderedDict(
        (k[len(_DP_PREFIX):] if k.startswith(_DP_PREFIX) else k, v)
        for k, v in state_dict.items()
    )


def _pad_amount(size, scale):
    """Return the smallest non-negative padding that makes *size* a multiple of *scale*."""
    return (-size) % scale


class Inferencer:
    """Load a checkpoint into *model* and run it on 4-D image tensors.

    In encoding mode the model returns a (halftone, restored colour) pair, and
    the halftone is binarised with ``Quantize``. In ``decoding_only`` mode the
    model returns only the restored image.
    """

    def __init__(self, checkpoint_path, model, use_cuda=True, multi_gpu=True):
        """
        Args:
            checkpoint_path: path to a ``torch.save``-d dict with a ``'state_dict'`` entry.
            model: the network instance to load weights into; it is put in eval mode.
            use_cuda: move the model (and inputs in ``__call__``) to the GPU.
            multi_gpu: wrap the model in ``torch.nn.DataParallel`` (only applied
                when ``use_cuda`` is true; DataParallel requires CUDA-resident
                parameters whenever GPUs are visible).
        """
        # Always deserialise on CPU so loading works on machines without the
        # GPU layout the checkpoint was saved from.
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.use_cuda = use_cuda
        self.model = model.eval()
        # Load into the bare model first: stripping an optional 'module.' prefix
        # accepts checkpoints saved both with and without DataParallel.
        self.model.load_state_dict(_strip_dp_prefix(self.checkpoint['state_dict']))
        if multi_gpu and self.use_cuda:
            self.model = torch.nn.DataParallel(self.model)
        if self.use_cuda:
            self.model = self.model.cuda()

    def __call__(self, input_img, decoding_only=False):
        """Run the model on *input_img* (shape ``(N, C, H, W)``).

        H and W are reflect-padded up to multiples of 8 before inference, and
        the outputs are cropped back to the original ``H x W``.

        Args:
            input_img: 4-D tensor; the CLI below feeds values mapped to [-1, 1].
            decoding_only: if true, treat the input as a halftone and return
                only the restored image.

        Returns:
            ``resColor`` when ``decoding_only`` is true, otherwise the tuple
            ``(resHalftone, resColor)``, where ``resHalftone`` has been quantised
            and mapped back to the [-1, 1] convention.
        """
        with torch.no_grad():
            scale = 8
            _, _, H, W = input_img.shape
            pad_h, pad_w = _pad_amount(H, scale), _pad_amount(W, scale)
            if pad_h or pad_w:
                # Pad only the misaligned dimension(s); padding an aligned one
                # by a full `scale` would also break reflect mode for tiny images.
                input_img = F.pad(input_img, [0, pad_w, 0, pad_h], mode='reflect')
            if self.use_cuda:
                input_img = input_img.cuda()
            if decoding_only:
                resColor = self.model(input_img, decoding_only)
                # Cropping is a no-op when no padding was applied.
                return resColor[:, :, :H, :W]
            resHalftone, resColor = self.model(input_img, decoding_only)
            # Quantize operates on [0, 1]; map there and back to [-1, 1].
            resHalftone = Quantize.apply((resHalftone + 1.0) * 0.5) * 2.0 - 1.
            return resHalftone[:, :, :H, :W], resColor[:, :, :H, :W]


def _parse_args():
    """Parse the command-line options of the inference script."""
    parser = argparse.ArgumentParser(description='invHalf')
    parser.add_argument('--model', required=True, type=str, help='model weight file path')
    parser.add_argument('--decoding', action='store_true', default=False,
                        help='restoration from halftone input')
    parser.add_argument('--data_dir', required=True, type=str,
                        help='where to load input data (RGB images)')
    parser.add_argument('--save_dir', required=True, type=str,
                        help='where to save the result')
    return parser.parse_args()


def _stem(path):
    """Return the file name of *path* without its directory or final extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _read_scaled(path, flags):
    """Read an image with OpenCV and map 8-bit values to [-1, 1].

    Returns None when OpenCV cannot read the file (cv2.imread signals failure
    by returning None rather than raising).
    """
    img = cv2.imread(path, flags=flags)
    if img is None:
        return None
    return img / 127.5 - 1.


def _to_image(tensor):
    """Map a [-1, 1] model output back to a [0, 255] array for cv2.imwrite."""
    return util.tensor2img(tensor / 2. + 0.5) * 255.


def main():
    """Halftone (or, with --decoding, restore) every image in --data_dir."""
    args = _parse_args()
    invhalfer = Inferencer(
        checkpoint_path=args.model,
        model=ResHalf(train=False)
    )
    save_dir = args.save_dir
    util.ensure_dir(save_dir)

    # Sorted for a deterministic processing order.
    test_imgs = sorted(glob(join(args.data_dir, '*.*g')))
    print('------loaded %d images.' % len(test_imgs))
    for img in test_imgs:
        print('[*] processing %s ...' % img)
        name = _stem(img)
        flags = cv2.IMREAD_GRAYSCALE if args.decoding else cv2.IMREAD_COLOR
        input_img = _read_scaled(img, flags)
        if input_img is None:
            print('[!] skipping unreadable image %s' % img)
            continue
        if args.decoding:
            c = invhalfer(util.img2tensor(input_img), decoding_only=True)
            cv2.imwrite(join(save_dir, 'restored_' + name + '.png'), _to_image(c))
        else:
            h, c = invhalfer(util.img2tensor(input_img), decoding_only=False)
            cv2.imwrite(join(save_dir, 'halftone_' + name + '.png'), _to_image(h))
            cv2.imwrite(join(save_dir, 'restored_' + name + '.png'), _to_image(c))


if __name__ == '__main__':
    main()