# -*- encoding: utf-8 -*-
import copy
import os

# Install the Python dependencies listed in requirements.txt at startup.
os.system('pip install -r requirements.txt')

import time
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn.functional as F

from networks.paperedge_cpu import GlobalWarper, LocalWarper, WarperUtil

import gradio as gr

# Keep OpenCV single-threaded and disable OpenCL for the CPU-only inference.
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)


class PaperEdge(object):

    def __init__(self, enet_path, tnet_path, device, dst_dir) -> None:
        self.device = device
        self.dst_dir = dst_dir

        # Global warping network (ENet checkpoint): coarse, edge-based unwarping.
        self.netG = GlobalWarper().to(device)
        netG_state = torch.load(enet_path, map_location=device)['G']
        self.netG.load_state_dict(netG_state)
        self.netG.eval()

        # Local warping network (TNet checkpoint): fine, texture-level refinement.
        self.netL = LocalWarper().to(device)
        netL_state = torch.load(tnet_path, map_location=device)['L']
        self.netL.load_state_dict(netL_state)
        self.netL.eval()

        self.warpUtil = WarperUtil(64).to(device)

    @staticmethod
    def load_img(img_path):
        # Read the image, scale to [0, 1], convert BGR -> RGB, and resize to the
        # 256x256 resolution expected by the warping networks.
        im = cv2.imread(img_path).astype(np.float32) / 255.0
        im = im[:, :, (2, 1, 0)]
        im = cv2.resize(im, (256, 256), interpolation=cv2.INTER_AREA)
        im = torch.from_numpy(np.transpose(im, (2, 0, 1)))
        return im

    def __call__(self, img_path):
        time_stamp = time.strftime('%Y-%m-%d-%H-%M-%S',
                                   time.localtime(time.time()))
        gs_d, ls_d = None, None
        with torch.no_grad():
            x = self.load_img(img_path)
            x = x.unsqueeze(0).to(self.device)

            # Global stage: predict and post-process the coarse deformation field.
            d = self.netG(x)
            d = self.warpUtil.global_post_warp(d, 64)
            gs_d = copy.deepcopy(d)

            # Local stage: refine on the globally unwarped 256x256 image.
            d = F.interpolate(d, size=256, mode='bilinear', align_corners=True)
            y0 = F.grid_sample(x, d.permute(0, 2, 3, 1), align_corners=True)
            ls_d = self.netL(y0)
            ls_d = F.interpolate(ls_d, size=256, mode='bilinear', align_corners=True)
            ls_d = ls_d.clamp(-1.0, 1.0)

        # Apply both deformation fields to the full-resolution input image.
        im = cv2.imread(img_path).astype(np.float32) / 255.0
        im = torch.from_numpy(np.transpose(im, (2, 0, 1)))
        im = im.to(self.device).unsqueeze(0)

        gs_d = F.interpolate(gs_d, (im.size(2), im.size(3)), mode='bilinear', align_corners=True)
        gs_y = F.grid_sample(im, gs_d.permute(0, 2, 3, 1), align_corners=True).detach()

        ls_d = F.interpolate(ls_d, (im.size(2), im.size(3)), mode='bilinear', align_corners=True)
        ls_y = F.grid_sample(gs_y, ls_d.permute(0, 2, 3, 1), align_corners=True).detach()
        ls_y = ls_y.squeeze().permute(1, 2, 0).cpu().numpy()

        save_path = f'{self.dst_dir}/{time_stamp}.png'
        cv2.imwrite(save_path, ls_y * 255.)
        return save_path


def inference(img_path):
    save_img_path = paper_edge(img_path)
    return save_img_path


enet_path = 'models/G_w_checkpoint_13820.pt'
tnet_path = 'models/L_w_checkpoint_27640.pt'
device = torch.device('cpu')

dst_dir = Path('inference/')
if not dst_dir.exists():
    dst_dir.mkdir(parents=True, exist_ok=True)

paper_edge = PaperEdge(enet_path, tnet_path, device, dst_dir)
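
# Direct-use sketch (a comment-only illustration, not part of the demo flow):
# with the checkpoints above in place and a local test image such as
# images/1.jpg, the wrapped model can be called without the Gradio UI:
#
#     unwarped_path = paper_edge('images/1.jpg')
#     print(unwarped_path)  # e.g. 'inference/<timestamp>.png'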

title = 'PaperEdge Demo'
description = ('This is the demo for the paper "Learning From Documents in the Wild '
               'to Improve Document Unwarping" (SIGGRAPH 2022). '
               'GitHub repo: https://github.com/cvlab-stonybrook/PaperEdge')
css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
examples = [['images/1.jpg']]

demo = gr.Interface(
    inference,
    # gr.Image replaces the older gr.inputs.Image / gr.outputs.Image namespaces,
    # which newer Gradio releases no longer provide.
    inputs=gr.Image(type='filepath', label='Input'),
    outputs=[
        gr.Image(type='filepath', label='Output_image'),
    ],
    title=title,
    description=description,
    examples=examples,
    css=css,
    allow_flagging='never',
)
# Enable the request queue explicitly; the enable_queue argument to launch()
# is deprecated in newer Gradio versions.
demo.queue().launch(debug=True)
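
# Usage note (assumption, not stated in the original script): running
# `python app.py` locally starts the Gradio server on http://127.0.0.1:7860 by
# default; every processed image is also saved under the inference/ directory.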