File size: 3,998 Bytes
6e70c4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40d12a9
6e70c4a
 
 
40d12a9
6e70c4a
 
 
 
40d12a9
 
 
 
 
 
 
6e70c4a
 
40d12a9
6e70c4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import numpy as np
import cv2
import os, argparse, json
from os.path import join
from glob import glob

import torch
import torch.nn.functional as F

from model.model import ResHalf
from model.model import Quantize
from model.loss import l1_loss
from utils import util
from utils.dct import DCT_Lowfrequency
from utils.filters_tensor import bgr2gray
from collections import OrderedDict

class Inferencer:
    """Run a halftoning / restoration network for inference on image tensors.

    The constructor loads ``checkpoint['state_dict']`` into ``model`` (optionally
    wrapped in ``torch.nn.DataParallel``). Calling the instance pads the input so
    both spatial dimensions are multiples of 8, runs the network under
    ``torch.no_grad()``, and crops every output back to the input's size.
    """

    # Both spatial dimensions of the network input are padded up to a multiple of this.
    _SCALE = 8
    # Prefix torch.nn.DataParallel adds to every parameter name.
    _DP_PREFIX = 'module.'

    def __init__(self, checkpoint_path, model, use_cuda=True, multi_gpu=True):
        """Load the weights stored in ``checkpoint_path`` into ``model``.

        Args:
            checkpoint_path: path of a file readable by ``torch.load`` holding a
                dict with a ``'state_dict'`` entry.
            model: the ``torch.nn.Module`` to load; it is switched to eval mode.
            use_cuda: move the model (and, on each call, the input) to the GPU.
            multi_gpu: wrap the model in ``torch.nn.DataParallel``.

        Parameter names are normalised to the wrapper in use: the ``'module.'``
        prefix is stripped when ``multi_gpu`` is False and added when it is
        True, so checkpoints saved with or without DataParallel load either way.
        """
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints that come from a trusted source.
        self.checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.use_cuda = use_cuda
        self.model = model.eval()
        prefix = self._DP_PREFIX
        raw_state = self.checkpoint['state_dict']
        if multi_gpu:
            self.model = torch.nn.DataParallel(self.model)
            if all(k.startswith(prefix) for k in raw_state):
                # Already in DataParallel layout: load it unchanged.
                state_dict = raw_state
            else:
                state_dict = OrderedDict(
                    (k if k.startswith(prefix) else prefix + k, v)
                    for k, v in raw_state.items()
                )
        else:
            # Strip the DataParallel prefix only where it is actually present;
            # blindly dropping 7 characters would corrupt un-prefixed keys.
            state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in raw_state.items()
            )
        if self.use_cuda:
            self.model = self.model.cuda()
        self.model.load_state_dict(state_dict)

    def __call__(self, input_img, decoding_only=False):
        """Run the network on ``input_img`` without tracking gradients.

        Args:
            input_img: 4-D tensor whose last two dimensions are height and width
                (presumably (N, C, H, W) in [-1, 1], as this file's ``__main__``
                code prepares it; confirm for other callers).
            decoding_only: forwarded to the model; when True the model is
                expected to return only the restored image.

        Returns:
            ``resColor`` when ``decoding_only`` is True, otherwise the tuple
            ``(resHalftone, resColor)``. Every output is cropped to H x W.
        """
        with torch.no_grad():
            scale = self._SCALE
            _, _, H, W = input_img.shape
            # Smallest padding reaching a multiple of `scale` (0 when already aligned).
            pad_h = (-H) % scale
            pad_w = (-W) % scale
            padded = bool(pad_h or pad_w)
            if padded:
                # 'reflect' requires each pad to be smaller than the dimension it
                # extends; fall back to 'replicate' for tiny inputs where it would raise.
                mode = 'reflect' if pad_h < H and pad_w < W else 'replicate'
                input_img = F.pad(input_img, [0, pad_w, 0, pad_h], mode=mode)
            if self.use_cuda:
                input_img = input_img.cuda()
            if decoding_only:
                resColor = self.model(input_img, decoding_only)
                if padded:
                    resColor = resColor[:, :, :H, :W]
                return resColor
            resHalftone, resColor = self.model(input_img, decoding_only)
            # Map [-1, 1] -> [0, 1], quantize (presumably to {0, 1}), map back to [-1, 1].
            resHalftone = Quantize.apply((resHalftone + 1.0) * 0.5) * 2.0 - 1.
            if padded:
                resHalftone = resHalftone[:, :, :H, :W]
                resColor = resColor[:, :, :H, :W]
            return resHalftone, resColor


def output_path(save_dir, img_path, prefix):
    """Return the PNG path inside ``save_dir`` for a result derived from ``img_path``.

    The input's directory and final extension are removed with ``os.path`` (so
    any OS path separator works, and only the last extension is dropped):
    ``/data/a.b.jpg`` with prefix ``'restored_'`` -> ``<save_dir>/restored_a.b.png``.
    """
    stem = os.path.splitext(os.path.basename(img_path))[0]
    return join(save_dir, prefix + stem + '.png')


def main():
    """Parse CLI arguments and process every image in ``--data_dir``.

    Without ``--decoding``, each colour image is encoded to a halftone and
    restored; both results are written. With ``--decoding``, each grayscale
    (halftone) input is only restored. Unreadable files are skipped.
    """
    parser = argparse.ArgumentParser(description='invHalf')
    parser.add_argument('--model', required=True, type=str,
                        help='model weight file path')
    parser.add_argument('--decoding', action='store_true', default=False, help='restoration from halftone input')
    parser.add_argument('--data_dir', required=True, type=str,
                        help='where to load input data (RGB images)')
    parser.add_argument('--save_dir', required=True, type=str,
                        help='where to save the result')
    args = parser.parse_args()

    invhalfer = Inferencer(
        checkpoint_path=args.model,
        model=ResHalf(train=False),
        # Avoid crashing on CPU-only hosts.
        use_cuda=torch.cuda.is_available(),
    )
    save_dir = args.save_dir
    util.ensure_dir(save_dir)
    # Sorted so runs are deterministic; '*.*g' matches e.g. .png/.jpg/.jpeg.
    test_imgs = sorted(glob(join(args.data_dir, '*.*g')))
    print('------loaded %d images.' % len(test_imgs))
    read_flag = cv2.IMREAD_GRAYSCALE if args.decoding else cv2.IMREAD_COLOR
    for img in test_imgs:
        print('[*] processing %s ...' % img)
        raw = cv2.imread(img, flags=read_flag)
        if raw is None:
            # cv2.imread returns None instead of raising for unreadable files.
            print('[!] skipping unreadable image %s' % img)
            continue
        # Scale 8-bit pixel values from [0, 255] to [-1, 1].
        input_img = raw / 127.5 - 1.
        if args.decoding:
            c = invhalfer(util.img2tensor(input_img), decoding_only=True)
            c = util.tensor2img(c / 2. + 0.5) * 255.
            cv2.imwrite(output_path(save_dir, img, 'restored_'), c)
        else:
            h, c = invhalfer(util.img2tensor(input_img), decoding_only=False)
            h = util.tensor2img(h / 2. + 0.5) * 255.
            c = util.tensor2img(c / 2. + 0.5) * 255.
            cv2.imwrite(output_path(save_dir, img, 'halftone_'), h)
            cv2.imwrite(output_path(save_dir, img, 'restored_'), c)


if __name__ == '__main__':
    main()