qoobeeshy's picture
examples updated
6bd511c
raw
history blame
7.42 kB
import gradio as gr
import requests
import torch
import os
from tqdm import tqdm
# import wandb
from ultralytics import YOLO
import cv2
import numpy as np
import pandas as pd
from skimage.transform import resize
from skimage import img_as_bool
from skimage.morphology import convex_hull_image
import json
# wandb.init(mode='disabled')
def tableConvexHull(img, masks):
    """Merge the convex hulls of every instance mask into a single mask.

    Params:
        img: Unused by this function.
        masks: Indexable, iterable collection of mask tensors. Each entry must
            support ``.cpu().detach().numpy()``; ``masks[0]`` must exist
            because its shape sizes the result.
    Returns:
        Boolean np.ndarray shaped like ``masks[0]`` that is True wherever the
        convex hull of at least one instance covers the pixel.
    """
    combined = np.zeros(masks[0].shape, dtype="bool")
    for instance in masks:
        # Each instance is hulled separately, then OR-ed into the running mask.
        hull = convex_hull_image(instance.cpu().detach().numpy())
        combined = np.bitwise_or(combined, hull)
    return combined
def cls_exists(clss, cls):
    """Return True when the value ``cls`` occurs at least once in tensor ``clss``."""
    return bool(torch.any(clss == cls))
def empty_mask(img):
    """Return an all-False boolean mask covering the first two axes of ``img``.

    Params:
        img: Array-like exposing ``.shape``; only ``shape[:2]`` is read.
    Returns:
        np.ndarray of dtype bool and shape ``img.shape[:2]``, filled with False.
    """
    height_width = img.shape[:2]
    return np.zeros(height_width, dtype=bool)
def extract_img_mask(img_model, img, config):
    """Run ``img_model`` on ``img`` and return the merged mask for class 0.

    Params:
        img_model: Model handed to ``get_predictions``.
        img: Image array; also used to size the empty mask.
        config: Prediction settings handed to ``get_predictions``.
    Returns:
        ``{'status': -1}`` when prediction failed. Otherwise
        ``{'status': 1, 'mask': ...}``, where the mask is empty if the
        prediction reported status 0 (no usable masks).
    """
    prediction = get_predictions(img_model, img, config)
    status = prediction['status']

    if status == -1:
        return {'status': -1}
    if status == 0:
        # Nothing usable came back: report success with an empty mask.
        return {'status': 1, 'mask': empty_mask(img)}

    boxes = prediction['boxes']
    class_ids = boxes[:, 5]  # column 5 of the box tensor is used as the class id
    return {
        'status': 1,
        'mask': extract_mask(img, prediction['masks'], boxes, class_ids, 0),
    }
def get_predictions(model, img2, config):
    """Run ``model.predict`` on one source and return the first result's tensors.

    Params:
        model: Object exposing ``predict(...)`` with the keyword arguments
            used below.
        img2: Prediction source, passed through as ``source``.
        config: Mapping with keys ``'rm'`` (retina_masks), ``'sz'`` (imgsz),
            ``'conf'`` and ``'classes'``.
    Returns:
        dict whose ``'status'`` is one of:
          1  -> success; ``'masks'`` and ``'boxes'`` hold ``result.masks.data``
                and ``result.boxes.data`` of the first result.
          0  -> prediction ran but produced no usable masks/boxes (including
                the case where the result stream is empty).
          -1 -> prediction itself raised (bad config, model failure, ...).
    """
    try:
        results = model.predict(
            source=img2,
            verbose=False,
            retina_masks=config['rm'],
            imgsz=config['sz'],
            conf=config['conf'],
            stream=True,
            classes=config['classes'],
        )
        # The results are consumed as an iterator, so failures may only
        # surface when the first item is pulled; keep that inside the try.
        # Only the first result is ever used.
        first = next(iter(results), None)
    except Exception:
        # Deliberately broad: any prediction failure is reported to the caller
        # as status -1. Unlike a bare ``except:``, this lets KeyboardInterrupt
        # and SystemExit propagate.
        return {'status': -1}

    try:
        return {
            'status': 1,
            'masks': first.masks.data,
            'boxes': first.boxes.data,
        }
    except AttributeError:
        # ``first`` is None (empty stream) or its masks/boxes are None.
        return {'status': 0}
