# DocTr / app.py
# Author: HaoFeng2019 — last revision 4cbf82a ("Update app.py").
from seg import U2NETP
from GeoTr import GeoTr
from IllTr import IllTr
from inference_ill import rec_ill
import torch
import torch.nn as nn
import torch.nn.functional as F
import skimage.io as io
import numpy as np
import cv2
import glob
import os
from PIL import Image
import argparse
import warnings
warnings.filterwarnings('ignore')
import gradio as gr
# Example image files offered as clickable examples in the Gradio UI
# (passed to gr.Interface(examples=...)).
example_img_list = [
    '51_1 copy.png',
    '48_2 copy.png',
    '25.jpg',
]
def reload_model(model, path="", prefix_len=7):
    """Load matching weights from a checkpoint file into ``model`` in place.

    The first ``prefix_len`` characters are stripped from every checkpoint key
    (7 by default — the length of a ``"module."`` prefix; presumably the
    checkpoint was saved from a DataParallel-wrapped model, TODO confirm).
    Only entries whose stripped key exists in ``model.state_dict()`` are
    loaded; every other parameter keeps its current value. Note that a
    checkpoint whose keys match nothing is silently ignored.

    Args:
        model: the ``nn.Module`` to update.
        path: checkpoint file path. An empty (falsy) path skips loading.
        prefix_len: number of leading characters removed from each
            checkpoint key before matching.

    Returns:
        The same ``model`` object, updated in place.
    """
    if not path:
        return model
    model_dict = model.state_dict()
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    pretrained_dict = torch.load(path, map_location='cpu')
    matched = {
        k[prefix_len:]: v
        for k, v in pretrained_dict.items()
        if k[prefix_len:] in model_dict
    }
    # Merge into the full state dict so load_state_dict sees every key.
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return model
def reload_segmodel(model, path="", prefix_len=6):
    """Load matching segmentation-model weights from a checkpoint in place.

    Identical to the generic loader except for the default prefix length:
    the first ``prefix_len`` characters (6 by default; the exact prefix the
    segmentation checkpoint uses is not visible here — TODO confirm) are
    stripped from every checkpoint key. Only entries whose stripped key exists
    in ``model.state_dict()`` are loaded; all other parameters keep their
    current values, and a checkpoint that matches nothing is silently ignored.

    Args:
        model: the ``nn.Module`` to update.
        path: checkpoint file path. An empty (falsy) path skips loading.
        prefix_len: number of leading characters removed from each
            checkpoint key before matching.

    Returns:
        The same ``model`` object, updated in place.
    """
    if not path:
        return model
    model_dict = model.state_dict()
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    pretrained_dict = torch.load(path, map_location='cpu')
    matched = {
        k[prefix_len:]: v
        for k, v in pretrained_dict.items()
        if k[prefix_len:] in model_dict
    }
    # Merge into the full state dict so load_state_dict sees every key.
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return model
class GeoTr_Seg(nn.Module):
    """Mask the input with U2NETP, then predict a sampling map with GeoTr.

    ``forward`` zeroes every pixel whose first U2NETP output is not above
    0.5, runs GeoTr on the masked image, and linearly maps GeoTr's output
    from [0, 286.8] onto [-0.99, 0.99]. That range matches the normalized
    coordinates ``F.grid_sample`` expects. 286.8 is presumably tied to the
    288-px network input — TODO confirm.
    """

    def __init__(self):
        super().__init__()
        self.msk = U2NETP(3, 1)
        self.GeoTr = GeoTr(num_attn_layers=6)

    def forward(self, x):
        # U2NETP yields seven outputs; only the first one is used.
        seg, _d1, _d2, _d3, _d4, _d5, _d6 = self.msk(x)
        foreground = seg.gt(0.5).float()
        backward_map = self.GeoTr(foreground * x)
        normalized = backward_map / 286.8 * 2 - 1
        return normalized * 0.99
# Initialize models once at import time.
GeoTr_Seg_model = GeoTr_Seg()
# Illumination correction (IllTr) is currently disabled; re-enable these
# commented-out lines together with `ill_rec` in process_image.
# IllTr_model = IllTr()

