Spaces:
Running
Running
File size: 4,647 Bytes
6a07cb2 1081f7c 6a07cb2 2bb6556 6a07cb2 2bb6556 6a07cb2 2bb6556 6a07cb2 2bb6556 6a07cb2 2bb6556 6a07cb2 2bb6556 6a07cb2 2bb6556 6a07cb2 2bb6556 6a07cb2 2bb6556 6a07cb2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import importlib
import warnings
from collections import defaultdict
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from config import Config
from data_utils.image_utils import _to_2d
# Silence every warning for the whole process (no category/module filter is given).
warnings.filterwarnings("ignore")
# Loaded through importlib because the package name "DocTr-Plus" contains a
# hyphen, which a regular `import` statement cannot spell. DocScanner is
# loaded the same way.
DocTr_Plus = importlib.import_module("models.DocTr-Plus.inference")
DocScanner = importlib.import_module("models.DocScanner.inference")
# Despite its name this is the CPU device whenever CUDA is not available.
cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Module-level mapping; missing keys read as 0 because the default factory is int.
mask_dict = defaultdict(int)
def load_geotrp_model(cuda, path=""):
    """Build a GeoTrP network on the given device and restore its weights.

    Args:
        cuda: Device handed to the network's ``.to()`` call.
        path: Checkpoint location forwarded to ``DocTr_Plus.reload_model``.

    Returns:
        The GeoTrP instance, after ``eval()`` has been called on it.
    """
    network = DocTr_Plus.GeoTrP().to(cuda)
    # Only the inner ``GeoTr`` sub-module is handed to the checkpoint loader.
    DocTr_Plus.reload_model(network.GeoTr, path)
    network.eval()
    return network
def load_docscanner_model(cuda, path_l="", path_m=""):
    """Build a DocScanner ``Net`` on the given device and restore its weights.

    Args:
        cuda: Device the network is moved to; also passed to both loaders.
        path_l: Checkpoint forwarded to ``DocScanner.reload_rec_model`` for
            the ``bm`` sub-module.
        path_m: Checkpoint forwarded to ``DocScanner.reload_seg_model`` for
            the ``msk`` sub-module.

    Returns:
        The network, after ``eval()`` has been called on it.
    """
    scanner = DocScanner.Net()
    scanner = scanner.to(cuda)
    # Segmentation weights are restored first, then the rectification weights.
    DocScanner.reload_seg_model(cuda, scanner.msk, path_m)
    DocScanner.reload_rec_model(cuda, scanner.bm, path_l)
    scanner.eval()
    return scanner
def preprocess_image(img, target_size=(288, 288)):
    """Normalise an image and build the batched tensor fed to the networks.

    Args:
        img: Array indexed as ``[row, column, channel]`` with at least three
            channels; only the first three are used.
        target_size: ``(width, height)`` passed to ``cv2.resize`` for the
            network input. Any two-element sequence is accepted.

    Returns:
        A tuple ``(im_ori, im, h_, w_)`` where ``im_ori`` is the
        three-channel image divided by 255 at its original resolution,
        ``im`` is a float tensor of shape ``(1, 3, height, width)`` holding
        the resized image, and ``h_`` / ``w_`` are the original height and
        width.
    """
    im_ori = img[:, :, :3] / 255.0
    h_, w_, _ = im_ori.shape
    # Resize straight to the requested size. The previous hard-coded
    # intermediate 288x288 resize interpolated twice for any other size and
    # was a plain copy for the default one.
    im = cv2.resize(im_ori, tuple(target_size))
    im = im.transpose(2, 0, 1)  # HWC -> CHW
    im = torch.from_numpy(im).float().unsqueeze(0)  # add the batch axis
    return im_ori, im, h_, w_
def geotrp_rec(img, model, cuda):
    """Rectify ``img`` with the two-channel map predicted by ``model``.

    Args:
        img: Input image, forwarded to ``preprocess_image``.
        model: Callable returning a batched map tensor; channel 0 and
            channel 1 of the first batch item are used as the two remap
            arguments.
        cuda: Device the input tensor is moved to before inference.

    Returns:
        The remapped image multiplied by 255 and resized to the input's
        original width and height.
    """
    original, batch, height, width = preprocess_image(img)
    with torch.no_grad():
        flow = model(batch.to(cuda)).cpu().numpy()[0]
        # Smooth each map channel with a 3x3 box filter before remapping.
        map_x = cv2.blur(flow[0, :, :], (3, 3))
        map_y = cv2.blur(flow[1, :, :], (3, 3))
        # NOTE(review): the maps index into ``original`` at its native
        # resolution - confirm the model emits coordinates in that space.
        rectified = cv2.remap(original, map_x, map_y, cv2.INTER_LINEAR) * 255
    return cv2.resize(rectified, (width, height))
