|  | import gradio as gr | 
					
						
						|  | import requests | 
					
						
						|  | import torch | 
					
						
						|  | import os | 
					
						
						|  | from tqdm import tqdm | 
					
						
						|  |  | 
					
						
						|  | from ultralytics import YOLO | 
					
						
						|  | import cv2 | 
					
						
						|  | import numpy as np | 
					
						
						|  | import pandas as pd | 
					
						
						|  | from skimage.transform import resize | 
					
						
						|  | from skimage import img_as_bool | 
					
						
						|  | from skimage.morphology import convex_hull_image | 
					
						
						|  | import json | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def tableConvexHull(img, masks): | 
					
						
						|  | mask=np.zeros(masks[0].shape,dtype="bool") | 
					
						
						|  | for msk in masks: | 
					
						
						|  | temp=msk.cpu().detach().numpy(); | 
					
						
						|  | chull = convex_hull_image(temp); | 
					
						
						|  | mask=np.bitwise_or(mask,chull) | 
					
						
						|  | return mask | 
					
						
						|  |  | 
					
						
						|  | def cls_exists(clss, cls): | 
					
						
						|  | indices = torch.where(clss==cls) | 
					
						
						|  | return len(indices[0])>0 | 
					
						
						|  |  | 
					
						
						|  | def empty_mask(img): | 
					
						
						|  | mask = np.zeros(img.shape[:2], dtype="uint8") | 
					
						
						|  | return np.array(mask, dtype=bool) | 
					
						
						|  |  | 
					
						
						|  | def extract_img_mask(img_model, img, config): | 
					
						
						|  | res_dict = { | 
					
						
						|  | 'status' : 1 | 
					
						
						|  | } | 
					
						
						|  | res = get_predictions(img_model, img, config) | 
					
						
						|  |  | 
					
						
						|  | if res['status']==-1: | 
					
						
						|  | res_dict['status'] = -1 | 
					
						
						|  |  | 
					
						
						|  | elif res['status']==0: | 
					
						
						|  | res_dict['mask']=empty_mask(img) | 
					
						
						|  |  | 
					
						
						|  | else: | 
					
						
						|  | masks = res['masks'] | 
					
						
						|  | boxes = res['boxes'] | 
					
						
						|  | clss = boxes[:, 5] | 
					
						
						|  | mask = extract_mask(img, masks, boxes, clss, 0) | 
					
						
						|  | res_dict['mask'] = mask | 
					
						
						|  | return res_dict | 
					
						
						|  |  | 
					
						
						|  | def get_predictions(model, img2, config): | 
					
						
						|  | res_dict = { | 
					
						
						|  | 'status': 1 | 
					
						
						|  | } | 
					
						
						|  | try: | 
					
						
						|  | for result in model.predict(source=img2, verbose=False, retina_masks=config['rm'],\ | 
					
						
						|  | imgsz=config['sz'], conf=config['conf'], stream=True,\ | 
					
						
						|  | classes=config['classes']): | 
					
						
						|  | try: | 
					
						
						|  | res_dict['masks'] = result.masks.data | 
					
						
						|  | res_dict['boxes'] = result.boxes.data | 
					
						
						|  | del result | 
					
						
						|  | return res_dict | 
					
						
						|  | except Exception as e: | 
					
						
						|  | res_dict['status'] = 0 | 
					
						
						|  | return res_dict | 
					
						
						|  | except: | 
					
						
						|  | res_dict['status'] = -1 | 
					
						
						|  | return res_dict | 
					
						
						|  |  | 
					
						
						|  | def extract_mask(img, masks, boxes, clss, cls): | 
					
						
						|  | if not cls_exists(clss, cls): | 
					
						
						|  | return empty_mask(img) | 
					
						
						|  | indices = torch.where(clss==cls) | 
					
						
						|  | c_masks = masks[indices] | 
					
						
						|  | mask_arr = torch.any(c_masks, dim=0).bool() | 
					
						
						|  | mask_arr = mask_arr.cpu().detach().numpy() | 
					
						
						|  | mask = mask_arr | 
					
						
						|  | return mask | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
