# File size: 2,798 bytes (revision e5b70eb)
import os
import numpy as np
import cv2
import torch.nn as nn
from tqdm import tqdm
import torch
import torchvision
from model import encoder, decoder
from opt.option import args
# Device setting: expose only the GPU(s) requested with --gpu_id to CUDA.
# This must happen before the first CUDA call (the .cuda() calls further
# down); PyTorch initialises CUDA lazily, so the earlier `import torch` is fine.
# Inside this process the selected GPU then appears as device 0.
if args.gpu_id is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    print('using GPU %s' % args.gpu_id)
else:
    # The models below are moved to CUDA unconditionally, so running without
    # a selected GPU is an error: report it on stderr and exit with status 1.
    raise SystemExit('use --gpu_id to specify GPU ID to use')
# Make the directory the restored images are written to, creating any missing
# parent directories; an already existing directory is reused.
os.makedirs(args.results, exist_ok=True)
# numpy array -> torch tensor
class ToTensor(object):
    """Turn an H x W x C numpy image into a C x H x W torch tensor.

    Only the axes are permuted: the tensor keeps the array's dtype and shares
    its memory, because np.transpose returns a view and torch.from_numpy
    does not copy.
    """

    def __call__(self, sample):
        return torch.from_numpy(np.transpose(sample, axes=(2, 0, 1)))
# Create the RRDB encoder and super-resolution decoder on the GPU, each wrapped
# in DataParallel (presumably matching how the checkpoint was saved, i.e. keys
# with a 'module.' prefix -- confirm against the training script).
model_Enc = nn.DataParallel(
    encoder.Encoder_RRDB(num_feat=args.n_hidden_feats).cuda())
model_Dec_SR = nn.DataParallel(
    decoder.Decoder_SR_RRDB(num_in_ch=args.n_hidden_feats).cuda())

# Restore the trained weights and put both networks into inference mode.
checkpoint = torch.load(args.weights)
for ckpt_key, net in (('model_Enc', model_Enc), ('model_Dec_SR', model_Dec_SR)):
    net.load_state_dict(checkpoint[ckpt_key])
    net.eval()
# input transform: H x W x C numpy image -> C x H x W tensor
transforms = torchvision.transforms.Compose([ToTensor()])

# File extensions treated as input images (compared case-insensitively).
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _load_image(path):
    """Read *path* and return a 1 x C x H x W float32 RGB tensor in [0, 1] on the GPU.

    Returns None when OpenCV cannot decode the file (cv2.imread yields None).
    """
    img = cv2.imread(path)
    if img is None:
        return None
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32) / 255
    # transforms() already returns a tensor; move it directly instead of
    # re-wrapping it with torch.tensor(), which copies and emits a warning.
    return transforms(img).cuda().unsqueeze(0)


def _to_bgr_uint8(out):
    """Convert a batched 1 x 3 x H x W RGB model output to an H x W x 3 BGR uint8 image.

    Values are clamped to [0, 1] and scaled to [0, 255] with rounding, the
    layout cv2.imwrite expects.
    """
    # Index the batch dimension explicitly rather than calling squeeze(), which
    # would also drop H or W when either equals 1 and break the reorder below.
    out = out.detach()[0].float().cpu().clamp_(0, 1).numpy()
    out = np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))  # RGB->BGR, CHW->HWC
    return (out * 255.0).round().astype(np.uint8)


filenames = sorted(os.listdir(args.dir_test))
with torch.no_grad():
    for filename in tqdm(filenames):
        stem, ext = os.path.splitext(filename)
        if ext.lower() not in IMG_EXTENSIONS:
            continue
        img_name = os.path.join(args.dir_test, filename)
        img = _load_image(img_name)
        if img is None:
            tqdm.write('skipping unreadable image: %s' % img_name)
            continue
        # inference: encoder features -> super-resolved image
        out = _to_bgr_uint8(model_Dec_SR(model_Enc(img)))
        # tqdm.write keeps the progress bar intact, unlike print().
        tqdm.write('%s %s' % (args.results, filename))
        # Derive the output name from the real extension length instead of
        # assuming a 4-character suffix.
        save_path = os.path.join(args.results, stem + '_out.png')
        if not cv2.imwrite(save_path, out):
            tqdm.write('failed to write %s' % save_path)
|