"""Inference helpers for DocTr-Plus (GeoTrP) and DocScanner document rectification."""

import importlib
import warnings
from collections import defaultdict

import cv2
import numpy as np
import torch
import torch.nn.functional as F

from config import Config
from data_utils.image_utils import _to_2d

warnings.filterwarnings("ignore")

# "DocTr-Plus" contains a hyphen, so it cannot be named in a plain ``import``
# statement; both model packages are loaded through importlib for symmetry.
DocTr_Plus = importlib.import_module("models.DocTr-Plus.inference")
DocScanner = importlib.import_module("models.DocScanner.inference")

# Preferred inference device: GPU when available, otherwise CPU.
cuda = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Counter keyed by (y, x) pixel coordinate; filled in by ``main``.
mask_dict = defaultdict(int)

# Spatial size (width, height) the images are resized to before inference.
_NET_INPUT_SIZE = (288, 288)


def _model_device(model):
    """Return the device *model* lives on, falling back to the module default."""
    try:
        return next(model.parameters()).device
    except (StopIteration, AttributeError):
        # Parameter-less module or a plain callable: use the module-level device.
        return cuda


def _first_output(output):
    """Return the first element of a tuple/list model output, or the output itself."""
    if isinstance(output, (tuple, list)):
        return output[0]
    return output


def _mask_to_image(msk, w, h):
    """Convert channel 0 of the first batch item of *msk* to a uint8 image of size (w, h)."""
    mask_np = (msk.cpu()[0, 0].numpy() * 255).astype(np.uint8)
    return cv2.resize(mask_np, (w, h))


def _unwarp_with_flow(im_ori, bm, w, h):
    """Resample *im_ori* with the two-channel sampling grid *bm*.

    The two grid channels are resized to (w, h), smoothed with a 3x3 box
    filter, and fed to ``F.grid_sample``. The result is scaled back to
    0-255, its channel order is reversed, and it is returned as uint8.
    """
    bm = bm.cpu()
    bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))  # x flow
    bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))  # y flow
    bm0 = cv2.blur(bm0, (3, 3))
    bm1 = cv2.blur(bm1, (3, 3))
    grid = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)  # 1 * h * w * 2
    source = torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float()
    out = F.grid_sample(source, grid, align_corners=True)
    # NOTE(review): the channel flip assumes the caller wants the opposite
    # channel order from the input; confirm against how inputs are loaded.
    return ((out[0] * 255).permute(1, 2, 0).numpy()[:, :, ::-1]).astype(np.uint8)


def load_geotrp_model(cuda, path=""):
    """Build a GeoTrP model on device *cuda*, load weights from *path*, set eval mode."""
    _GeoTrP = DocTr_Plus.GeoTrP()
    _GeoTrP = _GeoTrP.to(cuda)
    DocTr_Plus.reload_model(_GeoTrP.GeoTr, path)
    _GeoTrP.eval()
    return _GeoTrP


def load_docscanner_model(cuda, path_l="", path_m=""):
    """Build a DocScanner net on device *cuda* and load its two weight files.

    Args:
        cuda: Target ``torch.device``.
        path_l: Weights passed to ``reload_rec_model`` for ``net.bm``.
        path_m: Weights passed to ``reload_seg_model`` for ``net.msk``.

    Returns:
        The network in eval mode.
    """
    net = DocScanner.Net().to(cuda)
    DocScanner.reload_seg_model(net.msk, path_m)
    DocScanner.reload_rec_model(net.bm, path_l)
    net.eval()
    return net