def extract_mask(img, masks, boxes, clss, cls):
    """Collapse all instance masks of class ``cls`` into one boolean mask.

    Params:
        img: Image array, used only to size the empty mask.
        masks: Tensor of instance masks, indexed along its first axis.
        boxes: Unused by this function.
        clss: Tensor of class ids, one per instance in ``masks``.
        cls: Class id to select.
    Returns:
        Boolean np.ndarray: the union over the selected instances, or an empty
        mask sized from ``img`` when ``cls`` does not occur in ``clss``.
    """
    if not cls_exists(clss, cls):
        return empty_mask(img)

    selected = masks[torch.where(clss == cls)]
    # Union across the instance axis, then move to host memory as numpy.
    union = torch.any(selected, dim=0).bool()
    return union.cpu().detach().numpy()
def get_masks(img, model, img_model, flags, configs):
    """Build the four layout masks for ``img``.

    Two prediction passes are made with ``model``: one using
    ``configs['paratext']`` (classes 0 and 1) and one using
    ``configs['imgtab']`` (classes 2 and 3). When class 2 is detected, its mask
    is recomputed with the dedicated ``img_model``; class 3 masks are replaced
    by the union of their convex hulls.

    Params:
        img: Image array; unpacked as ``(h, w, c)`` when masks must be resized.
        model: General model used for both prediction passes.
        img_model: Model used to refine the class-2 mask.
        flags: Unused by this function.
        configs: Mapping with ``'paratext'``, ``'imgtab'`` and ``'image'``
            prediction settings.
    Returns:
        ``{'status': -1}`` if any prediction failed, otherwise
        ``{'status': 1, 'masks': [m0, m1, m2, m3]}`` with one mask per class id.
    """
    layers = []

    # ***** Paragraph and text masks (classes 0 and 1)
    text_pred = get_predictions(model, img, configs['paratext'])
    if text_pred['status'] == -1:
        return {'status': -1}
    if text_pred['status'] == 0:
        layers.extend(empty_mask(img) for _ in range(2))
    else:
        masks, boxes = text_pred['masks'], text_pred['boxes']
        clss = boxes[:, 5]
        layers.extend(
            extract_mask(img, masks, boxes, clss, cls) for cls in range(2)
        )

    # ***** Image and table masks (classes 2 and 3)
    block_pred = get_predictions(model, img, configs['imgtab'])
    if block_pred['status'] == -1:
        return {'status': -1}
    if block_pred['status'] == 0:
        layers.extend(empty_mask(img) for _ in range(2))
    else:
        masks, boxes = block_pred['masks'], block_pred['boxes']
        clss = boxes[:, 5]

        if cls_exists(clss, 2):
            # The general model only signals that class 2 is present; the
            # actual mask comes from the dedicated model.
            refined = extract_img_mask(img_model, img, configs['image'])
            if refined['status'] != 1:
                return {'status': -1}
            layers.append(refined['mask'])
        else:
            layers.append(empty_mask(img))

        if cls_exists(clss, 3):
            table_instances = masks[torch.where(clss == 3)]
            layers.append(tableConvexHull(img, table_instances))
        else:
            layers.append(empty_mask(img))

    if not configs['paratext']['rm']:
        # With retina masks off, every mask is resized to the image's
        # height and width before being returned.
        h, w, c = img.shape
        layers = [img_as_bool(resize(layer, (h, w))) for layer in layers]

    return {'status': 1, 'masks': layers}
def overlay(image, mask, color, alpha, resize=None):
    """Combine an image and its segmentation mask into a single image.

    Adapted from
    https://www.kaggle.com/code/purplejester/showing-samples-with-segmentation-mask-overlay

    Params:
        image: Channel-last image, np.ndarray of shape (H, W, 3).
        mask: Segmentation mask, np.ndarray of shape (H, W); truthy pixels are
            painted with ``color``.
        color: Three-component colour for the mask, e.g. ``(255, 0, 0)``. The
            component order is reversed before use (presumably RGB -> BGR to
            match OpenCV-loaded images -- confirm against callers).
        alpha: Weight of the painted layer in the blend, e.g. ``0.5``; the
            untouched image gets ``1 - alpha``.
        resize: Optional ``(width, height)`` passed to ``cv2.resize``. When
            given, the image and the painted layer are both resized before
            blending.
    Returns:
        image_combined: The blended image, np.ndarray.
    """
    fill = color[::-1]

    # Replicate the 2-D mask across the channel axis so it matches (H, W, 3).
    mask_hwc = np.repeat(np.expand_dims(mask, -1), 3, axis=-1)

    # Painted layer: the image with every masked pixel replaced by ``fill``.
    painted = np.ma.MaskedArray(image, mask=mask_hwc, fill_value=fill).filled()

    if resize is not None:
        # Both arrays are already channel-last here (the mask above only lines
        # up with an (H, W, 3) image), so they go to cv2.resize as-is. The
        # previous ``transpose(1, 2, 0)`` was a leftover from a channel-first
        # layout and scrambled the axes.
        image = cv2.resize(image, resize)
        painted = cv2.resize(painted, resize)

    image_combined = cv2.addWeighted(image, 1 - alpha, painted, alpha, 0)
    return image_combined
