File size: 2,798 Bytes
e5b70eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import numpy as np
import cv2
import torch.nn as nn
from tqdm import tqdm
import torch
import torchvision

from model import encoder, decoder
from opt.option import args


# device setting
if args.gpu_id is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = "0"
    print('using GPU 0')
else:
    print('use --gpu_id to specify GPU ID to use')
    exit()


# make the directory the output images are written to. makedirs also creates
# missing parent directories, and exist_ok avoids a check-then-create race.
os.makedirs(args.results, exist_ok=True)


# numpy array -> torch tensor
class ToTensor:
    """Callable transform turning an (H, W, C) numpy array into a (C, H, W) tensor.

    The dtype is kept as-is, and the tensor is built with ``torch.from_numpy``,
    so it shares memory with the (transposed view of the) input array.
    """

    # Axis permutation HWC -> CHW.
    _CHW_AXES = (2, 0, 1)

    def __call__(self, sample):
        return torch.from_numpy(np.transpose(sample, self._CHW_AXES))


# create model: RRDB encoder and SR decoder, both sized by --n_hidden_feats,
# moved to the GPU and wrapped in DataParallel.
# NOTE(review): the wrapping presumably matches how the checkpoint was saved
# (DataParallel state_dict keys carry a "module." prefix) -- confirm.
model_Enc = nn.DataParallel(
    encoder.Encoder_RRDB(num_feat=args.n_hidden_feats).cuda()
)
model_Dec_SR = nn.DataParallel(
    decoder.Decoder_SR_RRDB(num_in_ch=args.n_hidden_feats).cuda()
)

# load weights
# map_location='cpu' deserialises every tensor on the CPU, whatever device it
# was saved from. Without it, torch.load tries to restore tensors onto their
# original device and raises if that device index is not visible here.
# load_state_dict then copies the values into the (GPU-resident) parameters.
checkpoint = torch.load(args.weights, map_location='cpu')
model_Enc.load_state_dict(checkpoint['model_Enc'])
model_Dec_SR.load_state_dict(checkpoint['model_Dec_SR'])
# Switch both models to evaluation mode for inference.
model_Enc.eval()
model_Dec_SR.eval()

# input transform: a single-step pipeline that only reorders HWC -> CHW and
# wraps the array as a tensor.
transforms = torchvision.transforms.Compose(transforms=[ToTensor()])


# Image extensions (compared case-insensitively) that are run through the model.
_IMAGE_EXTS = ('.png', '.jpg', '.jpeg')


def _tensor_to_bgr_uint8(tensor):
    """Convert a decoder output into an 8-bit image suitable for ``cv2.imwrite``.

    Args:
        tensor: model output. Only the first entry along dim 0 is used, and it
            is indexed as (C, H, W). Channels 2, 1, 0 are taken, so at least
            3 channels are required. They are treated as RGB, matching the
            RGB input fed to the model (assumption -- confirm for the decoder).

    Returns:
        ``uint8`` numpy array of shape (H, W, 3) in BGR channel order, with
        values scaled from [0, 1] to [0, 255] and rounded.
    """
    # First batch element on the CPU, clamped to [0, 1]. No squeeze(): it would
    # also drop an H or W of size 1 and break the transpose below.
    img = tensor.detach()[0].float().cpu().clamp_(0, 1).numpy()
    # RGB -> BGR (OpenCV order), then CHW -> HWC.
    img = np.transpose(img[[2, 1, 0], :, :], (1, 2, 0))
    return (img * 255.0).round().astype(np.uint8)


# Run every image in args.dir_test through encoder -> SR decoder and write the
# result to <args.results>/<stem>_out.png.
filenames = sorted(os.listdir(args.dir_test))
with torch.no_grad():
    for filename in tqdm(filenames):
        stem, ext = os.path.splitext(filename)
        if ext.lower() not in _IMAGE_EXTS:
            continue

        img_name = os.path.join(args.dir_test, filename)
        img = cv2.imread(img_name)
        if img is None:
            # cv2.imread reports unreadable/corrupt files by returning None.
            tqdm.write('skipping unreadable image: %s' % img_name)
            continue

        # OpenCV loads BGR; the model is fed RGB float32 scaled by 1/255.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255
        # HWC array -> CHW tensor on the GPU with a leading batch dimension.
        # (torch.tensor(tensor) would make a redundant copy and warn.)
        img = transforms(img).cuda().unsqueeze(0)

        # inference output
        feat = model_Enc(img)
        out = _tensor_to_bgr_uint8(model_Dec_SR(feat))

        # splitext-based naming works for any extension length.
        out_path = os.path.join(args.results, '%s_out.png' % stem)
        # tqdm.write keeps the progress bar intact (a plain print garbles it).
        tqdm.write(out_path)
        if not cv2.imwrite(out_path, out):
            tqdm.write('failed to write %s' % out_path)