def preprocess_image(img, target_size=(288, 288)):
    """Normalise *img* and build the network input tensor.

    Args:
        img: H x W x C array; only the first three channels are used.
        target_size: Final (width, height) passed to ``cv2.resize``. Any
            two-element sequence is accepted.

    Returns:
        Tuple ``(im_ori, im, h, w)``: the image scaled to [0, 1] at its
        original size, the 1 x 3 x H' x W' float tensor, and the original
        height and width.
    """
    im_ori = img[:, :, :3] / 255.0
    h_, w_, _ = im_ori.shape
    # Two-step resize kept from the original so non-default target sizes
    # produce the same pixels as before.
    im = cv2.resize(im_ori, _NET_INPUT_SIZE)
    im = cv2.resize(im, tuple(int(v) for v in target_size))
    im = im.transpose(2, 0, 1)
    im = torch.from_numpy(im).float().unsqueeze(0)
    return im_ori, im, h_, w_


def geotrp_rec(img, model):
    """Rectify *img* with a GeoTrP model and return it at the original size.

    The two channels of the model output are used as the x and y maps for
    ``cv2.remap``; the remapped image is then resized to the input's
    width and height.
    """
    im_ori, im, h_, w_ = preprocess_image(img)

    with torch.no_grad():
        # Run on whatever device the model is on instead of assuming CUDA.
        bm = model(im.to(_model_device(model)))
    bm = bm.cpu().numpy()[0]

    # cv2.remap needs float32 maps when x and y are supplied separately.
    bm0 = np.ascontiguousarray(bm[0, :, :], dtype=np.float32)
    bm1 = np.ascontiguousarray(bm[1, :, :], dtype=np.float32)
    bm0 = cv2.blur(bm0, (3, 3))
    bm1 = cv2.blur(bm1, (3, 3))

    img_geo = cv2.remap(im_ori, bm0, bm1, cv2.INTER_LINEAR) * 255
    img_geo = cv2.resize(img_geo, (w_, h_))
    return img_geo


def docscanner_get_mask(img, model):
    """Return the model's mask output for *img* as a uint8 image at the input size."""
    _, im, h, w = preprocess_image(img)

    with torch.no_grad():
        _, msk = model(im.to(_model_device(model)))
    return _mask_to_image(msk, w, h)


def docscanner_rec_img(img, model):
    """Return the rectified version of *img* (uint8) produced by *model*."""
    im_ori, im, h, w = preprocess_image(img)

    with torch.no_grad():
        # Sibling functions unpack this model's output as (bm, msk); accept
        # both that form and a bare tensor.
        bm = _first_output(model(im.to(_model_device(model))))
    return _unwarp_with_flow(im_ori, bm, w, h)


def docscanner_rec(img, model):
    """Run *model* once and return ``(rectified_image, mask_image)`` for *img*."""
    im_ori, im, h, w = preprocess_image(img)

    with torch.no_grad():
        bm, msk = model(im.to(_model_device(model)))

    mask_img = _mask_to_image(msk, w, h)
    rectified = _unwarp_with_flow(im_ori, bm, w, h)
    return rectified, mask_img


# To be moved into data_utils later.
def get_mask_white_area(mask):
    """
    Get the white area (non-zero pixels) of a mask.

    Args:
        mask (np.ndarray): Input mask image (2D or 3D array)

    Returns:
        np.ndarray: Array of (y, x) coordinates of white pixels
    """
    mask = _to_2d(mask)
    return np.argwhere(mask > 0)


def main():
    """Load the models, compute a mask for the test image and record its white pixels."""
    config = Config()
    image_path = "input/test.jpg"  # adjust before running the script
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # NOTE(review): the config attributes are passed without being called;
    # this assumes they are properties or plain attributes - confirm in Config.
    docscanner = load_docscanner_model(
        cuda, path_l=config.get_rec_model_path, path_m=config.get_seg_model_path
    )
    # Loaded but not yet used by this script.
    doctr = load_geotrp_model(cuda, path=config.get_geotr_model_path)

    mask = docscanner_get_mask(img, docscanner)

    # defaultdict has no ``add``; count each white pixel coordinate instead.
    # NOTE(review): per-pixel counting is inferred from defaultdict(int) -
    # confirm this is the intended aggregation.
    for y, x in get_mask_white_area(mask).tolist():
        mask_dict[(y, x)] += 1


if __name__ == "__main__":
    main()