# File size: 5,512 Bytes
# 2720487
# Imports and import-time model loading shared by the helpers in this file.
import cv2
import supervision as sv # pip install supervision
from ultralytics import YOLO
import numpy as np
import matplotlib.pyplot as plt
# YOLO detector, created at import time from the local weights file
# 'yolov10x_best.pt'.
yolo_model = YOLO('yolov10x_best.pt')
from surya.model.detection.segformer import load_processor , load_model
import torch
import os
# NOTE(review): the next three imports repeat the three directly above;
# harmless, but redundant.
from surya.model.detection.segformer import load_processor , load_model
import torch
import os
# os.environ['HF_HOME'] = '/share/data/drive_3/ketan/orc/HF_Cache'
# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Surya model loaded from "vikp/surya_layout2" at import time and moved to
# the selected device.
model = load_model("vikp/surya_layout2").to(device)
from PIL import Image
from surya.input.processing import prepare_image_detection
def predicted_mask_function(image_path):
    """Run the Surya model on one image and return its class-index mask.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).

    Returns:
        A NumPy array of class indices, obtained by taking the arg-max over
        the first axis of the first batch item's logits (presumably the class
        axis, giving a (rows, cols) map at the model's output resolution --
        TODO confirm). The mask is not resized to the input image here.
    """
    # Keep the file open only while the image is turned into model input, so
    # the handle is released even if preprocessing raises.
    with Image.open(image_path) as img:
        batch = [prepare_image_detection(img=img, processor=load_processor())]

    batch = torch.stack(batch, dim=0).to(model.dtype).to(model.device)

    # Inference only: no_grad avoids recording an autograd graph for the
    # forward pass, which saves memory and does not change the outputs.
    with torch.no_grad():
        logits = model(batch).logits

    return torch.argmax(logits[0], dim=0).cpu().numpy()
def predict_boxes_labels(image_path):
    """Detect regions in an image with the module-level YOLO model.

    Args:
        image_path: Image source handed to the YOLO model as ``source``.

    Returns:
        A ``(bboxes, labels)`` pair: ``bboxes`` is a list of
        ``[xmin, ymin, xmax, ymax]`` lists and ``labels`` is the list of
        class names, in the same order.
    """
    # The model returns one result per source; only a single image is passed.
    first_result = yolo_model(source=image_path, conf=0.2, iou=0.8)[0]
    found = sv.Detections.from_ultralytics(first_result)
    return found.xyxy.tolist(), found.data["class_name"].tolist()
def resize_segment(mask, class_id, target_size, method=cv2.INTER_AREA):
    """Resize the binary mask of a single class.

    Args:
        mask: Array of class ids.
        class_id: The class whose pixels are selected.
        target_size: ``(rows, cols)`` of the output.
        method: OpenCV interpolation flag passed to ``cv2.resize``.

    Returns:
        A ``uint8`` array of shape ``target_size`` produced by resizing the
        0/1 mask of ``class_id``.
    """
    rows, cols = target_size[0], target_size[1]
    # 1 where the pixel belongs to the class, 0 elsewhere.
    binary = np.asarray(mask == class_id).astype(np.uint8)
    # cv2.resize expects the size as (width, height), i.e. (cols, rows).
    return cv2.resize(binary, (cols, rows), interpolation=method)
def resize_and_combine_classes(mask, target_size, method=cv2.INTER_AREA):
    """Resize a class-id mask one class at a time and merge the results.

    Args:
        mask: Array of class ids.
        target_size: ``(rows, cols)`` of the output.
        method: OpenCV interpolation flag forwarded to ``resize_segment``.

    Returns:
        A ``uint8`` array of shape ``target_size``. Each pixel holds the id of
        the last class (in ascending id order) whose resized binary mask
        equals 1 there, or 0 if no class claims it.
    """
    combined = np.zeros((target_size[0], target_size[1]), dtype=np.uint8)
    for class_id in np.unique(mask):
        claimed = resize_segment(mask, class_id, target_size, method) == 1
        # Later classes overwrite earlier ones where their masks overlap.
        combined[claimed] = class_id
    return combined
