|
import cv2 |
|
import supervision as sv |
|
from ultralytics import YOLO |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
# Checkpoint file handed to the YOLO constructor
# (presumably fine-tuned layout-detection weights -- TODO confirm).
YOLO_WEIGHTS = 'yolov10x_best.pt'

# Detector instance used by predict_boxes_labels; built once at import time.
yolo_model = YOLO(YOLO_WEIGHTS)
|
|
|
|
|
from surya.model.detection.segformer import load_processor , load_model |
|
import torch |
|
import os |
|
|
|
|
|
from surya.model.detection.segformer import load_processor , load_model |
|
import torch |
|
import os |
|
|
|
|
|
# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Surya layout model, loaded once at import time and moved to the chosen device.
model = load_model("vikp/surya_layout2")
model = model.to(device)
|
|
|
|
|
from PIL import Image |
|
from surya.input.processing import prepare_image_detection |
|
|
|
|
|
def _get_detection_processor():
    """Return the Surya image processor, building it only on first use."""
    cached = getattr(_get_detection_processor, "_cached", None)
    if cached is None:
        cached = load_processor()
        _get_detection_processor._cached = cached
    return cached


def predicted_mask_function(image_path):
    """Run the Surya layout model on one image and return its class mask.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (typically a path).

    Returns:
        A NumPy array holding, for every output position, the index of the
        highest-scoring class (argmax over the first axis of the first
        batch item's logits). The array keeps whatever spatial size the
        model emits; ``suryolo`` resizes it to the image size afterwards.
    """
    # The context manager releases the file handle as soon as the pixels
    # have been turned into a tensor.
    # NOTE(review): the image is passed on in its native mode; if non-RGB
    # inputs (grayscale, RGBA) are expected, confirm whether a
    # ``convert("RGB")`` is needed before preprocessing.
    with Image.open(image_path) as img:
        pixel_values = prepare_image_detection(
            img=img, processor=_get_detection_processor()
        )

    # Batch of one, matched to the model's dtype and device.
    batch = torch.stack([pixel_values], dim=0).to(model.dtype).to(model.device)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(batch).logits

    return torch.argmax(logits[0], dim=0).cpu().numpy()
|
|
|
|
|
|
|
def predict_boxes_labels(image_path):
    """Detect regions in an image with the module-level YOLO model.

    The model is called with ``conf=0.2`` and ``iou=0.8``; only the first
    result is used and converted through ``sv.Detections.from_ultralytics``.

    Args:
        image_path: Image source forwarded to ``yolo_model``.

    Returns:
        A ``(bboxes, labels)`` tuple. ``bboxes`` is the ``xyxy`` array as a
        nested list and ``labels`` is the ``"class_name"`` data field as a
        list, both taken from the same detections object.
    """
    first_result = yolo_model(source=image_path, conf=0.2, iou=0.8)[0]
    found = sv.Detections.from_ultralytics(first_result)

    class_names = found.data["class_name"].tolist()
    boxes_xyxy = found.xyxy.tolist()
    return boxes_xyxy, class_names
|
|
|
|
|
|
|
def resize_segment(mask, class_id, target_size, method=cv2.INTER_AREA):
    """Resize the binary mask of a single class.

    Args:
        mask: Array of class ids.
        class_id: Class to isolate; matching pixels become 1, the rest 0.
        target_size: Indexable whose item 0 is the output row count and
            item 1 the output column count.
        method: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        The ``uint8`` array produced by ``cv2.resize`` for that binary mask.
    """
    binary = np.asarray(mask == class_id).astype(np.uint8)

    # cv2.resize takes its size as (width, height), i.e. (cols, rows).
    out_rows = target_size[0]
    out_cols = target_size[1]
    return cv2.resize(binary, (out_cols, out_rows), interpolation=method)
|
|
|
def resize_and_combine_classes(mask, target_size, method=cv2.INTER_AREA):
    """Resize a class-id mask by resizing each class separately.

    Every distinct value in ``mask`` is resized as its own binary mask via
    ``resize_segment`` and written into a single ``uint8`` output. Classes
    are handled in the ascending order returned by ``np.unique``, so where
    two resized masks overlap the larger class id wins.

    Args:
        mask: Array of class ids.
        target_size: ``(rows, cols)`` of the output array.
        method: OpenCV interpolation flag forwarded to ``resize_segment``.

    Returns:
        A ``uint8`` array of shape ``(target_size[0], target_size[1])``.
    """
    present_ids = np.unique(mask)
    out_shape = (target_size[0], target_size[1])
    combined = np.zeros(out_shape, dtype=np.uint8)

    for class_id in present_ids:
        per_class = resize_segment(mask, class_id, target_size, method)
        # Only pixels that are exactly 1 after resizing count as this class.
        selected = per_class == 1
        combined[selected] = class_id

    return combined
