File size: 3,739 Bytes
1828176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f0d74c
1828176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8cf023c
1828176
 
 
0f0d74c
1828176
0f0d74c
1828176
 
 
8cf023c
1828176
 
8cf023c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# -*- encoding: utf-8 -*-
import copy
import os
import subprocess
import sys

# Bootstrap runtime dependencies before importing them (Hugging Face
# Spaces-style self-install). Using an argument list avoids a shell, and
# sys.executable guarantees packages install into the interpreter actually
# running this script. Failures are deliberately non-fatal (best effort),
# matching the original os.system() behavior.
subprocess.run([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'])

import time
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from networks.paperedge_cpu import GlobalWarper, LocalWarper, WarperUtil
import gradio as gr


# Keep OpenCV single-threaded and CPU-only inside the demo container to
# avoid thread contention with torch.
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)


class PaperEdge(object):
    """Two-stage document unwarping pipeline (PaperEdge, SIGGRAPH 2022).

    Stage 1: a global warper (``netG``) predicts a coarse rectification grid
    from a 256x256 thumbnail. Stage 2: a local warper (``netL``) refines the
    globally-warped result. Both grids are upsampled to the original image
    resolution and applied with ``grid_sample``.
    """

    def __init__(self, enet_path, tnet_path, device, dst_dir) -> None:
        """Load both warping networks in eval mode.

        Args:
            enet_path: checkpoint path for the global warper (key ``'G'``).
            tnet_path: checkpoint path for the local warper (key ``'L'``).
            device: torch device to run inference on.
            dst_dir: directory where dewarped results are written.
        """
        self.device = device
        self.dst_dir = dst_dir

        # Global (coarse) warping network.
        self.netG = GlobalWarper().to(device)
        netG_state = torch.load(enet_path, map_location=device)['G']
        self.netG.load_state_dict(netG_state)
        self.netG.eval()

        # Local (refinement) warping network.
        self.netL = LocalWarper().to(device)
        netL_state = torch.load(tnet_path, map_location=device)['L']
        self.netL.load_state_dict(netL_state)
        self.netL.eval()

        # Grid post-processing helper operating on 64x64 flow fields.
        self.warpUtil = WarperUtil(64).to(device)

    @staticmethod
    def load_img(img_path):
        """Load an image as a normalized RGB float tensor of shape (3, 256, 256).

        Raises:
            FileNotFoundError: if the file is missing or not a readable image.
        """
        im = cv2.imread(img_path)
        if im is None:
            # cv2.imread signals failure by returning None rather than raising;
            # without this check the .astype call below dies with an opaque
            # AttributeError.
            raise FileNotFoundError(f'cannot read image: {img_path}')
        im = im.astype(np.float32) / 255.0
        im = im[:, :, (2, 1, 0)]  # BGR -> RGB
        im = cv2.resize(im, (256, 256), interpolation=cv2.INTER_AREA)
        im = torch.from_numpy(np.transpose(im, (2, 0, 1)))
        return im

    def __call__(self, img_path):
        """Dewarp the image at ``img_path`` and return the saved output path.

        The result is written as ``<dst_dir>/<timestamp>.png``.
        """
        time_stamp = time.strftime('%Y-%m-%d-%H-%M-%S',
                                   time.localtime(time.time()))

        gs_d, ls_d = None, None
        with torch.no_grad():
            x = self.load_img(img_path)
            x = x.unsqueeze(0).to(self.device)

            # Coarse rectification grid from the 256x256 thumbnail.
            d = self.netG(x)

            d = self.warpUtil.global_post_warp(d, 64)

            # Keep an untouched copy of the global grid for full-resolution use.
            gs_d = copy.deepcopy(d)

            # Apply the global warp at thumbnail scale, then predict the
            # local refinement grid from the globally-warped image.
            d = F.interpolate(d, size=256, mode='bilinear', align_corners=True)
            y0 = F.grid_sample(x, d.permute(0, 2, 3, 1), align_corners=True)
            ls_d = self.netL(y0)

            ls_d = F.interpolate(ls_d, size=256, mode='bilinear', align_corners=True)
            # grid_sample expects normalized coordinates in [-1, 1].
            ls_d = ls_d.clamp(-1.0, 1.0)

        # Re-read the image at full resolution (kept in BGR so the final
        # cv2.imwrite emits correct colors).
        im = cv2.imread(img_path)
        if im is None:
            raise FileNotFoundError(f'cannot read image: {img_path}')
        im = im.astype(np.float32) / 255.0
        im = torch.from_numpy(np.transpose(im, (2, 0, 1)))
        im = im.to(self.device).unsqueeze(0)

        # Upsample both grids to the original resolution and apply them in
        # sequence: global warp first, then local refinement.
        gs_d = F.interpolate(gs_d, (im.size(2), im.size(3)), mode='bilinear', align_corners=True)
        gs_y = F.grid_sample(im, gs_d.permute(0, 2, 3, 1), align_corners=True).detach()

        ls_d = F.interpolate(ls_d, (im.size(2), im.size(3)), mode='bilinear', align_corners=True)
        ls_y = F.grid_sample(gs_y, ls_d.permute(0, 2, 3, 1), align_corners=True).detach()
        ls_y = ls_y.squeeze().permute(1, 2, 0).cpu().numpy()

        save_path = f'{self.dst_dir}/{time_stamp}.png'
        cv2.imwrite(save_path, ls_y * 255.)
        return save_path


def inference(img_path):
    """Gradio handler: dewarp the uploaded image, return the result's file path."""
    return paper_edge(img_path)


# Checkpoint paths: G_* is the global warper ('G' key), L_* the local
# refinement warper ('L' key). Inference runs on CPU.
enet_path = 'models/G_w_checkpoint_13820.pt'
tnet_path = 'models/L_w_checkpoint_27640.pt'
device = torch.device('cpu')

# mkdir with exist_ok=True is idempotent and race-free; no need for a
# separate exists() check first.
dst_dir = Path('inference/')
dst_dir.mkdir(parents=True, exist_ok=True)

paper_edge = PaperEdge(enet_path, tnet_path, device, dst_dir)

title = 'PaperEdge Demo'
description = 'This is the demo for the paper "Learning From Documents in the Wild to Improve Document Unwarping" (SIGGRAPH 2022). Github repo: https://github.com/cvlab-stonybrook/PaperEdge'
css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;}"
examples = [['images/1.jpg']]

# Wire the inference function into a Gradio UI and block until shutdown.
gr.Interface(
    inference,
    inputs=gr.inputs.Image(type='filepath', label='Input'),
    outputs=[
        gr.outputs.Image(type='filepath', label='Output_image'),
    ],
    title=title,
    description=description,
    examples=examples,
    css=css,
    allow_flagging='never',
    ).launch(debug=True, enable_queue=True)