# ReversibleHalftoning / inference.py
# Last commit 40d12a9 by menghanxia: "fixed checkpoint loading requires GPU issue"
# (Scraped from the Hugging Face file viewer: raw | history | blame | contribute | delete | No virus | 4 kB)
import numpy as np
import cv2
import os, argparse, json
from os.path import join
from glob import glob
import torch
import torch.nn.functional as F
from model.model import ResHalf
from model.model import Quantize
from model.loss import l1_loss
from utils import util
from utils.dct import DCT_Lowfrequency
from utils.filters_tensor import bgr2gray
from collections import OrderedDict
class Inferencer:
    """Run a trained halftoning / restoration network on image tensors.

    Loads weights from a checkpoint, optionally wraps the network in
    ``torch.nn.DataParallel`` and/or moves it to CUDA. When called, it pads
    the input so height and width are multiples of 8, runs the network under
    ``torch.no_grad()``, and crops the outputs back to the input size.
    """

    # Prefix that torch.nn.DataParallel puts in front of every state_dict key.
    _DP_PREFIX = 'module.'

    def __init__(self, checkpoint_path, model, use_cuda=True, multi_gpu=True):
        """Load ``checkpoint_path`` into ``model`` and prepare it for inference.

        Args:
            checkpoint_path: path of a ``torch.save``-d dict whose
                ``'state_dict'`` entry holds the network weights.
            model: network instance to load the weights into; it is switched
                to eval mode.
            use_cuda: move the network to the GPU. Inputs are moved there on
                every call as well.
            multi_gpu: wrap the network in ``torch.nn.DataParallel``.
        """
        # Deserialise onto the CPU so loading works on machines without a GPU;
        # the model is moved to CUDA afterwards only if requested.
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.use_cuda = use_cuda
        self.model = model.eval()
        if multi_gpu:
            self.model = torch.nn.DataParallel(self.model)
        state_dict = self._adapt_state_dict(self.checkpoint['state_dict'], wrapped=multi_gpu)
        if self.use_cuda:
            self.model = self.model.cuda()
        self.model.load_state_dict(state_dict)

    @classmethod
    def _adapt_state_dict(cls, state_dict, wrapped):
        """Return ``state_dict`` with keys matching a (non-)DataParallel model.

        A ``DataParallel`` wrapper expects every key to start with
        ``'module.'``; a bare network expects no such prefix. The prefix is
        added or stripped only on keys that need it, so checkpoints saved
        from either kind of model load into either configuration. Keys
        without the prefix are never truncated.

        Args:
            state_dict: mapping of parameter names to tensors.
            wrapped: True if the target model is a ``DataParallel`` wrapper.

        Returns:
            The original mapping when it already matches a wrapped model,
            otherwise a new ``OrderedDict`` with adjusted keys.
        """
        prefix = cls._DP_PREFIX
        if wrapped:
            if all(key.startswith(prefix) for key in state_dict):
                # Already in the expected form: pass the original mapping through.
                return state_dict
            return OrderedDict(
                (key if key.startswith(prefix) else prefix + key, value)
                for key, value in state_dict.items()
            )
        return OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )

    def __call__(self, input_img, decoding_only=False):
        """Run the network on a batch and crop the results to the input size.

        Args:
            input_img: 4-D tensor unpacked as ``(N, C, H, W)``.
            decoding_only: if True, call the network in decoding-only mode and
                return only its restored output.

        Returns:
            ``resColor`` when ``decoding_only`` is True, otherwise
            ``(resHalftone, resColor)``. ``resHalftone`` is passed through
            ``Quantize`` after being mapped from [-1, 1] to [0, 1], and then
            mapped back. Every output is cropped to ``H x W``.
        """
        # The network input is padded to a multiple of this size (presumably
        # the model's total down-sampling factor -- TODO confirm).
        scale = 8
        with torch.no_grad():
            _, _, H, W = input_img.shape
            # Smallest non-negative padding that makes each side divisible by
            # `scale`; zero for a side that is already aligned.
            pad_h = (-H) % scale
            pad_w = (-W) % scale
            if pad_h or pad_w:
                # 'reflect' requires the pad to be smaller than the padded
                # dimension; tiny images fall back to edge replication.
                mode = 'reflect' if pad_h < H and pad_w < W else 'replicate'
                input_img = F.pad(input_img, [0, pad_w, 0, pad_h], mode=mode)
            if self.use_cuda:
                input_img = input_img.cuda()
            if decoding_only:
                resColor = self.model(input_img, decoding_only)
                return resColor[:, :, :H, :W]
            resHalftone, resColor = self.model(input_img, decoding_only)
            resHalftone = Quantize.apply((resHalftone + 1.0) * 0.5) * 2.0 - 1.
            return resHalftone[:, :, :H, :W], resColor[:, :, :H, :W]
def _result_path(save_dir, prefix, img_path):
    """Return ``save_dir/<prefix><input file stem>.png`` for ``img_path``."""
    # basename/splitext handle any OS path separator and keep inner dots
    # ("a.b.png" -> "a.b"), so differently named inputs do not collide.
    stem = os.path.splitext(os.path.basename(img_path))[0]
    return join(save_dir, prefix + stem + '.png')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='invHalf')
    parser.add_argument('--model', required=True, type=str,
                        help='model weight file path')
    parser.add_argument('--decoding', action='store_true', default=False,
                        help='restoration from halftone input')
    parser.add_argument('--data_dir', required=True, type=str,
                        help='where to load input data (RGB images)')
    parser.add_argument('--save_dir', required=True, type=str,
                        help='where to save the result')
    args = parser.parse_args()

    # Run on the GPU only when one is present, so the script also works on
    # CPU-only machines.
    invhalfer = Inferencer(
        checkpoint_path=args.model,
        model=ResHalf(train=False),
        use_cuda=torch.cuda.is_available(),
    )
    save_dir = args.save_dir
    util.ensure_dir(save_dir)
    # Sorted so the processing order is deterministic across runs.
    test_imgs = sorted(glob(join(args.data_dir, '*.*g')))
    print('------loaded %d images.' % len(test_imgs))
    # Halftone inputs (decoding mode) are read as single-channel images,
    # everything else as colour images.
    read_flag = cv2.IMREAD_GRAYSCALE if args.decoding else cv2.IMREAD_COLOR
    for img in test_imgs:
        print('[*] processing %s ...' % img)
        raw = cv2.imread(img, flags=read_flag)
        if raw is None:
            # cv2.imread returns None instead of raising on unreadable files,
            # and the '*.*g' pattern can match non-images.
            print('[!] skipping unreadable image %s' % img)
            continue
        # Map 8-bit pixel values [0, 255] to [-1, 1].
        input_img = raw / 127.5 - 1.
        if args.decoding:
            c = invhalfer(util.img2tensor(input_img), decoding_only=True)
            c = util.tensor2img(c / 2. + 0.5) * 255.
            cv2.imwrite(_result_path(save_dir, 'restored_', img), c)
        else:
            h, c = invhalfer(util.img2tensor(input_img), decoding_only=False)
            h = util.tensor2img(h / 2. + 0.5) * 255.
            c = util.tensor2img(c / 2. + 0.5) * 255.
            cv2.imwrite(_result_path(save_dir, 'halftone_', img), h)
            cv2.imwrite(_result_path(save_dir, 'restored_', img), c)