|
|
|
|
|
# Class id -> layout class name. Id 0 is the blank/background class.
class_labels = {
    0: 'Blank',
    1: 'Caption',
    2: 'Footnote',
    3: 'Formula',
    4: 'List-item',
    5: 'Page-footer',
    6: 'Page-header',
    7: 'Picture',
    8: 'Section-header',
    9: 'Table',
    10: 'Text',
    11: 'Title',
}

# One colour per class, sampled from 'tab20'. mask_to_bboxes looks colours up
# in this colormap to recover class regions from a coloured mask.
# ``plt.get_cmap`` replaces ``plt.cm.get_cmap``, which was deprecated in
# Matplotlib 3.7 and removed in 3.9; both accept the same (name, lut) pair.
colors = plt.get_cmap('tab20', len(class_labels))
|
|
|
def colormap_to_rgb(cmap, index):
    """Convert one colormap entry to an integer RGB tuple.

    Args:
        cmap: Callable returning a colour sequence for ``index``; only the
            first three components are used.
        index: Value passed to ``cmap``.

    Returns:
        A tuple with each of those components multiplied by 255 and
        truncated with ``int``.
    """
    rgb_fractions = cmap(index)[:3]
    scaled = [int(channel * 255) for channel in rgb_fractions]
    return tuple(scaled)
|
|
|
def mask_to_bboxes(colored_mask, class_labels):
    """Collect bounding boxes of the coloured regions in a mask.

    For every class id in ``class_labels`` the pixels whose colour equals
    that class's entry in the module-level ``colors`` colormap are isolated,
    their external contours are found, and the bounding rectangle of each
    contour is recorded.

    Args:
        colored_mask: Colour image whose last axis holds the channels.
        class_labels: Mapping keyed by class id; only the keys are used.

    Returns:
        A list of ``(xmin, ymin, xmax, ymax)`` tuples, grouped in the
        iteration order of ``class_labels``. The class of each box is not
        part of the result.
    """
    boxes = []

    for class_id in class_labels:
        target_rgb = colormap_to_rgb(colors, class_id)

        # 1 where all channels equal this class's colour, 0 elsewhere.
        matches = np.all(colored_mask == target_rgb, axis=-1).astype(np.uint8)

        outlines, _hierarchy = cv2.findContours(
            matches, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )

        for outline in outlines:
            left, top, box_w, box_h = cv2.boundingRect(outline)
            boxes.append((left, top, left + box_w, top + box_h))

    return boxes
|
|
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def suryolo(image_path):
    """Combine Surya segmentation with YOLO detections into layout boxes.

    Steps:
      1. Predict a class-id mask with ``predicted_mask_function`` and resize
         it to the image size.
      2. For every YOLO box, paint the non-blank mask pixels inside the box
         with the colour of the box's YOLO label. Boxes are painted in
         detection order, so later boxes overwrite earlier ones where they
         overlap.
      3. Paint every still-black pixel with the colour of class 0.
      4. Extract bounding boxes from the coloured image with
         ``mask_to_bboxes``.

    Args:
        image_path: Path of the image to analyse.

    Returns:
        The list of ``(xmin, ymin, xmax, ymax)`` tuples produced by
        ``mask_to_bboxes``. Because the background is painted with the
        class-0 colour and ``mask_to_bboxes`` visits every class id, boxes
        for background regions are part of the result.
        NOTE(review): confirm that including class-0 boxes is intended.

    Raises:
        KeyError: If a YOLO class name is not a value of ``class_labels``.
    """
    # PIL reports size as (width, height). Only the dimensions are needed,
    # so the file handle is released immediately.
    with Image.open(image_path) as image:
        width, height = image.size

    predicted_mask = predicted_mask_function(image_path)

    colored_mask = np.zeros((height, width, 3), dtype=np.uint8)
    label_name_to_int = {name: idx for idx, name in class_labels.items()}

    bboxes, labels = predict_boxes_labels(image_path)

    # The resize does not depend on the box, so do it once instead of once
    # per detection.
    full_mask = resize_and_combine_classes(predicted_mask, (height, width))

    for box, label in zip(bboxes, labels):
        # Clamp to zero so a negative coordinate cannot wrap around when
        # used as a slice index; upper bounds are clipped by slicing itself.
        xmin, ymin, xmax, ymax = (max(int(coord), 0) for coord in box)

        mask_region = full_mask[ymin:ymax, xmin:xmax]

        # Use the module-level colormap so the painted colours are exactly
        # the ones mask_to_bboxes looks for.
        color = colormap_to_rgb(colors, label_name_to_int[label])

        # Basic slicing yields a view, so this writes into colored_mask.
        # Any non-zero class id counts as foreground.
        region_view = colored_mask[ymin:ymax, xmin:xmax]
        region_view[mask_region > 0] = color

    # Everything left black is background: give it the class-0 colour.
    blank_color = colormap_to_rgb(colors, 0)
    colored_mask[(colored_mask == 0).all(axis=-1)] = blank_color

    return mask_to_bboxes(colored_mask, class_labels)
|
|
|
|