# ArabicDoc-layout-Detection / surya_yolo_pipeline.py
# Uploaded by ketanmore to Hugging Face ("Upload folder using huggingface_hub"),
# commit 2720487 (verified), 5.51 kB.
import cv2
import supervision as sv # pip install supervision
from ultralytics import YOLO
import numpy as np
import matplotlib.pyplot as plt
# YOLO layout detector, instantiated once at import time and shared by
# predict_boxes_labels(). NOTE(review): the weights path is relative, so it is
# presumably resolved against the current working directory — confirm the
# .pt file is deployed next to where this script is launched.
yolo_model = YOLO('yolov10x_best.pt')
from surya.model.detection.segformer import load_processor , load_model
import torch
import os
from surya.model.detection.segformer import load_processor , load_model
import torch
import os
# Left disabled: uncommenting would redirect the Hugging Face cache (HF_HOME)
# to a shared directory for the checkpoint download below.
# os.environ['HF_HOME'] = '/share/data/drive_3/ketan/orc/HF_Cache'
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Surya layout segmentation model, loaded once at import time from the
# "vikp/surya_layout2" checkpoint and moved to `device`.
# NOTE(review): no explicit .eval() call here; confirm load_model() already
# returns the model in inference mode.
model = load_model("vikp/surya_layout2").to(device)
from PIL import Image
from surya.input.processing import prepare_image_detection
# Cache for the Surya detection processor; filled lazily by _get_layout_processor().
_layout_processor = None


def _get_layout_processor():
    """Return the Surya detection processor, loading it only on the first call."""
    global _layout_processor
    if _layout_processor is None:
        # NOTE(review): called without a checkpoint argument (as before), while
        # the model is loaded from "vikp/surya_layout2"; confirm the default
        # processor configuration matches the layout checkpoint.
        _layout_processor = load_processor()
    return _layout_processor


def predicted_mask_function(image_path):
    """Run the Surya layout model on one image and return its per-pixel class ids.

    Args:
        image_path: path of the image file to segment.

    Returns:
        A 2-D numpy array holding, for each pixel of the model's output
        resolution, the argmax over the class logits of the first (only) item
        in the batch. Callers resize it to the image size themselves.
    """
    # Close the file handle once preprocessing has produced the tensor.
    with Image.open(image_path) as img:
        batch = [prepare_image_detection(img=img, processor=_get_layout_processor())]
    batch = torch.stack(batch, dim=0).to(model.dtype).to(model.device)
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(batch).logits
    # Class dimension is dim 0 after selecting the single batch item.
    predicted_mask = torch.argmax(logits[0], dim=0).cpu().numpy()
    return predicted_mask
def predict_boxes_labels(image_path, *, conf=0.2, iou=0.8):
    """Run the YOLO layout detector on one image.

    Args:
        image_path: image source passed straight to ``yolo_model``.
        conf: confidence threshold forwarded to the detector (default 0.2,
            the value previously hard-coded).
        iou: IoU threshold forwarded to the detector for non-maximum
            suppression (default 0.8, the value previously hard-coded).

    Returns:
        ``(bboxes, labels)``: parallel lists. Each box is a 4-element
        ``[xmin, ymin, xmax, ymax]`` list (supervision ``xyxy`` format) and each
        label is the detector's class-name string.
    """
    results = yolo_model(source=image_path, conf=conf, iou=iou)[0]
    detections = sv.Detections.from_ultralytics(results)
    labels = detections.data["class_name"].tolist()
    bboxes = detections.xyxy.tolist()
    return bboxes, labels
def resize_segment(mask, class_id, target_size, method=cv2.INTER_AREA):
    """Resize the binary layer of a single class in a label mask.

    Args:
        mask: array of integer class ids.
        class_id: the class whose pixels form the layer.
        target_size: ``(height, width)`` of the output.
        method: OpenCV interpolation flag.

    Returns:
        A uint8 array of shape ``target_size`` produced by ``cv2.resize`` from
        a 0/1 layer (1 where ``mask == class_id``).
    """
    layer = (mask == class_id).astype(np.uint8)
    # cv2.resize takes dsize as (width, height), the reverse of target_size.
    out_height, out_width = target_size[0], target_size[1]
    return cv2.resize(layer, (out_width, out_height), interpolation=method)
def resize_and_combine_classes(mask, target_size, method=cv2.INTER_AREA):
    """Resize a label mask class by class and merge the layers back together.

    Each class present in ``mask`` is resized on its own via
    ``resize_segment``, then written into a uint8 output of shape
    ``(target_size[0], target_size[1])``. Classes are visited in ascending id
    order (``np.unique`` sorts), so a higher id wins wherever two resized
    layers overlap; pixels claimed by no layer stay 0.
    """
    combined = np.zeros((target_size[0], target_size[1]), dtype=np.uint8)
    for label in np.unique(mask):
        layer = resize_segment(mask, label, target_size, method)
        combined[layer == 1] = label
    return combined