# --- Model weights -----------------------------------------------------------
model_path = 'models'
general_model_name = 'e50_aug.pt'
image_model_name = 'e100_img.pt'
general_model = YOLO(os.path.join(model_path, general_model_name))
image_model = YOLO(os.path.join(model_path, image_model_name))

# --- Example images offered in the UI ----------------------------------------
image_path = 'examples'
sample_name = [
    '0040da34-25c8-4a5a-a6aa-36733ea3b8eb.png',
    '0050a8ee-382b-447e-9c5b-8506d9507bef.png',
    '0064d3e2-3ba2-4332-a28f-3a165f2b84b1.png',
]
sample_path = [os.path.join(image_path, name) for name in sample_name]

# Passed through to get_masks; not read by any function visible in this file.
flags = {'hist': False, 'bz': False}

# --- Per-pass prediction settings --------------------------------------------
# get_predictions forwards these to model.predict as:
#   'sz' -> imgsz, 'conf' -> conf, 'rm' -> retina_masks, 'classes' -> classes
configs = {
    # General model, first pass: classes 0 and 1.
    'paratext': {'sz': 640, 'conf': 0.25, 'rm': True, 'classes': [0, 1]},
    # General model, second pass: classes 2 and 3.
    'imgtab': {'sz': 640, 'conf': 0.35, 'rm': True, 'classes': [2, 3]},
    # Dedicated image model: its single class 0.
    'image': {'sz': 640, 'conf': 0.35, 'rm': True, 'classes': [0]},
}
def evaluate(img_path, model=general_model, img_model=image_model,
             configs=configs, flags=flags):
    """Segment the document at ``img_path`` and return it with coloured overlays.

    Params:
        img_path: Path of the image file to read with ``cv2.imread``.
        model: General layout model (defaults to the module-level one).
        img_model: Dedicated image-region model (defaults to the module-level
            one).
        configs: Prediction settings keyed by ``'paratext'``, ``'imgtab'`` and
            ``'image'``. Never modified by this function.
        flags: Extra flags passed through to ``get_masks``.
    Returns:
        np.ndarray: the input image with the four masks blended in.
    Raises:
        ValueError: if the image cannot be read.
        RuntimeError: if prediction fails even with retina masks disabled.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals an unreadable/missing file by returning None.
        raise ValueError(f"could not read image: {img_path!r}")

    res = get_masks(img, model, img_model, flags, configs)
    if res['status'] == -1:
        # Retry once with retina masks switched off. This works on a copy, so
        # the shared default ``configs`` is not degraded for later calls, and
        # the retry is bounded instead of recursing indefinitely.
        fallback = {name: {**cfg, 'rm': False} for name, cfg in configs.items()}
        res = get_masks(img, model, img_model, flags, fallback)
        if res['status'] == -1:
            raise RuntimeError(
                "layout prediction failed, even with retina masks disabled"
            )

    # One colour per mask, in the order get_masks returns them (class ids 0-3).
    palette = (
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
    )
    # NOTE(review): cv2.imread yields BGR data and the result is returned
    # without conversion; confirm whether the consumer expects RGB.
    for mask, color in zip(res['masks'], palette):
        img = overlay(image=img, mask=mask, color=color, alpha=0.4)
    return img
# output = evaluate(img_path=sample_path, model=general_model, img_model=image_model,\
# configs=configs, flags=flags)
# The input is delivered to ``evaluate`` as a file path; the output is the
# numpy array that ``evaluate`` returns.
inputs_image = [gr.components.Image(type="filepath", label="Input Image")]
outputs_image = [gr.components.Image(type="numpy", label="Output Image")]

_interface = gr.Interface(
    fn=evaluate,
    inputs=inputs_image,
    outputs=outputs_image,
    title="Document Layout Segmentor",
    examples=sample_path,
    cache_examples=True,
)
# Note: ``interface_image`` holds the return value of ``launch()``, not the
# Interface object itself.
interface_image = _interface.launch()