# Integer id -> layout class name. Id 0 ('Blank') is used as the background.
class_labels = {
    0: 'Blank',
    1: 'Caption',
    2: 'Footnote',
    3: 'Formula',
    4: 'List-item',
    5: 'Page-footer',
    6: 'Page-header',
    7: 'Picture',
    8: 'Section-header',
    9: 'Table',
    10: 'Text',
    11: 'Title'
}
# One colour per class, sampled from the 'tab20' colormap.
# ``plt.get_cmap`` takes the same (name, lut) arguments as the former
# ``plt.cm.get_cmap``, which was deprecated in Matplotlib 3.7 and removed in
# 3.9, so this keeps working on both old and current Matplotlib releases.
colors = plt.get_cmap('tab20', len(class_labels))
def colormap_to_rgb(cmap, index):
    """Look up ``index`` in ``cmap`` and return it as an integer RGB tuple.

    Args:
        cmap: Callable returning a colour sequence with float components
            (any component after the third, such as alpha, is ignored).
        index: Value passed to ``cmap``.

    Returns:
        A tuple of ints, each component multiplied by 255 and truncated
        toward zero.
    """
    rgb_floats = cmap(index)[:3]
    return tuple(map(lambda channel: int(channel * 255), rgb_floats))
def mask_to_bboxes(colored_mask, class_labels):
    """Extract bounding boxes of the per-class colour regions in a mask.

    For every label in ``class_labels`` the label's colour is looked up in
    the module-level ``colors`` colormap, the pixels matching that colour are
    selected, and the bounding rectangle of each external contour is taken.

    Args:
        colored_mask: Colour image whose last axis holds the channels.
        class_labels: Mapping whose keys are the class ids to look for.

    Returns:
        A list of ``(xmin, ymin, xmax, ymax)`` tuples, grouped by class in
        the mapping's iteration order. Class names are not included.
    """
    boxes = []
    for label in class_labels:
        rgb = colormap_to_rgb(colors, label)
        # 1 where all channels equal this class's colour, 0 elsewhere.
        matches = np.all(colored_mask == rgb, axis=-1).astype(np.uint8)
        outlines, _hierarchy = cv2.findContours(
            matches, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        # boundingRect yields (x, y, w, h); convert to corner form.
        for x, y, w, h in map(cv2.boundingRect, outlines):
            boxes.append((x, y, x + w, y + h))
    return boxes
import matplotlib.pyplot as plt
# from matplotlib import colors
def suryolo(image_path):
    """Combine the Surya class mask with YOLO detections into layout boxes.

    Inside every YOLO box, the pixels that the Surya mask marks as non-blank
    (class id > 0) are painted with the colour of the YOLO label. Everything
    left unpainted gets the 'Blank' colour, and the painted image is then
    converted back to boxes with ``mask_to_bboxes``.

    Args:
        image_path: Path of the image to analyse.

    Returns:
        The list of ``(xmin, ymin, xmax, ymax)`` tuples produced by
        ``mask_to_bboxes``.

    Raises:
        KeyError: If YOLO reports a class name that is not a value of
            ``class_labels``.
    """
    # PIL reports size as (width, height). Only the size is needed here, so
    # the file is closed straight away.
    with Image.open(image_path) as image:
        width, height = image.size

    # Bring the Surya mask to image resolution once, up front: the result
    # does not depend on the detections, so there is no reason to redo it
    # for every box.
    predicted_mask = resize_and_combine_classes(
        predicted_mask_function(image_path), (height, width)
    )

    colored_mask = np.zeros((height, width, 3), dtype=np.uint8)  # RGB canvas
    label_name_to_int = {name: idx for idx, name in class_labels.items()}

    bboxes, labels = predict_boxes_labels(image_path)
    for box, label in zip(bboxes, labels):
        xmin, ymin, xmax, ymax = (int(coord) for coord in box)
        mask_region = predicted_mask[ymin:ymax, xmin:xmax]
        # Paint with the module-level colormap, the same one that
        # mask_to_bboxes uses to decode the colours again.
        color = colormap_to_rgb(colors, label_name_to_int[label])
        # Class ids are integers, so "> 0.5" selects every non-blank pixel.
        # The box slice is a view, so this writes into colored_mask.
        colored_mask[ymin:ymax, xmin:xmax][mask_region > 0.5] = color

    # Anything still pure black was never painted: mark it as 'Blank'.
    blank_color = colormap_to_rgb(colors, 0)
    colored_mask[(colored_mask == 0).all(axis=-1)] = blank_color
    return mask_to_bboxes(colored_mask, class_labels)
|