# Load pretrained weights into the two sub-modules (mapped to CPU).
reload_segmodel(GeoTr_Seg_model.msk, './model_pretrained/seg.pth')
reload_model(GeoTr_Seg_model.GeoTr, './model_pretrained/geotr.pth')
# reload_model(IllTr_model, './model_pretrained/illtr.pth')

# Inference only: switch to eval mode once, before compiling.
GeoTr_Seg_model.eval()

# torch.compile only exists in PyTorch >= 2.0; run eagerly on older versions
# instead of failing at import time.
if hasattr(torch, "compile"):
    GeoTr_Seg_model = torch.compile(GeoTr_Seg_model)
    # IllTr_model = torch.compile(IllTr_model)
def _to_rgb_float(input_image):
    """Return ``input_image`` as an (H, W, 3) float64 array divided by 255.

    Accepts anything ``np.asarray`` understands (NumPy arrays, PIL images).
    Handling by channel count:
    - 2-D (grayscale) input and 1- or 2-channel input: the first channel is
      replicated into three channels.
    - Channels beyond the third (e.g. alpha): dropped.

    Assumes 8-bit pixel values.
    """
    arr = np.asarray(input_image)
    if arr.ndim == 2:
        arr = arr[:, :, np.newaxis]
    if arr.shape[2] < 3:
        arr = np.repeat(arr[:, :, :1], 3, axis=2)
    return arr[:, :, :3] / 255.


def process_image(input_image):
    """Geometrically rectify a document image with ``GeoTr_Seg_model``.

    Args:
        input_image: the uploaded image as a NumPy array or PIL image
            (assumed 8-bit). It may be None when the user submits without an
            image.

    Returns:
        The unwarped image as a ``PIL.Image.Image`` at the input's
        resolution, or None if ``input_image`` is None.
    """
    if input_image is None:
        return None
    GeoTr_Seg_model.eval()
    # IllTr_model.eval()

    im_ori = _to_rgb_float(input_image)
    h, w, _ = im_ori.shape

    # The network runs on a fixed 288x288 input in NCHW layout.
    im = cv2.resize(im_ori, (288, 288)).transpose(2, 0, 1)
    im = torch.from_numpy(im).float().unsqueeze(0)

    with torch.no_grad():
        bm = GeoTr_Seg_model(im).cpu()

    # Upsample each predicted coordinate channel to the original resolution
    # (cv2 sizes are (width, height)) and smooth it with a 3x3 box filter.
    bm0 = cv2.blur(cv2.resize(bm[0, 0].numpy(), (w, h)), (3, 3))
    bm1 = cv2.blur(cv2.resize(bm[0, 1].numpy(), (w, h)), (3, 3))

    # grid_sample takes an (N, H_out, W_out, 2) grid whose last axis holds
    # (x, y) sampling positions normalized to [-1, 1].
    grid = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)
    src = torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float()
    out = F.grid_sample(src, grid, align_corners=True)

    # Round instead of truncating. float32 round-off can leave a value that
    # should be exactly k at k - epsilon, and astype(uint8) would floor it
    # to k - 1.
    img_geo = np.clip(out[0].permute(1, 2, 0).numpy() * 255.0, 0, 255)
    img_geo = img_geo.round().astype(np.uint8)

    # Illumination correction is disabled. IllTr_model is only created in
    # commented-out module-level lines. Setting this to True without
    # restoring them raises NameError.
    ill_rec = False
    if ill_rec:
        return Image.fromarray(rec_ill(IllTr_model, img_geo))
    return Image.fromarray(img_geo)
# Define the Gradio interface.
# Gradio 3+ exposes components at top level (gr.Image). The gr.inputs /
# gr.outputs namespaces were deprecated in 3.x and removed in 4.0, so they
# are used only as a fallback for legacy installs.
if hasattr(gr, "Image"):
    input_image = gr.Image()
    output_image = gr.Image(type='pil')
else:
    input_image = gr.inputs.Image()
    output_image = gr.outputs.Image(type='pil')
iface = gr.Interface(fn=process_image, inputs=input_image, outputs=output_image, title="DocTr",
                     examples=example_img_list)
iface.launch()