Spaces:
Build error
Build error
| # -*- encoding: utf-8 -*- | |
| import argparse | |
| import copy | |
| import time | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from networks.paperedge_cpu import GlobalWarper, LocalWarper, WarperUtil | |
# Process-wide OpenCV configuration, applied once at import time.
# OpenCL acceleration is switched off ...
cv2.ocl.setUseOpenCL(False)
# ... and setNumThreads(0) asks OpenCV (per its documentation) not to use its
# own internal thread pool, so PyTorch alone controls CPU parallelism.
cv2.setNumThreads(0)
class PaperEdge:
    """Dewarp a photographed document with PaperEdge's two warping networks.

    ``netG`` (``GlobalWarper``) predicts a sampling grid for the whole page.
    ``netL`` (``LocalWarper``) predicts a second grid, which is applied to the
    globally warped image. Both grids are estimated on a 256x256 RGB copy of
    the input. They are then upsampled and applied to the full-resolution image.
    """

    # Side length of the square image fed to the networks.
    NET_SIZE = 256

    def __init__(self, enet_path, tnet_path, device) -> None:
        """Load both warper networks onto *device* and put them in eval mode.

        Args:
            enet_path: checkpoint whose ``'G'`` entry is the GlobalWarper state dict.
            tnet_path: checkpoint whose ``'L'`` entry is the LocalWarper state dict.
            device: torch device (or device string) the networks run on.
        """
        self.device = device
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        self.netG = GlobalWarper().to(device)
        self.netG.load_state_dict(torch.load(enet_path, map_location=device)['G'])
        self.netG.eval()
        self.netL = LocalWarper().to(device)
        self.netL.load_state_dict(torch.load(tnet_path, map_location=device)['L'])
        self.netL.eval()
        # 64 matches the size passed to global_post_warp in infer();
        # presumably the global grid resolution -- confirm in networks.paperedge_cpu.
        self.warpUtil = WarperUtil(64).to(device)

    @staticmethod
    def _read_bgr(img_path):
        """Read *img_path* with OpenCV as a float32 BGR array scaled to [0, 1]."""
        im = cv2.imread(str(img_path))  # str() so pathlib.Path inputs also work
        if im is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f'could not read image: {img_path}')
        return im.astype(np.float32) / 255.0

    @classmethod
    def _to_net_input(cls, im_bgr):
        """Convert a float BGR HxWx3 array into a 3xSxS RGB tensor (S = NET_SIZE)."""
        im = im_bgr[:, :, (2, 1, 0)]  # BGR -> RGB
        im = cv2.resize(im, (cls.NET_SIZE, cls.NET_SIZE), interpolation=cv2.INTER_AREA)
        return torch.from_numpy(np.transpose(im, (2, 0, 1)))

    @staticmethod
    def load_img(img_path):
        """Load *img_path* as a 3x256x256 float32 RGB network-input tensor in [0, 1]."""
        return PaperEdge._to_net_input(PaperEdge._read_bgr(img_path))

    def infer(self, img_path, dst_dir=None):
        """Dewarp *img_path* and write the result to ``<dst_dir>/result_ls.png``.

        Args:
            img_path: path of the input image (str or pathlib.Path).
            dst_dir: output directory, created if missing. When omitted, the
                module-level ``dst_dir`` (set by the ``__main__`` block, which is
                what earlier versions always used) is used, else ``'output'``.

        Returns:
            The path of the written PNG as a string.

        Raises:
            FileNotFoundError: if OpenCV cannot read *img_path*.
            OSError: if OpenCV fails to write the output image.
        """
        if dst_dir is None:
            dst_dir = globals().get('dst_dir', 'output')

        # Read the image once and reuse it for the network input and the full-res warp.
        im_bgr = self._read_bgr(img_path)
        size = self.NET_SIZE

        with torch.no_grad():
            x = self._to_net_input(im_bgr).unsqueeze(0).to(self.device)

            # Stage 1: global grid, predicted on the low-resolution input.
            gs_d = self.warpUtil.global_post_warp(self.netG(x), 64)
            d = F.interpolate(gs_d, size=size, mode='bilinear', align_corners=True)
            y0 = F.grid_sample(x, d.permute(0, 2, 3, 1), align_corners=True)

            # Stage 2: local grid, predicted on the globally warped input.
            ls_d = self.netL(y0)
            ls_d = F.interpolate(ls_d, size=size, mode='bilinear', align_corners=True)
            ls_d = ls_d.clamp(-1.0, 1.0)

            # Upsample both grids and apply them to the full-resolution BGR image.
            im = torch.from_numpy(np.transpose(im_bgr, (2, 0, 1)))
            im = im.unsqueeze(0).to(self.device)
            full_size = (im.size(2), im.size(3))
            gs_d = F.interpolate(gs_d, full_size, mode='bilinear', align_corners=True)
            gs_y = F.grid_sample(im, gs_d.permute(0, 2, 3, 1), align_corners=True)
            ls_d = F.interpolate(ls_d, full_size, mode='bilinear', align_corners=True)
            ls_y = F.grid_sample(gs_y, ls_d.permute(0, 2, 3, 1), align_corners=True)

        # [0] drops only the batch axis; squeeze() would also drop 1-pixel dims.
        out = ls_y[0].permute(1, 2, 0).cpu().numpy()
        # Convert explicitly to 8-bit (round + saturate) instead of handing
        # cv2.imwrite a float image to convert implicitly.
        out = np.clip(np.rint(out * 255.0), 0, 255).astype(np.uint8)

        Path(dst_dir).mkdir(parents=True, exist_ok=True)
        save_path = f'{dst_dir}/result_ls.png'
        if not cv2.imwrite(save_path, out):
            raise OSError(f'could not write image: {save_path}')
        return save_path
if __name__ == '__main__':
    # Command-line entry point: dewarp one image and save it under --out_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument('--Enet_ckpt', type=str,
                        default='models/G_w_checkpoint_13820.pt')
    parser.add_argument('--Tnet_ckpt', type=str,
                        default='models/L_w_checkpoint_27640.pt')
    parser.add_argument('--img_path', type=str, default='images/3.jpg')
    parser.add_argument('--out_dir', type=str, default='output')
    parser.add_argument('--device', type=str, default='cpu')
    args = parser.parse_args()

    # Use CUDA only when it was requested AND is available; otherwise run on CPU.
    if args.device == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    # Module-level global: PaperEdge.infer reads `dst_dir` to choose the output folder.
    dst_dir = args.out_dir
    Path(dst_dir).mkdir(parents=True, exist_ok=True)

    # Pass the resolved device (not the raw CLI string) so the CPU fallback applies.
    paper_edge = PaperEdge(args.Enet_ckpt, args.Tnet_ckpt, device)
    paper_edge.infer(args.img_path)  # was misspelled `inder` -> AttributeError
    print('ok')