def docscanner_get_mask(img, model, cuda):
    """Return the mask predicted by ``model`` at the input's resolution.

    Args:
        img: Input image, forwarded to ``preprocess_image``.
        model: Callable returning a pair whose second element is the mask
            tensor; the first element is ignored.
        cuda: Device the input tensor is moved to before inference.

    Returns:
        A ``uint8`` array: the first batch item / first channel of the mask
        scaled by 255 and resized to the original width and height.
    """
    _, batch, height, width = preprocess_image(img)
    with torch.no_grad():
        _, seg = model(batch.to(cuda))
        seg_plane = seg.cpu()[0, 0].numpy()
    mask_u8 = (seg_plane * 255).astype(np.uint8)
    return cv2.resize(mask_u8, (width, height))
def docscanner_rec_img(img, model, cuda):
    """Rectify ``img`` with the backward map predicted by ``model``.

    Args:
        img: Input image, forwarded to ``preprocess_image``.
        model: Callable returning either the backward-map tensor alone or a
            tuple/list whose first element is that tensor (the other
            DocScanner helpers in this module unpack ``(bm, msk)``).
        cuda: Device the input tensor is moved to before inference.

    Returns:
        A ``uint8`` array at the input's original height and width, with the
        channel order reversed relative to ``img``.
    """
    im_ori, im, h, w = preprocess_image(img)
    with torch.no_grad():
        output = model(im.to(cuda))
        # The model may return the map alone or a (map, mask) pair; calling
        # .cpu() on a tuple would raise AttributeError.
        bm = output[0] if isinstance(output, (tuple, list)) else output
        bm = bm.cpu()

    # Bring both map channels to the original resolution, then smooth them.
    bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))  # x flow
    bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))  # y flow
    bm0 = cv2.blur(bm0, (3, 3))
    bm1 = cv2.blur(bm1, (3, 3))
    lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 * h * w * 2

    out = F.grid_sample(
        torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(),
        lbl,
        align_corners=True,
    )
    # Back to HWC, scale to 0..255 and reverse the channel order.
    rectified = (((out[0] * 255).permute(1, 2, 0).numpy())[:, :, ::-1]).astype(np.uint8)
    return rectified
def docscanner_rec(img, model, cuda):
    """Rectify ``img`` and return the predicted mask from one forward pass.

    Args:
        img: Array indexed as ``[row, column, channel]``; only the first
            three channels are used.
        model: Callable returning a ``(backward_map, mask)`` pair of tensors.
        cuda: Device the input tensor is moved to before inference.

    Returns:
        A tuple ``(rectified, mask)``: the rectified ``uint8`` image with its
        channel order reversed, and the ``uint8`` mask, both at the input's
        original height and width.
    """
    source = img[:, :, :3] / 255.0
    height, width, _ = source.shape

    # The network input is a fixed 288x288, channel-first, batched tensor.
    batch = torch.from_numpy(cv2.resize(source, (288, 288)).transpose(2, 0, 1))
    batch = batch.float().unsqueeze(0)

    with torch.no_grad():
        flow, seg = model(batch.to(cuda))
        flow, seg = flow.cpu(), seg.cpu()

    mask_u8 = (seg[0, 0].numpy() * 255).astype(np.uint8)
    mask_img = cv2.resize(mask_u8, (width, height))

    def _full_res_channel(index):
        # Upsample one map channel to the original size, then box-blur it.
        upsampled = cv2.resize(flow[0, index].numpy(), (width, height))
        return cv2.blur(upsampled, (3, 3))

    grid = torch.from_numpy(
        np.stack([_full_res_channel(0), _full_res_channel(1)], axis=2)
    ).unsqueeze(0)  # 1 * h * w * 2

    sampled = F.grid_sample(
        torch.from_numpy(source).permute(2, 0, 1).unsqueeze(0).float(),
        grid,
        align_corners=True,
    )
    # Back to HWC, scale to 0..255 and reverse the channel order.
    rectified = (sampled[0] * 255).permute(1, 2, 0).numpy()[:, :, ::-1].astype(np.uint8)
    return rectified, mask_img
# TODO: move this helper into data_utils later.
def get_mask_white_area(mask):
    """Locate every pixel of a mask whose value is greater than zero.

    Args:
        mask (np.ndarray): Mask image; it is passed through ``_to_2d``
            before the comparison.

    Returns:
        np.ndarray: Output of ``np.argwhere`` - one row of indices per
        pixel with a value above zero.
    """
    return np.argwhere(_to_2d(mask) > 0)
def main():
    """Run DocScanner segmentation on a sample image and record its mask area.

    Loads both models, predicts the mask for the sample image and stores the
    number of white (greater than zero) mask pixels in ``mask_dict`` under
    the image path.

    Raises:
        FileNotFoundError: If the sample image cannot be read.
    """
    config = Config()
    image_path = "input/test.jpg"  # adjust this path before running the script
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # NOTE(review): the config.get_* attributes are passed without being
    # called - presumably they are properties; confirm against Config.
    docscanner = load_docscanner_model(
        cuda, path_l=config.get_rec_model_path, path_m=config.get_seg_model_path
    )
    # Loaded as in the original flow, but not used further in this function.
    doctr = load_geotrp_model(cuda, path=config.get_geotr_model_path)

    mask = docscanner_get_mask(img, docscanner, cuda)
    # mask_dict is a defaultdict(int) and has no .add(); record the white
    # pixel count as the integer value for this image instead.
    mask_dict[image_path] = len(get_mask_white_area(mask))
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|