# PaperEdgeDemo / demo_cpu.py
# Author: SWHL
# Commit 1828176: "First commit"
# -*- encoding: utf-8 -*-
import argparse
import copy
import time
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from networks.paperedge_cpu import GlobalWarper, LocalWarper, WarperUtil
# Global OpenCV configuration for this demo: the OpenCL (T-API) path is
# switched off and OpenCV's internal threading is disabled (0 threads).
# Presumably this avoids oversubscribing the CPU alongside PyTorch's own
# thread pool -- TODO confirm the original rationale.
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)
class PaperEdge:
    """Two-stage document dewarping with the PaperEdge networks.

    A ``GlobalWarper`` predicts a sampling field from a square RGB thumbnail
    of the input. A ``LocalWarper`` then predicts a second field from the
    globally-warped thumbnail. Both fields are upsampled and applied, one
    after the other, to the full-resolution image.
    """

    # Side length of the square thumbnail fed to both networks.
    INPUT_SIZE = 256
    # Argument passed to both WarperUtil(...) and global_post_warp(...). Its
    # meaning is defined in networks.paperedge_cpu (presumably the coarse grid
    # resolution of the global field) -- TODO confirm there.
    GRID_SIZE = 64

    def __init__(self, enet_path, tnet_path, device) -> None:
        """Load both networks from their checkpoints and put them in eval mode.

        Args:
            enet_path: checkpoint file whose ``'G'`` entry is the
                GlobalWarper state dict.
            tnet_path: checkpoint file whose ``'L'`` entry is the
                LocalWarper state dict.
            device: torch device (or device string) to load and run on.
        """
        self.device = device
        # NOTE(security): torch.load unpickles the file -- only load trusted
        # checkpoints.
        self.netG = GlobalWarper().to(device)
        netG_state = torch.load(enet_path, map_location=device)['G']
        self.netG.load_state_dict(netG_state)
        self.netG.eval()

        self.netL = LocalWarper().to(device)
        netL_state = torch.load(tnet_path, map_location=device)['L']
        self.netL.load_state_dict(netL_state)
        self.netL.eval()

        self.warpUtil = WarperUtil(self.GRID_SIZE).to(device)

    @staticmethod
    def _read_bgr(img_path):
        """Read an image with OpenCV as a float32 BGR array scaled to [0, 1].

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (imread -> None).
        """
        im = cv2.imread(str(img_path))
        if im is None:
            raise FileNotFoundError(f'Cannot read image: {img_path}')
        return im.astype(np.float32) / 255.0

    @staticmethod
    def _to_net_input(im_bgr, size=INPUT_SIZE):
        """Turn an H x W x 3 BGR float array into a 3 x size x size RGB tensor."""
        im = im_bgr[:, :, (2, 1, 0)]  # BGR -> RGB (fancy indexing copies)
        im = cv2.resize(im, (size, size), interpolation=cv2.INTER_AREA)
        return torch.from_numpy(np.transpose(im, (2, 0, 1)))

    @staticmethod
    def load_img(img_path):
        """Load ``img_path`` as the 3 x 256 x 256 RGB float tensor the networks take."""
        return PaperEdge._to_net_input(PaperEdge._read_bgr(img_path))

    @staticmethod
    def _warp(img, field):
        """Resample ``img`` (N, C, H, W) with the sampling field ``field`` (N, 2, h, w).

        The field is bilinearly resized to the spatial size of ``img`` and
        then used as the ``grid_sample`` grid (channels last), with
        ``align_corners=True`` in both calls.
        """
        field = F.interpolate(field, size=(img.size(2), img.size(3)),
                              mode='bilinear', align_corners=True)
        return F.grid_sample(img, field.permute(0, 2, 3, 1), align_corners=True)

    @staticmethod
    def _to_uint8_hwc(batch):
        """Convert a (1, C, H, W) float tensor in [0, 1] to an H x W x C uint8 array."""
        arr = batch[0].permute(1, 2, 0).cpu().numpy()
        return np.clip(arr * 255.0, 0, 255).round().astype(np.uint8)

    def infer(self, img_path, dst_dir='output'):
        """Dewarp the image at ``img_path`` and write ``result_ls.png``.

        Args:
            img_path: path of an image readable by ``cv2.imread``.
            dst_dir: output directory, created if missing. Defaults to
                ``'output'``, the same default as the ``--out_dir`` CLI flag.

        Returns:
            Path of the written PNG, as a string.

        Raises:
            FileNotFoundError: if the input image cannot be read.
            OSError: if OpenCV fails to write the result.
        """
        im_bgr = self._read_bgr(img_path)  # read once, reused for both scales
        with torch.no_grad():
            # The networks see an RGB thumbnail.
            x = self._to_net_input(im_bgr, self.INPUT_SIZE)
            x = x.unsqueeze(0).to(self.device)

            # Global stage. Nothing below mutates gs_d, so no copy is needed.
            gs_d = self.netG(x)
            gs_d = self.warpUtil.global_post_warp(gs_d, self.GRID_SIZE)
            y0 = self._warp(x, gs_d)

            # Local stage: the field is resized to thumbnail size and clamped
            # before it is upsampled to full resolution.
            ls_d = self.netL(y0)
            ls_d = F.interpolate(ls_d, size=self.INPUT_SIZE,
                                 mode='bilinear', align_corners=True)
            ls_d = ls_d.clamp(-1.0, 1.0)

            # The full-resolution image stays BGR, because cv2.imwrite expects
            # BGR.
            im = torch.from_numpy(np.transpose(im_bgr, (2, 0, 1)))
            im = im.unsqueeze(0).to(self.device)
            gs_y = self._warp(im, gs_d)
            ls_y = self._warp(gs_y, ls_d)

        out_dir = Path(dst_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        save_path = str(out_dir / 'result_ls.png')
        if not cv2.imwrite(save_path, self._to_uint8_hwc(ls_y)):
            raise OSError(f'Failed to write {save_path}')
        return save_path


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--Enet_ckpt', type=str,
                        default='models/G_w_checkpoint_13820.pt')
    parser.add_argument('--Tnet_ckpt', type=str,
                        default='models/L_w_checkpoint_27640.pt')
    parser.add_argument('--img_path', type=str, default='images/3.jpg')
    parser.add_argument('--out_dir', type=str, default='output')
    parser.add_argument('--device', type=str, default='cpu')
    args = parser.parse_args()

    # Use CUDA only when it is requested AND available; otherwise fall back
    # to the CPU.
    if args.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    # Pass the resolved device rather than args.device, so the CPU fallback
    # above takes effect when CUDA is unavailable.
    paper_edge = PaperEdge(args.Enet_ckpt, args.Tnet_ckpt, device)
    paper_edge.infer(args.img_path, args.out_dir)
    print('ok')