USR-DA / inference.py
DS
dump shiet
e5b70eb
raw
history blame
2.8 kB
import os
import numpy as np
import cv2
import torch.nn as nn
from tqdm import tqdm
import torch
import torchvision
from model import encoder, decoder
from opt.option import args
# device setting
# The script only runs when --gpu_id was supplied. A missing value is a usage
# error, so abort with the message on stderr and a non-zero exit status
# (raising SystemExit with a string does both) instead of the interactive
# `exit()` helper, which terminates with status 0.
if args.gpu_id is None:
    raise SystemExit('use --gpu_id to specify GPU ID to use')

# NOTE(review): the value of args.gpu_id is not used here -- device "0" is
# always selected. Confirm whether that is intentional before wiring the
# option through (its type/default are defined in opt/option.py).
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
print('using GPU 0')
# make directory for saving results
# makedirs(exist_ok=True) replaces the separate exists()/mkdir() pair: it also
# creates missing parent directories and does not fail if the directory is
# created by someone else between the check and the call.
os.makedirs(args.results, exist_ok=True)
# numpy array -> torch tensor
class ToTensor:
    """Callable transform that converts a 3-D numpy array to a torch tensor."""

    def __call__(self, sample):
        """Move the last axis of ``sample`` to the front and wrap it as a tensor.

        The axes are reordered as (2, 0, 1), e.g. an h x w x c image becomes
        c x h x w. The result is created with ``torch.from_numpy``, so no
        dtype conversion takes place.
        """
        channels_first = np.transpose(sample, axes=(2, 0, 1))
        return torch.from_numpy(channels_first)
# create model
# Both networks are sized by args.n_hidden_feats, moved to the GPU and wrapped
# in DataParallel before the checkpoint is loaded into them.
# (earlier variants, kept for reference:)
# model_Enc = encoder.Encoder().cuda()
# model_Dec_SR = decoder.Decoder_SR().cuda()
model_Enc = nn.DataParallel(
    encoder.Encoder_RRDB(num_feat=args.n_hidden_feats).cuda()
)
model_Dec_SR = nn.DataParallel(
    decoder.Decoder_SR_RRDB(num_in_ch=args.n_hidden_feats).cuda()
)
#model_Dec_Id = nn.DataParallel(model_Dec_Id)

# load weights
# The checkpoint is a mapping holding one state dict per network under the
# keys 'model_Enc' and 'model_Dec_SR'.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(args.weights)

# Restore each network and switch it to evaluation mode.
model_Enc.load_state_dict(checkpoint['model_Enc'])
model_Enc.eval()

model_Dec_SR.load_state_dict(checkpoint['model_Dec_SR'])
model_Dec_SR.eval()
# input transform
transforms = torchvision.transforms.Compose([ToTensor()])

# Extensions treated as input images. The comparison below is done on the
# lower-cased extension so that files such as "photo.PNG" are not skipped.
image_exts = ('.png', '.jpg', '.jpeg')

filenames = sorted(os.listdir(args.dir_test))

# Run every image in args.dir_test through encoder + SR decoder and write the
# result to args.results as "<original stem>_out.png".
with torch.no_grad():
    for filename in tqdm(filenames):
        # splitext (rather than slicing off 4 characters) keeps the output
        # name correct for extensions of any length.
        stem, ext = os.path.splitext(filename)
        if ext.lower() not in image_exts:
            continue

        img_name = os.path.join(args.dir_test, filename)
        img = cv2.imread(img_name)
        if img is None:
            # cv2.imread returns None instead of raising when a file cannot
            # be read/decoded; skip it rather than crash inside cvtColor.
            print('skipping unreadable image: %s' % img_name)
            continue

        # OpenCV loads BGR; the network input is built from RGB scaled to [0, 1].
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        #img = cv2.resize(img, ((img.shape[1] // 4),(img.shape[0] // 4)))
        img = img.astype('float32') / 255
        # img = img[0:256, 0:256, :]

        # h x w x c (numpy array) -> 1 x c x h x w (torch tensor on the GPU).
        # The transform already yields a tensor, so it is moved to the GPU
        # directly instead of being copy-constructed through torch.tensor().
        img = transforms(img).cuda().unsqueeze(0)

        # inference output
        feat = model_Enc(img)
        out = model_Dec_SR(feat)

        # result image save (b x c x h x w (torch tensor) -> h x w x c (numpy array))
        # Take the single batch item, clamp to [0, 1] on the CPU and convert.
        out = out.detach()[0].float().cpu().clamp_(0, 1).numpy()
        # Reverse the channel order (RGB -> BGR) for cv2.imwrite, channels last.
        out = np.transpose(out[[2, 1, 0], :, :], (1, 2, 0))
        out = (out * 255.0).round().astype(np.uint8)

        print(args.results, filename)
        cv2.imwrite(os.path.join(args.results, '%s_out.png' % stem), out)