def get_masks(img, model, img_model, flags, configs):
    """Build the four layout masks for one document image.

    Params:
        img: input image (np.ndarray, H x W x C).
        model: general layout model (configs 'paratext' and 'imgtab').
        img_model: dedicated model used to refine image regions.
        flags: feature-flag dict (not read in this function).
        configs: dict of per-stage prediction configs
            ('paratext', 'imgtab', 'image').

    Returns:
        {'status': -1} on any prediction failure, otherwise
        {'status': 1, 'masks': [m0, m1, m2, m3]} — the order is the
        contract: indices 0/1 come from the 'paratext' classes, index 2
        is the image mask, index 3 the table mask.
    """
    response = {
        'status': 1
    }
    ans_masks = []
    img2 = img

    # Stage 1: classes 0 and 1 (the 'paratext' config —
    # presumably paragraph/text; verify against training labels).
    res = get_predictions(model, img2, configs['paratext'])
    if res['status']==-1:
        response['status'] = -1
        return response
    elif res['status']==0:
        # no detections: keep positions 0 and 1 as empty masks
        for i in range(2): ans_masks.append(empty_mask(img))
    else:
        masks, boxes = res['masks'], res['boxes']
        clss = boxes[:, 5]  # class-id column of the box tensor
        for cls in range(2):
            mask = extract_mask(img, masks, boxes, clss, cls)
            ans_masks.append(mask)

    # Stage 2: classes 2 and 3 (the 'imgtab' config — image and table).
    res2 = get_predictions(model, img2, configs['imgtab'])
    if res2['status']==-1:
        response['status'] = -1
        return response
    elif res2['status']==0:
        # no detections: positions 2 and 3 become empty masks
        for i in range(2): ans_masks.append(empty_mask(img))
    else:
        masks, boxes = res2['masks'], res2['boxes']
        clss = boxes[:, 5]

        # Class 2 (image): refine with the dedicated image model rather
        # than using the general model's masks directly.
        if cls_exists(clss, 2):
            img_res = extract_img_mask(img_model, img, configs['image'])
            if img_res['status'] == 1:
                img_mask = img_res['mask']
            else:
                response['status'] = -1
                return response

        else:
            img_mask = empty_mask(img)
        ans_masks.append(img_mask)

        # Class 3 (table): union of per-instance convex hulls.
        if cls_exists(clss, 3):
            indices = torch.where(clss==3)
            tbl_mask = tableConvexHull(img, masks[indices])
        else:
            tbl_mask = empty_mask(img)
        ans_masks.append(tbl_mask)

    # Without retina masks the model emits masks at inference resolution;
    # resize all four back to the image size.
    # NOTE(review): reconstructed from a flattened source — placement at
    # function level (not inside the stage-2 else) assumed; confirm.
    if not configs['paratext']['rm']:
        h, w, c = img.shape
        for i in range(4):
            ans_masks[i] = img_as_bool(resize(ans_masks[i], (h, w)))

    response['masks'] = ans_masks
    return response
					
						
						|  |  | 
					
						
						|  | def overlay(image, mask, color, alpha, resize=None): | 
					
						
						|  | """Combines image and its segmentation mask into a single image. | 
					
						
						|  | https://www.kaggle.com/code/purplejester/showing-samples-with-segmentation-mask-overlay | 
					
						
						|  |  | 
					
						
						|  | Params: | 
					
						
						|  | image: Training image. np.ndarray, | 
					
						
						|  | mask: Segmentation mask. np.ndarray, | 
					
						
						|  | color: Color for segmentation mask rendering.  tuple[int, int, int] = (255, 0, 0) | 
					
						
						|  | alpha: Segmentation mask's transparency. float = 0.5, | 
					
						
						|  | resize: If provided, both image and its mask are resized before blending them together. | 
					
						
						|  | tuple[int, int] = (1024, 1024)) | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | image_combined: The combined image. np.ndarray | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  | color = color[::-1] | 
					
						
						|  | colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0) | 
					
						
						|  | colored_mask = np.moveaxis(colored_mask, 0, -1) | 
					
						
						|  | masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color) | 
					
						
						|  | image_overlay = masked.filled() | 
					
						
						|  |  | 
					
						
						|  | if resize is not None: | 
					
						
						|  | image = cv2.resize(image.transpose(1, 2, 0), resize) | 
					
						
						|  | image_overlay = cv2.resize(image_overlay.transpose(1, 2, 0), resize) | 
					
						
						|  |  | 
					
						
						|  | image_combined = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0) | 
					
						
						|  |  | 
					
						
						|  | return image_combined | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# --- model setup -------------------------------------------------------------
model_path = 'models'
general_model_name = 'e50_aug.pt'   # general layout model (classes 0-3 below)
image_model_name = 'e100_img.pt'    # dedicated model for image regions

general_model = YOLO(os.path.join(model_path, general_model_name))
image_model = YOLO(os.path.join(model_path, image_model_name))

# --- Gradio example images ---------------------------------------------------
image_path = 'examples'
sample_name = ['0040da34-25c8-4a5a-a6aa-36733ea3b8eb.png',
    '0050a8ee-382b-447e-9c5b-8506d9507bef.png', '0064d3e2-3ba2-4332-a28f-3a165f2b84b1.png']

sample_path = [os.path.join(image_path, sample) for sample in sample_name]

# Feature flags passed through to get_masks (not read there currently).
flags = {
    'hist': False,
    'bz': False
}

# Per-stage prediction configs consumed by get_predictions:
#   sz      -> imgsz, conf -> confidence threshold,
#   rm      -> retina_masks (flipped to False on retry in evaluate),
#   classes -> class ids requested from the model.
configs = {}
configs['paratext'] = {
    'sz' : 640,
    'conf': 0.25,
    'rm': True,
    'classes': [0, 1]
}
configs['imgtab'] = {
    'sz' : 640,
    'conf': 0.35,
    'rm': True,
    'classes': [2, 3]
}
configs['image'] = {
    'sz' : 640,
    'conf': 0.35,
    'rm': True,
    'classes': [0]
}
					
						
						|  |  | 
					
						
						|  | def evaluate(img_path, model=general_model, img_model=image_model,\ | 
					
						
						|  | configs=configs, flags=flags): | 
					
						
						|  |  | 
					
						
						|  | img = cv2.imread(img_path) | 
					
						
						|  | res = get_masks(img, general_model, image_model, flags, configs) | 
					
						
						|  | if res['status']==-1: | 
					
						
						|  | for idx in configs.keys(): | 
					
						
						|  | configs[idx]['rm'] = False | 
					
						
						|  | return evaluate(img, model, img_model, flags, configs) | 
					
						
						|  | else: | 
					
						
						|  | masks = res['masks'] | 
					
						
						|  |  | 
					
						
						|  | color_map = { | 
					
						
						|  | 0 : (255, 0, 0), | 
					
						
						|  | 1 : (0, 255, 0), | 
					
						
						|  | 2 : (0, 0, 255), | 
					
						
						|  | 3 : (255, 255, 0), | 
					
						
						|  | } | 
					
						
						|  | for i, mask in enumerate(masks): | 
					
						
						|  | img = overlay(image=img, mask=mask, color=color_map[i], alpha=0.4) | 
					
						
						|  |  | 
					
						
						|  | return img | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# --- Gradio UI ---------------------------------------------------------------
# Single image in (as a filepath, matching evaluate's signature),
# segmented numpy image out. launch() blocks and serves the app.
inputs_image = [
    gr.components.Image(type="filepath", label="Input Image"),
]
outputs_image = [
    gr.components.Image(type="numpy", label="Output Image"),
]
interface_image = gr.Interface(
    fn=evaluate,
    inputs=inputs_image,
    outputs=outputs_image,
    title="Document Layout Segmentor",
    examples=sample_path,
    cache_examples=True,  # precomputes outputs for the example images
).launch()