# Layout class names keyed by their integer id (0..11, in this order).
# Id 0 ('Blank') is also used as the background/fill class by the pipeline.
class_labels = dict(enumerate([
    'Blank',
    'Caption',
    'Footnote',
    'Formula',
    'List-item',
    'Page-footer',
    'Page-header',
    'Picture',
    'Section-header',
    'Table',
    'Text',
    'Title',
]))
# One color per layout class: 'tab20' resampled to len(class_labels) entries.
# plt.get_cmap(name, lut) replaces plt.cm.get_cmap, which Matplotlib
# deprecated in 3.7 and removed in 3.9.
colors = plt.get_cmap('tab20', len(class_labels))
def colormap_to_rgb(cmap, index):
    """Look up ``index`` in ``cmap`` and return its color as 0-255 integers.

    Only the first three channels (RGB) of the colormap's output are kept;
    alpha is dropped. Each channel is scaled by 255 and truncated with int().
    """
    rgb_channels = cmap(index)[:3]
    scaled = [int(255 * channel) for channel in rgb_channels]
    return tuple(scaled)
def mask_to_bboxes(colored_mask, class_labels, *, cmap=None, with_labels=False):
    """Convert a class-colored mask into axis-aligned bounding boxes.

    For every class in ``class_labels`` the pixels whose RGB value equals that
    class's colormap color are isolated, and each external contour of that
    region becomes one box.

    Args:
        colored_mask: RGB image whose last axis holds 3 channels (compared
            against 3-element colors); assumed to be painted with colors from
            ``cmap``.
        class_labels: mapping of integer class index -> class name.
        cmap: colormap used to paint ``colored_mask``. Defaults to the
            module-level ``colors`` (the previous hard-wired behavior).
        with_labels: if True, each box is ``(xmin, ymin, xmax, ymax, class_name)``;
            by default boxes are plain 4-tuples, as before.

    Returns:
        List of boxes in xyxy order, grouped by class in ``class_labels``
        iteration order. ``xmax``/``ymax`` are ``x + w``/``y + h`` from
        ``cv2.boundingRect``, i.e. one past the last covered pixel.

    NOTE(review): every entry of ``class_labels`` is processed, so when
    label 0 ('Blank') is included its regions are returned as boxes too.
    """
    if cmap is None:
        cmap = colors
    bboxes = []
    for label, class_name in class_labels.items():
        color = colormap_to_rgb(cmap, label)
        # 1 where all three channels match this class's color exactly.
        class_mask = np.all(colored_mask == color, axis=-1).astype(np.uint8)
        # Two return values: OpenCV 4.x findContours signature.
        contours, _ = cv2.findContours(class_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            box = (x, y, x + w, y + h)
            bboxes.append(box + (class_name,) if with_labels else box)
    return bboxes
import matplotlib.pyplot as plt
# from matplotlib import colors
def suryolo(image_path):
    """Fuse YOLO layout boxes with the Surya segmentation mask and return region boxes.

    Steps:
      1. Surya's class-id mask is resized to the image size.
      2. Inside each YOLO box, every segmented (non-zero id) pixel is painted
         with the color of the YOLO label.
      3. Pixels left unpainted (pure black) receive the 'Blank' (id 0) color.
      4. The painted image is converted back to boxes by ``mask_to_bboxes``.

    Args:
        image_path: path of the page image.

    Returns:
        The list of ``(xmin, ymin, xmax, ymax)`` tuples from ``mask_to_bboxes``
        (class names are not included).

    Raises:
        KeyError: if the YOLO model emits a class name missing from
            ``class_labels``.

    NOTE(review): ``class_labels`` includes 'Blank', so the Blank fill region
    is also returned as one or more boxes.
    """
    # PIL reports size as (width, height); close the file once it is read.
    with Image.open(image_path) as image:
        width, height = image.size

    predicted_mask = predicted_mask_function(image_path)
    # Resize once, before the loop: resizing is full-image work and does not
    # depend on the box being processed.
    predicted_mask = resize_and_combine_classes(predicted_mask, (height, width))

    colored_mask = np.zeros((height, width, 3), dtype=np.uint8)  # RGB canvas
    label_name_to_int = {name: idx for idx, name in class_labels.items()}

    bboxes, labels = predict_boxes_labels(image_path)
    for box, label in zip(bboxes, labels):
        # Clamp at 0 so a negative coordinate cannot wrap around via negative
        # slicing; slicing past the far edge is already safe.
        xmin, ymin, xmax, ymax = (max(int(v), 0) for v in box)
        mask_region = predicted_mask[ymin:ymax, xmin:xmax]
        color = colormap_to_rgb(colors, label_name_to_int[label])
        # mask_region holds Surya class ids, so "> 0.5" selects every pixel
        # Surya assigned to any non-Blank class, not only the YOLO label's class.
        colored_mask[ymin:ymax, xmin:xmax][mask_region > 0.5] = color

    # Whatever is still pure black was never painted: treat it as Blank.
    colored_mask[(colored_mask == 0).all(axis=-1)] = colormap_to_rgb(colors, 0)
    return mask_to_bboxes(colored_mask, class_labels)